from typing import Dict, List, Any | |
from transformers import AutoProcessor, Blip2ForConditionalGeneration | |
import base64 | |
from io import BytesIO | |
from PIL import Image | |
import string | |
import torch | |
class EndpointHandler:
    """Inference-endpoint handler for a BLIP-2 image-to-text model.

    Loads a 4-bit quantized BLIP-2 checkpoint once at startup and serves
    conditional-generation requests carrying a base64-encoded image and an
    optional text prompt.
    """

    def __init__(self, path=""):
        """Load processor and model from *path* (local model directory).

        The model is placed automatically across available devices
        (``device_map="auto"``) and loaded in 4-bit precision to reduce
        GPU memory use.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            path, device_map="auto", load_in_4bit=True
        )

    def __call__(self, data):
        """Generate text for one request.

        Args:
            data: Request payload. Expected shape is
                ``{"inputs": {"image": <base64 str>, "text": <prompt str>}}``;
                a flat dict without the ``"inputs"`` wrapper is also accepted,
                and ``"text"`` may be omitted for unconditional captioning.

        Returns:
            ``[{"generated_text": <str>}]`` — the decoded generation, with a
            trailing period appended when it does not already end in
            punctuation.

        Raises:
            KeyError: if the payload has no ``"image"`` entry.
        """
        # Accept both {"inputs": {...}} and a flat payload.
        payload = data.pop("inputs", data)

        # Decode the base64 image; convert to RGB so grayscale/RGBA/palette
        # uploads don't break the vision processor (fix: original assumed
        # the upload was already RGB).
        image = Image.open(BytesIO(base64.b64decode(payload["image"]))).convert("RGB")

        # "text" is optional — BLIP-2 captions unconditionally when the
        # prompt is None (fix: original raised KeyError on image-only
        # requests).
        model_inputs = self.processor(
            images=image, text=payload.get("text"), return_tensors="pt"
        ).to(self.model.device, torch.float16)
        # self.model.device replaces a hard-coded "cuda" so the handler also
        # works wherever device_map="auto" actually placed the model.

        generated_ids = self.model.generate(
            **model_inputs,
            # NOTE(review): with num_beams=5 and no do_sample=True this runs
            # beam search, so temperature/top_p are effectively ignored; kept
            # unchanged for backward compatibility with the original call.
            temperature=1.0,
            length_penalty=1.0,
            repetition_penalty=1.5,
            max_length=30,
            min_length=1,
            num_beams=5,
            top_p=0.9,
        )
        result = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        # Ensure the returned text ends with punctuation for nicer display.
        if result and result[-1] not in string.punctuation:
            result += "."
        return [{"generated_text": result}]