import base64
import io
from typing import Any, Dict

import torch
from PIL import Image
from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel


# Take in a base64 string and return a PIL RGB image
def stringToRGB(base64_string):
    imgdata = base64.b64decode(str(base64_string))
    img = Image.open(io.BytesIO(imgdata)).convert("RGB")
    return img


class EndpointHandler:
    def __init__(self, path=""):
        # Preload the model, feature extractor, and tokenizer once at startup
        # so inference requests do not pay the loading cost.
        self.model = VisionEncoderDecoderModel.from_pretrained(path)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def predict_caption(self, image_str, max_token=32):
        """Generate a caption for a single base64-encoded image using beam search."""
        num_beams = 4
        gen_kwargs = {"max_length": max_token, "num_beams": num_beams}
        image = stringToRGB(image_str)
        pixel_values = self.feature_extractor(images=[image], return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(self.device)
        output_ids = self.model.generate(pixel_values, **gen_kwargs)
        preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return preds[0].strip()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            data (:obj:`str`): base64-encoded image
            max_token (:obj:`int`, optional): maximum caption length, defaults to 32
        Return:
            A :obj:`dict` with the generated caption; it will be serialized and returned.
        """
        max_token = data.pop("max_token", 32)
        img_str = data.pop("data", None)
        caption = self.predict_caption(img_str, max_token=max_token)
        return {"caption": caption}
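

# --- Usage sketch ------------------------------------------------------------
# A minimal local smoke test, assuming a vision-encoder-decoder captioning
# checkpoint saved at "./model" and an image file "example.jpg"; both paths
# are placeholders for illustration, not part of the handler itself.
if __name__ == "__main__":
    with open("example.jpg", "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler(path="./model")
    result = handler({"data": img_b64, "max_token": 32})
    print(result)  # e.g. {"caption": "..."}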