entity-linking / app.py
ajanz's picture
bug fixes in prediction function
c4f4d05
raw
history blame
1.55 kB
import gradio as gr
import datasets
import faiss
from transformers import pipeline
# Example query shown as the textbox placeholder (Polish: "European
# astronomers discovered an extrasolar planet originating from outside our
# galaxy, i.e. [the Milky Way]. The observations were made with the 2.2-metre
# MPG/ESO telescope."). The [unused0] / [unused1] tokens appear to mark the
# mention to be linked — NOTE(review): confirm this is the encoder's expected
# markup.
sample_text = """Europejscy astronomowie odkryli planetę
pozasłoneczną pochodzącą spoza naszej galaktyki, czyli
[unused0] Drogi Mlecznej [unused1]. Obserwacji dokonali
2,2-metrowym teleskopem MPG/ESO."""

# Ten-line free-text input for the demo; the sample is only a placeholder,
# not a pre-filled value.
textbox = gr.Textbox(
    lines=10,
    label="Type your query here.",
    placeholder=sample_text,
)
def load_index(
    index_data: str = "clarin-knext/entity-linking-index",
    index_path: str = "./encoder.faissindex",
):
    """Load the entity lookup table and the FAISS index used for retrieval.

    Args:
        index_data: Name of the dataset passed to ``datasets.load_dataset``;
            its ``train`` split must have ``entities`` and ``texts`` columns.
        index_path: Path (relative to the current working directory by
            default) of the serialized FAISS index file.

    Returns:
        A ``(lookup, faiss_index)`` tuple. ``lookup`` maps each dataset row
        position ``i`` to that row's ``(entity, text)`` pair; ``faiss_index``
        is the FAISS index read with ``IO_FLAG_MMAP``.
    """
    ds = datasets.load_dataset(index_data)["train"]
    # Keyed by row position because FAISS search returns positional ids.
    # NOTE(review): this assumes the index was built in dataset row order.
    lookup = dict(enumerate(zip(ds["entities"], ds["texts"])))
    # IO_FLAG_MMAP memory-maps the index file instead of reading it fully.
    faiss_index = faiss.read_index(index_path, faiss.IO_FLAG_MMAP)
    return lookup, faiss_index
def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
    """Return a ``transformers`` feature-extraction pipeline for *model_name*."""
    return pipeline(task="feature-extraction", model=model_name)
# Load the encoder pipeline, then the retrieval index, once at module level;
# ``predict`` reads both globals on every call.
model, index = load_model(), load_index()
def predict(query: str = sample_text, top_k: int = 3):
    """Return the ``top_k`` indexed entities nearest to *query*, as a string.

    The query is run through ``model`` and only the first-token ([CLS])
    vector is searched against the FAISS index held in ``index``.

    Args:
        query: Input text (see ``sample_text`` for the expected markup).
        top_k: Maximum number of candidates to retrieve. Values ``<= 0``
            yield an empty result without querying FAISS.

    Returns:
        ``str`` of a list of ``((entity, text), score)`` tuples in the order
        FAISS returns them. Slots FAISS pads with id ``-1`` (when the index
        holds fewer than ``top_k`` vectors) are dropped rather than looked up.
    """
    index_data, faiss_index = index
    if top_k <= 0:
        return str([])
    # Takes only the [CLS] embedding (for now).
    embeddings = model(query, return_tensors="pt")
    # FAISS expects a 2-D float32 array of shape (n_queries, dim).
    query_vector = embeddings[0][0].numpy().astype("float32").reshape(1, -1)
    scores, indices = faiss_index.search(query_vector, top_k)
    # Exactly one query row was searched, so only row 0 is relevant.
    results = [
        (index_data[idx], score)
        for idx, score in zip(indices[0].tolist(), scores[0].tolist())
        if idx != -1
    ]
    return str(results)
# Bind the Interface itself to ``demo``: ``launch()`` returns server details
# rather than the Interface, so chaining it left ``demo`` holding the wrong
# object.
demo = gr.Interface(fn=predict, inputs=textbox, outputs="text")
demo.launch()