from io import BytesIO
import base64
import logging
import traceback

from PIL import Image
import torch
from transformers import CLIPModel, CLIPProcessor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger = logging.getLogger(__name__)
logger.setLevel('INFO')

# Hub id of the CLIP checkpoint loaded when the caller does not pick another one.
DEFAULT_MODEL_NAME = "openai/clip-vit-large-patch14"


class EndpointHandler:
    """Return CLIP embeddings for a text or image request payload.

    The payload is read from ``data["inputs"]`` and may be:

    * a ``PIL.Image.Image`` -- embedded with ``get_image_features``;
    * a ``str`` -- embedded with ``get_text_features``;
    * a dict with an optional ``"text"`` entry and/or an optional ``"image"``
      entry holding base64-encoded image bytes. When non-empty text is
      present it takes precedence and the image is not embedded.
    """

    def __init__(self, path="", model_name=DEFAULT_MODEL_NAME):
        """Load the CLIP model (moved to ``device``) and its processor.

        Args:
            path: Unused; kept for interface compatibility.
            model_name: Hub id or local path of the CLIP checkpoint to load.
        """
        self.model = CLIPModel.from_pretrained(model_name).to(device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    @staticmethod
    def _parse_inputs(inputs):
        """Split a request payload into ``(text, image)``; either may be ``None``.

        Raises:
            TypeError: if ``inputs`` is not a PIL image, a string, or a dict.
        """
        if isinstance(inputs, Image.Image):
            logger.info('image sent directly')
            return None, inputs
        if isinstance(inputs, str):
            # A bare string payload is treated as the text to embed.
            return inputs, None
        if not isinstance(inputs, dict):
            raise TypeError(
                f"unsupported 'inputs' payload of type {type(inputs).__name__}"
            )

        text_input = inputs.get("text")
        image_data = inputs.get("image")
        image = None
        if image_data is not None:
            logger.info('image is encoded')
            image = Image.open(BytesIO(base64.b64decode(image_data)))
        return text_input, image

    def __call__(self, data):
        """Embed the text or image carried in ``data["inputs"]``.

        ``"inputs"`` is popped from ``data``, so the caller's dict is mutated.

        Returns:
            ``{"embeddings": <nested list of floats>}`` on success,
            ``{"embeddings": None}`` when the payload has neither usable text
            nor an image, or ``{"Error": <formatted traceback>}`` if any step
            raises.
        """
        try:
            text_input, image = self._parse_inputs(data.pop("inputs", None))

            if text_input:
                features = self.processor(
                    text=text_input, return_tensors="pt", padding=True
                ).to(device)
                with torch.no_grad():
                    embeddings = self.model.get_text_features(**features)
                return {"embeddings": embeddings.tolist()}

            if image is not None:
                features = self.processor(images=image, return_tensors="pt").to(device)
                with torch.no_grad():
                    embeddings = self.model.get_image_features(**features)
                return {"embeddings": embeddings.tolist()}

            # Nothing to embed (e.g. dict with empty/missing text and no image).
            return {'embeddings': None}
        except Exception as ex:
            # Top-level request boundary: log once (message + traceback) and
            # report the failure in the response instead of crashing the server.
            logger.exception('error doing request: %s', ex)
            # NOTE(review): returning the raw traceback exposes server internals
            # to clients; kept because callers may depend on this response shape.
            return {'Error': traceback.format_exc()}