finhdev committed
Commit aa10251 · verified · 1 Parent(s): 825b375

Update handler.py

Files changed (1)
  1. handler.py +101 -13
handler.py CHANGED
@@ -1,13 +1,16 @@
 # handler.py (repo root)
+
 import io, base64, torch
 from PIL import Image
 import open_clip
+from open_clip import fuse_conv_bn_sequential
+
 
 class EndpointHandler:
     """
     Zero‑shot classifier for MobileCLIP‑B (OpenCLIP).
 
-    Expected client JSON *to the endpoint*:
+    Client JSON format:
         {
           "inputs": {
             "image": "<base64 PNG/JPEG>",
@@ -16,43 +19,128 @@ class EndpointHandler:
         }
     """
 
+    # ----------------------------------------------------- #
+    #  INITIALISATION  (once)                                #
+    # ----------------------------------------------------- #
     def __init__(self, path: str = ""):
         weights = f"{path}/mobileclip_b.pt"
+
+        # Load model + transforms
         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
             "MobileCLIP-B", pretrained=weights
         )
-        self.model.eval()
 
+        # Fuse Conv+BN for faster inference
+        self.model = fuse_conv_bn_sequential(self.model).eval()
+
+        # Tokeniser
         self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
+
+        # Device
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model.to(self.device)
 
+        # -------- text‑embedding cache --------
+        # key: prompt string  •  value: torch.Tensor [512] on correct device
+        self.label_cache: dict[str, torch.Tensor] = {}
+
+    # ----------------------------------------------------- #
+    #  INFERENCE  (per request)                              #
+    # ----------------------------------------------------- #
     def __call__(self, data):
-        # ── unwrap Hugging Face's `inputs` envelope ───────────
+        # 1. Unwrap the HF "inputs" envelope
         payload = data.get("inputs", data)
 
-        img_b64 = payload["image"]
-        labels = payload.get("candidate_labels", [])
+        img_b64 = payload["image"]
+        labels  = payload.get("candidate_labels", [])
         if not labels:
             return {"error": "candidate_labels list is empty"}
 
-        # Decode & preprocess image
+        # 2. Decode & preprocess image
         image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
         img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
 
-        # Tokenise labels
-        text_tokens = self.tokenizer(labels).to(self.device)
+        # 3. Text embeddings with cache
+        missing = [l for l in labels if l not in self.label_cache]
+        if missing:
+            tokens = self.tokenizer(missing).to(self.device)
+            with torch.no_grad():
+                emb = self.model.encode_text(tokens)
+                emb = emb / emb.norm(dim=-1, keepdim=True)
+            for lbl, vec in zip(missing, emb):
+                self.label_cache[lbl] = vec  # store on device
+
+        txt_feat = torch.stack([self.label_cache[l] for l in labels])
 
-        # Forward pass
+        # 4. Forward pass for image
         with torch.no_grad(), torch.cuda.amp.autocast():
             img_feat = self.model.encode_image(img_tensor)
-            txt_feat = self.model.encode_text(text_tokens)
             img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
-            txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
-            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
 
-        # Sorted output
+        # 5. Similarity & softmax
+        probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
+
+        # 6. Return sorted list
         return [
             {"label": l, "score": float(p)}
             for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
         ]
+

(The remaining 58 added lines re-append the previous implementation as a commented-out block, a verbatim duplicate of the code removed above; omitted here.)
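
Note: `torch.cuda.amp.autocast()` in step 4 assumes a CUDA device. On a CPU-only endpoint it is disabled with a runtime warning, and recent PyTorch releases deprecate it in favour of `torch.amp.autocast`. A minimal device-agnostic sketch, assuming PyTorch ≥ 1.10 (where `torch.amp.autocast` accepts a `device_type` string); `autocast_ctx` is a hypothetical helper, not part of the commit:

import torch
from contextlib import nullcontext

def autocast_ctx(device: str):
    # Hypothetical helper: mixed precision only on CUDA, no-op context on CPU.
    return torch.amp.autocast(device_type="cuda") if device == "cuda" else nullcontext()

# Possible drop-in inside __call__:
#     with torch.no_grad(), autocast_ctx(self.device):
#         img_feat = self.model.encode_image(img_tensor)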
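
For reference, a minimal client call matching the docstring's JSON format could look like the following; the endpoint URL, token, and file name are placeholders, not part of the commit:

import base64, requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HEADERS = {"Authorization": "Bearer <HF_TOKEN>"}                 # placeholder

# Encode a local image as base64, as the handler expects
with open("photo.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": {
        "image": img_b64,
        "candidate_labels": ["cat", "dog"],
    }
}

resp = requests.post(API_URL, headers=HEADERS, json=payload)
print(resp.json())  # sorted list, e.g. [{"label": "cat", "score": 0.93}, ...]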