# NOTE: updated the model path in the handler (author: Ishaan Gupta, commit 8009a47)
import base64
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor
class EndpointHandler():
    """Hugging Face Inference Endpoints handler for a vision-to-sequence
    grounding model (Kosmos-2 style).

    Expects request payloads of the form:
        {"prompt": "<grounding>An image of", "image_base64": "<base64 image bytes>"}
    """

    def __init__(self, path=""):
        # Preload the model and processor once at startup so each request
        # only pays for inference.  Fall back to CPU when no GPU is present
        # instead of crashing on a hard-coded "cuda" device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForVision2Seq.from_pretrained(path).to(self.device)
        self.processor = AutoProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run grounded text generation on one image/prompt pair.

        Args:
            data: request body; must contain "prompt" (str) and
                "image_base64" (base64-encoded image bytes).

        Returns:
            A single-element list: ``[{"processed_text": str, "entities": list}]``.
            ``entities`` has the shape
            ``[(phrase, (start, end), [(x1, y1, x2, y2), ...]), ...]``.

        Raises:
            KeyError: if "prompt" or "image_base64" is missing from *data*.
        """
        prompt = data.pop("prompt")
        image_base64 = data.pop("image_base64")

        image_data = base64.b64decode(image_base64)
        # Normalize to RGB so grayscale/RGBA/palette uploads don't break the
        # processor's image preprocessing.
        image = Image.open(BytesIO(image_data)).convert("RGB")

        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device)
        generated_ids = self.model.generate(
            pixel_values=inputs["pixel_values"],
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            image_embeds=None,
            image_embeds_position_mask=inputs["image_embeds_position_mask"],
            use_cache=True,
            max_new_tokens=128,
        )
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # post_process_generation strips the raw grounding markup (e.g.
        # `<phrase> a snowman</phrase><object><patch_index_0044>...`) and
        # extracts the grounded entities with normalized bounding boxes.
        # (Pass cleanup_and_extract=False to inspect the raw generation;
        # the original code did that in a second, discarded call.)
        processed_text, entities = self.processor.post_process_generation(generated_text)

        # "entities" is an additive key — existing callers that only read
        # "processed_text" are unaffected.
        return [{"processed_text": processed_text, "entities": entities}]