'''
Demo app: retrieval of existing formulae from the OpenStax Calculus textbook.
'''
import os
import pickle

import gradio as gr
import torch

from pdb import set_trace  # debugging helper (unused in normal operation)
from retrieval_utils import *  # provides load_all_data, PositionalEncoding, EncoderRNN, retrieval, PAD, ...
### for the webapp
title = "Math equation retrieval demo using the OpenStax Calculus textbook"
description = "This is a demo for math equation retrieval based on research developed at Rice University. Click one of the examples or type an equation of your own, then click Submit to see the retrieved equations along with their surrounding context. Each retrieved equation is marked in red. Single-symbol retrieval is currently NOT supported. The demo uses the OpenStax Calculus textbook."
# article = "<b>Warning and disclaimer:</b> Currently we do not guarantee the generated questions to be good 100 percent of the time. Also the model may generate content that is biased, hostile, or toxic; be extra careful about what you provide as input. We are not responsible for any impacts that the generated content may have.<p style='text-align: center'>Developed at <a href='https://rice.edu/'>Rice University</a> and <a href='https://openstax.org/'>OpenStax</a> with <a href='https://gradio.app/'>gradio</a> and <a href='https://beta.openai.com/'>OpenAI API</a></p>"
examples = [
["""y=\\text{sin}\\phantom{\\rule{0.1em}{0ex}}x"""],
["""\\left[a,b\\right]"""],
["""20g"""],
["""x=5"""],
["""f\\left(x\\right)=\\frac{1}{\\sqrt{1+{x}^{2}}}"""],
["""P\\left(x\\right)=30x-0.3{x}^{2}-250"""],
["""\\epsilon =0.8;"""],
["""{x}_{n}=\\frac{{x}_{n-1}}{2}+\\frac{3}{2{x}_{n-1}}"""],
["""y=f\\left(x\\right),y=1+f\\left(x\\right),x=0,y=0"""],
["""\\frac{1}{2}du=d\\theta"""]
]
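# Note on escaping: the examples above are ordinary (non-raw) Python strings, so every
# LaTeX backslash is written as "\\". The first entry, for instance, decodes to the LaTeX
#   y=\text{sin}\phantom{\rule{0.1em}{0ex}}x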
######################## for the retrieval model #######################
########################################################################
## configs
def retrieval_fn(inp):
    ## equation-tree / vocabulary settings (must match the trained model)
    _top = 1000        # top nodes kept as vocab
    _min_nodes = 3     # 3, 5
    _max_nodes = 150   # 150, 10
    _max_children = 8  # 1 is actually 2 children because indexing starts at 0
    _max_depth = 19
    _simple = False    # non-typed nodes; all tail nodes are converted to UNKNOWN
    simple = 'simple' if _simple else 'complex'
    _w_unknown = True  # typed nodes but with unknown tail for "other"; cannot be used together with "simple"
    _collection = 'OS_calculus'  # 'ARQMath' # 'WikiTopiceqArxiv'
    simple = False      # if data_path.split('_')[-1] == 'simple' else False
    w_unknown = True    # if data_path.split('_')[-1] == 'unknown' else False  # TODO
    use_softmax = True  # TODO
    topN = 5            # number of contexts to retrieve

    ## trained encoder checkpoint and preprocessed OpenStax Calculus equation data
    result_path = '/mnt/math-text-embedding/math_embedding/results_retrieval/jack@51.143.92.116/2020-12-15_06-09-44_eqsWikiTopiceqArxiv-RNN-top_1000-minmax_nodes_3_150-max_children_8-max_depth_19-onehot2010_w_unknown_softmax'
    data_path = '/mnt/math-text-embedding/math_embedding/openstax_retrieval_demo/model_input_data/eqsOS_calculus-top_1000-minmax_nodes_3_150-max_children_8-max_depth_19_w_unknown'

    ## load vocabularies, equation trees, and pre-encoded equations
    vocab, vocab_T, vocab_N, vocab_V, vocab_other, stoi, itos, eqs, encoder_src, encoder_pos, \
        _, _, _, _, max_n_children = load_all_data(data_path)
    # pickled lookup tables: equation tree -> LaTeX string, equation -> surrounding textbook context
    with open('tree_to_eq_dict.pkl', 'rb') as f:
        tree_to_eq_dict = pickle.load(f)
    with open('eq_to_context_dict.pkl', 'rb') as f:
        eq_to_context_dict = pickle.load(f)

    ## load model - RNN encoder
    print('load model ...')
    # Configure models
    hidden_size = 500
    n_layers = 2
    dropout = 0.1
    pad_idx = stoi[PAD]
    # Initialize models
    pos_embedding = PositionalEncoding(max_depth=20, max_children=10, mode='onehot')
    # encoder = EncoderRNN(len(vocab), hidden_size, pos_embedding, n_layers, dropout=dropout, use_softmax=use_softmax)
    encoder = EncoderRNN(4009, hidden_size, pos_embedding, n_layers, dropout=dropout, use_softmax=use_softmax)  # 4009 = vocab size of the trained checkpoint
    encoder.load_state_dict(torch.load(os.path.join(result_path, 'encoder_best.pt')))
    encoder.cuda()  # requires a CUDA-capable GPU
    encoder.eval()

    ## retrieval: embed the input equation and rank textbook equations by cosine similarity
    retrieval_result_eqs, retrieval_result_context, cos_values = retrieval(
        inp, eqs, tree_to_eq_dict, eq_to_context_dict, encoder_src, encoder_pos, encoder,
        stoi, vocab_N, vocab_V, vocab_T, vocab_other, topN,
        _top, _min_nodes, _max_nodes, _max_children, _max_depth, _simple, _w_unknown)
    return retrieval_result_context
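# Minimal sanity check for retrieval_fn (hypothetical usage, not executed by the app;
# assumes the checkpoint, the two pickle files, and a CUDA device are available):
#
#   contexts = retrieval_fn('x=5')          # any of the example queries above works
#   print(len(contexts), contexts[0][:80])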
def generate(inp):
    ## formatter for gr.outputs.HighlightedText (not the one wired into the Interface below)
    retrieved = retrieval_fn(inp)
    new_r = []
    for r in retrieved:
        # each stored context holds ~250 characters of text on either side of the retrieved equation
        p = r[0:250]     # preceding context
        n = r[-250:]     # following context
        m = r[250:-250]  # the retrieved equation itself
        new_r.append(('..... ' + p, None))
        new_r.append((m, 'retrieved equation'))
        new_r.append((n + ' ......\r\n\r\n\r\n', None))
    # output = '\n\n\n'.join(retrieval_result_contexts['$$' + inp + '$$'][0:10])
    return new_r
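# Illustrative shape of the HighlightedText payload built by generate() (values are hypothetical):
#
#   [('..... Consider the function ', None),
#    ('f(x)=x^2', 'retrieved equation'),
#    (' on the closed interval ......\r\n\r\n\r\n', None),
#    ...]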
counter = 0  # simple request counter, printed on every query

def generate_html(inp):
    global counter
    retrieved = retrieval_fn(inp)
    html_txt = ''
    for r in retrieved:
        p = r[0:250]     # preceding context
        n = r[-250:]     # following context
        m = r[250:-250]  # the retrieved equation itself
        # escape angle brackets in the context and highlight the retrieved equation in red
        html_txt += ('<p>...... ' + p.replace('<', '&lt;').replace('>', '&gt;')
                     + '<span style="color: #ff0000">{}</span>'.format(m.replace('<', '&lt;').replace('>', '&gt;'))
                     + n.replace('<', '&lt;').replace('>', '&gt;') + ' ... </p><br><br>')
    html = "<!DOCTYPE html><html><body>" + html_txt + "</body></html>"
    html = html.replace('$$', '')
    # print(html)
    counter += 1
    print(counter)
    return html
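# Per-context HTML fragment produced above (illustrative):
#   <p>...... {prefix}<span style="color: #ff0000">{retrieved equation}</span>{suffix} ... </p><br><br>
# The final .replace('$$', '') presumably strips the $$ delimiters that wrap equations
# inside the stored context strings (cf. the commented-out '$$' + inp + '$$' lookup in generate()).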
# output_text = gr.outputs.Textbox()
# output_text = gr.outputs.HighlightedText()  # color_map={"retrieved equation": "red"}
output_text = gr.outputs.HTML()  # color_map={"retrieved equation": "red"}
gr.Interface(fn=generate_html,
             inputs=gr.inputs.Textbox(lines=5, label="Input Text"),
             outputs=["html"],  # output_text,
             title=title, description=description,
             # article=article,
             examples=examples).launch(share=True)
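# To serve the demo locally, run this file directly (filename assumed to be app.py):
#
#   python app.py
#
# share=True additionally creates a temporary public Gradio link for the session.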