# Captioning / src / api / app.py
# Mokhtar
# Deploying backend code
# e4721a6
import io
import os
import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import GPT2Tokenizer
# Adjust imports based on your folder structure
from config.config import config
from src.models.model import get_model
from src.preprocessing.transforms import get_transforms
app = FastAPI(title="Object Captioning LLM API")

# --- CORS CONFIGURATION (Crucial for Vercel) ---
# Allowed origins are read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (e.g. "https://my-app.vercel.app"). When it is unset
# or blank, every origin is allowed ("*"), which matches the previous default.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin must not be combined with credentialed requests:
    # together they let any website call the API with the user's cookies.
    # Credentials are therefore only enabled for an explicit origin list.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load Model
# Everything below runs once at import time; the resulting module-level
# names (device, model, tokenizer, transform) are what the request handlers use.
print(f"Loading model: {config.MODEL_TYPE} on {config.DEVICE}...")
device = config.DEVICE
model = get_model(config).to(device)

# Only the legacy ResNetGPT2 path uses an external tokenizer and a locally
# saved checkpoint; every other model type leaves the tokenizer as None.
tokenizer = None
if config.MODEL_TYPE == "resnet_gpt2":
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # Use a relative path or ensure this file is uploaded to the Docker container
    model_path = os.path.join(config.MODEL_SAVE_DIR, "best_model_llm.pth")
    if not os.path.exists(model_path):
        print("Warning: No trained model found for ResNetGPT2. Using random weights.")
    else:
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("Loaded trained custom model.")

# Switch to evaluation mode before serving requests.
model.eval()
transform = get_transforms(train=False)
@app.get("/")
def home():
    """Return a static status message that points clients at ``/predict``."""
    status_message = "Image Captioning API is running. Send POST requests to /predict"
    return dict(message=status_message)
def _decode_upload(image_data):
    """Decode raw upload bytes into an RGB PIL image.

    Args:
        image_data: The bytes read from the uploaded file.

    Returns:
        The decoded image converted to "RGB" mode.

    Raises:
        HTTPException: 400 if the payload is empty or Pillow cannot decode it.
    """
    if not image_data:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    try:
        return Image.open(io.BytesIO(image_data)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # A payload that cannot be decoded is a client error, not a server fault.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from e


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    Args:
        file: The multipart upload containing the image to caption.

    Returns:
        A dict with the single key ``"caption"`` holding the value returned
        by ``model.generate_caption``.

    Raises:
        HTTPException: 400 if the upload is empty or not a decodable image;
            500 if caption generation fails.
    """
    # Read Image
    image = _decode_upload(await file.read())

    # Generate Caption
    try:
        # Inference only: skip autograd bookkeeping for every request.
        with torch.no_grad():
            if config.MODEL_TYPE == "resnet_gpt2":
                img_tensor = transform(image).to(device)
                # Ensure generate_caption handles the tensor/tokenizer correctly
                caption = model.generate_caption(img_tensor, tokenizer)
            else:
                # SOTA models (BLIP/ViT) take the PIL image directly
                caption = model.generate_caption(image)
    except Exception as e:
        # Top-level boundary: log the failure and surface it as a 500.
        print(f"Error during prediction: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"caption": caption}