idefics2-8b-ocr / handler.py
huz-relay's picture
Add user input
1b88ea1
raw
history blame contribute delete
No virus
3.04 kB
from typing import Any, Dict, List
from transformers import Idefics2Processor, Idefics2ForConditionalGeneration
import torch
import logging
from PIL import Image
import requests
class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler for Idefics2.

    Each request supplies one image under ``data["inputs"]``. The handler
    pairs it with a fixed reference image (downloaded once, then cached) and
    asks the model what the difference between the two images is.
    """

    # Second image compared against the user's image on every request.
    REFERENCE_IMAGE_URL = "http://images.cocodataset.org/val2017/000000219578.jpg"
    # Text of the user turn; two image placeholders follow it in the template.
    PROMPT = "What’s the difference between these two images?"
    # Upper bound on newly generated tokens per request.
    MAX_NEW_TOKENS = 500
    # Seconds to wait on the reference-image download before giving up.
    REQUEST_TIMEOUT = 30

    # Cache for the downloaded reference image; filled on first request.
    # Declared at class level so instances built without __init__ still work.
    _reference_image = None

    def __init__(self, path=""):
        """Load the processor and model from *path* and move the model to the device.

        Args:
            path: Model location passed to ``from_pretrained`` for both the
                processor and the model.
        """
        # Use a module-scoped logger and attach a handler only when nothing
        # up the hierarchy already has one, so repeated instantiation does
        # not print every line multiple times.
        self.logger = logging.getLogger(__name__)
        if not self.logger.hasHandlers():
            self.logger.addHandler(logging.StreamHandler())
        # Without this the default WARNING threshold drops all info() calls.
        self.logger.setLevel(logging.INFO)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = Idefics2Processor.from_pretrained(path)
        self.model = Idefics2ForConditionalGeneration.from_pretrained(path)
        # Moved once here; there is no need to repeat this on every request.
        self.model.to(self.device)
        self.logger.info("Initialisation finished!")

    def _get_reference_image(self):
        """Return the reference image, downloading and decoding it on first use."""
        if self._reference_image is None:
            with requests.get(
                self.REFERENCE_IMAGE_URL, stream=True, timeout=self.REQUEST_TIMEOUT
            ) as response:
                response.raise_for_status()
                image = Image.open(response.raw)
                # Image.open is lazy; decode now, before the stream is closed
                # when the ``with`` block exits.
                image.load()
            self._reference_image = image
        return self._reference_image

    def __call__(self, data: Dict[str, Any]) -> str:
        """Ask the model how the request image differs from the reference image.

        Args:
            data: Request payload. ``data["inputs"]`` holds the user image in a
                form the Idefics2 processor accepts (e.g. a ``PIL.Image``).
                The dict is not modified.

        Returns:
            The first (only) decoded sequence from ``batch_decode`` with
            special tokens removed.

        Raises:
            ValueError: if ``data`` has no (non-None) ``"inputs"`` entry.
        """
        # Validate before doing any work; the old fallback of treating the
        # whole payload dict as the image only produced an obscure failure later.
        user_image = data.get("inputs")
        if user_image is None:
            raise ValueError("Request payload must contain an 'inputs' image.")

        images = [user_image, self._get_reference_image()]
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": self.PROMPT},
                    {"type": "image"},
                    {"type": "image"},
                ],
            }
        ]
        # add_generation_prompt=True appends the assistant turn so the model
        # completes the answer instead of continuing the user message.
        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        self.logger.info(prompt)

        inputs = self.processor(images=images, text=prompt, return_tensors="pt").to(
            self.device
        )
        self.logger.info("inputs")
        generated_ids = self.model.generate(**inputs, max_new_tokens=self.MAX_NEW_TOKENS)
        self.logger.info("generated")

        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        self.logger.info("Generated text: %s", generated_text)
        return generated_text