File size: 6,967 Bytes
46ffa30 cdb537e bc21832 cdb537e bc21832 cdb537e 46ffa30 bc21832 46ffa30 8cd0b56 bc21832 8cd0b56 46ffa30 bc21832 46ffa30 bc21832 46ffa30 bc21832 3f553b1 46ffa30 cdb537e 8cd0b56 cdb537e bc21832 8cd0b56 bc21832 8cd0b56 bc21832 8cd0b56 bc21832 8cd0b56 bc21832 8cd0b56 bc21832 8cd0b56 46ffa30 cdb537e 46ffa30 cdb537e 46ffa30 cdb537e 46ffa30 cdb537e 46ffa30 bc21832 46ffa30 cdb537e bc21832 cdb537e a19a543 3f553b1 a19a543 46ffa30 3f553b1 46ffa30 3f553b1 8cd0b56 3f553b1 46ffa30 cdb537e 3f553b1 bc21832 cdb537e bc21832 cdb537e bc21832 cdb537e 3f553b1 cdb537e 3f553b1 cdb537e 3f553b1 46ffa30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
import time
import psutil
import streamlit as st
import torch
from langdetect import detect
from default_texts import default_texts
from generator import GeneratorFactory
# NOTE(review): `device` is not referenced anywhere else in this file. It is the
# index of the last CUDA device, or -1 when torch reports no CUDA devices —
# presumably meant as a pipeline device argument; confirm before removing.
device = torch.cuda.device_count() - 1

# Task identifiers. Every generator config declares one of these, and main()
# picks the generators whose task matches the detected input language.
TRANSLATION_EN_TO_NL = "translation_en_to_nl"
TRANSLATION_NL_TO_EN = "translation_nl_to_en"


def _generator_config(model_name, desc, task, split_sentences):
    """Return one generator configuration dict for ``GeneratorFactory``.

    Keys are emitted in the order model_name, desc, task, split_sentences;
    values are passed through unchanged.
    """
    return {
        "model_name": model_name,
        "desc": desc,
        "task": task,
        "split_sentences": split_sentences,
    }


# Models offered by the app, in display order. main() hands this list to
# GeneratorFactory once per session.
GENERATOR_LIST = [
    _generator_config(
        "yhavinga/t5-small-24L-ccmatrix-multi",
        "T5 small nl24 ccmatrix en->nl",
        TRANSLATION_EN_TO_NL,
        True,
    ),
    _generator_config(
        "yhavinga/t5-small-24L-ccmatrix-multi",
        "T5 small nl24 ccmatrix nl-en",
        TRANSLATION_NL_TO_EN,
        True,
    ),
    _generator_config(
        "Helsinki-NLP/opus-mt-en-nl",
        "Opus MT en->nl",
        TRANSLATION_EN_TO_NL,
        True,
    ),
    _generator_config(
        "Helsinki-NLP/opus-mt-nl-en",
        "Opus MT nl->en",
        TRANSLATION_NL_TO_EN,
        True,
    ),
    _generator_config(
        "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-nl-en",
        "longT5 large nl8 512beta/512l nl->en",
        TRANSLATION_NL_TO_EN,
        False,
    ),
    _generator_config(
        "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
        "longT5 large nl8 256cc/512beta/512l en->nl",
        TRANSLATION_EN_TO_NL,
        False,
    ),
    # Disabled candidates — uncomment a line to offer that model again:
    # _generator_config("yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-nl-en", "longT5 large nl8 256cc/512beta/512l nl->en", TRANSLATION_NL_TO_EN, False),
    # NOTE(review): the desc below says "large nl8" but the model name is a base nl36 model — confirm before re-enabling.
    # _generator_config("yhavinga/longt5-local-eff-base-nl36-voc8k-256l-472beta-256l-472beta-en-nl", "longT5 large nl8 256l/472beta/256l/472beta en->nl", TRANSLATION_EN_TO_NL, False),
    # _generator_config("yhavinga/byt5-small-ccmatrix-en-nl", "ByT5 small ccmatrix en->nl", TRANSLATION_EN_TO_NL, True),
    # _generator_config("yhavinga/t5-eff-large-8l-nedd-en-nl", "T5 eff large nl8 en->nl", TRANSLATION_EN_TO_NL, True),
    # _generator_config("yhavinga/t5-base-36L-ccmatrix-multi", "T5 base nl36 ccmatrix en->nl", TRANSLATION_EN_TO_NL, True),
    # _generator_config("yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl", "longT5 large nl8 512beta/512l en->nl", TRANSLATION_EN_TO_NL, False),
    # _generator_config("yhavinga/t5-base-36L-nedd-x-en-nl-300", "T5 base 36L nedd en->nl 300", TRANSLATION_EN_TO_NL, True),
    # _generator_config("yhavinga/long-t5-local-small-ccmatrix-en-nl", "longT5 small ccmatrix en->nl", TRANSLATION_EN_TO_NL, True),
]
def main():
    """Render the Rosetta en/nl translation page.

    Builds the sidebar (default-text picker and generation parameters), a text
    area on the left and a results column on the right. When "Run" is pressed:
    - detects the input language with ``langdetect``;
    - maps English to en->nl and Dutch to nl->en;
    - runs every generator registered for that task, showing each result with
      its generation time and the parameters used;
    - writes host memory figures obtained from ``psutil``.

    Returns early, after showing an error in the left column, when the input
    is empty, its language cannot be detected or is neither English nor Dutch,
    or when num_beams is not a multiple of num_beam_groups.
    """
    # Local imports keep the file's top-level import block unchanged.
    from langdetect import DetectorFactory
    from langdetect.lang_detect_exception import LangDetectException

    st.set_page_config(
        page_title="Rosetta en/nl",  # Streamlit appends "• Streamlit" to this title.
        layout="wide",
        initial_sidebar_state="auto",
        # NOTE(review): the original emoji was mis-encoded (only its lead byte
        # survived); "📚" is a best-effort restoration — confirm.
        page_icon="📚",
    )

    # Build the generator factory once per session and reuse it on reruns.
    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
    generators = st.session_state["generators"]

    with open("style.css", encoding="utf-8") as css_file:
        st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)

    st.sidebar.image("rosetta.png", width=200)
    st.sidebar.markdown(
        """# Rosetta
Vertaal van en naar Engels"""
    )
    default_text = st.sidebar.radio(
        "Change default text",
        tuple(default_texts.keys()),
        index=0,
    )
    # NOTE(review): the radio always returns one of the keys, so for non-empty
    # keys this condition is always true and the prompt box is re-seeded with
    # the selected default text on every rerun.
    if default_text or "prompt_box" not in st.session_state:
        st.session_state["prompt_box"] = default_texts[default_text]["text"]

    # Input on the left, translations on the right.
    left, right = st.columns(2)
    text_area = left.text_area("Enter text", st.session_state.prompt_box, height=500)
    st.session_state["text"] = text_area

    # Sidebar generation parameters.
    st.sidebar.title("Parameters:")
    num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
    num_beam_groups = st.sidebar.number_input(
        "Num beam groups", min_value=1, max_value=10, value=1
    )
    length_penalty = st.sidebar.number_input(
        "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
    )
    st.sidebar.markdown(
        """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
"""
    )
    params = {
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "length_penalty": length_penalty,
        "early_stopping": True,
    }

    if not left.button("Run"):
        return

    memory = psutil.virtual_memory()
    text = st.session_state.text

    # langdetect raises on empty/feature-less input; report it instead of
    # letting the exception crash the page.
    if not text.strip():
        left.error("Please enter some text to translate")
        return
    # langdetect is non-deterministic unless seeded; seeding keeps the chosen
    # translation direction stable for the same input.
    DetectorFactory.seed = 0
    try:
        language = detect(text)
    except LangDetectException:
        left.error("Could not detect the language of the input text")
        return

    if language == "en":
        task = TRANSLATION_EN_TO_NL
    elif language == "nl":
        task = TRANSLATION_NL_TO_EN
    else:
        left.error(f"Language {language} not supported")
        return

    # Num beam groups should be a divisor of num beams.
    if num_beams % num_beam_groups != 0:
        left.error("Num beams should be a multiple of num beam groups")
        return

    for generator in generators.filter(task=task):
        right.markdown(f"🧮 **Model `{generator}`**")
        # perf_counter is monotonic and high-resolution, unlike time.time().
        time_start = time.perf_counter()
        result, params_used = generator.generate(text=text, **params)
        time_diff = time.perf_counter() - time_start
        # Two trailing spaces form a Markdown hard line break, so the output
        # keeps its line structure when rendered.
        right.write(result.replace("\n", "  \n"))
        text_line = ", ".join(f"{k}={v}" for k, v in params_used.items())
        # NOTE(review): the original emoji was mis-encoded; "🕙" is a
        # best-effort restoration — confirm.
        right.markdown(f" 🕙 *generated in {time_diff:.2f}s, `{text_line}`*")

    st.write(
        f"""
---
*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
"""
    )


if __name__ == "__main__":
    main()
|