"""FastAPI service that answers questions about an image using moondream2.

The model and tokenizer are loaded once at import time. ``POST /query`` accepts
a text prompt plus a base64-encoded image and returns the model's answer.
"""

import base64
import binascii
import io
import subprocess  # NOTE(review): not used by the code in this module
from tempfile import NamedTemporaryFile  # NOTE(review): no longer used by /query

from fastapi import FastAPI, HTTPException
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()


class RequestData(BaseModel):
    """Request body for ``POST /query``."""

    prompt: str  # question to ask about the image
    image: str  # base64-encoded image bytes; a "data:...;base64," prefix is also accepted


def load_model():
    """Load the pinned moondream2 model and its tokenizer from the Hugging Face Hub.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_id = "vikhyatk/moondream2"
    # trust_remote_code executes Python shipped with the model repo; pinning the
    # revision keeps that code (and the weights) fixed.
    revision = "2024-08-26"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, trust_remote_code=True, revision=revision
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
    return model, tokenizer


MODEL, TOKENIZER = load_model()
print("INFO: Model loaded successfully!")


@app.get("/")
def greet_json():
    """Return a fixed JSON greeting (usable as a simple liveness check)."""
    return {"Hello": "World!"}


def _decode_image(encoded: str) -> Image.Image:
    """Decode a base64 string (optionally a data URI) into a fully loaded PIL image.

    Raises:
        HTTPException: 400 if the payload is not valid base64 or not a readable image.
    """
    # Browsers' FileReader.readAsDataURL yields "data:image/png;base64,<payload>".
    if encoded.startswith("data:"):
        _, _, encoded = encoded.partition(",")
    try:
        raw = base64.b64decode(encoded)
    except binascii.Error as err:
        raise HTTPException(
            status_code=400, detail=f"Invalid base64 image: {err}"
        ) from err
    try:
        # Decode from memory: no temp file, so nothing to clean up and no
        # platform-specific "reopen a NamedTemporaryFile by name" problems.
        img = Image.open(io.BytesIO(raw))
        # Image.open is lazy; force decoding now so corrupt/truncated data is
        # reported as a client error instead of failing inside the model.
        img.load()
    except (UnidentifiedImageError, Image.DecompressionBombError, OSError) as err:
        raise HTTPException(
            status_code=400, detail=f"Could not read image: {err}"
        ) from err
    return img


@app.post("/query")
def query(data: RequestData):
    """Answer ``data.prompt`` about the image in ``data.image``.

    Returns:
        ``{"response": <model answer as str>}``.

    Raises:
        HTTPException: 400 if the image cannot be decoded, 500 if the model fails.
    """
    image = _decode_image(data.image)
    try:
        enc_image = MODEL.encode_image(image)
        response = MODEL.answer_question(enc_image, data.prompt, TOKENIZER)
    except Exception as err:  # request boundary: report model failures as HTTP 500
        raise HTTPException(status_code=500, detail=str(err)) from err
    finally:
        image.close()
    return {"response": str(response)}