blip2-flan-t5-xxl / handler.py
from typing import Dict, List, Any
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import base64
from io import BytesIO
from PIL import Image
import string
import torch


class EndpointHandler:
    def __init__(self, path=""):
        # load model and processor from path; the model is loaded in 4-bit
        # precision and sharded across available GPUs via device_map="auto"
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            path, device_map="auto", load_in_4bit=True
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Args:
            data:
                Dict whose "inputs" key holds a base64-encoded image and a text prompt.
        """
        # process input: decode the image and tokenize image + text together
        inputs = data.pop("inputs", data)
        image = Image.open(BytesIO(base64.b64decode(inputs["image"])))
        inputs = self.processor(
            images=image, text=inputs["text"], return_tensors="pt"
        ).to("cuda", torch.float16)

        # generate a short answer/caption with beam search
        generated_ids = self.model.generate(
            **inputs,
            temperature=1.0,
            length_penalty=1.0,
            repetition_penalty=1.5,
            max_length=30,
            min_length=1,
            num_beams=5,
            top_p=0.9,
        )
        result = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

        # ensure the generated text ends with punctuation
        if result and result[-1] not in string.punctuation:
            result += "."
        return [{"generated_text": result}]