bobs-reduced-clip / handler.py
from io import BytesIO
import base64

from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler():
    def __init__(self, path=""):
        # Load the CLIP model and its processor once at startup.
        self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

    def __call__(self, data):
        text_input = None
        image_data = None
        if isinstance(data, dict):
            # Guard against a missing "inputs" key so .get() below cannot
            # be called on None.
            inputs = data.pop("inputs", None) or {}
            text_input = inputs.get("text", None)
            image_data = BytesIO(base64.b64decode(inputs["image"])) if "image" in inputs else None
        else:
            # Assume the payload is a raw image sent as binary.
            image_data = BytesIO(data)

        if text_input:
            # Tokenize the text and return CLIP text embeddings.
            processed = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
            with torch.no_grad():
                return {"embeddings": self.model.get_text_features(**processed).to("cpu").tolist()}
        elif image_data:
            # Decode the image and return CLIP image embeddings.
            image = Image.open(image_data)
            processed = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                return {"embeddings": self.model.get_image_features(**processed).to("cpu").tolist()}
        else:
            return {"embeddings": None}
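

# --- Local smoke test (sketch) ---
# A minimal example of exercising both input paths of the handler outside
# the Inference Endpoints runtime. The file name "cat.jpg" and the query
# text below are illustrative placeholders, not part of this repo.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Text path: the endpoint receives {"inputs": {"text": ...}}.
    text_out = handler({"inputs": {"text": "a photo of a cat"}})
    print(len(text_out["embeddings"][0]))  # 768-dim embedding for ViT-L/14

    # Image path: JSON payloads carry the image as a base64 string.
    with open("cat.jpg", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    image_out = handler({"inputs": {"image": encoded}})
    print(len(image_out["embeddings"][0]))  # 768-dim embedding for ViT-L/14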