finhdev committed on
Commit
5ca6bfd
·
verified ·
1 Parent(s): 382a099

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -19
handler.py CHANGED
@@ -2,33 +2,31 @@
2
 
3
  import io, base64, torch, open_clip
4
  from PIL import Image
 
5
 
6
  class EndpointHandler:
7
  """
8
- MobileCLIP‑B (pretrained='datacompdr') zero-shot classifier.
9
- Expects JSON:
10
- {
11
- "inputs": {
12
- "image": "<base64 PNG/JPEG>",
13
- "candidate_labels": ["a photo of a cat", ...]
14
- }
15
  }
 
16
  """
17
 
18
  def __init__(self, path=""):
19
- # Model + transforms
20
  self.model, _, self.preprocess = open_clip.create_model_and_transforms(
21
  "mobileclip_b", pretrained="datacompdr"
22
  )
23
- self.model.eval()
24
 
25
- # Tokeniser & device
26
  self.tokenizer = open_clip.get_tokenizer("mobileclip_b")
27
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
28
  self.model.to(self.device)
29
 
30
- # Cache {prompt → 1×512 tensor}
31
- self.cache: dict[str, torch.Tensor] = {}
32
 
33
  def __call__(self, data):
34
  payload = data.get("inputs", data)
@@ -37,22 +35,22 @@ class EndpointHandler:
37
  if not labels:
38
  return {"error": "candidate_labels list is empty"}
39
 
40
- # Image tensor
41
  img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
42
  img_t = self.preprocess(img).unsqueeze(0).to(self.device)
43
 
44
- # Text embeddings with cache
45
- new_labels = [l for l in labels if l not in self.cache]
46
- if new_labels:
47
- tok = self.tokenizer(new_labels).to(self.device)
48
  with torch.no_grad():
49
  emb = self.model.encode_text(tok)
50
  emb = emb / emb.norm(dim=-1, keepdim=True)
51
- for l, e in zip(new_labels, emb):
52
  self.cache[l] = e
53
  txt_t = torch.stack([self.cache[l] for l in labels])
54
 
55
- # Forward
56
  with torch.no_grad(), torch.cuda.amp.autocast():
57
  img_f = self.model.encode_image(img_t)
58
  img_f = img_f / img_f.norm(dim=-1, keepdim=True)
 
2
 
3
  import io, base64, torch, open_clip
4
  from PIL import Image
5
+ from mobileclip.modules.common.mobileone import reparameterize_model # optional
6
 
7
  class EndpointHandler:
8
  """
9
+ MobileCLIP‑B ('datacompdr') · text-embedding cache.
10
+ Expects: {
11
+ "inputs": {
12
+ "image": "<base64>",
13
+ "candidate_labels": ["a photo of a cat", ...]
 
 
14
  }
15
+ }
16
  """
17
 
18
  def __init__(self, path=""):
19
+ # -- Load MobileCLIP‑B checkpoint identical to local run -------------
20
  self.model, _, self.preprocess = open_clip.create_model_and_transforms(
21
  "mobileclip_b", pretrained="datacompdr"
22
  )
23
+ self.model = reparameterize_model(self.model).eval() # matches local pipeline
24
 
 
25
  self.tokenizer = open_clip.get_tokenizer("mobileclip_b")
26
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
  self.model.to(self.device)
28
 
29
+ self.cache: dict[str, torch.Tensor] = {} # label → embedding
 
30
 
31
  def __call__(self, data):
32
  payload = data.get("inputs", data)
 
35
  if not labels:
36
  return {"error": "candidate_labels list is empty"}
37
 
38
+ # -------- image preprocessing --------------------------------------
39
  img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
40
  img_t = self.preprocess(img).unsqueeze(0).to(self.device)
41
 
42
+ # -------- text embeddings with cache -------------------------------
43
+ new = [l for l in labels if l not in self.cache]
44
+ if new:
45
+ tok = self.tokenizer(new).to(self.device)
46
  with torch.no_grad():
47
  emb = self.model.encode_text(tok)
48
  emb = emb / emb.norm(dim=-1, keepdim=True)
49
+ for l, e in zip(new, emb):
50
  self.cache[l] = e
51
  txt_t = torch.stack([self.cache[l] for l in labels])
52
 
53
+ # -------- forward & softmax ----------------------------------------
54
  with torch.no_grad(), torch.cuda.amp.autocast():
55
  img_f = self.model.encode_image(img_t)
56
  img_f = img_f / img_f.norm(dim=-1, keepdim=True)