grantpitt commited on
Commit
b86e84d
1 Parent(s): e9ccfc7
Files changed (1) hide show
  1. handler.py +7 -4
handler.py CHANGED
@@ -10,12 +10,13 @@ class EndpointHandler:
10
  Initialize the model
11
  """
12
  self.sign_ids = np.load(os.path.join(path, "sign_ids.npy"))
13
- self.sign_embeddings = np.load(os.path.join(path, "vanilla_large-patch14_image_embeddings_normalized.npy"))
 
 
14
 
15
  hf_model_path = "openai/clip-vit-large-patch14"
16
  self.model = CLIPModel.from_pretrained(hf_model_path)
17
  self.tokenizer = CLIPTokenizer.from_pretrained(hf_model_path)
18
-
19
 
20
  def __call__(self, data: Dict[str, Any]) -> List[float]:
21
  """
@@ -25,7 +26,9 @@ class EndpointHandler:
25
  Return:
26
  A :obj:`list` | `dict`: will be serialized and returned
27
  """
28
- token_inputs = self.tokenizer([data["inputs"]], padding=True, return_tensors="pt")
 
 
29
  query_embed = self.model.get_text_features(**token_inputs)
30
  np_query_embed = query_embed.detach().cpu().numpy()[0]
31
  np_query_embed /= np.linalg.norm(np_query_embed)
@@ -37,7 +40,7 @@ class EndpointHandler:
37
  cos_similarites = w * (self.sign_embeddings @ np_query_embed)
38
  count_above_threshold = np.sum(cos_similarites > threshold)
39
  sign_id_arg_rankings = np.argsort(cos_similarites)[::-1]
40
-
41
  threshold_id_arg_rankings = sign_id_arg_rankings[:count_above_threshold]
42
 
43
  result_sign_ids = self.sign_ids[threshold_id_arg_rankings]
 
10
  Initialize the model
11
  """
12
  self.sign_ids = np.load(os.path.join(path, "sign_ids.npy"))
13
+ self.sign_embeddings = np.load(
14
+ os.path.join(path, "vanilla_large-patch14_image_embeddings_normalized.npy")
15
+ )
16
 
17
  hf_model_path = "openai/clip-vit-large-patch14"
18
  self.model = CLIPModel.from_pretrained(hf_model_path)
19
  self.tokenizer = CLIPTokenizer.from_pretrained(hf_model_path)
 
20
 
21
  def __call__(self, data: Dict[str, Any]) -> List[float]:
22
  """
 
26
  Return:
27
  A :obj:`list` | `dict`: will be serialized and returned
28
  """
29
+ token_inputs = self.tokenizer(
30
+ [data["inputs"]], padding=True, return_tensors="pt"
31
+ )
32
  query_embed = self.model.get_text_features(**token_inputs)
33
  np_query_embed = query_embed.detach().cpu().numpy()[0]
34
  np_query_embed /= np.linalg.norm(np_query_embed)
 
40
  cos_similarites = w * (self.sign_embeddings @ np_query_embed)
41
  count_above_threshold = np.sum(cos_similarites > threshold)
42
  sign_id_arg_rankings = np.argsort(cos_similarites)[::-1]
43
+
44
  threshold_id_arg_rankings = sign_id_arg_rankings[:count_above_threshold]
45
 
46
  result_sign_ids = self.sign_ids[threshold_id_arg_rankings]