ajanz commited on
Commit
8182466
1 Parent(s): 6a3aaf8

model as a pipeline, bug fixes in index loader function

Browse files
Files changed (2) hide show
  1. app.py +5 -4
  2. space.py +0 -50
app.py CHANGED
@@ -4,6 +4,8 @@ import faiss
4
 
5
  from transformers import pipeline
6
 
 
 
7
 
8
  sample_text = """Europejscy astronomowie odkryli planetę
9
  pozasłoneczną pochodzącą spoza naszej galaktyki, czyli
@@ -23,14 +25,13 @@ def load_index(index_data: str = "clarin-knext/entity-linking-index"):
23
  idx: (e_id, e_text) for idx, (e_id, e_text) in
24
  enumerate(zip(ds['entities'], ds['texts']))
25
  }
26
- faiss_index = faiss.load_index("./encoder.faissindex")
27
  return index_data, faiss_index
28
-
29
 
30
 
31
  def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
32
- model = pipeline(task=model_name)
33
- return model
34
 
35
 
36
  def predict(model, index, query: str = sample_text, top_k: int=3):
 
4
 
5
  from transformers import pipeline
6
 
7
+ import requests
8
+
9
 
10
  sample_text = """Europejscy astronomowie odkryli planetę
11
  pozasłoneczną pochodzącą spoza naszej galaktyki, czyli
 
25
  idx: (e_id, e_text) for idx, (e_id, e_text) in
26
  enumerate(zip(ds['entities'], ds['texts']))
27
  }
28
+ faiss_index = faiss.read_index("./encoder.faissindex", faiss.IO_FLAG_MMAP)
29
  return index_data, faiss_index
 
30
 
31
 
32
  def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
33
+ pipe = pipeline("feature-extraction", model=model_name)
34
+ return pipe
35
 
36
 
37
  def predict(model, index, query: str = sample_text, top_k: int=3):
space.py DELETED
@@ -1,50 +0,0 @@
1
- import gradio as gr
2
- import datasets
3
- import faiss
4
-
5
- from transformers import pipeline
6
-
7
-
8
- sample_text = """Europejscy astronomowie odkryli planetę
9
- pozasłoneczną pochodzącą spoza naszej galaktyki, czyli
10
- [START_ENT] Drogi Mlecznej [END_ENT]. Obserwacji dokonali
11
- 2,2-metrowym teleskopem MPG/ESO."""
12
-
13
-
14
- textbox = gr.Textbox(
15
- label="Type your query here.",
16
- placeholder=sample_text, lines=10
17
- )
18
-
19
-
20
- def load_index(index_data: str = "clarin-knext/entity-linking-index"):
21
- ds = datasets.load_dataset(index_data)['train']
22
- index_data = {
23
- idx: (e_id, e_text) for idx, (e_id, e_text) in
24
- enumerate(zip(ds['entities'], ds['texts']))
25
- }
26
- faiss_index = faiss.load_index("./encoder.faissindex")
27
- return index_data, faiss_index
28
-
29
-
30
-
31
- def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
32
- model = pipeline(task=model_name)
33
- return model
34
-
35
-
36
- def predict(model, index, query: str = sample_text, top_k: int=3):
37
- index_data, faiss_index = index
38
- query = model(query)
39
-
40
- scores, indices = faiss_index.search(query, top_k)
41
- results = [index_data[idx] for row in indices for idx in row]
42
-
43
- return "\n".join(str(results))
44
-
45
-
46
- model = load_model()
47
- index = load_index()
48
-
49
-
50
- demo = gr.Interface(fn=predict, inputs=textbox, outputs="text").launch()