'''
Gradio demo for retrieving existing formulae (equations) from the OpenStax Calculus textbook.
'''

import os
import pickle

import gradio as gr
import torch

# retrieval_utils is assumed to provide load_all_data, PositionalEncoding,
# EncoderRNN, retrieval, and the PAD token used below.
from retrieval_utils import *


### for the webapp
title = "Math equation retrieval demo using the OpenStax Calculus textbook"
description = "This is a demo for math equation retrieval based on research developed at Rice University. Click on one of the examples or type an equation of your own, then click Submit to see the retrieved equations along with their surrounding context. Each retrieved equation is marked in red. Single-symbol retrieval is currently not supported. The demo uses the OpenStax Calculus textbook."
# article = "<b>Warning and disclaimer:</b> Currently we do not guarantee the generated questions to be good 100 percent of the time. The model may also generate content that is biased, hostile, or toxic; be extra careful about what you provide as input. We are not responsible for any impacts that the generated content may have.<p style='text-align: center'>Developed at <a href='https://rice.edu/'>Rice University</a> and <a href='https://openstax.org/'>OpenStax</a> with <a href='https://gradio.app/'>gradio</a> and <a href='https://beta.openai.com/'>OpenAI API</a></p>"
examples = [
    ["""y=\\text{sin}\\phantom{\\rule{0.1em}{0ex}}x"""],
    ["""\\left[a,b\\right]"""],
    ["""20g"""],
    ["""x=5"""],
    ["""f\\left(x\\right)=\\frac{1}{\\sqrt{1+{x}^{2}}}"""],
    ["""P\\left(x\\right)=30x-0.3{x}^{2}-250"""],
    ["""\\epsilon =0.8;"""],
    ["""{x}_{n}=\\frac{{x}_{n-1}}{2}+\\frac{3}{2{x}_{n-1}}"""],
    ["""y=f\\left(x\\right),y=1+f\\left(x\\right),x=0,y=0"""],
    ["""\\frac{1}{2}du=d\\theta"""]
]



######################## for the retrieval model #######################
########################################################################
## configs

def retrieval_fn(inp):
    ## equation-tree configs (must match the settings used when the data files were built)
    _top = 1000        # number of top nodes kept as vocabulary
    _min_nodes = 3     # minimum nodes per equation tree (3 or 5)
    _max_nodes = 150   # maximum nodes per equation tree (150 or 10)
    _max_children = 8  # 0-indexed, so 1 actually means 2 children
    _max_depth = 19
    _simple = False    # non-typed nodes; all tail nodes are converted to UNKNOWN
    _w_unknown = True  # typed nodes with an unknown tail for "other"; cannot be used together with _simple
    _collection = 'OS_calculus'  # 'ARQMath' # 'WikiTopiceqArxiv'
    use_softmax = True
    topN = 5  # number of results to return
    result_path = '/mnt/math-text-embedding/math_embedding/results_retrieval/jack@51.143.92.116/2020-12-15_06-09-44_eqsWikiTopiceqArxiv-RNN-top_1000-minmax_nodes_3_150-max_children_8-max_depth_19-onehot2010_w_unknown_softmax'
    data_path = '/mnt/math-text-embedding/math_embedding/openstax_retrieval_demo/model_input_data/eqsOS_calculus-top_1000-minmax_nodes_3_150-max_children_8-max_depth_19_w_unknown'

    vocab, vocab_T, vocab_N, vocab_V, vocab_other, stoi, itos, eqs, encoder_src, encoder_pos, \
    _, _, _, _, max_n_children = load_all_data(data_path)

    # lookup tables built offline: equation-tree key -> LaTeX equation, and equation -> surrounding textbook context
    with open('tree_to_eq_dict.pkl', 'rb') as f:
        tree_to_eq_dict = pickle.load(f)
    with open('eq_to_context_dict.pkl', 'rb') as f:
        eq_to_context_dict = pickle.load(f)
    
    ## load model - rnn
    print('load model ...')
    # Configure models
    hidden_size = 500
    n_layers = 2
    dropout = 0.1
    pad_idx = stoi[PAD]

    # Initialize models
    pos_embedding = PositionalEncoding(max_depth=20, max_children=10, mode='onehot')
    # vocabulary size is hard-coded to match the pretrained checkpoint
    # (the commented line uses len(vocab) of the demo data instead)
    # encoder = EncoderRNN(len(vocab), hidden_size, pos_embedding, n_layers, dropout=dropout, use_softmax=use_softmax)
    encoder = EncoderRNN(4009, hidden_size, pos_embedding, n_layers, dropout=dropout, use_softmax=use_softmax)
    encoder.load_state_dict(torch.load(os.path.join(result_path, 'encoder_best.pt')))
    encoder.cuda()
    encoder.eval()
    
    # retrieval
    retrieval_result_eqs, retrieval_result_context, cos_values = retrieval(inp, eqs, tree_to_eq_dict, eq_to_context_dict, encoder_src, encoder_pos, encoder, 
                                                                stoi, vocab_N, vocab_V, vocab_T, vocab_other, topN,
                                                                _top, _min_nodes, _max_nodes, _max_children, _max_depth, _simple, _w_unknown)

    return retrieval_result_context
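
# Minimal usage sketch (illustrative; assumes the model checkpoint, pickled
# dictionaries, and a CUDA device are available):
#
#   contexts = retrieval_fn(r"f\left(x\right)=\frac{1}{\sqrt{1+{x}^{2}}}")
#   # 'contexts' holds the topN surrounding-text passages, each containing the
#   # retrieved equation embedded in its textbook context.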


def generate(inp):

    retrieved = retrieval_fn(inp)
    new_r = []
    for r in retrieved:
        p = r[0:250]
        n = r[-250:]
        m = r[250:-250]
        new_r.append(('..... ' + p, None))
        new_r.append((m, 'retrieved equation'))
        new_r.append((n+' ......\r\n\r\n\r\n', None))

    # output = '\n\n\n'.join(retrieval_result_contexts['$$' + inp + '$$'][0:10])

    return new_r
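
# Note: generate() assumes each retrieved context string is built as
# <~250 chars of preceding text><matched equation><~250 chars of following text>,
# so the slices [0:250], [250:-250], [-250:] recover prefix, equation, and suffix.
# It returns (text, label) tuples in the format expected by gr.outputs.HighlightedText;
# only generate_html() below is wired into the interface.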

counter = 0  # simple request counter, printed on each call for logging
def generate_html(inp):
    
    global counter
    
    retrieved = retrieval_fn(inp)
    html_txt = ''
    for r in retrieved:
        p = r[0:250]
        n = r[-250:]
        m = r[250:-250]
        html_txt += '<p>...... ' + p.replace('<', '&lt;').replace('>', '&gt;') + '<span style="color: #ff0000">{}</span>'.format(m.replace('<', '&lt;').replace('>', '&gt;')) + n.replace('<', '&lt;').replace('>', '&gt;') + ' ... </p><br><br>'
        
    html = "<!DOCTYPE html><html><body>" + html_txt + "</body></html>"
    html = html.replace('$$', '')
    # print(html)
    
    counter += 1
    print(counter)
    
    return html
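
# Illustrative shape of the HTML produced above (not verbatim output): each retrieved
# context becomes one paragraph with the matched equation wrapped in a red span, e.g.
#   <p>...... preceding text <span style="color: #ff0000">x=5</span> following text ... </p><br><br>
# '<' and '>' in the surrounding text are escaped, and '$$' delimiters are stripped at the end.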

# output_text = gr.outputs.Textbox()
# output_text = gr.outputs.HighlightedText()  # color_map={"retrieved equation": "red"})
output_text = gr.outputs.HTML()  # unused below; the interface passes outputs=["html"] directly


gr.Interface(fn=generate_html,
             inputs=gr.inputs.Textbox(lines=5, label="Input Text"),
             outputs=["html"],  # or output_text
             title=title, description=description,
             # article=article,
             examples=examples).launch(share=True)