# NOTE(review): removed non-Python extraction residue that would break the module
# (a "File size" banner, a row of git-blame commit hashes and a line-number strip).
import platform
from typing import Tuple, List, Dict, Optional
import streamlit as st
import torch
from trecover.config import var
from trecover.utils.beam_search import beam_search, dashboard_loop
from trecover.utils.inference import data_to_columns, create_noisy_columns
from trecover.utils.transform import columns_to_tensor, tensor_to_target
from trecover.utils.visualization import visualize_columns, visualize_target
# Hard cap on the plain-text input length (characters); noisy-column input
# is allowed 4x this (see inference_page).
MAX_CHARS = 256

# Ready-made inputs for 'Plain text' mode. The placeholder entry maps to None
# so the user's previously entered text is kept instead (see get_input_data).
PLAIN_EXAMPLES = {
    'Select example': None,
    'Example 1': 'As people around the country went into the streets to cheer the conviction, some businesses in '
                 'Portland boarded up their windows once again.',
    'Example 2': 'That night, a small group of activists wearing black approached a group of journalists, threatening'
                 ' to smash the cameras of those who remained on scene.',
    'Example 3': 'English as we know it today came to be exported to other parts of the world through British '
                 'colonisation, and is now the dominant language in Britain'
}

# Ready-made space-separated noisy columns for 'Noisy columns' mode.
NOISED_EXAMPLES = {
    'Select example': None,
    'Example 1': 'a ds fpziq ofe ngkhbo p pghl ue waq frlqjo o u dnxrm dgr yrtsco kho deuasm dhysc ao u nwzhy tle r '
                 'yzpe xwabc gce nger klqto wiq nfprso t no tpgq tcfh ae twas tw ur re e t gyutsm t xgo rc ubhq e wle '
                 'r ty h nwpeaq xdsc o dnhelm v thir ikcq tkuo i o twn ps frio mo oe b kuiqtb jsq zi tnye ge dgrqs s '
                 'cioe ys whic wne wp thlo dnprsc xvpyrt hurlm kveaj nbfp dome pbeaj dusmo a r dzrqsm xace du nxkuai '
                 'gpulcm tpi h pie uim r wbhrj ui n dwgp dkeio nkwhqs zs'
}

# Prefer GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the latest published TRecover model via torch.hub.
# NOTE(review): this runs at import time and may hit the network on every
# process start — confirm that is intended (Streamlit re-imports on deploy).
model = torch.hub.load('alex-snd/TRecover', model='trecover', device=device, version='latest')
def main() -> None:
    """Configure the page, seed missing session-state keys and draw the UI."""
    st.set_page_config(
        page_title='TRecover',
        page_icon='🩹',
        layout='wide',
        initial_sidebar_state='expanded')

    # Defaults are only written when the key is absent, so values survive reruns.
    session_defaults = {
        'history': list(),
        'data': '',
        'regenerate': False,
        'columns': None,
        'is_unix': platform.system() != 'Windows',
    }
    for key, default in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default

    sidebar()
def set_regenerate() -> None:
    """Flag that the noisy columns must be rebuilt on the next rerun."""
    st.session_state['regenerate'] = True
def unset_regenerate() -> None:
    """Clear the column-regeneration flag after the columns were rebuilt."""
    st.session_state['regenerate'] = False
def sidebar() -> None:
    """Draw the sidebar header and dispatch to the section the user picked."""
    st.sidebar.markdown(body=
                        """
                        <h1 align="center">
                        <font size="20">🤷</font>
                        <a href="https://alex-snd.github.io/TRecover">About the Project</a>
                        </h1>
                        <br><br>
                        """,
                        unsafe_allow_html=True)

    section = st.sidebar.radio('Sections', ('Inference', 'Inference history'))

    if section == 'Inference':
        # inference_sidebar returns (is_plain, min_noise, max_noise, bw),
        # which is exactly inference_page's positional signature.
        inference_page(*inference_sidebar())
    else:
        history_sidebar()
        history_page()
def inference_sidebar() -> Tuple[bool, int, int, int]:
    """Render inference controls and return (is_plain, min_noise, max_noise, bw).

    The returned max_noise is the slider value plus one, i.e. an exclusive
    upper bound for downstream noise generation.
    """
    st.sidebar.text('\n')

    previous_type = st.session_state.get('data_type', 'Plain text')
    data_type = st.sidebar.radio('Input type', ('Plain text', 'Noisy columns'), key='data_type',
                                 index=0 if previous_type == 'Plain text' else 1)
    is_plain = data_type == 'Plain text'

    st.sidebar.text('\n')

    if is_plain:
        # Changing the range invalidates previously generated columns.
        noise_range = st.sidebar.slider('\nNoise range', 0, 5, key='noise_range',
                                        value=st.session_state.get('noise_range', (0, 5)),
                                        on_change=set_regenerate)
        min_noise, max_noise = noise_range
    else:
        min_noise = max_noise = 0

    bw = st.sidebar.slider('Beam search width', 1, 25, key='beam_width',
                           value=st.session_state.get('beam_width', 5))

    if max_noise > var.MAX_NOISE:
        st.sidebar.warning('Max noise value is too large. This will entail poor performance')

    return is_plain, min_noise, max_noise + 1, bw
def history_sidebar() -> None:
    """Placeholder: the history section currently adds no sidebar controls."""
def save_to_history(is_plain: bool,
                    min_noise: int,
                    max_noise: int,
                    bw: int,
                    columns: List[str],
                    chains: List[Tuple[str, float]]
                    ) -> None:
    """Append one inference record to the session history.

    The original plain text is only stored for plain-text runs; for noisy-column
    input the text slot is None.
    """
    plain_text = st.session_state.data if is_plain else None
    record = (is_plain, plain_text, min_noise, max_noise, bw, columns, chains)
    st.session_state.history.append(record)
# NOTE(review): st.cache is deprecated in recent Streamlit releases — confirm
# the pinned version before migrating to st.cache_data/st.cache_resource.
@st.cache(ttl=3600, show_spinner=False, suppress_st_warning=True)
def predict(columns: List[str], bw: int) -> List[Tuple[str, float]]:
    """Run beam search over the columns and return (visualized text, probability) pairs."""
    src = columns_to_tensor(columns, device)
    raw_chains = beam_search(src, model, bw, device, beam_loop=dashboard_loop)

    return [(visualize_target(tensor_to_target(chain)), prob)
            for chain, prob in raw_chains]
def get_noisy_columns(data: str, min_noise: int, max_noise: int) -> List[str]:
    """Generate noisy columns for *data* and shuffle the letters inside each column.

    Args:
        data: plain text to convert into columns.
        min_noise: inclusive lower bound of extra noise characters per column.
        max_noise: exclusive upper bound of extra noise characters per column.

    Returns:
        The noisy columns with each column's characters randomly permuted.

    BUG FIX: the previous implementation "shuffled" a column with
    ``''.join(set(c))``, which silently DROPPED duplicate characters from the
    column (and only gave an arbitrary, not random, order) — corrupting the
    input fed to the model. ``random.sample(c, len(c))`` is a true permutation
    that keeps every character.
    """
    import random  # local import keeps the fix self-contained

    columns = create_noisy_columns(data, min_noise, max_noise)

    return [''.join(random.sample(c, len(c))) for c in columns]
def get_input_data(examples: Dict[str, Optional[str]], max_chars: int) -> str:
    """Render the input field with an example selector and return the entered text."""
    input_field, examples_filed = st.columns([1, 0.27])

    option = examples_filed.selectbox(label='', options=examples.keys())
    # Fall back to the user's previous input when the placeholder (None) is selected.
    default_text = examples[option] or st.session_state.data

    return input_field.text_input(label='', value=default_text, max_chars=max_chars)
def inference_page(is_plain: bool, min_noise: int, max_noise: int, bw: int) -> None:
    """Render the main inference view: input, column preview, recovery, results.

    Args:
        is_plain: True when the user typed plain text (noise is generated here),
            False when they pasted pre-noised columns.
        min_noise: inclusive lower bound of the noise range.
        max_noise: exclusive upper bound (inference_sidebar returns value + 1).
        bw: beam-search width forwarded to predict().
    """
    st.subheader('Insert plain text' if is_plain else 'Insert noisy columns separated by spaces')
    if is_plain:
        data = get_input_data(PLAIN_EXAMPLES, max_chars=MAX_CHARS)
    else:
        # Noised input encodes the text with extra characters, hence the 4x budget.
        data = get_input_data(NOISED_EXAMPLES, max_chars=MAX_CHARS * 4)
    if not data:
        st.stop()  # nothing to render until the user provides input
    if is_plain:
        # Rebuild noisy columns when explicitly requested, on first run, or when
        # the input text changed; otherwise reuse the cached columns so an
        # unrelated rerun does not reshuffle the noise.
        if st.session_state.regenerate or not st.session_state.columns or data != st.session_state.data:
            columns = get_noisy_columns(data, min_noise, max_noise)
            st.session_state.columns = columns
            unset_regenerate()
        else:
            columns = st.session_state.columns
    else:
        columns = data_to_columns(data, separator=' ')
    st.session_state.data = data  # remembered for the change check above
    st.subheader('\nColumns')
    st.text(visualize_columns(columns, delimiter=''))
    st.subheader('\n')
    # A single placeholder is reused: first for the buttons, then for progress,
    # finally for the prediction — each write replaces the previous content.
    placeholder = st.empty()
    recover_field, regen_filed = placeholder.columns([.11, 1])
    if is_plain:
        regen_filed.button('Regenerate', on_click=set_regenerate)
    if columns and recover_field.button('Recover'):
        if st.session_state.is_unix:
            # On Unix the dashboard beam loop can report progress, so show a
            # progress area alongside a 'Stop' button.
            with placeholder.container():
                progress_bar_placeholder = st.empty()
                st.button('Stop')
                with progress_bar_placeholder:
                    chains = predict(columns, bw)
        else:
            with placeholder:
                chains = predict(columns, bw)
        with placeholder.container():
            st.subheader('\nPrediction')
            st.text('\n\n'.join(chain for chain, _ in chains))
            if st.button('Clear'):
                # NOTE(review): 'task_id' is never read in this file —
                # presumably used by an async-task integration; confirm.
                st.session_state.task_id = None
        save_to_history(is_plain, min_noise, max_noise, bw, columns, chains)
def history_page() -> None:
    """Display every saved inference record in chronological order."""
    st.header('Inference History')

    history = st.session_state.history
    if not history:
        st.info('No records saved')
        return

    for record_id, record in enumerate(history, start=1):
        is_plain, text, min_noise, max_noise, bw, columns, chains = record

        st.info(f'Record {record_id}')
        if is_plain:
            st.text(f'Plain data: {text}')
        st.text(f'Noise range: [{min_noise}, {max_noise}]')
        st.text(f'Beam search width: {bw}')
        st.text('Columns:')
        st.text(visualize_columns(columns, delimiter=''))
        st.text('Prediction:')
        st.text('\n\n'.join(chain for chain, _ in chains))
# Script entry point: render the dashboard when executed directly.
if __name__ == '__main__':
    main()