ajanz commited on
Commit
6a3aaf8
1 Parent(s): 11cc362

model as a pipeline, bug fixes in index loader function

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import datasets
3
+ import faiss
4
+
5
+ from transformers import pipeline
6
+
7
+
8
+ sample_text = """Europejscy astronomowie odkryli planetę
9
+ pozasłoneczną pochodzącą spoza naszej galaktyki, czyli
10
+ [START_ENT] Drogi Mlecznej [END_ENT]. Obserwacji dokonali
11
+ 2,2-metrowym teleskopem MPG/ESO."""
12
+
13
+
14
+ textbox = gr.Textbox(
15
+ label="Type your query here.",
16
+ placeholder=sample_text, lines=10
17
+ )
18
+
19
+
20
+ def load_index(index_data: str = "clarin-knext/entity-linking-index"):
21
+ ds = datasets.load_dataset(index_data)['train']
22
+ index_data = {
23
+ idx: (e_id, e_text) for idx, (e_id, e_text) in
24
+ enumerate(zip(ds['entities'], ds['texts']))
25
+ }
26
+ faiss_index = faiss.load_index("./encoder.faissindex")
27
+ return index_data, faiss_index
28
+
29
+
30
+
31
+ def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
32
+ model = pipeline(task=model_name)
33
+ return model
34
+
35
+
36
+ def predict(model, index, query: str = sample_text, top_k: int=3):
37
+ index_data, faiss_index = index
38
+ query = model(query)
39
+
40
+ scores, indices = faiss_index.search(query, top_k)
41
+ results = [index_data[idx] for row in indices for idx in row]
42
+
43
+ return "\n".join(str(results))
44
+
45
+
46
+ model = load_model()
47
+ index = load_index()
48
+
49
+
50
+ demo = gr.Interface(fn=predict, inputs=textbox, outputs="text").launch()