# Spaces: Running
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
import torch | |
import open_clip | |
from mobileclip.modules.common.mobileone import reparameterize_model | |
from PIL import Image | |
import requests | |
from io import BytesIO | |
import logging | |
# NumPy is needed further down in this module (the similarity helper uses
# `np`), so report the import result and fail fast when it is missing.
try:
    import numpy as np
    print("✅ NumPy imported successfully:", np.__version__)
except ImportError as e:
    print("❌ NumPy failed to import:", str(e))
    # Re-raise: carrying on without `np` only defers the failure to a
    # NameError at the first later use of `np`.
    raise
import os | |
# Route the Hugging Face, PyTorch and Transformers caches into /app/.cache.
# NOTE(review): this runs after the torch/open_clip imports at the top of the
# file; a library that reads these variables at import time may not pick them
# up — confirm.
os.environ.update({
    'HF_HOME': '/app/.cache',
    'TORCH_HOME': '/app/.cache/torch',
    'TRANSFORMERS_CACHE': '/app/.cache/transformers',
})
# Make sure each configured cache directory exists before anything writes to it.
os.makedirs(os.environ['HF_HOME'], exist_ok=True)
os.makedirs(os.environ['TORCH_HOME'], exist_ok=True)
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
# Module logger, with the root logger configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# FastAPI application object; the metadata below is what the app reports as
# its title, description and version.
app = FastAPI(
    version="1.0.0",
    title="MobileCLIP API",
    description="API for MobileCLIP image and text embeddings",
)
# Model state shared by the request handlers. All three stay None until
# load_model() fills them in; handlers treat `model is None` as "not ready".
model = preprocess = tokenizer = None
class TextRequest(BaseModel):
    """Request body for the text-embedding endpoint."""

    # Text to be tokenized and encoded.
    text: str
class ImageRequest(BaseModel):
    """Request body for the image-embedding endpoint."""

    # URL the server downloads the image from.
    image_url: str
class SimilarityRequest(BaseModel):
    """Request body for the image/text similarity endpoint."""

    # URL the server downloads the image from.
    image_url: str
    # Text to compare against the image.
    text: str
class EmbeddingResponse(BaseModel):
    """Response body carrying a single embedding vector."""

    # Embedding values; the handlers fill this from `tensor.tolist()`.
    embedding: list
class SimilarityResponse(BaseModel):
    """Response body carrying an image/text similarity score."""

    # Score produced by calculate_similarity().
    similarity: float
def load_model():
    """Load MobileCLIP-S2 and publish it through the module-level globals.

    Builds the model and its preprocessing transform with
    ``open_clip.create_model_and_transforms`` (``datacompdr`` weights, cache
    directory ``/app/.cache``), fetches the matching tokenizer, puts the model
    in eval mode and passes it through ``reparameterize_model``.

    ``model``, ``preprocess`` and ``tokenizer`` are assigned together, and only
    after every step has succeeded, so a failure part-way through never leaves
    a half-initialised state visible to code that checks ``model is None``.

    Raises:
        Exception: whatever the underlying libraries raise; the error is
            logged and re-raised unchanged.
    """
    global model, preprocess, tokenizer
    try:
        logger.info("📥 Downloading MobileCLIP-S2 model...")
        # Explicit cache directory for the downloaded weights.
        loaded_model, _, loaded_preprocess = open_clip.create_model_and_transforms(
            'MobileCLIP-S2',
            pretrained='datacompdr',
            cache_dir='/app/.cache',
        )
        logger.info("🔧 Loading tokenizer...")
        loaded_tokenizer = open_clip.get_tokenizer('MobileCLIP-S2')
        # Reparameterize for inference.
        logger.info("⚡ Reparameterizing model for inference...")
        loaded_model.eval()
        loaded_model = reparameterize_model(loaded_model)
    except Exception as e:
        logger.error(f"❌ Failed to load model: {str(e)}")
        # Bare raise propagates the original exception and traceback as-is.
        raise
    # Everything succeeded: publish the three objects in one step.
    model, preprocess, tokenizer = loaded_model, loaded_preprocess, loaded_tokenizer
    logger.info("✅ Model loaded and optimized successfully!")
def download_image(url: str) -> Image.Image:
    """Fetch an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Location of the image; passed unchanged to ``requests.get``
            (called with ``timeout=10``).

    Returns:
        An RGB copy of the decoded image.

    Raises:
        HTTPException: status 400 if the request fails, the server answers
            with an error status, or the payload cannot be decoded as an
            image. The underlying error is chained as ``__cause__``.
    """
    # NOTE(security): the URL comes from the API caller and is fetched as
    # given — no host allow-list or response-size cap is applied here.
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        # convert() returns a new image, so the decoded source can be closed
        # as soon as the RGB copy exists.
        with Image.open(BytesIO(response.content)) as image:
            return image.convert('RGB')
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to download image: {str(e)}") from e
def get_image_embedding(image: Image.Image):
    """Encode one image with the loaded model and return its normalized embedding.

    The image goes through the module-level ``preprocess`` transform, gets a
    leading batch dimension, and is encoded by ``model.encode_image`` with
    gradients disabled. The result is divided by its norm along the last
    dimension, squeezed, and moved to the CPU.

    Raises:
        HTTPException: status 500 if any step fails.
    """
    try:
        batch = preprocess(image).unsqueeze(0)
        with torch.no_grad():
            features = model.encode_image(batch)
            # Scale to unit length along the feature dimension.
            norms = features.norm(dim=-1, keepdim=True)
            features = features / norms
        flattened = features.squeeze()
        return flattened.cpu()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process image: {str(e)}")
def get_text_embedding(text: str):
    """Encode one text string with the loaded model and return its normalized embedding.

    The text is tokenized as a single-item batch by the module-level
    ``tokenizer`` and encoded by ``model.encode_text`` with gradients
    disabled. The result is divided by its norm along the last dimension,
    squeezed, and moved to the CPU.

    Raises:
        HTTPException: status 500 if any step fails.
    """
    try:
        tokens = tokenizer([text])
        with torch.no_grad():
            features = model.encode_text(tokens)
            # Scale to unit length along the feature dimension.
            norms = features.norm(dim=-1, keepdim=True)
            features = features / norms
        flattened = features.squeeze()
        return flattened.cpu()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process text: {str(e)}")
def calculate_similarity(embedding1: np.ndarray, embedding2: np.ndarray) -> float:
    """Return the cosine similarity of two embedding vectors.

    Both inputs are converted with ``np.asarray`` (as float64) and flattened,
    so any array-like that NumPy can convert is accepted. The dot product is
    divided by the product of the two norms, which makes the result a true
    cosine even when the inputs are not unit-length; for unit-length inputs
    it equals the plain dot product up to floating-point rounding.

    Args:
        embedding1: First vector.
        embedding2: Second vector; must have as many elements as the first.

    Returns:
        The cosine similarity, or ``0.0`` if either vector has zero norm.

    Raises:
        ValueError: if the two vectors differ in length (raised by ``np.dot``).
    """
    vec1 = np.asarray(embedding1, dtype=np.float64).ravel()
    vec2 = np.asarray(embedding2, dtype=np.float64).ravel()
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    # A zero vector has no direction; avoid dividing by zero.
    if denominator == 0.0:
        return 0.0
    return float(np.dot(vec1, vec2) / denominator)
async def startup_event():
    """Log the start-up banners, load the model, and log readiness.

    Intended as the application's start-up hook.
    NOTE(review): no registration decorator is visible on this function in
    this view of the file — confirm it is actually attached to ``app``.
    """
    for banner in (
        "🚀 Starting MobileCLIP API...",
        "📦 Loading model - this may take 2-5 minutes...",
    ):
        logger.info(banner)
    # Synchronous call: the coroutine does not yield until loading finishes.
    load_model()
    logger.info("✅ Model loaded successfully! API is ready.")
async def root():
    """Health check: return a fixed payload saying the API is up.

    NOTE(review): no route decorator is visible on this handler in this view
    of the file — confirm it is registered on ``app``.
    """
    payload = dict(message="MobileCLIP API is running!", status="healthy")
    return payload
async def image_embedding(request: ImageRequest):
    """Return the embedding of the image found at ``request.image_url``.

    The download is a blocking ``requests`` call, so it is handed to the
    event loop's default executor and awaited; the loop stays free to serve
    other requests while the image is being fetched.

    NOTE(review): no route decorator is visible on this handler in this view
    of the file — confirm it is registered on ``app``.

    Raises:
        HTTPException: 503 if the model has not been loaded; errors raised by
            the download/encoding helpers are passed through unchanged; any
            other failure is logged and reported as a generic 500.
    """
    import asyncio  # local import: only the executor hand-off needs it

    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        loop = asyncio.get_running_loop()
        image = await loop.run_in_executor(None, download_image, request.image_url)
        embedding = get_image_embedding(image)
        return EmbeddingResponse(embedding=embedding.tolist())
    except HTTPException:
        # Already carries the right status and detail.
        raise
    except Exception as e:
        logger.error(f"Error in image_embedding: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
async def text_embedding(request: TextRequest):
    """Return the embedding of ``request.text``.

    NOTE(review): no route decorator is visible on this handler in this view
    of the file — confirm it is registered on ``app``.

    Raises:
        HTTPException: 503 if the model has not been loaded; an
            ``HTTPException`` raised by ``get_text_embedding`` is passed
            through unchanged (as the image and similarity handlers do for
            their helpers); any other failure is logged and reported as a
            generic 500.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        embedding = get_text_embedding(request.text)
        return EmbeddingResponse(embedding=embedding.tolist())
    except HTTPException:
        # Keep the helper's status and detail instead of re-wrapping them.
        raise
    except Exception as e:
        logger.error(f"Error in text_embedding: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
async def similarity(request: SimilarityRequest):
    """Return the similarity between the image at ``request.image_url`` and ``request.text``.

    The download is a blocking ``requests`` call, so it is handed to the
    event loop's default executor and awaited; the loop stays free to serve
    other requests while the image is being fetched.

    NOTE(review): no route decorator is visible on this handler in this view
    of the file — confirm it is registered on ``app``.

    Raises:
        HTTPException: 503 if the model has not been loaded; errors raised by
            the download/encoding helpers are passed through unchanged; any
            other failure is logged and reported as a generic 500.
    """
    import asyncio  # local import: only the executor hand-off needs it

    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        loop = asyncio.get_running_loop()
        image = await loop.run_in_executor(None, download_image, request.image_url)
        # Local names chosen so they do not shadow the image_embedding /
        # text_embedding handler functions defined in this module.
        image_vector = get_image_embedding(image)
        text_vector = get_text_embedding(request.text)
        similarity_score = calculate_similarity(image_vector, text_vector)
        return SimilarityResponse(similarity=similarity_score)
    except HTTPException:
        # Already carries the right status and detail.
        raise
    except Exception as e:
        logger.error(f"Error in similarity: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")