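# Gradio demo for entity linking: a query is encoded with the
# clarin-knext/entity-linking-encoder model and matched against a FAISS index
# of entity embeddings built from the clarin-knext/entity-linking-index dataset.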
import gradio as gr
import datasets
import faiss
import os

from transformers import pipeline


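# Hugging Face access token, read from the environment; it is passed to
# load_dataset/pipeline below (presumably required for gated repos).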
auth_token = os.environ.get("CLARIN_KNEXT")


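# Polish sample query ("European astronomers have discovered an extrasolar
# planet originating from outside our galaxy, i.e. the Milky Way. The
# observation was made with the 2.2-metre MPG/ESO telescope."); the
# [unused0]/[unused1] tokens presumably mark the entity mention to be linked.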
sample_text = (
    "Europejscy astronomowie odkryli planetę "
    "pozasłoneczną pochodzącą spoza naszej galaktyki, czyli "
    "[unused0] Drogi Mlecznej [unused1]. Obserwacji dokonali "
    "2,2-metrowym teleskopem MPG/ESO."
)


textbox = gr.Textbox(
    label="Type your query here.",
    value=sample_text, lines=10
)


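# Build the index: a dict mapping FAISS row ids to (entity id, entity text)
# pairs from the dataset, plus a memory-mapped FAISS index of entity
# embeddings shipped with the app as ./encoder.faissindex.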
def load_index(index_name: str = "clarin-knext/entity-linking-index"):
    ds = datasets.load_dataset(index_name, use_auth_token=auth_token)['train']
    index_data = {
        idx: (e_id, e_text) for idx, (e_id, e_text) in
        enumerate(zip(ds['entities'], ds['texts']))
    }
    faiss_index = faiss.read_index("./encoder.faissindex", faiss.IO_FLAG_MMAP)
    return index_data, faiss_index


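# Query encoder: a feature-extraction pipeline over the entity-linking encoder.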
def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
    model = pipeline("feature-extraction", model=model_name, use_auth_token=auth_token)
    return model


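# Load the encoder and index once at startup; predict() reuses them per request.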
model = load_model()
index = load_index()


def predict(text: str = sample_text, top_k: int = 3):
    index_data, faiss_index = index
    # Encode the query; take only the [CLS] token embedding (for now).
    query = model(text, return_tensors='pt')[0][0].numpy().reshape(1, -1)

    # Retrieve the top_k nearest entities from the FAISS index.
    scores, indices = faiss_index.search(query, top_k)
    scores, indices = scores.tolist(), indices.tolist()

    # Format each hit as "(entity id, entity text): score", one per line.
    results = "\n".join([
        f"{index_data[result[0]]}: {result[1]}"
        for output in zip(indices, scores)
        for result in zip(*output)
    ])

    return results


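# Single-textbox demo: encode the query, search the index, print the hits.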
demo = gr.Interface(fn=predict, inputs=textbox, outputs="text").launch()