import base64
import os
from io import BytesIO

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# FastAPI application instance for this service.
app = FastAPI()
class ImageQuestion(BaseModel):
    """Request body: an encoded image plus a question to ask about it."""

    # Base64-encoded image data; decoded with base64.b64decode in generate_answer.
    image: str
    # Question text forwarded to the model's answer_question call.
    question: str
# Model configuration comes from the environment; both values are mandatory.
# NOTE(review): requiring REVISION presumably pins the exact remote code that
# trust_remote_code=True will execute — confirm that is the intent.
model_id = os.getenv("MODEL_ID")
revision = os.getenv("REVISION")
if not (model_id and revision):
    raise ValueError("Please set MODEL_ID and REVISION environment variables.")

# Loaded once at import time and shared by every request handler.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
def _strip_data_url_prefix(encoded):
    """Return the base64 payload of *encoded*, dropping a ``data:...;base64,`` URL header.

    Inputs that are not a base64 data URL (plain base64 text, bytes, or a
    non-base64 ``data:`` URL) are returned unchanged.
    """
    if isinstance(encoded, str) and encoded.startswith("data:"):
        header, sep, payload = encoded.partition(",")
        # Only strip when the header declares base64; plain base64 never
        # contains ",", so ordinary inputs never reach this branch.
        if sep and header.endswith(";base64"):
            return payload
    return encoded


def generate_answer(image_bytes, question):
    """Decode a base64-encoded image and ask the loaded model a question about it.

    Args:
        image_bytes: Base64-encoded image data (``str`` or ``bytes``). Despite
            the name this is the *encoded* text, not raw image bytes. A
            ``data:<mime>;base64,<payload>`` URL is also accepted; the header
            is stripped before decoding.
        question: Question text passed to ``model.answer_question``.

    Returns:
        Whatever ``model.answer_question`` returns for the encoded image.

    Raises:
        binascii.Error: if the payload is not decodable base64.
        PIL.UnidentifiedImageError: if the decoded bytes are not an image PIL
            can open.
    """
    # Without stripping, non-strict b64decode would drop ":;," from a data URL
    # but keep the header letters, producing corrupt bytes.
    raw = base64.b64decode(_strip_data_url_prefix(image_bytes))
    image = Image.open(BytesIO(raw))
    enc_image = model.encode_image(image)
    return model.answer_question(enc_image, question, tokenizer)
# NOTE(review): the original handler had no route decorator, so FastAPI never
# exposed it. The path below mirrors the handler name — TODO confirm it matches
# what clients call.
@app.post("/get_answer")
async def get_answer(data: ImageQuestion):
    """Answer ``data.question`` about the base64-encoded image in ``data.image``.

    Returns:
        ``{"answer": <result of generate_answer>}``.
    """
    # generate_answer does synchronous image decoding and model inference.
    # Calling it directly inside this coroutine would block the event loop
    # for the whole inference and stall every other request, so run it in
    # the worker threadpool instead.
    # NOTE(review): concurrent requests can now run the model in parallel
    # threads; if the model/device cannot handle that, serialize the calls
    # (e.g. with a lock) — TODO confirm.
    answer = await run_in_threadpool(generate_answer, data.image, data.question)
    return {"answer": answer}