import io
import os

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import GPT2Tokenizer

# Adjust imports based on your folder structure
from config.config import config
from src.models.model import get_model
from src.preprocessing.transforms import get_transforms

app = FastAPI(title="Object Captioning LLM API")

# --- CORS CONFIGURATION (Crucial for Vercel) ---
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, replace "*" with your Vercel URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
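
# A minimal sketch of tightening CORS for production (FRONTEND_ORIGIN is a
# hypothetical environment variable, not part of this project's config):
#
#   app.add_middleware(
#       CORSMiddleware,
#       allow_origins=[os.environ.get("FRONTEND_ORIGIN", "http://localhost:3000")],
#       allow_credentials=True,
#       allow_methods=["POST"],
#       allow_headers=["*"],
#   )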

# Load Model
print(f"Loading model: {config.MODEL_TYPE} on {config.DEVICE}...")
device = config.DEVICE
model = get_model(config).to(device)

# Legacy support for ResNetGPT2
if config.MODEL_TYPE == "resnet_gpt2":
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # Use a relative path or ensure this file is uploaded to the Docker container
    model_path = os.path.join(config.MODEL_SAVE_DIR, "best_model_llm.pth")
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("Loaded trained custom model.")
    else:
        print("Warning: No trained model found for ResNetGPT2. Using random weights.")
else:
    tokenizer = None

model.eval()
transform = get_transforms(train=False)


@app.get("/")
def home():
    return {"message": "Image Captioning API is running. Send POST requests to /predict"}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    try:
        # Read Image
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Generate Caption
        if config.MODEL_TYPE == "resnet_gpt2":
            img_tensor = transform(image).to(device)
            # Ensure generate_caption handles the tensor/tokenizer correctly
            caption = model.generate_caption(img_tensor, tokenizer)
        else:
            # SOTA models (BLIP/ViT) take the PIL image directly
            caption = model.generate_caption(image)

        return {"caption": caption}
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise HTTPException(status_code=500, detail=str(e))
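

# --- Usage sketch ---
# Assuming this file is saved as app.py (the filename is an assumption), serve
# it locally with uvicorn on 7860, the default Hugging Face Spaces port:
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
# and exercise the endpoint with curl (example.jpg is a placeholder path):
#
#   curl -X POST -F "file=@example.jpg" http://localhost:7860/predict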