"""
Model downloader for BackgroundFX Pro.
Handles downloading, caching, and verification of models.
"""

import os
import shutil
import tempfile
import hashlib
import requests
from pathlib import Path
from typing import Optional, Callable, Dict, Any, List
from dataclasses import dataclass
from enum import Enum
import time
import threading
from urllib.parse import urlparse
from concurrent.futures import ThreadPoolExecutor, Future
import logging

from .registry import ModelInfo, ModelStatus, ModelRegistry

logger = logging.getLogger(__name__)


class DownloadStatus(Enum):
    """Lifecycle states of a single model download."""

    PENDING = "pending"
    DOWNLOADING = "downloading"
    VERIFYING = "verifying"
    EXTRACTING = "extracting"  # not set anywhere in this module
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


@dataclass
class DownloadProgress:
    """Mutable progress record for one model download.

    The download worker updates the instance in place and passes it to the
    progress callback after each change.
    """

    model_id: str
    status: DownloadStatus
    current_bytes: int = 0
    total_bytes: int = 0
    speed_mbps: float = 0.0  # megabits per second
    eta_seconds: float = 0.0
    error: Optional[str] = None

    @property
    def progress(self) -> float:
        """Percentage complete (0-100), or 0.0 while the total size is unknown."""
        if self.total_bytes > 0:
            return (self.current_bytes / self.total_bytes) * 100
        return 0.0


class ModelDownloader:
    """Handle model downloading with progress tracking and resume support.

    Downloads run on a thread pool. Each file is streamed into
    ``<filename>.part`` next to its final location, so an interrupted transfer
    can be resumed with an HTTP ``Range`` request. The partial file is moved
    into place only after the transfer finishes.
    """

    # Models fetched by download_required_models() when no task is given.
    ESSENTIAL_MODELS = ('rmbg-1.4', 'u2netp', 'modnet')

    # Allowed difference (bytes) between the registry size and the file on disk.
    SIZE_TOLERANCE_BYTES = 1000

    # Upper bound (seconds) on the back-off between retries of the same URL.
    MAX_RETRY_DELAY = 30.0

    def __init__(self, registry: ModelRegistry, max_workers: int = 3,
                 chunk_size: int = 8192, timeout: int = 30, max_retries: int = 3):
        """
        Initialize model downloader.

        Args:
            registry: Model registry instance
            max_workers: Maximum concurrent downloads
            chunk_size: Chunk size in bytes for streaming and hashing
            timeout: Timeout in seconds passed to ``requests.get``
            max_retries: Attempts per URL before falling back to the next
                mirror (values below 1 are treated as 1)
        """
        self.registry = registry
        self.max_workers = max_workers
        self.chunk_size = chunk_size
        self.timeout = timeout
        self.max_retries = max_retries

        # Download management, keyed by the model ID the caller requested
        self.downloads: Dict[str, DownloadProgress] = {}
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.futures: Dict[str, Future] = {}
        self._stop_events: Dict[str, threading.Event] = {}
        # Serializes the check-and-submit in _start_download. Without it, two
        # workers could download one model into the same .part file.
        self._start_lock = threading.Lock()

        # Cache directory
        self.cache_dir = registry.models_dir / ".cache"
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def download_model(self, model_id: str,
                       progress_callback: Optional[Callable[[DownloadProgress], None]] = None,
                       force: bool = False) -> bool:
        """
        Download a model and block until it finishes.

        Args:
            model_id: Model ID to download
            progress_callback: Optional callback receiving the DownloadProgress
            force: Re-download even if the registry marks the model AVAILABLE

        Returns:
            True if the model is available afterwards, False otherwise.
        """
        model = self.registry.get_model(model_id)
        if not model:
            logger.error(f"Model not found: {model_id}")
            return False

        if not force and model.status == ModelStatus.AVAILABLE:
            logger.info(f"Model already available: {model_id}")
            return True

        future = self._start_download(model_id, model, progress_callback, force)
        try:
            return future.result()
        except Exception as e:
            logger.error(f"Download failed for {model_id}: {e}")
            return False

    def download_models_async(self, model_ids: List[str],
                              progress_callback: Optional[Callable[[str, DownloadProgress], None]] = None,
                              force: bool = False) -> Dict[str, Future]:
        """
        Download multiple models asynchronously.

        Args:
            model_ids: List of model IDs
            progress_callback: Optional callback receiving (model_id, progress)
            force: Re-download models already marked AVAILABLE

        Returns:
            Futures (resolving to bool) for every model that was scheduled.
            Unknown models and already-available models (unless ``force``) are
            omitted.
        """
        futures: Dict[str, Future] = {}
        for model_id in model_ids:
            model = self.registry.get_model(model_id)
            if not model:
                logger.warning(f"Model not found: {model_id}")
                continue

            # Skip if already available
            if not force and model.status == ModelStatus.AVAILABLE:
                continue

            callback = self._bind_model_id(model_id, progress_callback)
            futures[model_id] = self._start_download(model_id, model, callback, force)

        return futures

    @staticmethod
    def _bind_model_id(model_id: str,
                       callback: Optional[Callable[[str, DownloadProgress], None]]
                       ) -> Optional[Callable[[DownloadProgress], None]]:
        """Adapt a ``(model_id, progress)`` callback to the one-argument form.

        A factory function is used instead of a closure inside the caller's
        loop, so each wrapper keeps its own ``model_id`` rather than the loop
        variable's final value.
        """
        if callback is None:
            return None

        def wrapper(progress: DownloadProgress) -> None:
            callback(model_id, progress)

        return wrapper

    def _start_download(self, model_id: str, model: ModelInfo,
                        progress_callback: Optional[Callable], force: bool) -> Future:
        """Register progress/stop-event state for *model_id* and submit its task.

        If a download of the same model is still running, its future is
        returned instead of starting a second task. In that case the new
        ``progress_callback`` is not attached.
        """
        with self._start_lock:
            existing = self.futures.get(model_id)
            if existing is not None and not existing.done():
                logger.info(f"Download already in progress: {model_id}")
                return existing

            progress = DownloadProgress(
                model_id=model_id,
                status=DownloadStatus.PENDING,
                total_bytes=model.file_size
            )
            self.downloads[model_id] = progress
            self._stop_events[model_id] = threading.Event()

            future = self.executor.submit(
                self._download_model_task, model, progress, progress_callback, force
            )
            self.futures[model_id] = future
            return future

    def _download_model_task(self, model: ModelInfo, progress: DownloadProgress,
                             progress_callback: Optional[Callable], force: bool) -> bool:
        """
        Worker: fetch the model from its primary URL, then mirrors, and verify it.

        On success the registry entry is marked AVAILABLE and saved.

        Args:
            model: Model information
            progress: Progress tracker (mutated in place)
            progress_callback: Progress callback
            force: Accepted for interface compatibility; not used here

        Returns:
            True if the model was downloaded and verified. Never raises.
        """
        try:
            # Captured once so a later download's replacement event can't un-cancel us.
            stop_event = self._stop_events[model.model_id]
            output_path = self.registry.models_dir / model.filename
            urls = [model.url, *(model.mirror_urls or [])]

            for url in urls:
                if stop_event.is_set():
                    return self._finish_cancelled(progress, progress_callback)

                # Reset on every URL: a previous mirror may have left VERIFYING.
                progress.status = DownloadStatus.DOWNLOADING
                self._notify_progress(progress, progress_callback)

                try:
                    if not self._download_with_retries(
                        url, output_path, progress, progress_callback, model.model_id
                    ):
                        continue

                    progress.status = DownloadStatus.VERIFYING
                    self._notify_progress(progress, progress_callback)

                    if self._verify_download(output_path, model):
                        # Update registry
                        model.status = ModelStatus.AVAILABLE
                        model.local_path = str(output_path)
                        model.download_date = time.time()
                        self.registry._save_registry()

                        progress.status = DownloadStatus.COMPLETED
                        self._notify_progress(progress, progress_callback)
                        logger.info(f"Successfully downloaded: {model.model_id}")
                        return True

                    # Verification failed: drop the file so the next mirror starts fresh.
                    output_path.unlink(missing_ok=True)
                    logger.warning(f"Verification failed for {model.model_id}")
                except Exception as e:
                    logger.warning(f"Download failed from {url}: {e}")

            # A cancel during the last URL must not be reported as a failure.
            if stop_event.is_set():
                return self._finish_cancelled(progress, progress_callback)

            progress.status = DownloadStatus.FAILED
            progress.error = "All download attempts failed"
            self._notify_progress(progress, progress_callback)
            return False

        except Exception as e:
            progress.status = DownloadStatus.FAILED
            progress.error = str(e)
            self._notify_progress(progress, progress_callback)
            logger.error(f"Download task failed: {e}")
            return False

    def _finish_cancelled(self, progress: DownloadProgress,
                          progress_callback: Optional[Callable]) -> bool:
        """Mark *progress* CANCELLED, notify, and return False (the task result)."""
        progress.status = DownloadStatus.CANCELLED
        self._notify_progress(progress, progress_callback)
        return False

    def _download_with_retries(self, url: str, output_path: Path, progress: DownloadProgress,
                               progress_callback: Optional[Callable], model_id: str) -> bool:
        """Try *url* up to ``max_retries`` times (at least once).

        Each attempt resumes from the ``.part`` file left by the previous one.
        The delay between attempts grows exponentially, is capped at
        MAX_RETRY_DELAY, and ends early if the download is cancelled.
        """
        stop_event = self._stop_events[model_id]
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            if self._download_file(url, output_path, progress, progress_callback, model_id):
                return True
            if attempt < attempts:
                delay = min(2.0 ** (attempt - 1), self.MAX_RETRY_DELAY)
                logger.info(f"Retrying {url} in {delay:.0f}s (attempt {attempt + 1}/{attempts})")
                # Event.wait returns True as soon as the event is set (cancelled).
                if stop_event.wait(delay):
                    return False
        return False

    def _download_file(self, url: str, output_path: Path, progress: DownloadProgress,
                       progress_callback: Optional[Callable], model_id: str) -> bool:
        """
        Download file with resume support.

        The body is streamed into ``<output_path.name>.part``. If that file
        already exists, the transfer resumes from its size via an HTTP
        ``Range`` header. The partial file replaces ``output_path`` only once
        the transfer completes. A failed, truncated or cancelled attempt leaves
        it in place so the next attempt can resume.

        Args:
            url: Download URL
            output_path: Output file path
            progress: Progress tracker
            progress_callback: Progress callback
            model_id: Model ID for stop event

        Returns:
            True if the file was fully received and moved to ``output_path``.
            False on cancellation, a network or file error, or a transfer
            shorter than the advertised Content-Length.
        """
        # Append (not replace) the suffix so "x.onnx" and "x.pth" don't share a partial file.
        temp_path = output_path.with_name(output_path.name + '.part')
        stop_event = self._stop_events[model_id]

        try:
            output_path.parent.mkdir(parents=True, exist_ok=True)

            # Check for partial download
            resume_pos = temp_path.stat().st_size if temp_path.exists() else 0
            headers = {}
            if resume_pos > 0:
                logger.info(f"Resuming download from {resume_pos} bytes")
                headers['Range'] = f'bytes={resume_pos}-'

            # The context manager releases the connection even on early return.
            with requests.get(url, headers=headers, stream=True,
                              timeout=self.timeout) as response:
                if resume_pos > 0 and response.status_code == 416:
                    # The range starts at or past EOF, so the partial file already
                    # holds the whole body. _verify_download decides if it is usable.
                    logger.info(f"Partial download already complete: {temp_path.name}")
                    os.replace(temp_path, output_path)
                    return True
                response.raise_for_status()

                if resume_pos > 0 and response.status_code != 206:
                    # The server ignored Range and is sending the full body.
                    # Appending it would corrupt the file, so start over.
                    logger.info(f"Server did not honour resume request, restarting: {url}")
                    resume_pos = 0

                # On a 206, Content-Length covers only the remaining bytes.
                content_length = response.headers.get('content-length')
                total_size = int(content_length) + resume_pos if content_length is not None else None
                if total_size is not None:
                    progress.total_bytes = total_size

                bytes_downloaded = resume_pos
                progress.current_bytes = bytes_downloaded
                start_time = time.monotonic()

                mode = 'ab' if resume_pos > 0 else 'wb'
                with open(temp_path, mode) as f:
                    for chunk in response.iter_content(chunk_size=self.chunk_size):
                        if stop_event.is_set():
                            logger.info(f"Download cancelled: {model_id}")
                            return False
                        if not chunk:  # filter out keep-alive chunks
                            continue

                        f.write(chunk)
                        bytes_downloaded += len(chunk)
                        progress.current_bytes = bytes_downloaded

                        # Speed covers only this session's bytes, not the resumed prefix.
                        elapsed = time.monotonic() - start_time
                        if elapsed > 0:
                            speed_bps = (bytes_downloaded - resume_pos) / elapsed
                            progress.speed_mbps = (speed_bps * 8) / 1_000_000
                            if total_size and speed_bps > 0:
                                progress.eta_seconds = (total_size - bytes_downloaded) / speed_bps

                        self._notify_progress(progress, progress_callback)

            if total_size is not None and bytes_downloaded < total_size:
                # Connection closed early: keep the .part file so a retry can resume.
                logger.warning(
                    f"Incomplete download from {url}: got {bytes_downloaded} of {total_size} bytes"
                )
                return False

            # Same directory, so this is a rename that also overwrites an existing file.
            os.replace(temp_path, output_path)
            return True

        except requests.exceptions.RequestException as e:
            logger.error(f"Download error: {e}")
            return False
        except Exception as e:
            logger.error(f"File write error: {e}")
            return False

    def _verify_download(self, file_path: Path, model: ModelInfo) -> bool:
        """
        Verify downloaded file.

        The file must exist. If the registry gives a positive size, the file
        must be within SIZE_TOLERANCE_BYTES of it. If the registry gives a
        SHA256, the file's digest must match it (case-insensitively).

        Args:
            file_path: Downloaded file path
            model: Model information

        Returns:
            True if verification passed
        """
        if not file_path.exists():
            return False

        actual_size = file_path.stat().st_size
        if model.file_size > 0 and abs(actual_size - model.file_size) > self.SIZE_TOLERANCE_BYTES:
            logger.warning(f"Size mismatch: expected {model.file_size}, got {actual_size}")
            return False

        if model.sha256:
            try:
                sha256 = self._calculate_sha256(file_path)
            except Exception as e:
                logger.error(f"SHA256 calculation failed: {e}")
                return False
            # hexdigest() is lowercase. The registry value may be upper-case or padded.
            if sha256 != model.sha256.strip().lower():
                logger.warning(f"SHA256 mismatch for {model.model_id}")
                return False

        return True

    def _calculate_sha256(self, file_path: Path) -> str:
        """Return the lowercase hex SHA256 of *file_path*, read in ``chunk_size`` blocks."""
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(self.chunk_size), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def _notify_progress(self, progress: DownloadProgress, callback: Optional[Callable]):
        """Invoke *callback* with *progress*; callback errors are logged, never raised."""
        if callback:
            try:
                callback(progress)
            except Exception as e:
                logger.error(f"Progress callback error: {e}")

    def cancel_download(self, model_id: str) -> bool:
        """
        Cancel ongoing download.

        The worker is signalled and given up to 5 seconds to stop. The
        progress record is then marked CANCELLED, unless the download had
        already completed.

        Args:
            model_id: Model ID to cancel

        Returns:
            True if a download had been started for *model_id*.
        """
        stop_event = self._stop_events.get(model_id)
        if stop_event is None:
            return False

        stop_event.set()

        # Wait for cancellation
        future = self.futures.pop(model_id, None)
        if future is not None:
            try:
                future.result(timeout=5)
            except Exception as e:
                # Timeout if the worker is still busy, or an error from the task.
                logger.debug(f"Cancelled download {model_id} did not finish cleanly: {e}")

        progress = self.downloads.get(model_id)
        if progress is not None and progress.status != DownloadStatus.COMPLETED:
            progress.status = DownloadStatus.CANCELLED

        logger.info(f"Download cancelled: {model_id}")
        return True

    def get_progress(self, model_id: str) -> Optional[DownloadProgress]:
        """Get download progress for model."""
        return self.downloads.get(model_id)

    def get_all_progress(self) -> Dict[str, DownloadProgress]:
        """Get a shallow copy of all download progress records."""
        return self.downloads.copy()

    def cleanup_partial_downloads(self):
        """Delete every ``*.part`` file directly inside the models directory.

        NOTE(review): this also removes the partial file of a download that is
        still running; call it only when no downloads are active.
        """
        for file in self.registry.models_dir.glob("*.part"):
            try:
                file.unlink()
                logger.info(f"Removed partial download: {file.name}")
            except Exception as e:
                logger.error(f"Failed to remove {file}: {e}")

    def download_required_models(self, task: Optional[str] = None,
                                 gpu_available: bool = True) -> bool:
        """
        Download all required models for a task.

        Args:
            task: Optional task name (a ModelTask value). If given, only the
                registry's best model for that task is fetched. Otherwise
                ESSENTIAL_MODELS that exist in the registry are fetched.
            gpu_available: Passed to the registry as ``require_gpu``

        Returns:
            True if every scheduled download succeeded. Also True when nothing
            needed downloading.

        Raises:
            ValueError: if *task* is not a valid ModelTask value.
        """
        required: List[str] = []
        if task:
            from .registry import ModelTask
            task_enum = ModelTask(task)
            model = self.registry.get_best_model(task_enum, require_gpu=bool(gpu_available))
            if model:
                required.append(model.model_id)
            else:
                logger.warning(f"No model found for task: {task}")
        else:
            required = [m for m in self.ESSENTIAL_MODELS if self.registry.get_model(m)]

        if not required:
            return True

        logger.info(f"Downloading required models: {required}")
        futures = self.download_models_async(required)

        # Wait for completion
        success = True
        for model_id, future in futures.items():
            try:
                if not future.result():
                    success = False
            except Exception as e:
                logger.error(f"Download failed for {model_id}: {e}")
                success = False
        return success