ajanz commited on
Commit
c4f4d05
1 Parent(s): 8182466

bug fixes in prediction function

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -4,12 +4,10 @@ import faiss
4
 
5
  from transformers import pipeline
6
 
7
- import requests
8
-
9
 
10
  sample_text = """Europejscy astronomowie odkryli planetę
11
  pozasłoneczną pochodzącą spoza naszej galaktyki, czyli
12
- [START_ENT] Drogi Mlecznej [END_ENT]. Obserwacji dokonali
13
  2,2-metrowym teleskopem MPG/ESO."""
14
 
15
 
@@ -34,18 +32,25 @@ def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
34
  return pipe
35
 
36
 
37
- def predict(model, index, query: str = sample_text, top_k: int=3):
 
 
 
 
38
  index_data, faiss_index = index
39
- query = model(query)
 
40
 
41
  scores, indices = faiss_index.search(query, top_k)
42
- results = [index_data[idx] for row in indices for idx in row]
43
 
44
- return "\n".join(str(results))
45
-
46
-
47
- model = load_model()
48
- index = load_index()
 
 
49
 
50
 
51
  demo = gr.Interface(fn=predict, inputs=textbox, outputs="text").launch()
 
4
 
5
  from transformers import pipeline
6
 
 
 
7
 
8
  sample_text = """Europejscy astronomowie odkryli planetę
9
  pozasłoneczną pochodzącą spoza naszej galaktyki, czyli
10
+ [unused0] Drogi Mlecznej [unused1]. Obserwacji dokonali
11
  2,2-metrowym teleskopem MPG/ESO."""
12
 
13
 
 
32
  return pipe
33
 
34
 
35
+ model = load_model()
36
+ index = load_index()
37
+
38
+
39
+ def predict(query: str = sample_text, top_k: int=3):
40
  index_data, faiss_index = index
41
+ # takes only the [CLS] embedding (for now)
42
+ query = model(query, return_tensors = "pt")[0][0].numpy().reshape(1, -1)
43
 
44
  scores, indices = faiss_index.search(query, top_k)
45
+ scores, indices = scores.tolist(), indices.tolist()
46
 
47
+ results = [
48
+ (index_data[result[0]], result[1])
49
+ for output in zip(indices, scores)
50
+ for result in zip(*output)
51
+ ]
52
+
53
+ return str(results)
54
 
55
 
56
  demo = gr.Interface(fn=predict, inputs=textbox, outputs="text").launch()