Spaces:
Runtime error
Runtime error
'''
Gradio web demo: retrieve existing formulae (equations) from the OpenStax
Calculus textbook and show each one within its surrounding text.
'''
import functools
import os
import pickle
from pdb import set_trace

import gradio as gr
import torch

from retrieval_utils import *
# ----------------------------------------------------------------------------
# Web-app text: page title, description and the clickable example queries.
# ----------------------------------------------------------------------------
title = "math equation retrieval demo using the OpenStax Calculus textbook"

description = (
    "This is a demo for math equation retrieval based on research developed at "
    "Rice University. Click on one of the examples or type an equation of your "
    "own. Then click submit to see the retrieved equation along with its "
    "surrounding context. The retrieved equation is marked in red. Currently "
    "DOES NOT support single symbol retrieval. Demo uses the OpenStax calculus "
    "textbook. "
)

# Optional footer (currently disabled):
# article = "<b>Warning and disclaimer:</b> Currently we do not guarantee the generated questions to be good 100 percent of the time. Also the model may generate content that is biased, hostile, or toxic; be extra careful about what you provide as input. We are not responsible for any impacts that the generated content may have.<p style='text-align: center'>Developed at <a href='https://rice.edu/'>Rice University</a> and <a href='https://openstax.org/'>OpenStax</a> with <a href='https://gradio.app/'>gradio</a> and <a href='https://beta.openai.com/'>OpenAI API</a></p>"

# Example LaTeX queries, written as raw strings so the backslashes read as
# they would in LaTeX source.
_EXAMPLE_EQUATIONS = (
    r"y=\text{sin}\phantom{\rule{0.1em}{0ex}}x",
    r"\left[a,b\right]",
    "20g",
    "x=5",
    r"f\left(x\right)=\frac{1}{\sqrt{1+{x}^{2}}}",
    r"P\left(x\right)=30x-0.3{x}^{2}-250",
    r"\epsilon =0.8;",
    r"{x}_{n}=\frac{{x}_{n-1}}{2}+\frac{3}{2{x}_{n-1}}",
    r"y=f\left(x\right),y=1+f\left(x\right),x=0,y=0",
    r"\frac{1}{2}du=d\theta",
)

# One row per example; each row holds the value for the single text input.
examples = [[equation] for equation in _EXAMPLE_EQUATIONS]
######################## for the retrieval model #######################
########################################################################
## configs
# Trained encoder checkpoint directory (RESULT_PATH) and preprocessed
# OS_calculus model-input data (DATA_PATH).  The defaults are the paths this
# demo was built with; set the environment variables to run it elsewhere.
RESULT_PATH = os.environ.get(
    'RETRIEVAL_RESULT_PATH',
    '/mnt/math-text-embedding/math_embedding/results_retrieval/jack@51.143.92.116/2020-12-15_06-09-44_eqsWikiTopiceqArxiv-RNN-top_1000-minmax_nodes_3_150-max_children_8-max_depth_19-onehot2010_w_unknown_softmax',
)
DATA_PATH = os.environ.get(
    'RETRIEVAL_DATA_PATH',
    '/mnt/math-text-embedding/math_embedding/openstax_retrieval_demo/model_input_data/eqsOS_calculus-top_1000-minmax_nodes_3_150-max_children_8-max_depth_19_w_unknown',
)

# Tree-preprocessing settings forwarded to ``retrieval``.  They match the
# values spelled out in DATA_PATH's directory name.
_TOP = 1000          # top nodes kept as vocab
_MIN_NODES = 3       # 3, 5
_MAX_NODES = 150     # 150, 10
_MAX_CHILDREN = 8    # 1 is actually 2 children because indexing starts at 0
_MAX_DEPTH = 19
_SIMPLE = False      # True: non-typed nodes; all tail nodes are converted to UNKNOWN
_W_UNKNOWN = True    # typed nodes but with unknown tail for "other"; cannot be combined with _SIMPLE
_TOP_N = 5           # number of results requested from ``retrieval`` per query

# Encoder configuration; must match the checkpoint stored in RESULT_PATH.
_HIDDEN_SIZE = 500
_N_LAYERS = 2
_DROPOUT = 0.1
_USE_SOFTMAX = True
# Hard-coded instead of len(vocab); presumably the vocabulary size the
# checkpoint was trained with -- TODO confirm.
_VOCAB_SIZE = 4009


@functools.lru_cache(maxsize=1)
def _load_retrieval_resources():
    """Load the equation corpus, lookup tables and trained encoder, once.

    The result is cached, so the data directory, the two pickles and the
    checkpoint are read on the first query only instead of on every request.
    A load that raises is not cached, so the next request retries it.

    Returns:
        dict holding the objects ``retrieval_fn`` passes to ``retrieval``.
    """
    (_vocab, vocab_T, vocab_N, vocab_V, vocab_other, stoi, _itos, eqs,
     encoder_src, encoder_pos, _, _, _, _, _max_n_children) = load_all_data(DATA_PATH)

    # NOTE: pickle.load can execute arbitrary code; only deploy trusted .pkl files.
    with open('tree_to_eq_dict.pkl', 'rb') as f:
        tree_to_eq_dict = pickle.load(f)
    with open('eq_to_context_dict.pkl', 'rb') as f:
        eq_to_context_dict = pickle.load(f)

    print('load model ...')
    # Prefer the GPU but fall back to the CPU, so the demo can also start on
    # hosts without CUDA.
    # NOTE(review): confirm ``retrieval`` does not itself force tensors onto CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): max_depth/max_children here (20/10) differ from
    # _MAX_DEPTH/_MAX_CHILDREN (19/8); presumably intentional -- confirm.
    pos_embedding = PositionalEncoding(max_depth=20, max_children=10, mode='onehot')
    encoder = EncoderRNN(_VOCAB_SIZE, _HIDDEN_SIZE, pos_embedding, _N_LAYERS,
                         dropout=_DROPOUT, use_softmax=_USE_SOFTMAX)
    checkpoint = os.path.join(RESULT_PATH, 'encoder_best.pt')
    # map_location lets weights saved on a GPU load on a CPU-only host.
    encoder.load_state_dict(torch.load(checkpoint, map_location=device))
    encoder.to(device)
    encoder.eval()  # switch to inference mode

    return {
        'eqs': eqs,
        'tree_to_eq_dict': tree_to_eq_dict,
        'eq_to_context_dict': eq_to_context_dict,
        'encoder_src': encoder_src,
        'encoder_pos': encoder_pos,
        'encoder': encoder,
        'stoi': stoi,
        'vocab_N': vocab_N,
        'vocab_V': vocab_V,
        'vocab_T': vocab_T,
        'vocab_other': vocab_other,
    }


def retrieval_fn(inp):
    """Retrieve the contexts of the stored equations closest to ``inp``.

    Args:
        inp: the query equation as typed by the user (LaTeX text).

    Returns:
        The second value returned by ``retrieval`` (the retrieved contexts;
        presumably one string per retrieved equation, at most _TOP_N).
    """
    res = _load_retrieval_resources()
    _result_eqs, result_contexts, _cos_values = retrieval(
        inp, res['eqs'], res['tree_to_eq_dict'], res['eq_to_context_dict'],
        res['encoder_src'], res['encoder_pos'], res['encoder'],
        res['stoi'], res['vocab_N'], res['vocab_V'], res['vocab_T'], res['vocab_other'],
        _TOP_N,
        _TOP, _MIN_NODES, _MAX_NODES, _MAX_CHILDREN, _MAX_DEPTH, _SIMPLE, _W_UNKNOWN)
    return result_contexts
# Number of characters of surrounding text assumed to sit on each side of the
# retrieved equation in every context string.
CONTEXT_WINDOW = 250


def split_context(context, window=CONTEXT_WINDOW):
    """Split a retrieved context into ``(before, equation, after)``.

    The first and last ``window`` characters are treated as surrounding text
    and everything in between as the retrieved equation.  The three pieces
    always concatenate back to ``context``.  When the context is shorter than
    ``2 * window``, the tail therefore no longer repeats text already shown
    in the head.
    """
    end = len(context) - window
    before = context[:window]
    equation = context[window:end]
    # max() keeps the tail from overlapping the head for short contexts.
    after = context[max(window, end):]
    return before, equation, after


def highlight_contexts(contexts):
    """Convert context strings into ``(text, label)`` segments for gr.HighlightedText.

    Each context yields three segments; only the middle one is labelled
    ``'retrieved equation'``.
    """
    segments = []
    for context in contexts:
        before, equation, after = split_context(context)
        segments.append(('..... ' + before, None))
        segments.append((equation, 'retrieved equation'))
        segments.append((after + ' ......\r\n\r\n\r\n', None))
    return segments


def generate(inp):
    """Retrieve contexts for ``inp`` and return them as highlighted-text segments.

    Alternative to ``generate_html`` for a gr.HighlightedText output.
    """
    return highlight_contexts(retrieval_fn(inp))


# Number of ``generate_html`` requests served so far; printed per request.
counter = 0


def render_html(contexts):
    """Render context strings as an HTML page with each equation shown in red.

    All retrieved text is HTML-escaped before being placed in the page.
    """
    from html import escape

    paragraphs = []
    for context in contexts:
        # Split before escaping: escaping changes lengths and would shift the
        # fixed-size window boundaries.  escape() also handles '&', so math
        # such as "x<5" cannot break the markup.
        before, equation, after = (
            escape(piece, quote=False) for piece in split_context(context)
        )
        paragraphs.append(
            '<p>...... ' + before
            + '<span style="color: #ff0000">{}</span>'.format(equation)
            + after + ' ... </p><br><br>'
        )
    page = '<!DOCTYPE html><html><body>' + ''.join(paragraphs) + '</body></html>'
    # Remove every '$$' (presumably the LaTeX delimiters around equations in
    # the source text) so they are not displayed verbatim.
    return page.replace('$$', '')


def generate_html(inp):
    """Gradio handler: retrieve contexts for ``inp`` and render them as HTML.

    Also increments the module-level ``counter`` and prints its new value.
    """
    global counter
    page = render_html(retrieval_fn(inp))
    counter += 1
    print(counter)
    return page
# ----------------------------------------------------------------------------
# Web UI wiring.  ``output_text`` is an HTML output component that the
# interface below does not use (it passes the "html" shortcut string instead).
# ----------------------------------------------------------------------------
output_text = gr.outputs.HTML()

demo = gr.Interface(
    fn=generate_html,
    inputs=gr.inputs.Textbox(lines=5, label="Input Text"),
    outputs=["html"],
    title=title,
    description=description,
    # article=article,  (footer currently disabled)
    examples=examples,
)
demo.launch(share=True)