grantpitt commited on
Commit
40f1eaa
1 Parent(s): 7990b4e

stuff stuff

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. .gitignore +2 -0
  3. embeddings.npy +3 -0
  4. handler.py +26 -0
  5. spotify.csv +3 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.csv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .DS_Store
2
+ work
embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c78d4b65edb30b193dc6d8aa8a1f9fb90d3b8cf07ced05070fbcb74f4618f7a
3
+ size 64409216
handler.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from sentence_transformers import SentenceTransformer
3
+ import pandas as pd
4
+ import numpy as np
5
+ import os
6
+
7
+
8
+ class EndpointHandler:
9
+ def __init__(self, path=""):
10
+ self.model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
11
+
12
+ self.embeddings = np.load(os.path.join(path, "embeddings.npy"))
13
+ self.spotify = pd.read_csv(os.path.join(path, "spotify.csv"))
14
+
15
+ def __call__(self, data: Dict[str, Any]) -> List[float]:
16
+ """
17
+ data args:
18
+ inputs (:obj: `str` | `PIL.Image` | `np.array`)
19
+ kwargs
20
+ Return:
21
+ A :obj:`list` | `dict`: will be serialized and returned
22
+ """
23
+ input_embedding = self.model.encode(data["inputs"])
24
+ cos_score = self.embeddings @ input_embedding
25
+ top_10 = cos_score.argsort()[-10:][::-1]
26
+ return self.spotify.iloc[top_10].to_dict(orient="records")
spotify.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32167542d556bd066374acb61c72f656ed78f60150889e1abdfb210e0b43f2cd
3
+ size 9745265