EMaghakyan commited on
Commit
98660fb
1 Parent(s): bb750d5

add custom handler

Browse files
__pycache__/handler.cpython-39.pyc ADDED
Binary file (1.52 kB). View file
 
handler.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import CLIPModel, AutoProcessor, AutoTokenizer
3
+ import torch
4
+ from PIL import Image
5
+ import requests
6
+
7
+
8
+ class EndpointHandler:
9
+ def __init__(self):
10
+ self.model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
11
+ self.processor = AutoProcessor.from_pretrained("patrickjohncyh/fashion-clip")
12
+ self.tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
13
+
14
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
15
+ parameters = data.pop("parameters", {"mode": "image"})
16
+ inputs = data.pop("inputs", data)
17
+ with torch.no_grad():
18
+ if parameters["mode"] == "text":
19
+ inputs = self.tokenizer(inputs, padding=True, return_tensors="pt")
20
+ features = self.model.get_text_features(**inputs)
21
+
22
+ if parameters["mode"] == "image":
23
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
24
+ image = Image.open(requests.get(url, stream=True).raw)
25
+
26
+ inputs = self.processor(images=image, return_tensors="pt")
27
+ features = self.model.get_image_features(**inputs)
28
+
29
+ return features[0].tolist()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Pillow
test_handler.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+
3
+ # init handler
4
+ my_handler = EndpointHandler()
5
+
6
+ # prepare sample payload
7
+ non_holiday_payload = {
8
+ "inputs": "I am quite excited how this will turn out",
9
+ "parameters": {"mode": "text"},
10
+ }
11
+
12
+ # test the handler
13
+ non_holiday_pred = my_handler(non_holiday_payload)
14
+ print(non_holiday_pred)
15
+
16
+ non_holiday_payload = {
17
+ "inputs": "https://image.momoxfashion.com/Marc-O-Polo-i402d2w-0-detail",
18
+ "parameters": {"mode": "image"},
19
+ }
20
+
21
+ non_holiday_pred = my_handler(non_holiday_payload)
22
+ print(non_holiday_pred)