Spaces:
Runtime error
Runtime error
import gradio as gr | |
import argparse, torch, gc, os, random, json | |
from data import device | |
import numpy as np | |
from data import MyDataset, load_data, my_collate_fn, device | |
import re | |
# (compiled pattern, replacement) pairs applied in order by clean_str.  Compiled
# once at import time rather than on every call.
#
# The three replacements r" \( ", r" \) ", r" \? " are raw strings.  The
# original used non-raw " \( " etc., which only worked because Python keeps an
# invalid escape sequence verbatim (a SyntaxWarning since 3.12).  The string
# values are identical.  Note that re.sub keeps the backslash of an unknown
# non-letter escape such as "\(" in a replacement template, so parentheses and
# '?' come out as the literal tokens "\(", "\)", "\?".  This matches the
# original CNN_sentence behaviour and (presumably) the tokens the vocabulary was
# built from, so it is preserved on purpose.
_CLEAN_STR_RULES = [
    (re.compile(r"[^A-Za-z0-9(),!?\'\`]"), " "),
    (re.compile(r"\'s"), " 's"),
    (re.compile(r"\'ve"), " 've"),
    (re.compile(r"n\'t"), " n't"),
    (re.compile(r"\'re"), " 're"),
    (re.compile(r"\'d"), " 'd"),
    (re.compile(r"\'ll"), " 'll"),
    (re.compile(r","), " , "),
    (re.compile(r"!"), " ! "),
    (re.compile(r"\("), r" \( "),
    (re.compile(r"\)"), r" \) "),
    (re.compile(r"\?"), r" \? "),
    (re.compile(r"\s{2,}"), " "),
]


def clean_str(string, use=True):
    """Tokenise/clean a string (all datasets except SST).

    Replaces characters outside ``[A-Za-z0-9(),!?'`]`` with spaces, splits
    English clitics ('s, 've, n't, 're, 'd, 'll) and punctuation into separate
    tokens, collapses whitespace, strips and lower-cases the result.

    Original taken from
    https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py

    Args:
        string: raw text to clean.
        use: when False, *string* is returned unchanged.

    Returns:
        The cleaned, lower-cased string with tokens separated by single spaces.
    """
    if not use:
        return string
    for pattern, replacement in _CLEAN_STR_RULES:
        string = pattern.sub(replacement, string)
    return string.strip().lower()
# --- Resources loaded once at import time -----------------------------------

# Paper titles; a predicted token "p-<n>" is displayed as title_list[n].
# allow_pickle=True unpickles arbitrary objects, which is acceptable only
# because this file ships with the app and is trusted.
title_list = np.load("./title_list.npy", allow_pickle=True).tolist()
# NOTE(review): data_path is not used anywhere in this file as shown; confirm
# before removing.
data_path = os.path.join('..', 'data')
# Force CPU inference.  This deliberately overrides the `device` imported from `data`.
device = 'cpu'

# token -> embedding vector for every in-vocabulary word / paper id.
# Use a context manager so the handle is closed.  JSON text is UTF-8.
with open('./papers_embedding_inuse.json', encoding='utf-8') as _embedding_file:
    vec_inuse = json.load(_embedding_file)
if not vec_inuse:
    raise ValueError("papers_embedding_inuse.json contains no embeddings")

vocab = list(vec_inuse)
# Two reserved ids precede the real vocabulary: 0 = <PAD>, 1 = <OOV>.
vocab_size = len(vocab) + 2
word2index = {'<PAD>': 0, '<OOV>': 1}
index2word = ['<PAD>', '<OOV>', *vocab]

# Embedding matrix whose rows line up with index2word.  Rows 0 and 1
# (<PAD>, <OOV>) stay all-zero.
embedding_dim = len(next(iter(vec_inuse.values())))
word2vec = np.zeros((vocab_size, embedding_dim), dtype=np.float32)
# Take the index from the vocabulary position (offset by the 2 reserved ids)
# so word2index always agrees with index2word.
for index, wd in enumerate(vocab, start=2):
    word2index[wd] = index
    word2vec[index, :] = vec_inuse[wd]
def data2index(data_x, word2index):
    """Convert question/answer records into index form.

    Each record in *data_x* is a dict with a free-text 'question' and an
    'answer' token.  The question is tokenised with clean_str.  Tokens found in
    *word2index* map to their ids.  Unknown tokens map to the '<OOV>' id, and
    so does any occurrence of the answer token itself (presumably so the target
    cannot leak into the input).

    Returns:
        A list of dicts ``{'answer': <answer id>, 'question_words': [<ids>]}``,
        one per input record, in the same order.

    Raises:
        KeyError: if a record's answer is not a key of *word2index*.
    """
    converted = []
    for record in data_x:
        tokens = clean_str(record['question']).strip().split()
        answer = record['answer']
        token_ids = [
            word2index['<OOV>']
            if (tok not in word2index or tok == answer)
            else word2index[tok]
            for tok in tokens
        ]
        converted.append({'answer': word2index[answer], 'question_words': token_ids})
    return converted
# Loaded models keyed by checkpoint path.  The checkpoint is read from disk
# once rather than on every request.
_MODEL_CACHE = {}


def _get_model(path="saved.model"):
    """Return the model stored at *path*, loading it and setting eval mode on first use."""
    model = _MODEL_CACHE.get(path)
    if model is None:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        # The checkpoint is a whole pickled model object (it has .eval() and is
        # called directly), not a state_dict, so torch>=2.6 needs
        # weights_only=False.  That unpickles arbitrary code; it is acceptable
        # only because saved.model ships with the app and is trusted.
        try:
            model = torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            # torch < 1.13 has no weights_only parameter.
            model = torch.load(path, map_location=device)
        model.eval()
        _MODEL_CACHE[path] = model
    return model


def _paper_titles(tokens):
    """Map predicted tokens of the form ``p-<n>`` to ``title_list[n]``.

    Tokens that are not paper ids (e.g. ``<PAD>``, ``<OOV>`` or ordinary words,
    should the model rank them) are skipped instead of raising IndexError or
    ValueError.
    """
    titles = []
    for token in tokens:
        match = re.fullmatch(r"p-(\d+)", token)
        if match is not None:
            titles.append(title_list[int(match.group(1))])
    return titles


def greet(paper_str):
    """Return titles of the papers the model ranks highest for *paper_str*.

    Args:
        paper_str: free-text query entered by the user.

    Returns:
        A list of up to 10 paper titles, in the order the model returns them
        (presumably best first).
    """
    top_k = 10
    model = _get_model()
    # data2index/MyDataset require an 'answer' entry.  'p-7241' appears to be
    # an arbitrary in-vocabulary paper id used as a placeholder.  It must exist
    # in word2index, or data2index raises KeyError.
    test_dataset = MyDataset(data2index(
        [{'answer': 'p-7241', 'question': paper_str}], word2index))
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, collate_fn=my_collate_fn)

    predicted = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for words_t, definition_words_t in test_dataloader:
            indices = model('test', x=definition_words_t, w=words_t, mode="b")
            top = indices[:, :top_k].detach().cpu().numpy().tolist()
            predicted = [index2word[idx] for idx in top[0]]
    return _paper_titles(predicted)
# Minimal Gradio UI: a question box, an output box and a submit button that
# calls greet() with the question text.
with gr.Blocks() as demo:
    name = gr.Textbox(label="Question")
    output = gr.Textbox(label="Papers")
    greet_btn = gr.Button("Submit")
    greet_btn.click(fn=greet, inputs=name, outputs=output)

# Start the server only when run as a script (e.g. `python app.py`).  The
# module can then be imported, for tests or tooling, without launching it.
if __name__ == "__main__":
    demo.launch()