import base64
import binascii
import string
from typing import Any, Dict, List

import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from transformers.image_utils import load_image


class EndpointHandler:
    """Inference handler that generates text for an image with a BLIP-2 model.

    The model is loaded once in 8-bit with ``device_map="auto"``; each call
    decodes a base64 image, optionally conditions on a text prompt, and returns
    the generated text.
    """

    def __init__(self, path=""):
        # load model and processor from path
        self.processor = AutoProcessor.from_pretrained(path)
        # device_map="auto" already places the weights on the available devices,
        # and transformers refuses `.to(...)` on bitsandbytes 8-bit models, so the
        # model is used exactly as loaded.
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            path, device_map="auto", load_in_8bit=True
        )

    def __call__(self, inputs: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate text for a base64-encoded image and an optional prompt.

        Args:
            inputs: Request body, either ``{"inputs": {...}, "parameters": {...}}``
                or the inner ``{...}`` dict itself. The inner dict must contain
                ``"image"`` (base64-encoded image, optionally with a
                ``data:...;base64,`` prefix) and may contain ``"text"`` (prompt).
                ``"decoding_method"`` is read from ``parameters`` first, then from
                the inner dict; ``"Nucleus sampling"`` enables sampling, anything
                else (default ``"Beam search"``) uses deterministic beam search.

        Returns:
            A one-element list ``[{"generated_text": text}]``. A trailing ``"."``
            is appended when non-empty text does not already end in punctuation.

        Raises:
            KeyError: if no ``"image"`` field is present.
            ValueError: if ``"image"`` is not valid base64 or not a loadable image.
        """
        # Read without mutating the caller's request dict.
        data = inputs
        payload = data.get("inputs", data)
        parameters = data.get("parameters") or {}
        decoding_method = parameters.get(
            "decoding_method", payload.get("decoding_method", "Beam search")
        )

        image = self._decode_image(payload["image"])
        # BLIP-2 also supports unconditional captioning, so the prompt is optional.
        model_inputs = self.processor(
            images=image, text=payload.get("text"), return_tensors="pt"
        ).to(self.model.device, torch.float16)

        generated_ids = self.model.generate(
            **model_inputs,
            do_sample=decoding_method == "Nucleus sampling",
            temperature=1.0,
            length_penalty=1.0,
            repetition_penalty=1.5,
            max_length=30,
            min_length=1,
            num_beams=5,
            top_p=0.9,
        )
        output = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        if output and output[-1] not in string.punctuation:
            output += "."
        return [{"generated_text": output}]

    @staticmethod
    def _decode_image(encoded):
        """Turn a base64 image payload (str or bytes) into an RGB image.

        Raises:
            ValueError: if the payload is not strict base64 or not an image.
        """
        if isinstance(encoded, (bytes, bytearray)):
            encoded = bytes(encoded).decode("ascii")
        if encoded.startswith("data:") and "," in encoded:
            # Drop a "data:<mime>;base64," header if the client sent a data URI.
            encoded = encoded.split(",", 1)[1]
        # Remove line breaks / spaces some encoders insert.
        encoded = "".join(encoded.split())
        # Strict validation first: load_image would otherwise treat a string as
        # a URL or local file path, which must not be reachable from request data.
        try:
            base64.b64decode(encoded, validate=True)
        except (binascii.Error, ValueError) as err:
            raise ValueError("`image` must be a base64-encoded image") from err
        # NOTE(review): base64-string support in load_image requires a recent
        # transformers release — confirm against the pinned version.
        return load_image(encoded)