# NOTE(review): removed a stray hosting-platform banner ("Spaces:" /
# "Runtime error" x2) that was pasted above the code and made the file
# unparseable as Python.
from io import BytesIO

import requests
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

app = FastAPI()

# CORS setup: accept requests from any origin.
# NOTE(review): browsers reject the combination of allow_origins=["*"] with
# allow_credentials=True; confirm whether credentialed requests are actually
# needed, otherwise drop allow_credentials.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Two image-captioning models, both driven through the Hugging Face
# "image-to-text" pipeline: BLIP (Salesforce) and GIT (Microsoft, COCO-tuned).
# Loading happens at import time, so the first startup downloads the weights.
blip_model_name = "Salesforce/blip-image-captioning-large"
git_model_name = "microsoft/git-large-coco"
blip_pipeline = pipeline("image-to-text", model=blip_model_name)
git_coco_pipeline = pipeline("image-to-text", model=git_model_name)

# Sentence-embedding model used to score semantic similarity between the two
# candidate captions (see calculate_similarity / select_best_caption below).
model = SentenceTransformer('all-MiniLM-L6-v2')

# Companion service that classifies the chosen caption text.
classification_server_url = 'http://localhost:8001/classify_text'
def calculate_similarity(caption1, caption2):
    """Return the cosine similarity between two caption strings.

    Both captions are embedded with the module-level SentenceTransformer
    ``model`` in a single batch; the result is a plain Python float
    (cosine similarity, so nominally in [-1, 1]).
    """
    embeddings = model.encode([caption1, caption2], convert_to_tensor=True)
    similarity = util.pytorch_cos_sim(embeddings[0], embeddings[1])
    return similarity.item()
def select_best_caption(caption1, caption2):
    """Choose between two candidate captions for the same image.

    If the captions are semantically close (cosine similarity above the
    threshold), the longer — presumably more descriptive — caption wins.
    Otherwise ``caption1`` is returned (the caller passes the GIT caption
    first, hence the log message below).
    """
    # Empirically tuned cut-off — NOTE(review): provenance of this exact
    # value is not documented; confirm before changing.
    similarity_threshold = 0.6486401639
    similarity_score = calculate_similarity(caption1, caption2)
    if similarity_score > similarity_threshold:
        print("Similarity score:", similarity_score)
        return caption1 if len(caption1) > len(caption2) else caption2
    print(f"Git is selected with score :{similarity_score}")
    return caption1
# NOTE(review): the original function had no route decorator, so FastAPI never
# registered it and the endpoint was unreachable — confirm the intended path.
@app.post("/generate_caption_and_classify")
async def generate_caption_and_classify(request: Request):
    """Caption each image URL in the request body and classify the caption.

    Expects a JSON body of the form ``{"image_urls": [url, ...]}``.  For each
    URL: download the image, caption it with both GIT and BLIP, pick the best
    caption, send it to the classification service, and collect the result.

    Returns a list of ``{"image_url", "caption", "classification"}`` dicts.
    Raises HTTPException(500) on any failure.
    """
    try:
        body = await request.json()
        image_urls = body.get('image_urls', [])
        results = []
        for image_url in image_urls:
            # Timeout so a dead host can't hang the worker indefinitely;
            # raise_for_status so PIL isn't handed an HTML error page.
            response = requests.get(image_url, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
            # Generate captions from both models.
            caption_git = git_coco_pipeline(image, max_new_tokens=50)[0]['generated_text']
            print("Caption generated by GIT:", caption_git)
            caption_blip = blip_pipeline(image, max_new_tokens=50)[0]['generated_text']
            print("Caption generated by BLIP:", caption_blip)
            # Select the best caption (GIT first — see select_best_caption).
            best_caption = select_best_caption(caption_git, caption_blip)
            # Classify the best caption via the companion service.
            classification_response = requests.post(
                classification_server_url,
                json={"texts": [best_caption]},
                timeout=30,
            )
            # Assumes the service returns a JSON list with one entry per
            # input text — TODO confirm against the classifier's contract.
            classification_result = classification_response.json()[0]
            results.append({
                "image_url": image_url,
                "caption": best_caption,
                "classification": classification_result,
            })
        return results
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Serve on all interfaces, port 8000; the classification service is
    # expected separately on port 8001 (see classification_server_url).
    uvicorn.run(app, host="0.0.0.0", port=8000)