finhdev commited on
Commit
9188b68
·
verified ·
1 Parent(s): 1fdf84e

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +21 -0
handler.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io, base64, torch
2
+ from PIL import Image
3
+ from transformers import CLIPProcessor, CLIPModel
4
+
5
+ class EndpointHandler:
6
+ def __init__(self, path=""):
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ self.model = CLIPModel.from_pretrained(path).to(device)
9
+ self.processor = CLIPProcessor.from_pretrained(path)
10
+ self.device = device
11
+
12
+ def __call__(self, data):
13
+ # Expect JSON {"image": "<base64 PNG/JPEG>", "candidate_labels": ["cat","dog"]}
14
+ img_b64 = data["image"]
15
+ labels = data.get("candidate_labels", [])
16
+ image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
17
+
18
+ inputs = self.processor(text=labels, images=image,
19
+ return_tensors="pt", padding=True).to(self.device)
20
+ probs = self.model(**inputs).logits_per_image.softmax(dim=-1)[0].tolist()
21
+ return [{"label": l, "score": float(p)} for l, p in zip(labels, probs)]