# NsfwExtension / server/caption_server.py
# Ravi Pandey — "Docker deploy" (2fbf461), 3.06 kB
import os
from io import BytesIO

import requests
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
app = FastAPI()

# CORS policy: any origin may call the API, with any HTTP method and any request header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a risky
# combination (see the Starlette CORSMiddleware docs). Nothing in this file reads
# cookies or auth headers, so allow_credentials=False may suffice — confirm the
# browser-extension client does not rely on credentialed fetches before changing it.
origins = ["*"]
_cors_options = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Hugging Face model IDs for the two image-captioning backends.
blip_model_name = "Salesforce/blip-image-captioning-large"
git_model_name = "microsoft/git-large-coco"

# Both captioning pipelines are built once, at import time, and shared by every request.
blip_pipeline = pipeline("image-to-text", model=blip_model_name)
git_coco_pipeline = pipeline("image-to-text", model=git_model_name)

# Sentence-embedding model used by calculate_similarity to compare two captions.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Endpoint of the companion text-classification service.
# It can be overridden with the CLASSIFICATION_SERVER_URL environment variable,
# because "localhost" only reaches the classifier when both services share a
# network namespace (e.g. not when they run in separate Docker containers).
# The default is unchanged, so existing deployments behave as before.
classification_server_url = os.environ.get(
    "CLASSIFICATION_SERVER_URL", 'http://localhost:8001/classify_text'
)
def calculate_similarity(caption1, caption2):
    """Return the cosine similarity between the embeddings of two captions.

    Both captions are encoded together in a single call to the module-level
    SentenceTransformer ``model``. The single similarity value is returned
    as a Python number via ``.item()``.
    """
    first_vec, second_vec = model.encode([caption1, caption2], convert_to_tensor=True)
    score_matrix = util.pytorch_cos_sim(first_vec, second_vec)
    return score_matrix.item()
# Cosine-similarity cut-off above which two captions are treated as agreeing.
# NOTE(review): the value looks empirically tuned; its provenance is not recorded here.
SIMILARITY_THRESHOLD = 0.6486401639


def select_best_caption(caption1, caption2, *, similarity_threshold=SIMILARITY_THRESHOLD):
    """Pick one of two candidate captions for the same image.

    If the captions' similarity (from ``calculate_similarity``) is strictly
    above *similarity_threshold*, the longer caption is returned, presumably
    because it carries more detail. On a length tie, *caption2* wins. Otherwise
    *caption1* is returned.

    The log message for the fallback branch says "Git is selected". That is
    accurate only because the endpoint in this file passes the GIT caption as
    *caption1* and the BLIP caption as *caption2*.

    Args:
        caption1: Primary caption; returned when the two captions disagree.
        caption2: Secondary caption.
        similarity_threshold: Agreement cut-off. Defaults to SIMILARITY_THRESHOLD,
            the value previously hard-coded here.

    Returns:
        The selected caption.
    """
    similarity_score = calculate_similarity(caption1, caption2)
    if similarity_score > similarity_threshold:
        print("Similarity score:", similarity_score)
        return caption1 if len(caption1) > len(caption2) else caption2
    print(f"Git is selected with score :{similarity_score}")
    return caption1
# Seconds to wait on outbound HTTP calls. requests applies no timeout by
# default, so without these a stalled remote host would block the handler
# indefinitely. The values are conservative guesses; tune as needed.
IMAGE_FETCH_TIMEOUT = 30
CLASSIFICATION_TIMEOUT = 60


def _fetch_image(image_url):
    """Download *image_url* and return it opened as a PIL image.

    Raises:
        requests.HTTPError: the server answered with a 4xx/5xx status. This is
            raised instead of handing the error page to PIL.
        requests.RequestException: connection failures or timeouts.
    """
    response = requests.get(image_url, timeout=IMAGE_FETCH_TIMEOUT)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def _classify_caption(caption):
    """Send one caption to the classification server and return element 0 of its reply.

    The request body is ``{"texts": [caption]}``.
    NOTE(review): assumes the server replies with a JSON list; confirm against
    the classification service.
    """
    response = requests.post(
        classification_server_url,
        json={"texts": [caption]},
        timeout=CLASSIFICATION_TIMEOUT,
    )
    response.raise_for_status()
    return response.json()[0]


@app.post("/generate_caption_and_classify")
async def generate_caption_and_classify(request: Request):
    """Caption each requested image and classify the chosen caption.

    Expects a JSON object such as ``{"image_urls": ["https://...", ...]}``.
    A missing ``image_urls`` key is treated as an empty list. For every URL:

    1. Download the image.
    2. Caption it with both GIT and BLIP.
    3. Choose one caption with ``select_best_caption``.
    4. Send the chosen caption to the classification server.

    Returns:
        A list with one ``{"image_url", "caption", "classification"}`` dict per
        input URL, in input order.

    Raises:
        HTTPException: status 400 if the body is not valid JSON, is not a JSON
            object, or ``image_urls`` is not a list of strings. Status 500 for
            any other failure (download, image decoding, inference,
            classification call); ``detail`` carries the error text.

    NOTE(review): the handler is ``async`` but every step inside it (HTTP calls,
    model inference) blocks. It therefore holds the event loop for the whole
    request, and other requests wait. Offloading the work to a thread pool
    (e.g. ``fastapi.concurrency.run_in_threadpool``) would fix that, but first
    confirm the pipelines tolerate concurrent calls.
    NOTE(review): the URLs come from the client and are fetched server-side
    with no allow-list. That is an SSRF risk if this service is exposed.
    """
    try:
        try:
            body = await request.json()
        except ValueError as e:  # json.JSONDecodeError / UnicodeDecodeError
            raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}") from e
        if not isinstance(body, dict):
            raise HTTPException(status_code=400, detail="Request body must be a JSON object")
        image_urls = body.get('image_urls', [])
        # A string would otherwise be iterated character by character.
        if not isinstance(image_urls, list) or not all(isinstance(url, str) for url in image_urls):
            raise HTTPException(status_code=400, detail="'image_urls' must be a list of URL strings")

        results = []
        for image_url in image_urls:
            image = _fetch_image(image_url)
            # Generate captions from both models (GIT first, then BLIP).
            caption_git = git_coco_pipeline(image, max_new_tokens=50)[0]['generated_text']
            print("Caption generated by GIT:", caption_git)
            caption_blip = blip_pipeline(image, max_new_tokens=50)[0]['generated_text']
            print("Caption generated by BLIP:", caption_blip)
            # GIT is passed first: select_best_caption falls back to its first
            # argument when the two captions disagree.
            best_caption = select_best_caption(caption_git, caption_blip)
            results.append({
                "image_url": image_url,
                "caption": best_caption,
                "classification": _classify_caption(best_caption),
            })
        return results
    except HTTPException:
        # Let deliberate 4xx responses through instead of re-wrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Serve on port 8000, listening on every network interface (0.0.0.0), not just loopback.
    serve_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **serve_options)