wsd-linking / app.py
ajanz's picture
example text has been changed
404cfe2
raw
history blame contribute delete
No virus
1.71 kB
import gradio as gr
import datasets
import faiss
import os
from transformers import pipeline
# Access token for the private CLARIN-KNEXT Hub resources; None when the
# environment variable is not set.
auth_token = os.getenv("CLARIN_KNEXT")

# Default query shown in the UI. The [unused0] / [unused1] markers surround
# one word of the text -- presumably the mention to be linked; confirm
# against the encoder's expected input format.
sample_text = (
    "NASA poinformowała o nowych wynikach obserwacji "
    "przy pomocy [unused0] teleskopu [unused1] JWST. "
    "Naukowcom udało się odnaleźć parę wodną "
    "w jednym z systemów gwiazdowych w odległości od gwiazdy "
    "podobnej jak dystans Ziemi od Słońca."
)

# Input widget of the demo, pre-filled with the sample query.
textbox = gr.Textbox(
    value=sample_text,
    label="Type your query here.",
    lines=10,
)
def load_index(
    index_data: str = "clarin-knext/wsd-linking-index",
    index_path: str = "./encoder.faissindex",
):
    """Load the entity lookup table and the FAISS index used for retrieval.

    Args:
        index_data: Name of the dataset passed to ``datasets.load_dataset``;
            its ``train`` split must expose ``entities`` and ``texts`` columns.
        index_path: Path of the serialized FAISS index file.

    Returns:
        A ``(entries, faiss_index)`` pair. ``entries`` maps the row position
        in the dataset to an ``(entity, text)`` tuple; ``faiss_index`` is the
        index read from ``index_path``.
    """
    ds = datasets.load_dataset(index_data, use_auth_token=auth_token)['train']
    # Row position -> (entity id, entity text). Kept in a separate local
    # instead of rebinding the ``index_data`` parameter to a different type.
    entries = dict(enumerate(zip(ds['entities'], ds['texts'])))
    # IO_FLAG_MMAP asks FAISS to memory-map the file rather than read it fully.
    faiss_index = faiss.read_index(index_path, faiss.IO_FLAG_MMAP)
    return entries, faiss_index
def load_model(model_name: str = "clarin-knext/wsd-encoder"):
    """Build the feature-extraction pipeline for the given encoder.

    Args:
        model_name: Model identifier handed to ``transformers.pipeline``.

    Returns:
        The ``feature-extraction`` pipeline, created with the module-level
        ``auth_token``.
    """
    return pipeline(
        "feature-extraction",
        model=model_name,
        use_auth_token=auth_token,
    )
# Load the encoder pipeline and the index once, at import time, so that every
# predict() call reuses the same objects.
model = load_model()
# (index_data, faiss_index) pair as returned by load_index(); predict()
# unpacks it.
index = load_index()
def predict(text: str = sample_text, top_k: int=3):
    """Return the ``top_k`` index entries closest to ``text``.

    Args:
        text: Query text to embed with the module-level ``model``.
        top_k: Number of neighbours requested from the FAISS index.

    Returns:
        One line per hit, formatted as ``"<(entity, text) tuple>: <score>"``
        and joined with newlines. Fewer than ``top_k`` lines are returned when
        the index cannot supply that many neighbours.
    """
    index_data, faiss_index = index
    # takes only the [CLS] embedding (for now)
    # NOTE(review): assumes the pipeline output is indexed as
    # (sequence, token, hidden) so that [0][0] is the first token -- confirm.
    query = model(text, return_tensors='pt')[0][0].numpy().reshape(1, -1)
    scores, indices = faiss_index.search(query, top_k)
    # The query has exactly one row, so only the first result row exists.
    hits = zip(indices[0].tolist(), scores[0].tolist())
    # FAISS pads missing neighbours with label -1; looking that up in
    # index_data would raise KeyError, so such slots are skipped.
    return "\n".join(
        f"{index_data[idx]}: {score}"
        for idx, score in hits
        if idx != -1
    )
# Bind the Interface itself to `demo`; chaining .launch() onto the constructor
# would store launch()'s return value under that name instead of the app.
demo = gr.Interface(fn=predict, inputs=textbox, outputs="text")
demo.launch()