File size: 3,259 Bytes
586b853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a87018d
586b853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a87018d
586b853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import argparse, torch, gc, os, random, json
from data import device
import numpy as np
from data import MyDataset, load_data, my_collate_fn, device
import re
def clean_str(string, use=True):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py

    Drops characters outside a small whitelist, pads punctuation and
    common English contractions with spaces, collapses whitespace runs,
    and lowercases the result.  When ``use`` is False the input is
    returned untouched.

    NOTE(review): in the paren/question-mark rules the backslash in the
    replacement (e.g. " \( ") survives re.sub's template processing, so
    cleaned text contains literal "\(" tokens — this matches the
    upstream behavior the vocabulary was built with; do not "fix" it.
    """
    if not use:
        return string

    # (pattern, replacement) pairs, applied in order; the character
    # whitelist must run first, the whitespace collapse last.
    rules = (
        (r"[^A-Za-z0-9(),!?\'\`]", " "),
        (r"\'s", " \'s"),
        (r"\'ve", " \'ve"),
        (r"n\'t", " n\'t"),
        (r"\'re", " \'re"),
        (r"\'d", " \'d"),
        (r"\'ll", " \'ll"),
        (r",", " , "),
        (r"!", " ! "),
        (r"\(", " \( "),
        (r"\)", " \) "),
        (r"\?", " \? "),
        (r"\s{2,}", " "),
    )
    for pattern, replacement in rules:
        string = re.sub(pattern, replacement, string)
    return string.strip().lower()

# ---- Vocabulary / embedding setup (runs once at import time) ----

# Human-readable paper titles, indexed by numeric paper id; used to turn
# predicted "p-<id>" tokens back into titles for display.
title_list = np.load("./title_list.npy", allow_pickle=True).tolist()
data_path = os.path.join('..', 'data')
# Force CPU inference (deliberately overrides the `device` imported from data).
device = 'cpu'
# word -> embedding vector mapping; its keys double as the model vocabulary.
vec_inuse = json.load(open('./papers_embedding_inuse.json'))
vocab = list(vec_inuse)
# +2 reserves indices 0 and 1 for the special <PAD> and <OOV> tokens below.
vocab_size = len(vocab) + 2
word2index = dict()
index2word = list()
word2index['<PAD>'] = 0
word2index['<OOV>'] = 1
index2word.extend(['<PAD>', '<OOV>'])
index2word.extend(vocab)
# Embedding matrix: row i holds the vector of index2word[i]; rows 0 and 1
# (<PAD>/<OOV>) stay all-zero.  Width is taken from the first stored vector,
# so all vectors in the JSON are assumed to have equal length.
word2vec = np.zeros((vocab_size, len(list(vec_inuse.values())[0])), dtype=np.float32)
for wd in vocab:
    # len(word2index) is the next free index, keeping word2index and
    # index2word aligned (both already contain the two special tokens).
    index = len(word2index)
    word2index[wd] = index
    word2vec[index, :] = vec_inuse[wd]

def data2index(data_x, word2index):
    """
    Convert raw {'answer', 'question'} instances into index form.

    Each question is cleaned with ``clean_str`` and split on whitespace;
    every token is looked up in ``word2index``.  Unknown tokens — and
    tokens equal to the answer itself, so the target cannot leak into
    the input — map to the '<OOV>' index.

    Returns a list of {'answer': int, 'question_words': [int, ...]}.
    Raises KeyError if an instance's answer is absent from word2index.
    """
    indexed = []
    for instance in data_x:
        answer = instance['answer']
        tokens = clean_str(instance['question']).strip().split()
        question_idx = [
            word2index[tok] if (tok in word2index and tok != answer)
            else word2index['<OOV>']
            for tok in tokens
        ]
        indexed.append({'answer': word2index[answer],
                        'question_words': question_idx})
    return indexed


def greet(paper_str):
    """
    Gradio callback: recommend papers for a free-text question.

    Wraps ``paper_str`` into a single test instance (the 'answer' field
    is a dummy placeholder the dataset format requires), runs the saved
    model, and returns the titles of the top-10 predicted papers.

    Parameters
    ----------
    paper_str : str
        The question/query text typed by the user.

    Returns
    -------
    list[str]
        Titles of the ten highest-ranked papers (empty if the model
        produced no batch).
    """
    # Load the model once and cache it on the function object — the
    # original reloaded the checkpoint from disk on every button click.
    # map_location keeps this working on CPU-only hosts even when the
    # checkpoint was saved from a GPU.
    if not hasattr(greet, "_model"):
        greet._model = torch.load("saved.model", map_location=device)
        greet._model.eval()
    model = greet._model

    test_dataset = MyDataset(data2index(
        [
            {
                # Dummy answer id: data2index requires an 'answer' key,
                # but its value does not affect the returned ranking.
                'answer': 'p-7241',
                'question': paper_str,
            }
        ], word2index))
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, collate_fn=my_collate_fn)

    predicted = []
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        for words_t, definition_words_t in test_dataloader:
            indices = model('test', x=definition_words_t, w=words_t, mode="b")
            # Keep the 10 best candidates of the (single) batch row.
            predicted = indices[:, :10].detach().cpu().numpy().tolist()[0]

    # Map vocabulary indices back to "p-<id>" tokens, then to titles.
    papers_output = []
    for token in (index2word[i] for i in predicted):
        paper_id = int(token.split('-')[1])
        papers_output.append(title_list[paper_id])
    return papers_output
    
    
# ---- Gradio UI wiring ----
with gr.Blocks() as demo:
    question_box = gr.Textbox(label="Question")
    papers_box = gr.Textbox(label="Papers")
    submit_btn = gr.Button("Submit")
    # Clicking the button feeds the question text to greet() and shows
    # the returned paper titles in the output box.
    submit_btn.click(fn=greet, inputs=question_box, outputs=papers_box)

demo.launch()