from typing import Dict, List, Any
from io import BytesIO
import base64

from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq


class EndpointHandler():
    def __init__(self, path=""):
        # Preload the model and processor once at startup so each request
        # only pays for inference, not for loading weights.
        self.model = AutoModelForVision2Seq.from_pretrained(path).to("cuda")
        self.processor = AutoProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # The request payload carries a text prompt and a base64-encoded image.
        prompt = data.pop("prompt")
        image_base64 = data.pop("image_base64")
        image_data = base64.b64decode(image_base64)
        image = Image.open(BytesIO(image_data))

        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to("cuda")

        generated_ids = self.model.generate(
            pixel_values=inputs["pixel_values"],
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            image_embeds=None,
            image_embeds_position_mask=inputs["image_embeds_position_mask"],
            use_cache=True,
            max_new_tokens=128,
        )
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Pass `cleanup_and_extract=False` to inspect the raw model generation,
        # e.g. ` An image of a snowman warming himself by a fire.`
        raw_text = self.processor.post_process_generation(generated_text, cleanup_and_extract=False)

        # By default the generated text is cleaned up and the grounded entities
        # are extracted, e.g.
        #   processed_text: `An image of a snowman warming himself by a fire.`
        #   entities: `[('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]),
        #               ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]`
        # Each entity is (phrase, character span in the text, normalized bounding boxes).
        processed_text, entities = self.processor.post_process_generation(generated_text)

        return [{"processed_text": processed_text, "entities": entities}]
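

# Minimal local smoke test: a sketch of how a client would build the request
# payload, not part of the Inference Endpoints contract. The model path
# "microsoft/kosmos-2-patch14-224" and the image file name "snowman.png" are
# assumptions for illustration; substitute your own checkpoint and image.
if __name__ == "__main__":
    handler = EndpointHandler("microsoft/kosmos-2-patch14-224")

    # Encode a local image the same way a client would before sending it.
    with open("snowman.png", "rb") as f:
        payload_image = base64.b64encode(f.read()).decode("utf-8")

    result = handler({
        "prompt": "An image of",
        "image_base64": payload_image,
    })
    print(result[0]["processed_text"])
    print(result[0]["entities"])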