import os

import datasets
import faiss
import gradio as gr
from transformers import pipeline

# Token passed to the Hugging Face Hub loaders below; None when the
# CLARIN_KNEXT environment variable is unset.
auth_token = os.environ.get("CLARIN_KNEXT")

# Example Polish query. The mention to be linked ("Drogi Mlecznej") is wrapped
# in [unused0] / [unused1] tokens, which appear to act as mention markers.
sample_text = (
    "Europejscy astronomowie odkryli planetę "
    "pozasłoneczną pochodzącą spoza naszej galaktyki, czyli "
    "[unused0] Drogi Mlecznej [unused1]. Obserwacji dokonali "
    "2,2-metrowym teleskopem MPG/ESO."
)

# Literal padding appended to every query before encoding (same string the
# original `"".join(['[PAD]' * 252])` produced). Presumably mimics the padded
# inputs the encoder was trained on -- TODO confirm.
_PAD_SUFFIX = "[PAD]" * 252

textbox = gr.Textbox(
    label="Type your query here.", placeholder=sample_text, lines=10
)


def load_index(index_data: str = "clarin-knext/entity-linking-index"):
    """Load the entity catalogue and the FAISS index used for retrieval.

    Args:
        index_data: Hugging Face dataset id whose ``train`` split provides the
            ``entities`` and ``texts`` columns.

    Returns:
        A ``(catalogue, faiss_index)`` tuple. ``catalogue`` maps a row position
        to its ``(entity, text)`` pair. ``faiss_index`` is read from
        ``./encoder.faissindex``.
    """
    ds = datasets.load_dataset(index_data, use_auth_token=auth_token)["train"]
    # The row position is used as the FAISS vector id. This assumes the index
    # was built in the same row order as this dataset -- TODO confirm.
    catalogue = dict(enumerate(zip(ds["entities"], ds["texts"])))
    # IO_FLAG_MMAP memory-maps the index file rather than reading it eagerly.
    faiss_index = faiss.read_index("./encoder.faissindex", faiss.IO_FLAG_MMAP)
    return catalogue, faiss_index


def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
    """Build a ``feature-extraction`` pipeline for the query encoder.

    Args:
        model_name: Hugging Face model id of the encoder.

    Returns:
        The ``transformers`` pipeline object.
    """
    return pipeline("feature-extraction", model=model_name, use_auth_token=auth_token)


model = load_model()
index = load_index()


def _format_results(index_data, indices, scores):
    """Render FAISS search output as one ``"<entry>: <score>"`` line per hit.

    ``indices`` and ``scores`` are the nested lists produced by
    ``faiss_index.search(...)`` followed by ``.tolist()`` (one row per query).
    FAISS fills slots it could not populate with id ``-1``. Those slots are
    skipped instead of raising ``KeyError`` on the catalogue lookup.
    """
    lines = []
    for row_ids, row_scores in zip(indices, scores):
        for idx, score in zip(row_ids, row_scores):
            if idx < 0:
                continue
            lines.append(f"{index_data[idx]}: {score}")
    return "\n".join(lines)


def predict(text: str = sample_text, top_k: int = 3):
    """Return the ``top_k`` catalogue entries nearest to ``text``.

    Args:
        text: Query text, optionally with the mention wrapped in marker tokens
            as in ``sample_text``.
        top_k: Number of neighbours to request from the FAISS index.

    Returns:
        A newline-separated string with one ``"(entity, text): score"`` line
        per retrieved neighbour.
    """
    # NOTE(review): no truncation is requested, so a long query plus the 252
    # pad tokens may exceed the encoder's maximum sequence length -- confirm.
    text = text + _PAD_SUFFIX
    index_data, faiss_index = index
    # Take only the first-token ([CLS]) embedding as the query vector (for now).
    query = model(text, return_tensors="pt")[0][0].numpy().reshape(1, -1)
    scores, indices = faiss_index.search(query, top_k)
    return _format_results(index_data, indices.tolist(), scores.tolist())


# Keep `demo` bound to the Interface itself (not to launch()'s return value),
# so tooling that looks up `demo` finds the app object.
demo = gr.Interface(fn=predict, inputs=textbox, outputs="text")
demo.launch()