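# Residue from the Hugging Face Hub file view (kept as comments so the module parses):
# last commit 83e9d74 by gabrielearmento: "Refactor handler.py to improve VQA model performance"
# page controls: raw | history | blame | contribute | delete; scan status: "No virus"; file size: 1.04 kB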
from typing import Any, Dict, List

from transformers import AutoModel, AutoTokenizer, pipeline
from transformers.image_utils import load_image
class EndpointHandler:
    """Custom Inference Endpoints handler answering questions about an image.

    Loads a MiniCPM-Llama3-V checkpoint (``trust_remote_code=True``) and
    answers one free-form question about one image through the model's own
    ``chat`` method.

    The previous version wrapped the model in ``transformers.pipeline(model=...,
    tokenizer=...)`` without a ``task``. ``pipeline`` can only infer the task
    from a hub-id *string*, so passing an instantiated model raised
    ``RuntimeError`` during construction. The handler therefore calls the
    model's ``chat`` API directly.
    """

    # Hub id of the checkpoint loaded when no ``model_id`` is given.
    DEFAULT_MODEL_ID = "openbmb/MiniCPM-Llama3-V-2_5-int4"

    # Default keyword arguments forwarded to ``model.chat``; a request's
    # ``"parameters"`` dict overrides them key by key.
    # NOTE(review): values mirror the MiniCPM-V usage example - confirm against
    # the current model card.
    DEFAULT_GENERATION_KWARGS: Dict[str, Any] = {"sampling": True, "temperature": 0.7}

    def __init__(self, path: str = "", *, model_id: str = DEFAULT_MODEL_ID) -> None:
        """Load the model and its tokenizer.

        Args:
            path: Directory passed in by the Inference Endpoints runtime
                (presumably the repository checkout). Accepted for interface
                compatibility but not used; weights are loaded from ``model_id``.
            model_id: Hub id or local directory of the MiniCPM-V checkpoint.
        """
        # trust_remote_code: the architecture (including ``chat``) is defined by
        # Python code shipped inside the model repository.
        self.model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
        self.model.eval()  # inference only (torch ``nn.Module.eval``)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Answer ``question`` about ``image``.

        Args:
            data: Request body, either ``{"image": ..., "question": ...}`` or the
                same fields wrapped as ``{"inputs": {...}}``. ``image`` may be a
                URL, local file path or base64 string (decoded by
                ``transformers.image_utils.load_image``), or an already-decoded
                image object with a PIL-style ``convert`` method. An optional
                top-level ``"parameters"`` dict is forwarded to ``model.chat``
                and overrides ``DEFAULT_GENERATION_KWARGS``.

        Returns:
            ``[{"answer": <value returned by model.chat>}]``.

        Raises:
            ValueError: ``image``/``question`` missing or empty, or
                ``parameters`` is not a dict.
            TypeError: ``image`` is of an unsupported type.
        """
        # Endpoint clients conventionally wrap the payload in "inputs"; flat
        # requests (the original format) keep working.
        inputs = data.get("inputs")
        payload = inputs if isinstance(inputs, dict) else data

        image = payload.get("image")
        question = payload.get("question")
        if image is None or (isinstance(image, str) and not image.strip()):
            raise ValueError("Request must include a non-empty 'image' field.")
        if not isinstance(question, str) or not question.strip():
            raise ValueError("Request must include a non-empty 'question' string.")

        parameters = data.get("parameters") or {}
        if not isinstance(parameters, dict):
            raise ValueError("'parameters' must be a JSON object.")

        if isinstance(image, str):
            # URL, local file path or base64 payload -> RGB PIL image.
            image = load_image(image)
        elif hasattr(image, "convert"):
            # Already-decoded PIL-style image (e.g. when called in-process).
            image = image.convert("RGB")
        else:
            raise TypeError(
                f"Unsupported 'image' type {type(image).__name__}; "
                "expected a URL, file path, base64 string or PIL image."
            )

        answer = self.model.chat(
            image=image,
            msgs=[{"role": "user", "content": question}],
            tokenizer=self.tokenizer,
            **{**self.DEFAULT_GENERATION_KWARGS, **parameters},
        )
        return [{"answer": answer}]
# if __name__ == "__main__":
# handler = EndpointHandler()
# data = {
# "image": "https://pwm.im-cdn.it/image/1524723057/xxl.jpg",
# "question": "Describe the image:",
# }
# print(handler(data))