# tst / app.py
# Last commit 34fd186 by CAGmllab: "Update app.py" (file size 2.93 kB).
from styleformer import Styleformer
import streamlit as st
import numpy as np
import json
class Demo:
    """Streamlit front-end for the Styleformer style-transfer models.

    Construction sets up the page and the lookup tables and loads the example
    inputs; ``main`` renders the controls, the input widgets and the result.
    """

    def __init__(self):
        # Page setup lives in the constructor so it runs before every other
        # st.* call, all of which happen later in main().
        st.set_page_config(
            page_title="Styleformer Demo",
            initial_sidebar_state="expanded",
        )
        # key : (display name, style id passed to Styleformer(style=...))
        self.style_map = {
            'ctf': ('Casual to Formal', 0),
            'ftc': ('Formal to Casual', 1),
            'atp': ('Active to Passive', 2),
            'pta': ('Passive to Active', 3),
        }
        # inference_on value forwarded to sf.transfer : label shown in the UI
        self.inference_map = {
            0: 'Regular model on CPU',
            1: 'Regular model on GPU',
            2: 'Quantized model on CPU',
        }
        # Example inputs keyed by the style_map keys (presumably a list of
        # example sentences per key -- confirm against the JSON file).
        # Explicit UTF-8 so decoding does not depend on the host locale.
        with open("streamlit_examples.json", encoding="utf-8") as f:
            self.examples = json.load(f)

    @st.cache(show_spinner=False, suppress_st_warning=True, allow_output_mutation=True)
    def load_sf(self, style=0):
        """Return a ``Styleformer`` model for ``style``, memoised by ``st.cache``.

        Args:
            style: style id, i.e. the second element of a ``style_map`` entry.
        """
        # NOTE(review): st.cache is the legacy caching API and is deprecated
        # in newer Streamlit releases in favour of st.cache_resource; migrate
        # once the pinned Streamlit version supports it.
        return Styleformer(style=style)

    def _sidebar_controls(self):
        """Render the sidebar widgets and return their current values.

        Returns:
            Tuple ``(style_key, inference_on, quality_filter, max_candidates)``.
        """
        style_key = st.sidebar.selectbox(
            label='Choose Style',
            options=list(self.style_map),
            format_func=lambda key: self.style_map[key][0],
        )
        # Widgets are created directly on the expander object, so no
        # `with knobs:` context is needed to place them inside it.
        knobs = st.sidebar.expander('Knobs', expanded=True)
        inference_on = knobs.selectbox(
            label='Inference on',
            options=list(self.inference_map),
            format_func=lambda mode: self.inference_map[mode],
        )
        quality_filter = knobs.slider(
            label='Quality filter',
            min_value=0.5,
            max_value=0.99,
            value=0.95,
        )
        max_candidates = knobs.number_input(
            label='Max candidates',
            min_value=1,
            max_value=20,
            value=5,
        )
        return style_key, inference_on, quality_filter, max_candidates

    def main(self):
        """Render the demo page and show the style transfer of the input text."""
        st.title("Styleformer")
        st.write('A Neural Language Style Transfer framework to transfer natural language text smoothly between fine-grained language styles like formal/casual, active/passive, and many more')

        style_key, inference_on, quality_filter, max_candidates = self._sidebar_controls()

        with st.spinner('Loading model..'):
            sf = self.load_sf(self.style_map[style_key][1])

        # Offer a canned example, then let the user edit it or type new text.
        # .get(): a style missing from the JSON yields an empty list instead
        # of a KeyError.
        example = st.selectbox(
            label="Choose an example",
            options=self.examples.get(style_key, []),
        )
        # An empty options list makes the selectbox return None; fall back
        # to "" so the text box and the .strip() check below stay valid.
        input_text = st.text_input(
            label="Input text",
            value=example or "",
        )

        if not (input_text and input_text.strip()):
            st.warning("Please select/enter text to proceed")
            return

        result = sf.transfer(
            input_text,
            inference_on=inference_on,
            quality_filter=quality_filter,
            max_candidates=max_candidates,
        )
        st.markdown('#### Output:')
        st.write('')
        # A falsy result means no candidate passed the quality filter.
        if result:
            st.success(result)
        else:
            st.info('No good quality transfers available !')
if __name__ == "__main__":
    # Script entry point: build the demo (page config, maps, examples) and
    # render it.
    demo_app = Demo()
    demo_app.main()