|
import platform |
|
from typing import Tuple, List, Dict, Optional |
|
|
|
import streamlit as st |
|
import torch |
|
from trecover.config import var |
|
from trecover.utils.beam_search import beam_search, dashboard_loop |
|
from trecover.utils.inference import data_to_columns, create_noisy_columns |
|
from trecover.utils.transform import columns_to_tensor, tensor_to_target |
|
from trecover.utils.visualization import visualize_columns, visualize_target |
|
|
|
MAX_CHARS = 256 |
|
|
|
PLAIN_EXAMPLES = { |
|
'Select example': None, |
|
'Example 1': 'As people around the country went into the streets to cheer the conviction, some businesses in ' |
|
'Portland boarded up their windows once again.', |
|
'Example 2': 'That night, a small group of activists wearing black approached a group of journalists, threatening' |
|
' to smash the cameras of those who remained on scene.', |
|
'Example 3': 'English as we know it today came to be exported to other parts of the world through British ' |
|
'colonisation, and is now the dominant language in Britain' |
|
} |
|
|
|
NOISED_EXAMPLES = { |
|
'Select example': None, |
|
'Example 1': 'a ds fpziq ofe ngkhbo p pghl ue waq frlqjo o u dnxrm dgr yrtsco kho deuasm dhysc ao u nwzhy tle r ' |
|
'yzpe xwabc gce nger klqto wiq nfprso t no tpgq tcfh ae twas tw ur re e t gyutsm t xgo rc ubhq e wle ' |
|
'r ty h nwpeaq xdsc o dnhelm v thir ikcq tkuo i o twn ps frio mo oe b kuiqtb jsq zi tnye ge dgrqs s ' |
|
'cioe ys whic wne wp thlo dnprsc xvpyrt hurlm kveaj nbfp dome pbeaj dusmo a r dzrqsm xace du nxkuai ' |
|
'gpulcm tpi h pie uim r wbhrj ui n dwgp dkeio nkwhqs zs' |
|
} |
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
model = torch.hub.load('alex-snd/TRecover', model='trecover', device=device, version='latest') |
|
|
|
|
|
def main() -> None: |
|
st.set_page_config( |
|
page_title='TRecover', |
|
page_icon='🩹', |
|
layout='wide', |
|
initial_sidebar_state='expanded') |
|
|
|
if 'history' not in st.session_state: |
|
st.session_state.history = list() |
|
|
|
if 'data' not in st.session_state: |
|
st.session_state.data = '' |
|
|
|
if 'regenerate' not in st.session_state: |
|
st.session_state.regenerate = False |
|
|
|
if 'columns' not in st.session_state: |
|
st.session_state.columns = None |
|
|
|
if 'is_unix' not in st.session_state: |
|
st.session_state.is_unix = platform.system() != 'Windows' |
|
|
|
sidebar() |
|
|
|
|
|
def set_regenerate() -> None: |
|
st.session_state.regenerate = True |
|
|
|
|
|
def unset_regenerate() -> None: |
|
st.session_state.regenerate = False |
|
|
|
|
|
def sidebar() -> None: |
|
st.sidebar.markdown(body= |
|
""" |
|
<h1 align="center"> |
|
<font size="20">🤷</font> |
|
<a href="https://alex-snd.github.io/TRecover">About the Project</a> |
|
</h1> |
|
<br><br> |
|
""", |
|
unsafe_allow_html=True) |
|
|
|
option = st.sidebar.radio('Sections', ('Inference', 'Inference history')) |
|
|
|
if option == 'Inference': |
|
is_plain, min_noise, max_noise, bw = inference_sidebar() |
|
inference_page(is_plain, min_noise, max_noise, bw) |
|
else: |
|
history_sidebar() |
|
history_page() |
|
|
|
|
|
def inference_sidebar() -> Tuple[bool, int, int, int]: |
|
st.sidebar.text('\n') |
|
|
|
data_type = st.sidebar.radio('Input type', ('Plain text', 'Noisy columns'), key='data_type', |
|
index=0 if 'Plain text' == st.session_state.get('data_type', 'Plain text') else 1) |
|
is_plain = data_type == 'Plain text' |
|
|
|
st.sidebar.text('\n') |
|
|
|
if is_plain: |
|
min_noise, max_noise = st.sidebar.slider('\nNoise range', 0, 5, key='noise_range', |
|
value=st.session_state.get('noise_range', (0, 5)), |
|
on_change=set_regenerate) |
|
else: |
|
min_noise, max_noise = 0, 0 |
|
|
|
bw = st.sidebar.slider('Beam search width', 1, 25, key='beam_width', |
|
value=st.session_state.get('beam_width', 5)) |
|
|
|
if max_noise > var.MAX_NOISE: |
|
st.sidebar.warning('Max noise value is too large. This will entail poor performance') |
|
|
|
return is_plain, min_noise, max_noise + 1, bw |
|
|
|
|
|
def history_sidebar() -> None: |
|
pass |
|
|
|
|
|
def save_to_history(is_plain: bool, |
|
min_noise: int, |
|
max_noise: int, |
|
bw: int, |
|
columns: List[str], |
|
chains: List[Tuple[str, float]] |
|
) -> None: |
|
text = st.session_state.data if is_plain else None |
|
|
|
st.session_state.history.append((is_plain, text, min_noise, max_noise, bw, columns, chains)) |
|
|
|
|
|
@st.cache(ttl=3600, show_spinner=False, suppress_st_warning=True) |
|
def predict(columns: List[str], bw: int) -> List[Tuple[str, float]]: |
|
src = columns_to_tensor(columns, device) |
|
|
|
chains = beam_search(src, model, bw, device, beam_loop=dashboard_loop) |
|
chains = [(visualize_target(tensor_to_target(chain)), prob) for (chain, prob) in chains] |
|
|
|
return chains |
|
|
|
|
|
def get_noisy_columns(data: str, min_noise: int, max_noise: int) -> List[str]: |
|
columns = create_noisy_columns(data, min_noise, max_noise) |
|
|
|
return [''.join(set(c)) for c in columns] |
|
|
|
|
|
def get_input_data(examples: Dict[str, Optional[str]], max_chars: int) -> str: |
|
input_field, examples_filed = st.columns([1, 0.27]) |
|
|
|
option = examples_filed.selectbox(label='', options=examples.keys()) |
|
|
|
return input_field.text_input(label='', value=examples[option] or st.session_state.data, max_chars=max_chars) |
|
|
|
|
|
def inference_page(is_plain: bool, min_noise: int, max_noise: int, bw: int) -> None: |
|
st.subheader('Insert plain text' if is_plain else 'Insert noisy columns separated by spaces') |
|
|
|
if is_plain: |
|
data = get_input_data(PLAIN_EXAMPLES, max_chars=MAX_CHARS) |
|
else: |
|
data = get_input_data(NOISED_EXAMPLES, max_chars=MAX_CHARS * 4) |
|
|
|
if not data: |
|
st.stop() |
|
|
|
if is_plain: |
|
if st.session_state.regenerate or not st.session_state.columns or data != st.session_state.data: |
|
columns = get_noisy_columns(data, min_noise, max_noise) |
|
st.session_state.columns = columns |
|
unset_regenerate() |
|
else: |
|
columns = st.session_state.columns |
|
else: |
|
columns = data_to_columns(data, separator=' ') |
|
|
|
st.session_state.data = data |
|
|
|
st.subheader('\nColumns') |
|
st.text(visualize_columns(columns, delimiter='')) |
|
st.subheader('\n') |
|
|
|
placeholder = st.empty() |
|
recover_field, regen_filed = placeholder.columns([.11, 1]) |
|
|
|
if is_plain: |
|
regen_filed.button('Regenerate', on_click=set_regenerate) |
|
|
|
if columns and recover_field.button('Recover'): |
|
if st.session_state.is_unix: |
|
with placeholder.container(): |
|
progress_bar_placeholder = st.empty() |
|
st.button('Stop') |
|
|
|
with progress_bar_placeholder: |
|
chains = predict(columns, bw) |
|
else: |
|
with placeholder: |
|
chains = predict(columns, bw) |
|
|
|
with placeholder.container(): |
|
st.subheader('\nPrediction') |
|
st.text('\n\n'.join(chain for chain, _ in chains)) |
|
|
|
if st.button('Clear'): |
|
st.session_state.task_id = None |
|
|
|
save_to_history(is_plain, min_noise, max_noise, bw, columns, chains) |
|
|
|
|
|
def history_page() -> None: |
|
st.header('Inference History') |
|
|
|
if len(st.session_state.history) == 0: |
|
st.info('No records saved') |
|
return |
|
|
|
for record_id, (is_plain, text, min_noise, max_noise, bw, columns, chains) in enumerate(st.session_state.history, |
|
start=1): |
|
st.info(f'Record {record_id}') |
|
|
|
if is_plain: |
|
st.text(f'Plain data: {text}') |
|
|
|
st.text(f'Noise range: [{min_noise}, {max_noise}]') |
|
st.text(f'Beam search width: {bw}') |
|
|
|
st.text('Columns:') |
|
st.text(visualize_columns(columns, delimiter='')) |
|
|
|
st.text('Prediction:') |
|
st.text('\n\n'.join(chain for chain, _ in chains)) |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |