# VQA-Server / app.py
# Visual question answering API: serves the Salesforce BLIP VQA model
# (TensorFlow weights) behind a FastAPI endpoint.
# Provenance: Hugging Face Space by Matanew1, commit 201bd55.
import io

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import BlipProcessor, TFBlipForQuestionAnswering
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Initialize the processor and model manually
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
@app.post('/answer_question')
async def answer_question(image: UploadFile = File(...), question: str = Form(...)):
image_path = 'temp_image.jpg'
with open(image_path, 'wb') as f:
f.write(await image.read())
# Open the image using PIL
pil_image = Image.open(image_path)
# Process the image and question
inputs = processor(images=pil_image, text=question, return_tensors="tf")
pixel_values = inputs["pixel_values"]
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
answer = processor.decode(outputs[0], skip_special_tokens=True)
return JSONResponse(content={'answer': answer})
if __name__ == '__main__':
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8080)