mattmdjaga's picture
added custom handler
d364efb
raw
history blame contribute delete
No virus
1.45 kB
from typing import Dict, List, Any
from PIL import Image
from io import BytesIO
from transformers import CLIPProcessor, CLIPModel
import base64
import torch
class EndpointHandler():
    """Custom endpoint handler that returns CLIP image and text embeddings.

    The CLIP model and its processor are loaded once at construction time;
    each call embeds one base64-encoded image together with the given text.
    """

    def __init__(self, path="."):
        """Load the CLIP model and processor.

        Args:
            path: Location handed to ``from_pretrained`` for both the model
                and the processor. Defaults to the current directory.
        """
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The model is moved to the chosen device and put in eval mode once,
        # so individual requests do not pay for it.
        self.model = CLIPModel.from_pretrained(path).to(self.device).eval()
        self.processor = CLIPProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the image and text carried by a request.

        Args:
            data: Request payload. The actual inputs are read from
                ``data["inputs"]`` when that key exists, otherwise from
                ``data`` itself. The inputs must provide:

                * ``"image"``: the image file content, base64-encoded.
                * ``"text"``: the text handed to the processor as ``text=``.

        Returns:
            A dict ``{"image": ..., "text": ...}`` holding the image and text
            feature tensors converted to nested Python lists via ``tolist()``.

        Raises:
            KeyError: If ``"image"`` or ``"text"`` is missing from the inputs.
        """
        # Read without popping so the caller's request dict is left intact.
        inputs = data.get("inputs", data)

        # Decode the base64 payload and open it as a PIL image.
        pil_image = Image.open(BytesIO(base64.b64decode(inputs["image"])))
        text = inputs["text"]

        # Tokenize the text; padding lets several strings share one batch.
        text_inputs = self.processor(
            text=text, return_tensors="pt", padding=True
        ).to(self.device)
        # Padding is a tokenizer option, so it is not passed for the image.
        image_inputs = self.processor(
            images=pil_image, return_tensors="pt"
        ).to(self.device)

        # Inference only: skip gradient tracking.
        with torch.no_grad():
            text_features = self.model.get_text_features(**text_inputs)
            image_features = self.model.get_image_features(**image_inputs)

        # Convert tensors to plain lists so the result is JSON-serializable.
        return {
            "image": image_features.tolist(),
            "text": text_features.tolist(),
        }