# sync/backends/s3.py
"""
AWS S3 backend for cloud storage.
"""

import asyncio
from pathlib import Path
from typing import List, Optional, Any
from datetime import datetime
import logging

try:
    import boto3
    from botocore.exceptions import ClientError, NoCredentialsError
    HAS_BOTO3 = True
except ImportError:
    HAS_BOTO3 = False
    ClientError = Exception  # type: ignore
    NoCredentialsError = Exception  # type: ignore

from .base import BaseBackend, RemoteFile

logger = logging.getLogger(__name__)


class S3Backend(BaseBackend):
    """AWS S3 storage backend.

    boto3 is a synchronous library, so every blocking S3 call is
    dispatched to the default thread-pool executor (via ``_run``) to
    keep the event loop responsive.
    """

    def __init__(
        self,
        bucket: str,
        region: str = "us-east-1",
        prefix: str = "",
        access_key: Optional[str] = None,
        secret_key: Optional[str] = None
    ):
        """Configure the backend (no network I/O happens here).

        Args:
            bucket: Target S3 bucket name.
            region: AWS region of the bucket.
            prefix: Optional key prefix under which all synced files live.
            access_key: Explicit AWS access key id; when omitted, boto3's
                default credential chain (env vars, config, IAM role) is used.
            secret_key: Explicit AWS secret key, paired with ``access_key``.

        Raises:
            ImportError: If boto3 is not installed.
        """
        if not HAS_BOTO3:
            raise ImportError("boto3 is required for S3Backend. Install with: pip install boto3")

        self.bucket = bucket
        self.region = region
        # Normalize away a trailing slash so key joining in _get_s3_key
        # never produces "prefix//path".
        self.prefix = prefix.rstrip("/")
        self.access_key = access_key
        self.secret_key = secret_key

        self._client: Any = None
        self._resource: Any = None
        self._connected = False

    def _get_s3_key(self, path: str) -> str:
        """Return the full S3 object key for *path*, including the prefix."""
        if self.prefix:
            return f"{self.prefix}/{path}"
        return path

    async def _run(self, func, /, *args, **kwargs):
        """Run a blocking boto3 call in the default executor.

        Uses ``asyncio.get_running_loop()`` — ``get_event_loop()`` is
        deprecated inside coroutines since Python 3.10.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(*args, **kwargs))

    async def connect(self):
        """Initialize the S3 client and verify the bucket is reachable.

        Raises:
            NoCredentialsError: If no AWS credentials can be resolved.
            ClientError: If the bucket does not exist or is not accessible.
        """
        try:
            session_kwargs: dict[str, Any] = {"region_name": self.region}

            if self.access_key and self.secret_key:
                session_kwargs["aws_access_key_id"] = self.access_key
                session_kwargs["aws_secret_access_key"] = self.secret_key

            # Client construction is cheap and local; only the head_bucket
            # probe below actually hits the network.
            self._client = boto3.client("s3", **session_kwargs)
            self._resource = boto3.resource("s3", **session_kwargs)

            # head_bucket verifies both existence and access permissions.
            await self._run(self._client.head_bucket, Bucket=self.bucket)

            self._connected = True
            logger.info(f"S3Backend connected: s3://{self.bucket}/{self.prefix}")

        except NoCredentialsError:
            logger.error("AWS credentials not found")
            raise
        except ClientError as e:
            logger.error(f"Failed to connect to S3: {e}")
            raise

    async def disconnect(self):
        """Drop client references; boto3 clients need no explicit close."""
        self._client = None
        self._resource = None
        self._connected = False
        logger.info("S3Backend disconnected")

    async def upload(self, local_path: Path, remote_path: str) -> bool:
        """Upload a local file to S3.

        Args:
            local_path: Path of the file on disk.
            remote_path: Destination path, relative to the backend prefix.

        Returns:
            True on success; False when not connected or on S3 error.
        """
        if not self._client:
            return False

        try:
            s3_key = self._get_s3_key(remote_path)
            await self._run(
                self._client.upload_file, str(local_path), self.bucket, s3_key
            )
            logger.debug(f"Uploaded to S3: {s3_key}")
            return True

        except ClientError as e:
            logger.error(f"Error uploading to S3 {remote_path}: {e}")
            return False

    async def download(self, remote_path: str, local_path: Path) -> bool:
        """Download a file from S3, creating parent directories as needed.

        Args:
            remote_path: Source path, relative to the backend prefix.
            local_path: Destination path on disk.

        Returns:
            True on success; False when not connected or on S3 error.
        """
        if not self._client:
            return False

        try:
            s3_key = self._get_s3_key(remote_path)
            local_path.parent.mkdir(parents=True, exist_ok=True)
            await self._run(
                self._client.download_file, self.bucket, s3_key, str(local_path)
            )
            logger.debug(f"Downloaded from S3: {s3_key}")
            return True

        except ClientError as e:
            logger.error(f"Error downloading from S3 {remote_path}: {e}")
            return False

    async def delete(self, remote_path: str) -> bool:
        """Delete a file from S3.

        Returns:
            True on success; False when not connected or on S3 error.
            Note: S3 delete_object succeeds even for missing keys.
        """
        if not self._client:
            return False

        try:
            s3_key = self._get_s3_key(remote_path)
            await self._run(
                self._client.delete_object, Bucket=self.bucket, Key=s3_key
            )
            logger.debug(f"Deleted from S3: {s3_key}")
            return True

        except ClientError as e:
            logger.error(f"Error deleting from S3 {remote_path}: {e}")
            return False

    async def exists(self, remote_path: str) -> bool:
        """Check whether a file exists in S3 via a HEAD request.

        Returns False for any ClientError (404, but also 403 when HEAD
        permission is missing — indistinguishable here).
        """
        if not self._client:
            return False

        try:
            s3_key = self._get_s3_key(remote_path)
            await self._run(
                self._client.head_object, Bucket=self.bucket, Key=s3_key
            )
            return True

        except ClientError:
            return False

    async def list_files(self, prefix: str = "") -> List[RemoteFile]:
        """List files in the bucket under the configured prefix.

        Args:
            prefix: Optional relative prefix to narrow the listing.

        Returns:
            RemoteFile entries with paths relative to the backend prefix;
            an empty list when not connected or on error.
        """
        if not self._client:
            return []

        files: List[RemoteFile] = []

        try:
            s3_prefix = self._get_s3_key(prefix) if prefix else self.prefix
            client = self._client
            # Bind to a local so the worker thread doesn't touch self
            # attributes that could change mid-listing.
            stored_prefix = self.prefix

            paginator = client.get_paginator("list_objects_v2")

            def list_objects() -> List[RemoteFile]:
                # Runs in the executor thread: paginates the full listing
                # synchronously and converts each object to a RemoteFile.
                result: List[RemoteFile] = []
                for page in paginator.paginate(Bucket=self.bucket, Prefix=s3_prefix):
                    for obj in page.get("Contents", []):
                        # Strip "<prefix>/" to recover the relative path.
                        rel_path = obj["Key"]
                        if stored_prefix:
                            rel_path = rel_path[len(stored_prefix) + 1:]

                        result.append(RemoteFile(
                            path=rel_path,
                            size=obj["Size"],
                            modified_time=obj["LastModified"],
                            # ETag is quoted by S3; it equals the MD5 only
                            # for non-multipart uploads.
                            checksum=obj.get("ETag", "").strip('"')
                        ))
                return result

            loop = asyncio.get_running_loop()
            files = await loop.run_in_executor(None, list_objects)

        except ClientError as e:
            logger.error(f"Error listing S3 files: {e}")

        return files

    async def get_file_info(self, remote_path: str) -> Optional[RemoteFile]:
        """Return metadata for a single S3 object, or None if unavailable.

        Any ClientError from the HEAD request (404 missing, 403 denied)
        is treated as "no info available".
        """
        if not self._client:
            return None

        try:
            s3_key = self._get_s3_key(remote_path)
            response = await self._run(
                self._client.head_object, Bucket=self.bucket, Key=s3_key
            )

            return RemoteFile(
                path=remote_path,
                size=response["ContentLength"],
                modified_time=response["LastModified"],
                checksum=response.get("ETag", "").strip('"')
            )

        except ClientError:
            return None