File size: 8,316 Bytes
dd6041d
01c6f39
d014bc4
efadd28
d014bc4
2f634c0
 
 
 
 
d014bc4
01c6f39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d014bc4
2f634c0
efadd28
 
 
 
2fb3f40
43422b3
efadd28
 
 
d014bc4
 
 
 
 
 
 
 
 
 
 
 
dd6041d
 
 
d014bc4
 
 
 
 
 
 
 
 
 
 
 
abd8a31
 
 
 
 
 
 
 
 
 
 
d014bc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87e53ba
d014bc4
 
 
 
 
902a6c9
 
d014bc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01c6f39
f937011
01c6f39
 
 
 
 
 
d014bc4
01c6f39
7f005c9
 
01c6f39
7f005c9
01c6f39
d014bc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18c1dd1
d014bc4
 
 
 
2f634c0
d014bc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efadd28
 
 
82418b9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import platform
import random
from typing import Tuple, List, Dict, Optional

import streamlit as st
import torch
from trecover.config import var
from trecover.utils.beam_search import beam_search, dashboard_loop
from trecover.utils.inference import data_to_columns, create_noisy_columns
from trecover.utils.transform import columns_to_tensor, tensor_to_target
from trecover.utils.visualization import visualize_columns, visualize_target

# Maximum length of the plain-text input field; the noisy-columns field allows
# MAX_CHARS * 4 characters (see inference_page).
MAX_CHARS = 256

# Plain-text examples offered in the example picker. The 'Select example' entry maps
# to None so that get_input_data falls back to the previously entered text.
PLAIN_EXAMPLES = {
    'Select example': None,
    'Example 1': 'As people around the country went into the streets to cheer the conviction, some businesses in '
                 'Portland boarded up their windows once again.',
    'Example 2': 'That night, a small group of activists wearing black approached a group of journalists, threatening'
                 ' to smash the cameras of those who remained on scene.',
    'Example 3': 'English as we know it today came to be exported to other parts of the world through British '
                 'colonisation, and is now the dominant language in Britain'
}

# Pre-noised examples: space-separated groups of candidate characters, one group per
# column (parsed with data_to_columns(data, separator=' ') in inference_page).
NOISED_EXAMPLES = {
    'Select example': None,
    'Example 1': 'a ds fpziq ofe ngkhbo p pghl ue waq frlqjo o u dnxrm dgr yrtsco kho deuasm dhysc ao u nwzhy tle r '
                 'yzpe xwabc gce nger klqto wiq nfprso t no tpgq tcfh ae twas tw ur re e t gyutsm t xgo rc ubhq e wle '
                 'r ty h nwpeaq xdsc o dnhelm v thir ikcq tkuo i o twn ps frio mo oe b kuiqtb jsq zi tnye ge dgrqs s '
                 'cioe ys whic wne wp thlo dnprsc xvpyrt hurlm kveaj nbfp dome pbeaj dusmo a r dzrqsm xace du nxkuai '
                 'gpulcm tpi h pie uim r wbhrj ui n dwgp dkeio nkwhqs zs'
}

# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The model is loaded once, at import time, through torch.hub from the
# 'alex-snd/TRecover' repository; every session of the app shares this instance.
model = torch.hub.load('alex-snd/TRecover', model='trecover', device=device, version='latest')


def main() -> None:
    """Configure the page, seed missing per-session state with defaults and draw the app."""
    st.set_page_config(
        page_title='TRecover',
        page_icon='🩹',
        layout='wide',
        initial_sidebar_state='expanded',
    )

    # Each default is produced by a factory so it is evaluated only when the key is
    # absent from the session (e.g. platform detection runs once per session).
    session_defaults = (
        ('history', list),
        ('data', lambda: ''),
        ('regenerate', lambda: False),
        ('columns', lambda: None),
        ('is_unix', lambda: platform.system() != 'Windows'),
    )

    for state_key, make_default in session_defaults:
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()

    sidebar()


def set_regenerate() -> None:
    """Request fresh noisy columns on the next run; wired up as a widget callback."""
    st.session_state['regenerate'] = True


def unset_regenerate() -> None:
    """Clear the regeneration request once new noisy columns have been produced."""
    st.session_state['regenerate'] = False


def sidebar() -> None:
    """Draw the sidebar header and section switch, then render the chosen section."""
    about_html = """
                        <h1 align="center"> 
                            <font size="20">🤷</font>
                            <a href="https://alex-snd.github.io/TRecover">About the Project</a>
                        </h1>
                        <br><br>
                        """
    st.sidebar.markdown(body=about_html, unsafe_allow_html=True)

    section = st.sidebar.radio('Sections', ('Inference', 'Inference history'))

    # Anything other than 'Inference' is the history section.
    if section != 'Inference':
        history_sidebar()
        history_page()
        return

    is_plain, min_noise, max_noise, bw = inference_sidebar()
    inference_page(is_plain, min_noise, max_noise, bw)


def inference_sidebar() -> Tuple[bool, int, int, int]:
    """Render the inference controls in the sidebar and return the chosen settings.

    Returns:
        ``(is_plain, min_noise, max_noise, bw)``: whether the input is plain text, the
        noise bounds (``0, 1`` for noisy-column input) and the beam search width. The
        returned ``max_noise`` is one above the slider's upper value.
    """
    st.sidebar.text('\n')

    previous_type = st.session_state.get('data_type', 'Plain text')
    data_type = st.sidebar.radio('Input type', ('Plain text', 'Noisy columns'), key='data_type',
                                 index=0 if previous_type == 'Plain text' else 1)
    is_plain = data_type == 'Plain text'

    st.sidebar.text('\n')

    # Noise is only generated for plain text; noisy columns come from the user as-is.
    min_noise, max_noise = 0, 0
    if is_plain:
        min_noise, max_noise = st.sidebar.slider('\nNoise range', 0, 5, key='noise_range',
                                                 value=st.session_state.get('noise_range', (0, 5)),
                                                 on_change=set_regenerate)

    beam_width = st.sidebar.slider('Beam search width', 1, 25, key='beam_width',
                                   value=st.session_state.get('beam_width', 5))

    if max_noise > var.MAX_NOISE:
        st.sidebar.warning('Max noise value is too large. This will entail poor performance')

    # The slider's upper value is inclusive; +1 presumably turns it into the exclusive
    # bound expected by the noise generator — confirm against create_noisy_columns.
    return is_plain, min_noise, max_noise + 1, beam_width


def history_sidebar() -> None:
    """Render sidebar controls for the history section; there are none yet."""


def save_to_history(is_plain: bool,
                    min_noise: int,
                    max_noise: int,
                    bw: int,
                    columns: List[str],
                    chains: List[Tuple[str, float]]
                    ) -> None:
    """Append one inference record to the session history.

    The stored record is ``(is_plain, text, min_noise, max_noise, bw, columns, chains)``,
    where ``text`` is the current input (``st.session_state.data``) for plain-text runs
    and ``None`` for noisy-column runs.
    """
    if is_plain:
        text = st.session_state.data
    else:
        text = None

    record = (is_plain, text, min_noise, max_noise, bw, columns, chains)
    st.session_state.history.append(record)


@st.cache(ttl=3600, show_spinner=False, suppress_st_warning=True)
def predict(columns: List[str], bw: int) -> List[Tuple[str, float]]:
    """Run beam search over ``columns`` and return the decoded candidate texts.

    Each candidate is paired with the score ``beam_search`` reported for it. Results are
    memoized by ``st.cache`` (ttl=3600) keyed on the arguments.
    """
    def _decode(chain):
        # Model output tensor -> target sequence -> printable text.
        return visualize_target(tensor_to_target(chain))

    src_tensor = columns_to_tensor(columns, device)
    raw_chains = beam_search(src_tensor, model, bw, device, beam_loop=dashboard_loop)

    return [(_decode(chain), score) for chain, score in raw_chains]


def get_noisy_columns(data: str, min_noise: int, max_noise: int) -> List[str]:
    """Generate noisy columns for ``data`` and randomly reorder each column's characters.

    Args:
        data: Plain text to noise.
        min_noise: Lower noise bound, forwarded unchanged to ``create_noisy_columns``.
        max_noise: Upper noise bound, forwarded unchanged to ``create_noisy_columns``.

    Returns:
        One string per generated column, with duplicate characters removed and the
        remaining characters in random order (presumably so the layout does not hint at
        which character is the original one).
    """
    columns = create_noisy_columns(data, min_noise, max_noise)

    shuffled_columns = []
    for column in columns:
        # Iterating a set is not a shuffle: its order is fixed by per-process string
        # hashes, so the same characters would always appear in the same biased order.
        # Sort first so the result depends only on the RNG state, then shuffle.
        chars = sorted(set(column))
        random.shuffle(chars)
        shuffled_columns.append(''.join(chars))

    return shuffled_columns


def get_input_data(examples: Dict[str, Optional[str]], max_chars: int) -> str:
    """Render the input field next to an example picker and return the entered text.

    Picking an example pre-fills the field with that example. An example whose value is
    ``None`` (or empty) falls back to the last input stored in ``st.session_state.data``.
    ``max_chars`` caps the length of the text input.
    """
    input_col, examples_col = st.columns([1, 0.27])

    chosen_example = examples_col.selectbox(label='', options=examples.keys())
    default_text = examples[chosen_example] or st.session_state.data

    return input_col.text_input(label='', value=default_text, max_chars=max_chars)


def inference_page(is_plain: bool, min_noise: int, max_noise: int, bw: int) -> None:
    """Render the inference page and run the model when 'Recover' is pressed.

    Args:
        is_plain: True when the input is plain text that is noised here; False when the
            user supplies noisy columns directly, separated by spaces.
        min_noise: Lower noise bound forwarded to ``get_noisy_columns`` (plain text only).
        max_noise: Upper noise bound forwarded to ``get_noisy_columns``; as returned by
            ``inference_sidebar`` it is one above the slider value.
        bw: Beam search width passed to ``predict``.
    """
    st.subheader('Insert plain text' if is_plain else 'Insert noisy columns separated by spaces')

    # Noisy-column input may be up to four times longer than plain text.
    if is_plain:
        data = get_input_data(PLAIN_EXAMPLES, max_chars=MAX_CHARS)
    else:
        data = get_input_data(NOISED_EXAMPLES, max_chars=MAX_CHARS * 4)

    # Nothing to process until the user has entered some input.
    if not data:
        st.stop()

    if is_plain:
        # Regenerate the noise only when explicitly requested, when no columns exist yet,
        # or when the input text changed. Otherwise reuse the stored columns so that
        # unrelated reruns (e.g. moving the beam-width slider) keep the same noisy input.
        if st.session_state.regenerate or not st.session_state.columns or data != st.session_state.data:
            columns = get_noisy_columns(data, min_noise, max_noise)
            st.session_state.columns = columns
            unset_regenerate()
        else:
            columns = st.session_state.columns
    else:
        # User-supplied noisy columns are split on spaces; they are not stored in
        # st.session_state.columns.
        columns = data_to_columns(data, separator=' ')

    # Remember the current input so the next rerun can detect changes.
    st.session_state.data = data

    st.subheader('\nColumns')
    st.text(visualize_columns(columns, delimiter=''))
    st.subheader('\n')

    # A single placeholder first holds the action buttons and is then reused for the
    # progress display and the prediction output.
    placeholder = st.empty()
    recover_field, regen_filed = placeholder.columns([.11, 1])

    if is_plain:
        regen_filed.button('Regenerate', on_click=set_regenerate)

    if columns and recover_field.button('Recover'):
        if st.session_state.is_unix:
            # NOTE(review): the 'Stop' button is rendered only on non-Windows platforms;
            # pressing it presumably interrupts the running prediction through the rerun
            # Streamlit triggers on click — confirm why this is restricted to Unix.
            with placeholder.container():
                progress_bar_placeholder = st.empty()
                st.button('Stop')

                # predict() presumably draws its progress here via dashboard_loop.
                with progress_bar_placeholder:
                    chains = predict(columns, bw)
        else:
            with placeholder:
                chains = predict(columns, bw)

        with placeholder.container():
            st.subheader('\nPrediction')
            st.text('\n\n'.join(chain for chain, _ in chains))

            # NOTE(review): task_id is not read anywhere in this file, so this assignment
            # looks like a leftover. The button's effective purpose is presumably the
            # rerun it triggers, which redraws the page without the prediction.
            if st.button('Clear'):
                st.session_state.task_id = None

        # Every completed recovery is appended to the session history.
        save_to_history(is_plain, min_noise, max_noise, bw, columns, chains)


def history_page() -> None:
    """Render every inference record saved in the session history, oldest first.

    Each record is the tuple written by ``save_to_history``:
    ``(is_plain, text, min_noise, max_noise, bw, columns, chains)``.

    The stored ``max_noise`` is one above the slider's upper value, because
    ``inference_sidebar`` returns ``max_noise + 1``. It is shifted back before display so
    the shown range matches what the user selected. Noise is only generated for
    plain-text input, so the range is shown only for those records.
    """
    st.header('Inference History')

    if not st.session_state.history:
        st.info('No records saved')
        return

    for record_id, record in enumerate(st.session_state.history, start=1):
        is_plain, text, min_noise, max_noise, bw, columns, chains = record

        st.info(f'Record {record_id}')

        if is_plain:
            st.text(f'Plain data: {text}')
            # Undo the +1 applied by inference_sidebar to show the inclusive range.
            st.text(f'Noise range: [{min_noise}, {max_noise - 1}]')

        st.text(f'Beam search width: {bw}')

        st.text('Columns:')
        st.text(visualize_columns(columns, delimiter=''))

        st.text('Prediction:')
        st.text('\n\n'.join(chain for chain, _ in chains))


# Script entry point: build and render the dashboard.
if __name__ == '__main__':
    main()