# Spaces:
# Runtime error
# Runtime error
import gradio as gr | |
import datasets | |
import faiss | |
from transformers import pipeline | |
# Example query (Polish). The [unused0] / [unused1] markers surround
# "Drogi Mlecznej" -- presumably they delimit the mention to be linked;
# confirm against the encoder's expected input format.
sample_text = (
    "Europejscy astronomowie odkryli planetę\n"
    "pozasłoneczną pochodzącą spoza naszej galaktyki, czyli\n"
    "[unused0] Drogi Mlecznej [unused1]. Obserwacji dokonali\n"
    "2,2-metrowym teleskopem MPG/ESO."
)
# Ten-line query box; the sample sentence is passed as the placeholder, not
# as an initial value.
textbox = gr.Textbox(
    label="Type your query here.",
    placeholder=sample_text,
    lines=10,
)
def load_index(
    index_data: str = "clarin-knext/entity-linking-index",
    index_path: str = "./encoder.faissindex",
):
    """Load the entity lookup table and the FAISS index.

    Args:
        index_data: Dataset name passed to ``datasets.load_dataset``. Its
            ``train`` split must provide ``entities`` and ``texts`` columns.
        index_path: Path of the serialized FAISS index file.

    Returns:
        A ``(entries, faiss_index)`` pair. ``entries`` maps each row position
        in the dataset to its ``(entity, text)`` tuple; ``faiss_index`` is the
        index read from ``index_path`` with ``faiss.IO_FLAG_MMAP``.
    """
    ds = datasets.load_dataset(index_data)["train"]
    # Row position -> (entity, text). Kept under its own name so the
    # ``index_data`` string parameter is not rebound to a dict.
    # NOTE(review): assumes row positions match the ids stored in the FAISS
    # index -- confirm against how the index file was built.
    entries = dict(enumerate(zip(ds["entities"], ds["texts"])))
    faiss_index = faiss.read_index(index_path, faiss.IO_FLAG_MMAP)
    return entries, faiss_index
def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
    """Build the feature-extraction pipeline used to embed queries.

    Args:
        model_name: Model identifier passed to ``transformers.pipeline``.

    Returns:
        The ``"feature-extraction"`` pipeline created for ``model_name``.
    """
    return pipeline("feature-extraction", model=model_name)
# Loaded once at import time; predict() reads both module-level names on
# every call.
model = load_model()
index = load_index()  # (entity lookup dict, FAISS index) pair
def predict(query: str = sample_text, top_k: int = 3):
    """Return the nearest index entries for ``query``, formatted as a string.

    The query is embedded with the module-level ``model`` pipeline and looked
    up in the module-level ``index`` (entity lookup dict, FAISS index).

    Args:
        query: Text to embed and search for.
        top_k: Number of neighbours requested from the FAISS index.

    Returns:
        ``str()`` of a list of ``((entity, text), score)`` tuples in the order
        FAISS returned them. The list has fewer than ``top_k`` items when the
        index yields fewer neighbours.
    """
    index_data, faiss_index = index
    # takes only the [CLS] embedding (for now): first token of the first
    # sequence, reshaped to the (1, dim) batch that search() expects.
    embedding = model(query, return_tensors="pt")[0][0].numpy().reshape(1, -1)
    scores, indices = faiss_index.search(embedding, top_k)
    # One query row was submitted, so only row 0 of each result is populated.
    results = [
        (index_data[idx], score)
        for idx, score in zip(indices[0].tolist(), scores[0].tolist())
        # FAISS pads missing neighbours with label -1; skipping them avoids a
        # KeyError on the lookup table.
        if idx != -1
    ]
    return str(results)
# Bind `demo` to the Interface itself, then start the server. Chaining
# .launch() onto the constructor would bind `demo` to launch()'s return value
# instead of the app object.
demo = gr.Interface(fn=predict, inputs=textbox, outputs="text")
demo.launch()