import base64 from collections import Counter import graphviz import penman from penman.models.noop import NoOpModel from mbart_amr.data.linearization import linearized2penmanstr from transformers import LogitsProcessorList import streamlit as st from utils import get_resources, LANGUAGES, translate st.title("👩💻 Multilingual text to AMR") with st.form("input data"): text_col, lang_col = st.columns((4, 1)) text = text_col.text_input(label="Input text") src_lang = lang_col.selectbox(label="Language", options=list(LANGUAGES.keys()), index=0) submitted = st.form_submit_button("Submit") error_ct = st.empty() if submitted: text = text.strip() if not text: error_ct.error("Text cannot be empty!", icon="⚠️") else: error_ct.info("Generating abstract meaning representation (AMR)...", icon="💻") multilingual = src_lang != "English" model, tokenizer, logitsprocessor = get_resources(multilingual) gen_kwargs = { "max_length": model.config.max_length, "num_beams": model.config.num_beams, "logits_processor": LogitsProcessorList([logitsprocessor]) } linearized = translate(text, src_lang, model, tokenizer, **gen_kwargs) penman_str = linearized2penmanstr(linearized) error_ct.empty() try: graph = penman.decode(penman_str, model=NoOpModel()) except Exception as exc: st.write(f"The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt" f" to a valid graph but note that this is invalid Penman.") st.code(penman_str) with st.expander("Error trace"): st.write(exc) else: visualized = graphviz.Digraph(node_attr={"color": "#3aafa9", "style": "rounded,filled", "shape": "box", "fontcolor": "white"}) # Count which names occur multiple times, e.g. t/talk-01 t2/talk-01 nodename_c = Counter([item[2] for item in graph.triples if item[1] == ":instance"]) # Generated initial nodenames for each variable, e.g. {"t": "talk-01", "t2": "talk-01"} nodenames = {item[0]: item[2] for item in graph.triples if item[1] == ":instance"} # Modify nodenames, so that the values are unique, e.g. {"t": "talk-01 (1)", "t2": "talk-01 (2)"} # but only the value occurs more than once nodename_str_c = Counter() for varname in nodenames: nodename = nodenames[varname] if nodename_c[nodename] > 1: nodename_str_c[nodename] += 1 nodenames[varname] = f"{nodename} ({nodename_str_c[nodename]})" def get_node_name(item: str): return nodenames[item] if item in nodenames else item try: for triple in graph.triples: if triple[1] == ":instance": continue else: visualized.edge(get_node_name(triple[0]), get_node_name(triple[2]), label=triple[1]) except Exception as exc: st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt" " to a valid graph but note that this is probably invalid Penman.") st.code(penman_str) st.write("The initial linearized output of the model was:") st.code(linearized) with st.expander("Error trace"): st.write(exc) else: st.subheader("Graph visualization") st.graphviz_chart(visualized, use_container_width=True) # Download link def create_download_link(img_bytes: bytes): encoded = base64.b64encode(img_bytes).decode("utf-8") return f'Download graph' img = visualized.pipe(format="png") st.markdown(create_download_link(img), unsafe_allow_html=True) # Additional info st.subheader("Model output and Penman graph") st.write("The linearized output of the model (after some post-processing) is:") st.code(linearized) st.write("When converted into Penman, it looks like this:") st.code(penman.encode(graph)) ######################## # Information, socials # ######################## st.header("SignON 🤟") st.markdown("""
SignON aims to bridge the communication gap between deaf, hard-of-hearing and hearing people through an accessible translation service to translate between languages and modalities with particular attention to sign languages.