"""MobileCLIP embedding service.

FastAPI application exposing image embeddings, text embeddings and an
image/text similarity score computed with the open_clip ``MobileCLIP-S2``
model (``datacompdr`` weights).
"""
import logging
import os
from io import BytesIO

# Set cache directories.
# NOTE: this must run BEFORE the third-party imports below. huggingface_hub
# (pulled in by open_clip) and transformers resolve their cache locations from
# these environment variables when they are first imported, so assigning them
# afterwards does not redirect those libraries' caches.
os.environ['HF_HOME'] = '/app/.cache'
os.environ['TORCH_HOME'] = '/app/.cache/torch'
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache/transformers'

# Create cache directories if they don't exist
os.makedirs('/app/.cache', exist_ok=True)
os.makedirs('/app/.cache/torch', exist_ok=True)
os.makedirs('/app/.cache/transformers', exist_ok=True)

from fastapi import FastAPI, HTTPException  # noqa: E402
from pydantic import BaseModel  # noqa: E402
import torch  # noqa: E402
import open_clip  # noqa: E402
from mobileclip.modules.common.mobileone import reparameterize_model  # noqa: E402
from PIL import Image  # noqa: E402
import requests  # noqa: E402

# Diagnostic only: report whether NumPy is importable. Nothing below depends
# on NumPy, so a failed import here no longer breaks module loading.
try:
    import numpy as np
    print("✅ NumPy imported successfully:", np.__version__)
except ImportError as e:
    print("❌ NumPy failed to import:", str(e))

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="MobileCLIP API",
    description="API for MobileCLIP image and text embeddings",
    version="1.0.0"
)

# Global variables for model; populated once by load_model() at startup.
model = None
preprocess = None
tokenizer = None


class TextRequest(BaseModel):
    """Request body carrying the text to embed."""
    text: str


class ImageRequest(BaseModel):
    """Request body carrying the URL of the image to embed."""
    image_url: str


class SimilarityRequest(BaseModel):
    """Request body for scoring one image URL against one text."""
    image_url: str
    text: str


class EmbeddingResponse(BaseModel):
    """Response carrying an L2-normalized embedding as a plain list."""
    embedding: list


class SimilarityResponse(BaseModel):
    """Response carrying the cosine similarity score."""
    similarity: float


def load_model():
    """Load MobileCLIP-S2, its preprocessing transform and tokenizer.

    Sets the module-level ``model``, ``preprocess`` and ``tokenizer``
    globals. The model is put in eval mode and reparameterized for
    inference.

    Raises:
        Exception: whatever open_clip / reparameterization raised; it is
            logged and re-raised so application startup fails loudly.
    """
    global model, preprocess, tokenizer
    try:
        logger.info("📥 Downloading MobileCLIP-S2 model...")
        # Explicitly set cache directory
        model, _, preprocess = open_clip.create_model_and_transforms(
            'MobileCLIP-S2',
            pretrained='datacompdr',
            cache_dir='/app/.cache'
        )

        logger.info("🔧 Loading tokenizer...")
        tokenizer = open_clip.get_tokenizer('MobileCLIP-S2')

        # Reparameterize for inference
        logger.info("⚡ Reparameterizing model for inference...")
        model.eval()
        model = reparameterize_model(model)

        logger.info("✅ Model loaded and optimized successfully!")
    except Exception as e:
        # logger.exception keeps the traceback; bare raise preserves it too.
        logger.exception(f"❌ Failed to load model: {str(e)}")
        raise


def download_image(url: str) -> Image.Image:
    """Download an image from ``url`` and return it converted to RGB.

    Raises:
        HTTPException: 400 if the download fails or the content cannot be
            decoded as an image.
    """
    # NOTE(review): ``url`` comes straight from the client, so the server
    # will fetch arbitrary URLs (including internal addresses) and reads the
    # whole response into memory. Consider an allow-list and a size cap.
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return image.convert('RGB')
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to download image: {str(e)}") from e


def get_image_embedding(image: Image.Image) -> torch.Tensor:
    """Return the L2-normalized image embedding as a CPU tensor.

    The image is preprocessed, batched as a single item, encoded, normalized
    along the last dimension and the batch dimension is removed.

    Raises:
        HTTPException: 500 if preprocessing or encoding fails.
    """
    try:
        image_tensor = preprocess(image).unsqueeze(0)
        with torch.no_grad():
            image_features = model.encode_image(image_tensor)
            # Normalize the embedding so dot products are cosine similarities
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.squeeze().cpu()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process image: {str(e)}") from e


def get_text_embedding(text: str) -> torch.Tensor:
    """Return the L2-normalized text embedding as a CPU tensor.

    Raises:
        HTTPException: 500 if tokenization or encoding fails.
    """
    try:
        text_tokens = tokenizer([text])
        with torch.no_grad():
            text_features = model.encode_text(text_tokens)
            # Normalize the embedding so dot products are cosine similarities
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.squeeze().cpu()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process text: {str(e)}") from e


def calculate_similarity(embedding1: torch.Tensor, embedding2: torch.Tensor) -> float:
    """Return the dot product of two embeddings as a Python float.

    For the L2-normalized outputs of ``get_image_embedding`` /
    ``get_text_embedding`` this equals their cosine similarity. Inputs may
    be tensors or array-likes accepted by ``torch.as_tensor``; they are
    flattened so a leftover singleton batch dimension is tolerated.
    """
    a = torch.as_tensor(embedding1).flatten()
    # torch.dot requires matching dtypes.
    b = torch.as_tensor(embedding2).flatten().to(a.dtype)
    return float(torch.dot(a, b))


@app.on_event("startup")
async def startup_event():
    """Load model on startup (blocking; startup fails if loading fails)."""
    logger.info("🚀 Starting MobileCLIP API...")
    logger.info("📦 Loading model - this may take 2-5 minutes...")
    load_model()
    logger.info("✅ Model loaded successfully! API is ready.")


@app.get("/")
async def root():
    """Health check endpoint"""
    return {"message": "MobileCLIP API is running!", "status": "healthy"}


# The endpoints below are plain ``def`` on purpose: they perform blocking
# network I/O (requests) and CPU-bound inference. FastAPI runs sync path
# operations in a threadpool, so they no longer stall the event loop.

@app.post("/image-embedding", response_model=EmbeddingResponse)
def image_embedding(request: ImageRequest):
    """Get embedding for an image given its URL"""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        image = download_image(request.image_url)
        embedding = get_image_embedding(image)
        return EmbeddingResponse(embedding=embedding.tolist())
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error in image_embedding: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")


@app.post("/text-embedding", response_model=EmbeddingResponse)
def text_embedding(request: TextRequest):
    """Get embedding for text"""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        embedding = get_text_embedding(request.text)
        return EmbeddingResponse(embedding=embedding.tolist())
    except HTTPException:
        # Keep the specific error detail, consistent with the other endpoints.
        raise
    except Exception as e:
        logger.exception(f"Error in text_embedding: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")


@app.post("/similarity", response_model=SimilarityResponse)
def similarity(request: SimilarityRequest):
    """Calculate similarity between image and text"""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        image = download_image(request.image_url)
        image_emb = get_image_embedding(image)
        text_emb = get_text_embedding(request.text)

        similarity_score = calculate_similarity(image_emb, text_emb)
        return SimilarityResponse(similarity=similarity_score)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error in similarity: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)