from typing import Any, Dict, List

import logging

import requests
import torch
from PIL import Image
from transformers import Idefics2ForConditionalGeneration, Idefics2Processor


class EndpointHandler:
    def __init__(self, path=""):
        # Preload everything needed at inference time.
        self.logger = logging.getLogger()
        self.logger.addHandler(logging.StreamHandler())
        # Without setting the level, .info() calls are filtered out by the default WARNING level.
        self.logger.setLevel(logging.INFO)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = Idefics2Processor.from_pretrained(path)
        self.model = Idefics2ForConditionalGeneration.from_pretrained(path)
        self.model.to(self.device)
        self.logger.info("Initialisation finished!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str` | `PIL.Image` | `np.array`)
            kwargs
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        # Earlier single-image path, kept for reference:
        # image = data.pop("inputs", data)
        # inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        # generated_ids = self.model.generate(**inputs)
        # generated_text = self.processor.batch_decode(
        #     generated_ids, skip_special_tokens=True
        # )

        # First image comes from the request payload (expected to be a PIL.Image or array,
        # as described in the docstring); the second is a fixed COCO sample used for the
        # comparison prompt.
        url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg"
        image_1 = data.pop("inputs", data)
        image_2 = Image.open(requests.get(url_2, stream=True).raw)
        images = [image_1, image_2]

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "What’s the difference between these two images?",
                    },
                    {"type": "image"},
                    {"type": "image"},
                ],
            }
        ]

        # At inference time, pass `add_generation_prompt=True` so the model completes the prompt.
        text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        self.logger.info(text)
        # 'User: What’s the difference between these two images?\nAssistant:'

        inputs = self.processor(images=images, text=text, return_tensors="pt").to(
            self.device
        )
        self.logger.info("inputs")

        # Run prediction.
        generated_ids = self.model.generate(**inputs, max_new_tokens=500)
        self.logger.info("generated")

        # Decode output.
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        self.logger.info(f"Generated text: {generated_text}")

        return [{"generated_text": generated_text}]
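
# --- Local smoke test (illustrative sketch only; not part of the Inference Endpoints contract) ---
# The checkpoint id "HuggingFaceM4/idefics2-8b" and the sample image URL below are assumptions
# chosen for this example; substitute your own model path and payload image.
if __name__ == "__main__":
    from io import BytesIO

    handler = EndpointHandler("HuggingFaceM4/idefics2-8b")
    sample = Image.open(
        BytesIO(
            requests.get("http://images.cocodataset.org/val2017/000000039769.jpg").content
        )
    )
    print(handler({"inputs": sample}))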