# Gradio demo: recommend paper titles for a free-text question.
import gradio as gr
import argparse, torch, gc, os, random, json
from data import device
import numpy as np
from data import MyDataset, load_data, my_collate_fn, device
import re
def clean_str(string, use=True):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py

    Lower-cases the text, keeps only a small character whitelist, splits
    common English contractions into their own tokens and pads punctuation
    with spaces.  Note: the "(" / ")" / "?" replacements intentionally emit
    two-character tokens of a backslash plus the character (re.sub keeps an
    unknown non-letter escape in the replacement template literally), which
    matches the vocabulary produced by the original preprocessing.

    Args:
        string: raw text to clean.
        use: when False, return *string* unchanged (pass-through switch).

    Returns:
        Cleaned, stripped, lower-cased string.
    """
    if not use:
        return string
    # Replace everything outside the whitelist with a space.
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    # Detach possessives and contractions as separate tokens.
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    # Pad punctuation with spaces so each mark tokenizes on its own.
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    # Fix: raw strings here — the original non-raw " \( " literals trigger
    # "invalid escape sequence" SyntaxWarning on Python 3.12+ while producing
    # the exact same bytes, so behavior is unchanged.
    string = re.sub(r"\(", r" \( ", string)
    string = re.sub(r"\)", r" \) ", string)
    string = re.sub(r"\?", r" \? ", string)
    # Collapse runs of whitespace to a single space.
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()
# Paper titles; title_list[i] is looked up later for a predicted token 'p-<i>'.
title_list = np.load("./title_list.npy", allow_pickle=True).tolist()
data_path = os.path.join('..', 'data')
# Force CPU inference (deliberately shadows the device imported from data).
device = 'cpu'
# token -> embedding vector mapping used to build the vocabulary.
# NOTE(review): file handle from open() is never closed — acceptable for a
# one-shot script, but a `with` block would be cleaner.
vec_inuse = json.load(open('./papers_embedding_inuse.json'))
vocab = list(vec_inuse)
# +2 reserves indices 0 and 1 for the special <PAD> and <OOV> tokens.
vocab_size = len(vocab) + 2
word2index = dict()
index2word = list()
word2index['<PAD>'] = 0
word2index['<OOV>'] = 1
index2word.extend(['<PAD>', '<OOV>'])
index2word.extend(vocab)
# Embedding matrix: rows 0/1 (<PAD>/<OOV>) stay all-zero; the embedding
# width is taken from the first vector in the JSON file.
word2vec = np.zeros((vocab_size, len(list(vec_inuse.values())[0])), dtype=np.float32)
for wd in vocab:
    # len(word2index) grows by one per insertion, so vocab words get
    # consecutive indices starting at 2, matching index2word's layout.
    index = len(word2index)
    word2index[wd] = index
    word2vec[index, :] = vec_inuse[wd]
def data2index(data_x, word2index):
    """Convert question/answer instances into vocabulary indices.

    Each instance is a dict carrying a 'question' string and an 'answer'
    token.  Question words are normalised with clean_str() and mapped
    through *word2index*; any word missing from the vocabulary — and any
    word equal to the instance's answer (so the label is never leaked into
    the input) — is mapped to the '<OOV>' index instead.

    Args:
        data_x: iterable of {'answer': str, 'question': str} dicts.
        word2index: token -> int mapping; must contain '<OOV>' and every
            answer token.

    Returns:
        List of {'answer': int, 'question_words': list[int]} dicts.

    Raises:
        KeyError: if an instance's answer is absent from word2index.
    """
    oov = word2index['<OOV>']
    indexed = []
    for instance in data_x:
        answer = instance['answer']
        tokens = clean_str(instance['question']).strip().split()
        question_ids = [
            word2index[tok] if tok in word2index and tok != answer else oov
            for tok in tokens
        ]
        indexed.append({'answer': word2index[answer],
                        'question_words': question_ids})
    return indexed
def greet(paper_str):
    """Recommend papers for a free-text question.

    Loads the saved model, indexes the question through the module-level
    vocabulary, runs a single forward pass and returns the titles of the
    top-10 predicted papers.

    Args:
        paper_str: the user's question text (from the Gradio textbox).

    Returns:
        List of up to 10 paper-title strings.
    """
    # SECURITY NOTE: torch.load unpickles arbitrary objects — only load a
    # trusted checkpoint.  map_location makes a GPU-saved checkpoint
    # loadable on this CPU-only setup (device == 'cpu').
    model = torch.load("saved.model", map_location=device)
    model.eval()
    # 'p-7241' is a dummy answer token: the pipeline requires one to build
    # an instance, but the prediction itself does not depend on it.
    test_dataset = MyDataset(data2index(
        [{'answer': 'p-7241', 'question': paper_str}], word2index))
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, collate_fn=my_collate_fn)
    predicted = []
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        for words_t, definition_words_t in test_dataloader:
            indices = model('test', x=definition_words_t, w=words_t, mode="b")
            # Top-10 ranked vocabulary indices for the single batch row.
            top = indices[:, :10].cpu().numpy().tolist()
            predicted = [index2word[paper] for paper in top[0]]
    papers_output = []
    for token in predicted:
        # Predicted tokens look like 'p-<id>'; map the id to its title.
        papers_output.append(title_list[int(token.split('-')[1])])
    return papers_output
# Wire the UI: one question box in, one papers box out, a submit button
# that routes through greet().  (Fix: removed the stray trailing '|' after
# demo.launch(), which was a syntax error.)
with gr.Blocks() as demo:
    name = gr.Textbox(label="Question")
    output = gr.Textbox(label="Papers")
    greet_btn = gr.Button("Submit")
    greet_btn.click(fn=greet, inputs=name, outputs=output)
demo.launch()