# images/main.py
# Last change: "Update main.py" by badalsahani (commit db84333, verified).
import asyncio
import base64
import binascii
import os
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# FastAPI application object; the POST /get_answer/ endpoint defined in this
# module is registered on it through the @app.post decorator.
app = FastAPI()
class ImageQuestion(BaseModel):
    """JSON request body accepted by ``POST /get_answer/``.

    Both fields have no default, so pydantic treats them as required strings.
    """

    # Base64 text of an image file; it is base64-decoded and opened with
    # Pillow before being handed to the model.
    image: str
    # Free-form question the model should answer about the image.
    question: str
# Model location (Hub repo id or local path) and the pinned revision to load.
# Both are read from the environment; unset or empty values abort startup.
model_id, revision = (
    os.environ.get(var) for var in ("MODEL_ID", "REVISION")
)
if not (model_id and revision):
    raise ValueError("Please set MODEL_ID and REVISION environment variables.")

# Loaded once at import time and shared by every request. The model object
# must expose encode_image() and answer_question(), which generate_answer
# calls. trust_remote_code=True allows executing model code shipped in the
# repository, so MODEL_ID should only ever point at a trusted source.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    revision=revision,
)
def generate_answer(image_bytes, question):
    """Answer ``question`` about a base64-encoded image using the loaded model.

    Args:
        image_bytes: Base64 text of an image file. Despite the name, the
            endpoint passes the ``str`` field from the request body. A leading
            ``data:<mime>;base64,`` URL header is accepted and stripped.
        question: Question to ask about the image.

    Returns:
        Whatever ``model.answer_question`` returns for the encoded image.

    Raises:
        binascii.Error: If the base64 payload is incorrectly padded.
        PIL.UnidentifiedImageError: If the decoded bytes are not an image
            format Pillow recognizes.
    """
    # A data URL puts metadata before the first comma. b64decode (non-strict
    # mode) would silently drop ':', ';' and ',' and decode the header letters
    # as if they were image data, corrupting the payload.
    if isinstance(image_bytes, str) and image_bytes[:5].lower() == "data:":
        image_bytes = image_bytes.partition(",")[2]
    raw = base64.b64decode(image_bytes)
    # The context manager releases Pillow's handle on the buffer once the
    # image is encoded; only the encoding is needed after this point.
    with Image.open(BytesIO(raw)) as image:
        enc_image = model.encode_image(image)
    return model.answer_question(enc_image, question, tokenizer)
# Model inference is blocking work. Calling it directly inside an async
# endpoint would freeze the event loop for every other request. A single
# worker thread moves it off the loop while still running one inference at a
# time, as before, since the shared model/tokenizer are not assumed to be
# thread-safe.
_inference_executor = ThreadPoolExecutor(max_workers=1)


@app.post("/get_answer/")
async def get_answer(data: ImageQuestion):
    """Answer a question about a base64-encoded image.

    The request body is an ``ImageQuestion``; the response is
    ``{"answer": <model output>}``.

    Raises:
        HTTPException: 400 when ``image`` is not valid base64 or does not
            decode to an image Pillow recognizes (previously these surfaced
            as 500 Internal Server Error).
    """
    loop = asyncio.get_running_loop()
    try:
        answer = await loop.run_in_executor(
            _inference_executor, generate_answer, data.image, data.question
        )
    except (binascii.Error, UnidentifiedImageError) as err:
        # Client sent an undecodable payload: report it as a client error.
        raise HTTPException(
            status_code=400,
            detail="image must be a base64-encoded image file",
        ) from err
    return {"answer": answer}