new-test-model / handler.py
import base64
import io
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the BLIP-2 processor and model from the repository path.
        self.processor = Blip2Processor.from_pretrained(path)
        self.model = Blip2ForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
        # The endpoint is assumed to run on a GPU instance (float16 weights on CUDA).
        self.device = "cuda"
        self.model.to(self.device)
        self.model.eval()
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        # Unpack the request payload: {"inputs": {"image": <base64 string>, "text": <prompt>}}
        payload = data.pop("inputs", data)
        text = payload.get("text", "")
        image_bytes = base64.b64decode(payload["image"])
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Preprocess, generate, and decode the model's answer.
        inputs = self.processor(images=image, text=text, return_tensors="pt").to(self.device, torch.float16)
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs)
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        return [{"answer": generated_text}]