rosetta / app.py
yhavinga's picture
Rename app to rosetta. Make two-column. Add some texts.
cdb537e
raw
history blame
6.97 kB
import time
import psutil
import streamlit as st
import torch
from langdetect import detect
from default_texts import default_texts
from generator import GeneratorFactory
# Index of the last visible CUDA device. torch.cuda.device_count() is 0 when no
# GPU is visible, which makes this -1. Not referenced in this file — presumably
# consumed elsewhere as a pipeline device index (-1 meaning CPU); confirm.
device = torch.cuda.device_count() - 1

# Task identifiers; main() picks one from the detected input language and
# GeneratorFactory.filter(task=...) selects the matching models.
TRANSLATION_EN_TO_NL = "translation_en_to_nl"
TRANSLATION_NL_TO_EN = "translation_nl_to_en"

# Column names of the rows below, in the key order of every GENERATOR_LIST dict.
_GENERATOR_FIELDS = ("model_name", "desc", "task", "split_sentences")

# One row per enabled model: (Hugging Face model id, human-readable description,
# task, split_sentences flag handed to GeneratorFactory). Commented-out rows are
# alternative checkpoints kept around so they can be re-enabled quickly.
_GENERATOR_ROWS = (
    (
        "yhavinga/t5-small-24L-ccmatrix-multi",
        "T5 small nl24 ccmatrix en->nl",
        TRANSLATION_EN_TO_NL,
        True,
    ),
    (
        "yhavinga/t5-small-24L-ccmatrix-multi",
        # NOTE(review): every other label writes the direction as "x->y".
        "T5 small nl24 ccmatrix nl-en",
        TRANSLATION_NL_TO_EN,
        True,
    ),
    (
        "Helsinki-NLP/opus-mt-en-nl",
        "Opus MT en->nl",
        TRANSLATION_EN_TO_NL,
        True,
    ),
    (
        "Helsinki-NLP/opus-mt-nl-en",
        "Opus MT nl->en",
        TRANSLATION_NL_TO_EN,
        True,
    ),
    # (
    #     "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-nl-en",
    #     "longT5 large nl8 256cc/512beta/512l nl->en",
    #     TRANSLATION_NL_TO_EN,
    #     False,
    # ),
    (
        "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-nl-en",
        "longT5 large nl8 512beta/512l nl->en",
        TRANSLATION_NL_TO_EN,
        False,
    ),
    (
        "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
        "longT5 large nl8 256cc/512beta/512l en->nl",
        TRANSLATION_EN_TO_NL,
        False,
    ),
    # NOTE(review): the id below is a *base nl36* checkpoint, but its label
    # says "large nl8" — fix the label before re-enabling.
    # (
    #     "yhavinga/longt5-local-eff-base-nl36-voc8k-256l-472beta-256l-472beta-en-nl",
    #     "longT5 large nl8 256l/472beta/256l/472beta en->nl",
    #     TRANSLATION_EN_TO_NL,
    #     False,
    # ),
    # ("yhavinga/byt5-small-ccmatrix-en-nl", "ByT5 small ccmatrix en->nl", TRANSLATION_EN_TO_NL, True),
    # ("yhavinga/t5-eff-large-8l-nedd-en-nl", "T5 eff large nl8 en->nl", TRANSLATION_EN_TO_NL, True),
    # ("yhavinga/t5-base-36L-ccmatrix-multi", "T5 base nl36 ccmatrix en->nl", TRANSLATION_EN_TO_NL, True),
    # (
    #     "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
    #     "longT5 large nl8 512beta/512l en->nl",
    #     TRANSLATION_EN_TO_NL,
    #     False,
    # ),
    # ("yhavinga/t5-base-36L-nedd-x-en-nl-300", "T5 base 36L nedd en->nl 300", TRANSLATION_EN_TO_NL, True),
    # ("yhavinga/long-t5-local-small-ccmatrix-en-nl", "longT5 small ccmatrix en->nl", TRANSLATION_EN_TO_NL, True),
)

# The registry handed to GeneratorFactory: a list of plain dicts, one per row.
GENERATOR_LIST = [dict(zip(_GENERATOR_FIELDS, row)) for row in _GENERATOR_ROWS]
def _sidebar_generation_params():
    """Render the generation-parameter widgets in the sidebar and return their values.

    Returns:
        dict: keyword arguments passed to every generator's ``generate`` call
        (``num_beams``, ``num_beam_groups``, ``length_penalty`` and a fixed
        ``early_stopping=True``).
    """
    st.sidebar.title("Parameters:")
    num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
    num_beam_groups = st.sidebar.number_input(
        "Num beam groups", min_value=1, max_value=10, value=1
    )
    length_penalty = st.sidebar.number_input(
        "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
    )
    st.sidebar.markdown(
        """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
"""
    )
    return {
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "length_penalty": length_penalty,
        "early_stopping": True,
    }


def _detect_task(text):
    """Map *text* to the translation task for its language.

    Returns:
        tuple: ``(task, None)`` when the text is detected as English or Dutch,
        otherwise ``(None, message)`` with a user-facing explanation.

    ``langdetect.detect`` raises ``LangDetectException`` on input without usable
    features (e.g. empty or whitespace-only text). That case is reported as an
    error message instead of crashing the app.
    """
    from langdetect.lang_detect_exception import LangDetectException

    if not text.strip():
        return None, "Please enter some text to translate"
    try:
        language = detect(text)
    except LangDetectException:
        return None, "Could not detect the language of the text"
    task = {"en": TRANSLATION_EN_TO_NL, "nl": TRANSLATION_NL_TO_EN}.get(language)
    if task is None:
        return None, f"Language {language} not supported"
    return task, None


def _render_translations(generators, task, text, params, column):
    """Translate *text* with every generator registered for *task*.

    For each generator, writes to *column*: the model name, its output, the
    wall-clock generation time and the parameters the generator reports it used.
    """
    for generator in generators.filter(task=task):
        column.markdown(f"🧮 **Model `{generator}`**")
        time_start = time.perf_counter()
        result, params_used = generator.generate(text=text, **params)
        time_diff = time.perf_counter() - time_start
        # Markdown needs two trailing spaces before a newline to render a hard
        # line break; without them the output's line structure is collapsed.
        column.write(result.replace("\n", "  \n"))
        text_line = ", ".join(f"{k}={v}" for k, v in params_used.items())
        column.markdown(f" 🕙 *generated in {time_diff:.2f}s, `{text_line}`*")


def main():
    """Render the Rosetta en<->nl translation app and run translations on demand.

    The layout is:
    - Sidebar: logo, sample-text picker and generation parameters.
    - Left column: the input text area and the Run button.
    - Right column: the output of every model whose task matches the detected
      input language.
    """
    # Must be the first Streamlit call of the script run.
    st.set_page_config(
        page_title="Rosetta en/nl",  # Streamlit appends "• Streamlit" to the tab title.
        layout="wide",  # "centered" or "wide".
        initial_sidebar_state="auto",  # "auto", "expanded" or "collapsed".
        page_icon="📑",  # Emoji, image (anything st.image accepts) or None.
    )
    # Build the generators only once per session; later reruns reuse them.
    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
    generators = st.session_state["generators"]

    with open("style.css", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    st.sidebar.image("rosetta.png", width=200)
    st.sidebar.markdown(
        """# Rosetta
Vertaal van en naar Engels"""
    )
    default_text = st.sidebar.radio(
        "Change default text",
        tuple(default_texts.keys()),
        index=0,
    )
    # radio returns the selected key, so (assuming the keys are non-empty
    # strings) this re-seeds the prompt with the selected sample on every rerun.
    if default_text or "prompt_box" not in st.session_state:
        st.session_state["prompt_box"] = default_texts[default_text]["text"]

    left, right = st.columns(2)
    text_area = left.text_area("Enter text", st.session_state.prompt_box, height=500)
    st.session_state["text"] = text_area

    params = _sidebar_generation_params()

    if left.button("Run"):
        text = st.session_state.text
        task, error = _detect_task(text)
        if error:
            left.error(error)
            return
        # Diverse beam search splits the beams evenly over the groups.
        if params["num_beams"] % params["num_beam_groups"] != 0:
            left.error("Num beams should be a multiple of num beam groups")
            return
        _render_translations(generators, task, text, params, right)
        # Sampled after the translations have run.
        memory = psutil.virtual_memory()
        st.write(
            f"""
---
*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
"""
        )


if __name__ == "__main__":
    main()