Detect PyTorch models and provide a clear error message - Repository contains generator.pt (PyTorch), not a FastAI model
0454a91
"""
Colorization model wrapper built on the FastAI GAN colorization model
published on the Hugging Face Hub as Hammad712/GAN-Colorization-Model.
"""
| from __future__ import annotations | |
| import logging | |
| import os | |
| from typing import Tuple | |
# Point every Hugging Face / XDG cache variable at a single directory before
# any HF library is imported. main.py is expected to have done this already;
# it is repeated here as a safeguard.
cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "HUGGINGFACE_HUB_CACHE",
            "HF_HUB_CACHE",
            "XDG_CACHE_HOME",
        ),
        cache_dir,
    )
)
| import torch | |
| from PIL import Image | |
| from fastai.vision.all import * | |
| from huggingface_hub import from_pretrained_fastai, hf_hub_download, list_repo_files | |
| from app.config import settings | |
logger = logging.getLogger(__name__)


def _ensure_cache_dir() -> str:
    """Create the Hugging Face cache directory and export it to the environment.

    The directory is taken from ``HF_HOME`` (default ``/tmp/hf_cache``).
    Creation is best-effort: a failure is logged as a warning and the path is
    still exported and returned.

    Returns:
        The cache directory path that all cache variables now point to.
    """
    hf_home = os.environ.get("HF_HOME", "/tmp/hf_cache")
    try:
        os.makedirs(hf_home, exist_ok=True)
    except Exception as mkdir_err:
        logger.warning("Could not create cache directory %s: %s", hf_home, mkdir_err)

    # Keep every cache-related variable pointing at the same location.
    for env_name in (
        "HF_HOME",
        "TRANSFORMERS_CACHE",
        "HUGGINGFACE_HUB_CACHE",
        "HF_HUB_CACHE",
        "XDG_CACHE_HOME",
    ):
        os.environ[env_name] = hf_home
    return hf_home
class ColorizeModel:
    """Colorization model backed by a FastAI learner from the Hugging Face Hub.

    Loading first tries ``from_pretrained_fastai``. If that fails, the
    repository is searched for an exported learner (``.pkl``), which is
    downloaded and opened with ``load_learner``. A repository that only
    provides a raw PyTorch checkpoint (``.pt``) is rejected with an explicit
    error, because the generator architecture needed to load it is not
    available here.
    """

    # Filenames probed when the repository cannot be listed or the listing
    # contains neither .pkl nor .pt files.
    _FALLBACK_FILENAMES = (
        "model.pkl",
        "export.pkl",
        "learner.pkl",
        "model_export.pkl",
        "generator.pt",
    )

    def __init__(self, model_id: str | None = None) -> None:
        """Load the colorization learner.

        Args:
            model_id: Hugging Face repository id. Falls back to
                ``settings.MODEL_ID`` when omitted or empty.

        Raises:
            RuntimeError: If no FastAI learner could be loaded; the original
                failure is chained as ``__cause__``.
        """
        self.cache_dir = _ensure_cache_dir()
        # Recorded for callers; nothing in this class moves the learner to it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        self.model_id = model_id or settings.MODEL_ID
        self.output_caption = getattr(
            settings, "FASTAI_OUTPUT_CAPTION", "Colorized using GAN-Colorization-Model"
        )
        logger.info("Loading FastAI GAN Colorization model: %s", self.model_id)
        try:
            self.learn = self._load_learner()
        except Exception as e:
            error_msg = (
                f"Failed to load FastAI model '{self.model_id}'. "
                f"Error: {str(e)}\n"
                f"Please check the MODEL_ID environment variable. "
                f"Default model: 'Hammad712/GAN-Colorization-Model'"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _load_learner(self):
        """Return a learner via the Hub helper, else via manual download."""
        try:
            learn = from_pretrained_fastai(self.model_id)
        except Exception as hub_err:
            logger.warning(
                "from_pretrained_fastai failed: %s. Trying manual download...", str(hub_err)
            )
            primary_error = hub_err
        else:
            logger.info("FastAI GAN Colorization model loaded successfully via from_pretrained_fastai")
            return learn

        hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
        candidates = self._candidate_filenames(hf_token)
        model_path, filename = self._download_first_available(candidates, hf_token)

        if not (model_path and os.path.exists(model_path)):
            raise RuntimeError(
                f"Could not find model file in repository '{self.model_id}'. "
                f"Tried: {', '.join(candidates)}. "
                f"Original error: {str(primary_error)}"
            ) from primary_error

        if self._model_kind(filename) == "pytorch":
            # A bare checkpoint cannot be turned into a learner without the
            # generator's architecture, so fail with an actionable message
            # that names the file that was actually found.
            logger.warning(
                "Repository %s only provides a PyTorch checkpoint (%s); "
                "it cannot be loaded as a FastAI learner.",
                self.model_id,
                filename,
            )
            raise RuntimeError(
                f"Repository '{self.model_id}' contains a PyTorch model ({filename}), "
                f"not a FastAI model. FastAI models must be .pkl files created with FastAI's export. "
                f"Please use a FastAI-compatible colorization model, or switch to a different model backend."
            )

        logger.info("Loading FastAI model from: %s", model_path)
        # SECURITY: load_learner unpickles the downloaded file, which can run
        # arbitrary code. Only point MODEL_ID at repositories you trust.
        learn = load_learner(model_path)
        logger.info("FastAI GAN Colorization model loaded successfully from %s", model_path)
        return learn

    def _candidate_filenames(self, hf_token: str | None) -> list[str]:
        """List the repository and return the model filenames worth trying."""
        try:
            repo_files = list_repo_files(repo_id=self.model_id, token=hf_token)
        except Exception as list_err:
            logger.warning(
                "Could not list repository files: %s. Trying common filenames...", str(list_err)
            )
            return list(self._FALLBACK_FILENAMES)

        logger.info("Repository files: %s", repo_files)
        matches = self._select_model_files(repo_files)
        if matches:
            suffix = os.path.splitext(matches[0])[1]
            logger.info("Found %s files in repository: %s", suffix, matches)
            return matches
        return list(self._FALLBACK_FILENAMES)

    def _download_first_available(self, filenames, hf_token: str | None):
        """Download the first candidate that exists; return ``(path, filename)``.

        Returns ``(None, None)`` when every download attempt fails.
        """
        for filename in filenames:
            try:
                model_path = hf_hub_download(
                    repo_id=self.model_id,
                    filename=filename,
                    cache_dir=self.cache_dir,
                    token=hf_token,
                )
            except Exception as dl_err:
                logger.debug("Failed to download %s: %s", filename, str(dl_err))
                continue
            logger.info("Found model file: %s", filename)
            return model_path, filename
        return None, None

    @staticmethod
    def _select_model_files(repo_files) -> list[str]:
        """Return the ``.pkl`` files of a listing, else its ``.pt`` files, else ``[]``."""
        for suffix in (".pkl", ".pt"):
            matches = [name for name in repo_files if name.endswith(suffix)]
            if matches:
                return matches
        return []

    @staticmethod
    def _model_kind(filename: str) -> str:
        """Classify a model file by extension: ``"pytorch"`` for ``.pt``, else ``"fastai"``."""
        return "pytorch" if filename.endswith(".pt") else "fastai"

    @staticmethod
    def _to_uint8_pixels(tensor: torch.Tensor) -> torch.Tensor:
        """Convert a CHW (or 1xCHW) tensor to a uint8 HWC tensor on the CPU.

        Floating-point input of any precision is treated as [0, 1] and scaled
        to [0, 255]; integer input is clamped to [0, 255]. A single-channel
        image is returned as HW so it can be read as grayscale.

        Raises:
            ValueError: If the tensor is not 3-D after dropping a batch
                dimension, or has a channel count other than 1, 3 or 4.
        """
        if tensor.dim() == 4:
            tensor = tensor[0]  # drop the batch dimension
        if tensor.dim() != 3:
            raise ValueError(f"Unexpected tensor shape: {tensor.shape}")

        # detach() so tensors that still require grad can be exported.
        pixels = tensor.detach().permute(1, 2, 0).cpu()
        channels = pixels.shape[2]
        if channels not in (1, 3, 4):
            raise ValueError(f"Unexpected tensor shape: {tensor.shape}")

        if pixels.is_floating_point():
            # float() first: covers float64/bfloat16 and avoids half-precision
            # ops that some CPU builds do not implement.
            pixels = (pixels.float().clamp(0, 1) * 255).to(torch.uint8)
        elif pixels.dtype != torch.uint8:
            # e.g. int64 pixel values; without this cast the raw integers
            # would be reinterpreted as bytes.
            pixels = pixels.clamp(0, 255).to(torch.uint8)

        if channels == 1:
            pixels = pixels.squeeze(-1)
        return pixels.contiguous()

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """Colorize a grayscale or color image with the loaded learner.

        Args:
            image: PIL image in any mode; converted to RGB before prediction.
            num_inference_steps: Ignored; kept for API compatibility.

        Returns:
            Tuple of (colorized RGB PIL image at the input's size, caption).

        Raises:
            RuntimeError: If prediction or post-processing fails; the original
                exception is chained as ``__cause__``.
        """
        try:
            original_size = image.size
            if image.mode != "RGB":
                image = image.convert("RGB")

            logger.info("Running FastAI GAN colorization...")
            prediction = self.learn.predict(image)

            # The prediction may be the output itself or a sequence whose
            # first element is the output.
            if isinstance(prediction, (list, tuple)):
                if len(prediction) > 0:
                    colorized = prediction[0]
                else:
                    logger.warning(
                        "Model returned an empty prediction; returning the input image unchanged"
                    )
                    colorized = image
            else:
                colorized = prediction

            if not isinstance(colorized, Image.Image):
                if not isinstance(colorized, torch.Tensor):
                    raise ValueError(f"Unexpected prediction type: {type(colorized)}")
                # Let Pillow infer the mode from the array shape (L / RGB /
                # RGBA); the RGB conversion below normalises the result.
                colorized = Image.fromarray(self._to_uint8_pixels(colorized).numpy())

            if colorized.mode != "RGB":
                colorized = colorized.convert("RGB")
            if colorized.size != original_size:
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)

            logger.info("Colorization completed successfully")
            return colorized, self.output_caption
        except Exception as e:
            logger.error("Error during colorization: %s", str(e))
            raise RuntimeError(f"Colorization failed: {str(e)}") from e