rosetta / app.py
yhavinga's picture
Add diversity penalty, pin tokenizers on older version
0d30451
raw
history blame contribute delete
No virus
6.82 kB
import time
import psutil
import streamlit as st
import torch
from langdetect import detect
from transformers import TextIteratorStreamer
from default_texts import default_texts
from generator import GeneratorFactory
# Index of the last CUDA device reported by torch; evaluates to -1 when no GPU
# is visible.
device = torch.cuda.device_count() - 1

# Task identifiers. Every generator config below declares one of them, and the
# app selects the task from the detected language of the input text.
TRANSLATION_EN_TO_NL, TRANSLATION_NL_TO_EN = (
    "translation_en_to_nl",
    "translation_nl_to_en",
)

# (model_name, desc, task, split_sentences) for every enabled model, in display
# order. Disabled alternatives are kept as comments so they can be re-enabled.
_ENABLED_GENERATORS = (
    ("yhavinga/ul2-base-en-nl", "UL2 base en->nl", TRANSLATION_EN_TO_NL, False),
    # ("yhavinga/ul2-large-en-nl", "UL2 large en->nl", TRANSLATION_EN_TO_NL, False),
    ("Helsinki-NLP/opus-mt-en-nl", "Opus MT en->nl", TRANSLATION_EN_TO_NL, True),
    ("Helsinki-NLP/opus-mt-nl-en", "Opus MT nl->en", TRANSLATION_NL_TO_EN, True),
    # ("yhavinga/t5-small-24L-ccmatrix-multi", "T5 small nl24 ccmatrix nl-en", TRANSLATION_NL_TO_EN, True),
    (
        "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-neddx2-nl-en",
        "Long t5 large-nl8 nl-en",
        TRANSLATION_NL_TO_EN,
        False,
    ),
    # ("yhavinga/byt5-small-ccmatrix-en-nl", "ByT5 small ccmatrix en->nl", TRANSLATION_EN_TO_NL, True),
    # ("yhavinga/t5-base-36L-ccmatrix-multi", "T5 base nl36 ccmatrix en->nl", TRANSLATION_EN_TO_NL, True),
)

# Config dicts handed to GeneratorFactory, one per enabled model above.
GENERATOR_LIST = [
    {
        "model_name": model_name,
        "desc": desc,
        "task": task,
        "split_sentences": split_sentences,
    }
    for model_name, desc, task, split_sentences in _ENABLED_GENERATORS
]
class StreamlitTextIteratorStreamer(TextIteratorStreamer):
    """Text streamer that mirrors generated text into a Streamlit placeholder.

    Every finalized text chunk is appended to ``output_text``. The whole
    accumulated text is then re-rendered as Markdown in ``output_placeholder``.
    Finally the chunk is forwarded to ``TextIteratorStreamer.on_finalized_text``
    so the parent class's own handling of the chunk is preserved.

    Args:
        output_placeholder: Streamlit element (e.g. from ``container.empty()``)
            whose content is replaced on every update.
        tokenizer: tokenizer passed through to ``TextIteratorStreamer``.
        skip_prompt: passed through to ``TextIteratorStreamer``.
        **decode_kwargs: extra decode options passed through to the parent.
    """

    def __init__(
        self, output_placeholder, tokenizer, skip_prompt=False, **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.output_placeholder = output_placeholder
        # All text received so far; re-rendered in full on each update.
        self.output_text = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.output_text += text
        # Markdown treats a lone "\n" as a soft break (rendered as a space).
        # Two trailing spaces force a hard line break, matching how the
        # non-streaming output is rendered.
        self.output_placeholder.markdown(
            self.output_text.replace("\n", "  \n"), unsafe_allow_html=True
        )
        super().on_finalized_text(text, stream_end)
def _sidebar_generation_params():
    """Render the sidebar "Parameters" controls and return generation kwargs.

    Returns:
        dict: keyword arguments for ``generator.generate`` with keys
        ``num_beams``, ``num_beam_groups``, ``diversity_penalty``,
        ``length_penalty`` and ``early_stopping``. ``diversity_penalty`` is
        forced to 0.0 unless ``num_beam_groups > 1``. ``length_penalty`` is
        forced to 1.0 unless ``num_beams > 1``. ``early_stopping`` is always
        True.
    """
    st.sidebar.title("Parameters:")
    num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
    num_beam_groups = st.sidebar.number_input(
        "Num beam groups", min_value=1, max_value=10, value=1
    )
    length_penalty = st.sidebar.number_input(
        "Length penalty", min_value=0.0, max_value=2.0, value=1.0, step=0.1
    )
    diversity_penalty = st.sidebar.number_input(
        "Diversity penalty", min_value=0.0, max_value=2.0, value=0.1, step=0.1
    )
    st.sidebar.markdown(
        """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
"""
    )
    return {
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "diversity_penalty": diversity_penalty if num_beam_groups > 1 else 0.0,
        "length_penalty": length_penalty if num_beams > 1 else 1.0,
        "early_stopping": True,
    }


def _detect_task(text, column):
    """Map the detected language of ``text`` to a translation task.

    Args:
        text: the user's input text.
        column: Streamlit container in which error messages are shown.

    Returns:
        ``TRANSLATION_EN_TO_NL`` for English input, ``TRANSLATION_NL_TO_EN``
        for Dutch input, or ``None`` after showing an error when the language
        cannot be detected or is not supported.
    """
    from langdetect import DetectorFactory
    from langdetect.lang_detect_exception import LangDetectException

    # langdetect is non-deterministic by default. A fixed seed makes the same
    # text always resolve to the same language, and therefore the same task.
    DetectorFactory.seed = 0
    try:
        language = detect(text)
    except LangDetectException:
        # Raised e.g. for empty input or text without detectable features.
        column.error("Could not detect the language of the input text")
        return None
    task = {"en": TRANSLATION_EN_TO_NL, "nl": TRANSLATION_NL_TO_EN}.get(language)
    if task is None:
        column.error(f"Language {language} not supported")
    return task


def main():
    """Render the Rosetta en/nl translation app.

    Builds the page: sidebar, default-text picker, input area and generation
    parameters. When "Run" is clicked, it detects the input language and runs
    every generator registered for the matching translation task. For each
    model it shows the output, the parameters used and the generation time,
    followed by a system-memory footer.
    """
    st.set_page_config(  # Alternate names: setup_page, page, layout
        page_title="Rosetta en/nl",  # String or None. Strings get appended with "• Streamlit".
        layout="wide",  # Can be "centered" or "wide". In the future also "dashboard", etc.
        initial_sidebar_state="auto",  # Can be "auto", "expanded", "collapsed"
        page_icon="📑",  # String, anything supported by st.image, or None.
    )
    # Keep the factory in session state so it is built only once per session
    # instead of on every Streamlit rerun.
    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
    generators = st.session_state["generators"]

    with open("style.css", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    st.sidebar.image("rosetta.png", width=200)
    st.sidebar.markdown(
        """# Rosetta
Vertaal van en naar Engels"""
    )
    default_text = st.sidebar.radio(
        "Change default text",
        tuple(default_texts.keys()),
        index=0,
    )
    if default_text or "prompt_box" not in st.session_state:
        st.session_state["prompt_box"] = default_texts[default_text]["text"]

    # create a left and right column
    left, right = st.columns(2)
    text_area = left.text_area("Enter text", st.session_state.prompt_box, height=500)
    st.session_state["text"] = text_area

    params = _sidebar_generation_params()
    num_beams = params["num_beams"]
    num_beam_groups = params["num_beam_groups"]

    if not left.button("Run"):
        return

    memory = psutil.virtual_memory()
    task = _detect_task(st.session_state.text, left)
    if task is None:
        return
    # Num beam groups should be a divisor of num beams
    if num_beams % num_beam_groups != 0:
        left.error("Num beams should be a multiple of num beam groups")
        return
    # Output is only streamed token-by-token for single-beam generation.
    streaming_enabled = num_beams == 1
    if not streaming_enabled:
        left.markdown("*`num_beams > 1` so streaming is disabled*")

    for generator in generators.filter(task=task):
        model_container = right.container()
        model_container.markdown(f"🧮 **Model `{generator}`**")
        output_placeholder = model_container.empty()
        streamer = (
            StreamlitTextIteratorStreamer(output_placeholder, generator.tokenizer)
            if streaming_enabled
            else None
        )
        # perf_counter is monotonic, unlike time.time, so it suits durations.
        time_start = time.perf_counter()
        result, params_used = generator.generate(
            text=st.session_state.text, streamer=streamer, **params
        )
        time_diff = time.perf_counter() - time_start
        if not streaming_enabled:
            # Two trailing spaces make Markdown render each "\n" as a hard
            # line break; a single space would be rendered as a plain space.
            right.write(result.replace("\n", "  \n"))
        text_line = ", ".join(f"{k}={v}" for k, v in params_used.items())
        right.markdown(f" 🕙 *generated in {time_diff:.2f}s, `{text_line}`*")

    st.write(
        f"""
---
*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
"""
    )


if __name__ == "__main__":
    main()