# blip2-opt-6.7b / handler.py
# Hugging Face Hub page metadata captured with this file (not part of the code):
#   last commit "Update handler.py" (aa93800) by merve (HF staff); size 1.37 kB; no virus detected.
import base64
import string
from typing import Dict, List, Any

import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration
class EndpointHandler:
    """Custom endpoint handler that generates text from an image with BLIP-2.

    The handler is constructed once with the model directory and is then
    called with one request payload at a time.
    """

    # Value of the optional "decoding_method" request field that switches
    # generation from plain beam search to sampling.
    _NUCLEUS_SAMPLING = "Nucleus sampling"

    def __init__(self, path=""):
        """Load the processor and the 8-bit quantized model from ``path``."""
        self.processor = AutoProcessor.from_pretrained(path)
        # No trailing ``.to("cuda")`` here: transformers refuses ``.to()`` on
        # models loaded with ``load_in_8bit=True``, and ``device_map="auto"``
        # is what decides weight placement.
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            path, device_map="auto", load_in_8bit=True
        )

    def __call__(self, inputs: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate text for one request.

        Args:
            inputs: Request payload. Either the fields themselves or a dict
                wrapping them under an ``"inputs"`` key. Fields read:

                * ``"image"`` (required): base64-encoded image data.
                * ``"text"`` (optional): prompt passed to the processor.
                * ``"decoding_method"`` (optional): ``"Nucleus sampling"``
                  enables sampling; any other value (or none) disables it.

        Returns:
            A single-element list ``[{"generated_text": <text>}]``. A full
            stop is appended when the text is non-empty and does not already
            end with a punctuation character.

        Raises:
            KeyError: If ``"image"`` is missing from the payload.
        """
        # ``get`` rather than ``pop`` so the caller's dict is left untouched.
        payload = inputs.get("inputs", inputs)

        image = base64.b64decode(payload["image"])
        # NOTE(review): ``image`` is still raw encoded bytes at this point.
        # The processor presumably expects a decoded image object (PIL image,
        # array or tensor) -- confirm, and decode before this call if so.
        model_inputs = self.processor(
            images=image,
            text=payload.get("text"),
            return_tensors="pt",
        ).to("cuda", torch.float16)

        sample = payload.get("decoding_method") == self._NUCLEUS_SAMPLING
        generated_ids = self.model.generate(
            **model_inputs,
            do_sample=sample,
            temperature=1.0,
            length_penalty=1.0,
            repetition_penalty=1.5,
            max_length=30,
            min_length=1,
            num_beams=5,
            top_p=0.9,
        )

        output = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        if output and output[-1] not in string.punctuation:
            output += "."
        return [{"generated_text": output}]