from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
from PIL import Image
import requests
from io import BytesIO
import uvicorn
from sentence_transformers import SentenceTransformer, util
app = FastAPI()
# CORS setup
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
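# Note: wildcard origins combined with allow_credentials=True is very
# permissive. That is convenient for a demo Space, but a production
# deployment would normally pin allow_origins to the actual frontend domains.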
# Initialize the BLIP and GIT models for image captioning
blip_model_name = "Salesforce/blip-image-captioning-large"
git_model_name = "microsoft/git-large-coco"
blip_pipeline = pipeline("image-to-text", model=blip_model_name)
git_coco_pipeline = pipeline("image-to-text", model=git_model_name)
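# Each image-to-text pipeline call returns a list of dicts of the form
# [{"generated_text": "..."}] (hence the [0]['generated_text'] indexing below);
# model weights are downloaded on first use, so startup can take a while.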
# Initialize the model for semantic similarity
model = SentenceTransformer('all-MiniLM-L6-v2')
classification_server_url = 'https://raviipandey-classification.hf.space/classify_text'
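# The classification Space's contract is inferred from how it is called below:
# it appears to accept a payload of {"texts": [...]} and return a JSON list
# with one result per input text, so response.json()[0] is the result for a
# single caption.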
def calculate_similarity(caption1, caption2):
    embeddings = model.encode([caption1, caption2], convert_to_tensor=True)
    similarity = util.pytorch_cos_sim(embeddings[0], embeddings[1])
    return similarity.item()
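# util.pytorch_cos_sim returns a 1x1 tensor; .item() unwraps it to a plain
# float in [-1, 1]. Illustrative behavior (actual scores depend on the model):
#   calculate_similarity("a dog on the beach", "a dog running on a beach")
#   -> close to 1.0 for near-paraphrases, near 0.0 for unrelated captions.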
def select_best_caption(caption1, caption2):
    similarity_threshold = 0.6486401639
    similarity_score = calculate_similarity(caption1, caption2)
    if similarity_score > similarity_threshold:
        print("Similarity score:", similarity_score)
        return caption1 if len(caption1) > len(caption2) else caption2
    else:
        print(f"GIT is selected with score: {similarity_score}")
        return caption1
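# Selection policy, illustrated with hypothetical captions: if GIT returns
# "a dog on the beach" and BLIP returns "a brown dog running along the beach",
# the two agree semantically (score above the threshold), so the longer,
# presumably more detailed caption wins. Below the threshold, the GIT caption
# (caption1 at the call site below) is kept as the fallback.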
@app.post("/generate_caption_and_classify")
async def generate_caption_and_classify(request: Request):
    try:
        body = await request.json()
        image_urls = body.get('image_urls', [])
        results = []
        for image_url in image_urls:
            # Fetch the image and normalize to RGB (the captioning models
            # expect 3-channel input; palette/RGBA images would otherwise fail)
            response = requests.get(image_url)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
            # Generate captions from both models
            caption_git = git_coco_pipeline(image, max_new_tokens=50)[0]['generated_text']
            print("Caption generated by GIT:", caption_git)
            caption_blip = blip_pipeline(image, max_new_tokens=50)[0]['generated_text']
            print("Caption generated by BLIP:", caption_blip)
            # Select the best caption
            best_caption = select_best_caption(caption_git, caption_blip)
            # Classify the best caption via the external classification Space
            classification_response = requests.post(
                classification_server_url,
                json={"texts": [best_caption]}
            )
            classification_result = classification_response.json()[0]
            results.append({
                "image_url": image_url,
                "caption": best_caption,
                "classification": classification_result
            })
        return results
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)
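
# Example client call (hypothetical image URL), once the server is running:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/generate_caption_and_classify",
#       json={"image_urls": ["https://example.com/photo.jpg"]},
#   )
#   print(resp.json())
#   # -> [{"image_url": ..., "caption": ..., "classification": {...}}]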