|
from styleformer import Styleformer |
|
import streamlit as st |
|
import numpy as np |
|
import json |
|
|
|
class Demo:
    """Streamlit front-end for trying out Styleformer style transfers.

    Constructing the object configures the Streamlit page and loads the
    example sentences; ``main`` renders the widgets and runs one transfer.
    """

    def __init__(self) -> None:
        """Configure the page, define the option tables and load the examples.

        Raises:
            FileNotFoundError: if ``streamlit_examples.json`` cannot be found
                relative to the current working directory.
            json.JSONDecodeError: if that file does not contain valid JSON.
        """
        st.set_page_config(
            page_title="Styleformer Demo",
            initial_sidebar_state="expanded",
        )
        # key: (label shown in the sidebar, style number passed to Styleformer)
        self.style_map = {
            'ctf': ('Casual to Formal', 0),
            'ftc': ('Formal to Casual', 1),
            'atp': ('Active to Passive', 2),
            'pta': ('Passive to Active', 3),
        }
        # value passed as ``inference_on`` to ``transfer`` -> sidebar label
        self.inference_map = {
            0: 'Regular model on CPU',
            1: 'Regular model on GPU',
            2: 'Quantized model on CPU',
        }
        # Decode explicitly as UTF-8 so the examples read the same on every
        # platform instead of depending on the locale's default encoding.
        with open("streamlit_examples.json", encoding="utf-8") as f:
            # Indexed by the ``style_map`` keys in ``main``.
            self.examples = json.load(f)

    # NOTE(review): ``st.cache`` is reported as deprecated in newer Streamlit
    # releases (``st.cache_resource`` being the replacement for model objects);
    # confirm against the pinned Streamlit version before migrating.
    @st.cache(show_spinner=False, suppress_st_warning=True, allow_output_mutation=True)
    def load_sf(self, style: int = 0):
        """Return a ``Styleformer`` built for ``style``, memoised by ``st.cache``.

        Args:
            style: style number, one of the second elements of the
                ``style_map`` tuples.
        """
        return Styleformer(style=style)

    def main(self) -> None:
        """Render the controls, run the transfer on the input and show the result."""
        st.title("Styleformer")
        st.write(
            'A Neural Language Style Transfer framework to transfer natural '
            'language text smoothly between fine-grained language styles like '
            'formal/casual, active/passive, and many more'
        )

        # Sidebar: which transfer direction to use.
        style_key = st.sidebar.selectbox(
            label='Choose Style',
            options=list(self.style_map.keys()),
            format_func=lambda x: self.style_map[x][0],
        )

        # Sidebar: parameters forwarded to ``sf.transfer`` below.
        exp = st.sidebar.expander('Knobs', expanded=True)
        with exp:
            inference_on = exp.selectbox(
                label='Inference on',
                options=list(self.inference_map.keys()),
                format_func=lambda x: self.inference_map[x],
            )
            quality_filter = exp.slider(
                label='Quality filter',
                min_value=0.5,
                max_value=0.99,
                value=0.95,
            )
            max_candidates = exp.number_input(
                label='Max candidates',
                min_value=1,
                max_value=20,
                value=5,
            )

        # The cache decorator's own spinner is disabled, so show one here.
        with st.spinner('Loading model..'):
            sf = self.load_sf(self.style_map[style_key][1])

        # The chosen example pre-fills the text box, which the user may edit.
        input_text = st.selectbox(
            label="Choose an example",
            options=self.examples[style_key],
        )
        input_text = st.text_input(
            label="Input text",
            value=input_text,
        )

        if input_text.strip():
            result = sf.transfer(
                input_text,
                inference_on=inference_on,
                quality_filter=quality_filter,
                max_candidates=max_candidates,
            )
            st.markdown('#### Output:')
            st.write('')
            if result:
                st.success(result)
            else:
                # A falsy result is reported as "nothing passed the filter".
                st.info('No good quality transfers available !')
        else:
            st.warning("Please select/enter text to proceed")
|
|
|
|
|
|
|
# Script entry point: build the demo (page config + examples) and render it.
if __name__ == "__main__":
    Demo().main()
|
|
|
|
|
|