ajanz commited on
Commit
9e7f361
1 Parent(s): b2f66ed
Files changed (2) hide show
  1. app.py +60 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import datasets
3
+ import faiss
4
+ import os
5
+
6
+ from transformers import pipeline
7
+
8
+
9
+ auth_token = os.environ.get("CLARIN_KNEXT")
10
+
11
+
12
+ sample_text = (
13
+ "Wydarzenia te miały miejsce na początku mojej przyjaźni z Holmesem, "
14
+ "kiedy, jeszcze jako [unused0] kawalerowie [unused1], mieszkaliśmy razem przy Baker Street."
15
+ )
16
+
17
+
18
+ textbox = gr.Textbox(
19
+ label="Type your query here.",
20
+ value=sample_text, lines=10
21
+ )
22
+
23
+
24
+ def load_index(index_data: str = "clarin-knext/wsd-linking-index"):
25
+ ds = datasets.load_dataset(index_data, use_auth_token=auth_token)['train']
26
+ index_data = {
27
+ idx: (e_id, e_text) for idx, (e_id, e_text) in
28
+ enumerate(zip(ds['entities'], ds['texts']))
29
+ }
30
+ faiss_index = faiss.read_index("./encoder.faissindex", faiss.IO_FLAG_MMAP)
31
+ return index_data, faiss_index
32
+
33
+
34
+ def load_model(model_name: str = "clarin-knext/wsd-encoder"):
35
+ model = pipeline("feature-extraction", model=model_name, use_auth_token=auth_token)
36
+ return model
37
+
38
+
39
+ model = load_model()
40
+ index = load_index()
41
+
42
+
43
+ def predict(text: str = sample_text, top_k: int=3):
44
+ index_data, faiss_index = index
45
+ # takes only the [CLS] embedding (for now)
46
+ query = model(text, return_tensors='pt')[0][0].numpy().reshape(1, -1)
47
+
48
+ scores, indices = faiss_index.search(query, top_k)
49
+ scores, indices = scores.tolist(), indices.tolist()
50
+
51
+ results = "\n".join([
52
+ f"{index_data[result[0]]}: {result[1]}"
53
+ for output in zip(indices, scores)
54
+ for result in zip(*output)
55
+ ])
56
+
57
+ return results
58
+
59
+
60
+ demo = gr.Interface(fn=predict, inputs=textbox, outputs="text").launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ datasets
3
+ transformers==4.24.0
4
+ faiss-cpu
5
+ torch