from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
from PIL import Image
import requests
from io import BytesIO
import uvicorn
from sentence_transformers import SentenceTransformer, util

app = FastAPI()

# CORS setup: allow all origins so the API can be called from any frontend
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the BLIP and GIT models for image captioning
blip_model_name = "Salesforce/blip-image-captioning-large"
git_model_name = "microsoft/git-large-coco"
blip_pipeline = pipeline("image-to-text", model=blip_model_name)
git_coco_pipeline = pipeline("image-to-text", model=git_model_name)

# Initialize the sentence-embedding model used for semantic similarity
model = SentenceTransformer("all-MiniLM-L6-v2")

# External service that classifies the selected caption
classification_server_url = "https://raviipandey-classification.hf.space/classify_text"


def calculate_similarity(caption1: str, caption2: str) -> float:
    """Return the cosine similarity between the embeddings of two captions."""
    embeddings = model.encode([caption1, caption2], convert_to_tensor=True)
    similarity = util.pytorch_cos_sim(embeddings[0], embeddings[1])
    return similarity.item()


def select_best_caption(caption1: str, caption2: str) -> str:
    """Pick one of two candidate captions.

    If the captions agree semantically (similarity above the threshold),
    return the longer, presumably more descriptive one. Otherwise fall
    back to the first caption (the GIT caption, given the call order below).
    """
    similarity_threshold = 0.6486401639
    similarity_score = calculate_similarity(caption1, caption2)
    if similarity_score > similarity_threshold:
        print("Similarity score:", similarity_score)
        return caption1 if len(caption1) > len(caption2) else caption2
    else:
        print(f"GIT is selected with score: {similarity_score}")
        return caption1


@app.post("/generate_caption_and_classify")
async def generate_caption_and_classify(request: Request):
    try:
        body = await request.json()
        image_urls = body.get("image_urls", [])
        results = []
        for image_url in image_urls:
            # Download the image; the timeout and raise_for_status() fail fast
            # on hung or failed downloads instead of handing a broken payload to PIL
            response = requests.get(image_url, timeout=30)
            response.raise_for_status()
            # Convert to RGB so palette/RGBA images don't trip up the pipelines
            image = Image.open(BytesIO(response.content)).convert("RGB")

            # Generate captions from both models
            caption_git = git_coco_pipeline(image, max_new_tokens=50)[0]["generated_text"]
            print("Caption generated by GIT:", caption_git)
            caption_blip = blip_pipeline(image, max_new_tokens=50)[0]["generated_text"]
            print("Caption generated by BLIP:", caption_blip)

            # Select the best caption
            best_caption = select_best_caption(caption_git, caption_blip)

            # Classify the best caption; the server returns one result per input text
            classification_response = requests.post(
                classification_server_url,
                json={"texts": [best_caption]},
                timeout=30,
            )
            classification_result = classification_response.json()[0]

            results.append({
                "image_url": image_url,
                "caption": best_caption,
                "classification": classification_result,
            })
        return results
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
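
# --- Usage sketch (not part of the server) ---
# A minimal client-side call, assuming the server above is running locally on
# port 7860; the image URL below is a hypothetical placeholder.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/generate_caption_and_classify",
#       json={"image_urls": ["https://example.com/photo.jpg"]},
#       timeout=120,  # generous: two caption models run per image
#   )
#   print(resp.json())
#
# Each element of the returned list has the shape:
#   {"image_url": ..., "caption": ..., "classification": ...}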