from typing import Any, Dict
import base64
from io import BytesIO

import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler():
    def __init__(self, path=""):
        # Load the BLIP-2 captioning model and its processor once at startup.
        self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
        self.model.eval()
        self.model = self.model.to(device)

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Args:
            data (:obj:`dict`): includes the input data and the parameters for the inference.
                - "input": a PIL image or a base64-encoded image string.
                - "prompt" (optional): a text prompt to condition the caption.
        Return:
            A :obj:`dict` containing one list, e.g. {"captions": ["A hugging face at the office"]},
            where each entry is a string corresponding to a generated caption.
        """
        img_data = data.pop("input", data)
        prompt = data.pop("prompt", None)

        # Accept either a PIL image or a base64-encoded image payload.
        if isinstance(img_data, Image.Image):
            raw_image = img_data
        else:
            raw_image = Image.open(BytesIO(base64.b64decode(img_data))).convert("RGB")

        # Preprocess the image (and optional prompt) and move the tensors to the model's device.
        inputs = self.processor(raw_image, prompt, return_tensors="pt").to(device)

        # Generate the caption without tracking gradients.
        with torch.no_grad():
            out = self.model.generate(**inputs)

        caption = self.processor.decode(out[0], skip_special_tokens=True).strip()
        return {"captions": [caption]}
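

# Minimal local smoke test — a sketch only, not part of the endpoint contract.
# It assumes an image file named "test.jpg" sits next to this handler (hypothetical
# path; adjust as needed) and mimics the base64 payload an endpoint request would send.
if __name__ == "__main__":
    with open("test.jpg", "rb") as f:
        payload = {
            "input": base64.b64encode(f.read()).decode("utf-8"),
            "prompt": "a photo of",
        }
    handler = EndpointHandler()
    print(handler(payload))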