grantpitt committed on
Commit
e9ccfc7
1 Parent(s): bd8f682

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ __pycache__/
2
+ .ipynb_checkpoints/
3
+ local_test.ipynb
4
+ hosted_test.ipynb
handler.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import CLIPTokenizer, CLIPModel
3
+ import numpy as np
4
+ import os
5
+
6
+
7
class EndpointHandler:
    """Rank pre-computed sign embeddings against a free-text query using CLIP.

    On construction the handler loads two NumPy arrays from ``path`` plus the
    CLIP model and tokenizer.  Each call embeds the query text, scores it
    against every stored embedding and returns the ids whose weighted score
    exceeds ``SCORE_THRESHOLD``, best match first.
    """

    # Hugging Face checkpoint used for both the text encoder and the tokenizer.
    MODEL_NAME = "openai/clip-vit-large-patch14"
    # Arbitrary multiplier on the cosine similarity; it only rescales the
    # scores so they are easier to reason about.
    SCORE_WEIGHT = 2.5
    # A sign is returned only when its weighted score is strictly above this.
    SCORE_THRESHOLD = 0.475

    def __init__(self, path=""):
        """Load the sign lookup arrays and the CLIP model.

        Args:
            path: Directory containing ``sign_ids.npy`` and
                ``vanilla_large-patch14_image_embeddings_normalized.npy``.
        """
        # NOTE(review): presumably sign_ids[i] is the id of sign_embeddings[i];
        # the two arrays are indexed with the same positions in __call__.
        self.sign_ids = np.load(os.path.join(path, "sign_ids.npy"))
        # The file name indicates these rows are already L2-normalized, which
        # is what lets a plain dot product act as a cosine similarity below.
        self.sign_embeddings = np.load(
            os.path.join(
                path, "vanilla_large-patch14_image_embeddings_normalized.npy"
            )
        )

        self.model = CLIPModel.from_pretrained(self.MODEL_NAME)
        self.tokenizer = CLIPTokenizer.from_pretrained(self.MODEL_NAME)

    def __call__(self, data: Dict[str, Any]) -> List[List[Any]]:
        """Return the signs that best match a text query.

        Args:
            data: Request payload; ``data["inputs"]`` is the query text that is
                handed to the tokenizer.

        Returns:
            A two-element list ``[sign_ids, scores]``.  Both inner lists have
            the same length and are ordered by descending score; only entries
            whose weighted score exceeds ``SCORE_THRESHOLD`` are included.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
        """
        # Local import: torch is already required at runtime because the
        # tokenizer is asked for PyTorch tensors (return_tensors="pt").
        import torch

        query_text = data["inputs"]

        # truncation=True clips over-long queries to the tokenizer's maximum
        # length so they cannot overflow the text encoder's position
        # embeddings and fail the request.
        token_inputs = self.tokenizer(
            [query_text], padding=True, truncation=True, return_tensors="pt"
        )

        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            text_features = self.model.get_text_features(**token_inputs)

        query_embedding = text_features.detach().cpu().numpy()[0]
        query_embedding = query_embedding / np.linalg.norm(query_embedding)

        # Both sides are unit vectors, so the dot product is the cosine
        # similarity; the weight merely rescales it.
        scores = self.SCORE_WEIGHT * (self.sign_embeddings @ query_embedding)

        # Sort every candidate by descending score, then keep the leading
        # entries, which are exactly those strictly above the threshold.
        num_matches = int(np.sum(scores > self.SCORE_THRESHOLD))
        ranked_indices = np.argsort(scores)[::-1]
        top_indices = ranked_indices[:num_matches]

        return [
            self.sign_ids[top_indices].tolist(),
            scores[top_indices].tolist(),
        ]
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ numpy==1.23.1
2
+ transformers==4.21.1
sign_ids.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e282e229c3af38c7c0ee6ce5cf15317d5c6f83b7c44a18fe04f0239a0bbd8bde
3
+ size 465400
vanilla_large-patch14_image_embeddings_normalized.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f70fc1bcba9555a00344cb21132276955645a7b78c54de1a1efcb17f776f033
3
+ size 357329024