import os

# Set cache directories BEFORE importing libraries that read them at import
# time (huggingface_hub resolves HF_HOME when it is first imported).
os.environ['HF_HOME'] = '/app/.cache'
os.environ['TORCH_HOME'] = '/app/.cache/torch'
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache/transformers'

# Create cache directories if they don't exist
os.makedirs('/app/.cache', exist_ok=True)
os.makedirs('/app/.cache/torch', exist_ok=True)
os.makedirs('/app/.cache/transformers', exist_ok=True)

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import open_clip
from mobileclip.modules.common.mobileone import reparameterize_model
from PIL import Image
import requests
from io import BytesIO
import logging

try:
    import numpy as np
    print("✅ NumPy imported successfully:", np.__version__)
except ImportError as e:
    print("❌ NumPy failed to import:", str(e))
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
    title="MobileCLIP API",
    description="API for MobileCLIP image and text embeddings",
    version="1.0.0"
)
# Global variables for model
model = None
preprocess = None
tokenizer = None
class TextRequest(BaseModel):
    text: str

class ImageRequest(BaseModel):
    image_url: str

class SimilarityRequest(BaseModel):
    image_url: str
    text: str

class EmbeddingResponse(BaseModel):
    embedding: list

class SimilarityResponse(BaseModel):
    similarity: float
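
# Illustrative request bodies for the endpoints below (field names come from
# the models above; the image URL is a placeholder, not part of this app):
#   POST /text-embedding   {"text": "a photo of a dog"}
#   POST /image-embedding  {"image_url": "https://example.com/dog.jpg"}
#   POST /similarity       {"image_url": "https://example.com/dog.jpg",
#                           "text": "a photo of a dog"}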
def load_model():
    """Load and initialize the MobileCLIP model"""
    global model, preprocess, tokenizer
    try:
        logger.info("📥 Downloading MobileCLIP-S2 model...")
        # Explicitly set cache directory
        model, _, preprocess = open_clip.create_model_and_transforms(
            'MobileCLIP-S2',
            pretrained='datacompdr',
            cache_dir='/app/.cache'
        )

        logger.info("🔧 Loading tokenizer...")
        tokenizer = open_clip.get_tokenizer('MobileCLIP-S2')

        # Reparameterize for inference (switch to eval mode first, then fuse)
        logger.info("⚡ Reparameterizing model for inference...")
        model.eval()
        model = reparameterize_model(model)

        logger.info("✅ Model loaded and optimized successfully!")
    except Exception as e:
        logger.error(f"❌ Failed to load model: {str(e)}")
        raise
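
# Background note (general MobileOne behavior, not stated in this file):
# reparameterize_model folds the multi-branch MobileOne blocks used during
# training into single fused convolutions, which speeds up inference while
# producing mathematically equivalent outputs. The documented usage calls
# model.eval() before fusing, as done above.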
def download_image(url: str) -> Image.Image:
    """Download an image from a URL and convert it to RGB"""
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return image.convert('RGB')
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to download image: {str(e)}")
def get_image_embedding(image: Image.Image):
    """Get embedding for an image"""
    try:
        image_tensor = preprocess(image).unsqueeze(0)
        with torch.no_grad():
            image_features = model.encode_image(image_tensor)
            # Normalize the embedding
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.squeeze().cpu()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process image: {str(e)}")
def get_text_embedding(text: str):
    """Get embedding for text"""
    try:
        text_tokens = tokenizer([text])
        with torch.no_grad():
            text_features = model.encode_text(text_tokens)
            # Normalize the embedding
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.squeeze().cpu()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process text: {str(e)}")
def calculate_similarity(embedding1: torch.Tensor, embedding2: torch.Tensor) -> float:
    """Calculate cosine similarity between two L2-normalized embeddings"""
    # Both embeddings are unit-length, so their dot product IS the cosine similarity.
    return float(torch.dot(embedding1, embedding2))
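
# Sketch (for context, not part of the original app): the dot-product shortcut
# above is only valid because both inputs are unit-normalized in
# get_image_embedding / get_text_embedding. For unnormalized vectors the
# equivalent computation would be:
#   torch.nn.functional.cosine_similarity(embedding1, embedding2, dim=-1)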
@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    logger.info("🚀 Starting MobileCLIP API...")
    logger.info("📦 Loading model - this may take 2-5 minutes...")
    load_model()
    logger.info("✅ Model loaded successfully! API is ready.")
@app.get("/")
async def root():
"""Health check endpoint"""
return {"message": "MobileCLIP API is running!", "status": "healthy"}
@app.post("/image-embedding", response_model=EmbeddingResponse)
async def image_embedding(request: ImageRequest):
"""Get embedding for an image given its URL"""
if model is None:
raise HTTPException(status_code=503, detail="Model not loaded")
try:
image = download_image(request.image_url)
embedding = get_image_embedding(image)
return EmbeddingResponse(embedding=embedding.tolist())
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in image_embedding: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
@app.post("/text-embedding", response_model=EmbeddingResponse)
async def text_embedding(request: TextRequest):
"""Get embedding for text"""
if model is None:
raise HTTPException(status_code=503, detail="Model not loaded")
try:
embedding = get_text_embedding(request.text)
return EmbeddingResponse(embedding=embedding.tolist())
except Exception as e:
logger.error(f"Error in text_embedding: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
@app.post("/similarity", response_model=SimilarityResponse)
async def similarity(request: SimilarityRequest):
"""Calculate similarity between image and text"""
if model is None:
raise HTTPException(status_code=503, detail="Model not loaded")
try:
image = download_image(request.image_url)
image_embedding = get_image_embedding(image)
text_embedding = get_text_embedding(request.text)
similarity_score = calculate_similarity(image_embedding, text_embedding)
return SimilarityResponse(similarity=similarity_score)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in similarity: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
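
# Example usage once the server is running (sketch; the host and image URL
# are placeholders, not part of this app):
#   curl -X POST http://localhost:7860/similarity \
#        -H "Content-Type: application/json" \
#        -d '{"image_url": "https://example.com/dog.jpg", "text": "a photo of a dog"}'
# Higher scores indicate a stronger image-text match.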