rbanfield commited on
Commit
b926327
·
1 Parent(s): ab71b31

Upload 2 files

Browse files

Upload a handler.py and requirements.txt to run on inference endpoints

Files changed (2) hide show
  1. handler.py +30 -0
  2. requirements.txt +3 -0
handler.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ import base64
3
+
4
+ from PIL import Image
5
+ import torch
6
+ from transformers import CLIPProcessor, CLIPTextModel, CLIPVisionModelWithProjection
7
+
8
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
+
10
+ class EndpointHandler():
11
+ def __init__(self, path=""):
12
+ self.text_model = CLIPTextModel.from_pretrained("rbanfield/clip-vit-large-patch14")
13
+ self.image_model = CLIPVisionModelWithProjection.from_pretrained("rbanfield/clip-vit-large-patch14")
14
+ self.processor = CLIPProcessor.from_pretrained("rbanfield/clip-vit-large-patch14")
15
+
16
+ def __call__(self, data):
17
+ text_input = data.pop("text", None)
18
+ image_input = data.pop("image", None)
19
+
20
+ if text_input:
21
+ processor = self.processor(text=text_input, return_tensors="pt", padding=True)
22
+ with torch.no_grad():
23
+ return self.text_model(**processor).pooler_output.tolist()
24
+ elif image_input:
25
+ image = Image.open(BytesIO(base64.b64decode(image_input)))
26
+ processor = self.processor(images=image, return_tensors="pt")
27
+ with torch.no_grad():
28
+ return self.image_model(**processor).image_embeds.tolist()
29
+ else:
30
+ return None
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ Pillow
2
+ transformers
3
+ torch