Bram Vanroy commited on
Commit
05b9456
β€’
1 Parent(s): f8b0e70

add check for empty input and show info/error

Browse files
Files changed (1) hide show
  1. app.py +75 -63
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from collections import Counter
2
 
3
  import graphviz
@@ -18,83 +19,94 @@ with st.form("input data"):
18
  src_lang = lang_col.selectbox(label="Language", options=list(LANGUAGES.keys()), index=0)
19
  submitted = st.form_submit_button("Submit")
20
 
 
21
  if submitted:
22
- multilingual = src_lang != "English"
23
- model, tokenizer, logitsprocessor = get_resources(multilingual)
24
- gen_kwargs = {
25
- "max_length": model.config.max_length,
26
- "num_beams": model.config.num_beams,
27
- "logits_processor": LogitsProcessorList([logitsprocessor])
28
- }
29
-
30
- linearized = translate(text, src_lang, model, tokenizer, **gen_kwargs)
31
- penman_str = linearized2penmanstr(linearized)
32
-
33
- try:
34
- graph = penman.decode(penman_str, model=NoOpModel())
35
- except Exception as exc:
36
- st.write(f"The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
37
- f" to a valid graph but note that this is invalid Penman.")
38
- st.code(penman_str)
39
-
40
- with st.expander("Error trace"):
41
- st.write(exc)
42
  else:
43
- visualized = graphviz.Digraph(node_attr={"color": "#3aafa9", "style": "rounded,filled", "shape": "box",
44
- "fontcolor": "white"})
45
-
46
- # Count which names occur multiple times, e.g. t/talk-01 t2/talk-01
47
- nodename_c = Counter([item[2] for item in graph.triples if item[1] == ":instance"])
48
- # Generated initial nodenames for each variable, e.g. {"t": "talk-01", "t2": "talk-01"}
49
- nodenames = {item[0]: item[2] for item in graph.triples if item[1] == ":instance"}
50
-
51
- # Modify nodenames, so that the values are unique, e.g. {"t": "talk-01 (1)", "t2": "talk-01 (2)"}
52
- # but only the value occurs more than once
53
- nodename_str_c = Counter()
54
- for varname in nodenames:
55
- nodename = nodenames[varname]
56
- if nodename_c[nodename] > 1:
57
- nodename_str_c[nodename] += 1
58
- nodenames[varname] = f"{nodename} ({nodename_str_c[nodename]})"
59
-
60
- def get_node_name(item: str):
61
- return nodenames[item] if item in nodenames else item
62
 
63
  try:
64
- for triple in graph.triples:
65
- if triple[1] == ":instance":
66
- continue
67
- else:
68
- visualized.edge(get_node_name(triple[0]), get_node_name(triple[2]), label=triple[1])
69
  except Exception as exc:
70
- st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
71
- " to a valid graph but note that this is probably invalid Penman.")
72
  st.code(penman_str)
73
- st.write("The initial linearized output of the model was:")
74
- st.code(linearized)
75
 
76
  with st.expander("Error trace"):
77
  st.write(exc)
78
  else:
79
- st.subheader("Graph visualization")
80
- st.graphviz_chart(visualized, use_container_width=True)
81
-
82
- # Download
83
- img = visualized.pipe(format="png")
84
- st.download_button("Download graph", img, mime="image/png")
85
-
86
- # Additional info
87
- st.subheader("Model output and Penman graph")
88
- st.write("The linearized output of the model (after some post-processing) is:")
89
- st.code(linearized)
90
- st.write("When converted into Penman, it looks like this:")
91
- st.code(penman.encode(graph))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
 
94
  ########################
95
  # Information, socials #
96
  ########################
97
- st.markdown("## Project: SignON 🀟")
98
 
99
  st.markdown("""
100
  <div style="display: flex">
@@ -108,7 +120,7 @@ st.markdown("""
108
  """, unsafe_allow_html=True)
109
 
110
 
111
- st.markdown("## Contact βœ’οΈ")
112
 
113
  st.markdown("Would you like additional functionality in the demo? Or just want to get in touch?"
114
  " Give me a shout on [Twitter](https://twitter.com/BramVanroy)"
 
1
+ import base64
2
  from collections import Counter
3
 
4
  import graphviz
 
19
  src_lang = lang_col.selectbox(label="Language", options=list(LANGUAGES.keys()), index=0)
20
  submitted = st.form_submit_button("Submit")
21
 
22
+ error_ct = st.empty()
23
  if submitted:
24
+ text = text.strip()
25
+ if not text:
26
+ error_ct.error("Text cannot be empty!", icon="⚠️")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  else:
28
+ error_ct.info("Generating abstract meaning representation (AMR)...", icon="πŸ’»")
29
+ multilingual = src_lang != "English"
30
+ model, tokenizer, logitsprocessor = get_resources(multilingual)
31
+ gen_kwargs = {
32
+ "max_length": model.config.max_length,
33
+ "num_beams": model.config.num_beams,
34
+ "logits_processor": LogitsProcessorList([logitsprocessor])
35
+ }
36
+
37
+ linearized = translate(text, src_lang, model, tokenizer, **gen_kwargs)
38
+ penman_str = linearized2penmanstr(linearized)
39
+ error_ct.empty()
 
 
 
 
 
 
 
40
 
41
  try:
42
+ graph = penman.decode(penman_str, model=NoOpModel())
 
 
 
 
43
  except Exception as exc:
44
+ st.write(f"The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
45
+ f" to a valid graph but note that this is invalid Penman.")
46
  st.code(penman_str)
 
 
47
 
48
  with st.expander("Error trace"):
49
  st.write(exc)
50
  else:
51
+ visualized = graphviz.Digraph(node_attr={"color": "#3aafa9", "style": "rounded,filled", "shape": "box",
52
+ "fontcolor": "white"})
53
+
54
+ # Count which names occur multiple times, e.g. t/talk-01 t2/talk-01
55
+ nodename_c = Counter([item[2] for item in graph.triples if item[1] == ":instance"])
56
+ # Generated initial nodenames for each variable, e.g. {"t": "talk-01", "t2": "talk-01"}
57
+ nodenames = {item[0]: item[2] for item in graph.triples if item[1] == ":instance"}
58
+
59
+ # Modify nodenames, so that the values are unique, e.g. {"t": "talk-01 (1)", "t2": "talk-01 (2)"}
60
+ # but only the value occurs more than once
61
+ nodename_str_c = Counter()
62
+ for varname in nodenames:
63
+ nodename = nodenames[varname]
64
+ if nodename_c[nodename] > 1:
65
+ nodename_str_c[nodename] += 1
66
+ nodenames[varname] = f"{nodename} ({nodename_str_c[nodename]})"
67
+
68
+ def get_node_name(item: str):
69
+ return nodenames[item] if item in nodenames else item
70
+
71
+ try:
72
+ for triple in graph.triples:
73
+ if triple[1] == ":instance":
74
+ continue
75
+ else:
76
+ visualized.edge(get_node_name(triple[0]), get_node_name(triple[2]), label=triple[1])
77
+ except Exception as exc:
78
+ st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
79
+ " to a valid graph but note that this is probably invalid Penman.")
80
+ st.code(penman_str)
81
+ st.write("The initial linearized output of the model was:")
82
+ st.code(linearized)
83
+
84
+ with st.expander("Error trace"):
85
+ st.write(exc)
86
+ else:
87
+ st.subheader("Graph visualization")
88
+ st.graphviz_chart(visualized, use_container_width=True)
89
+
90
+ # Download link
91
+ def create_download_link(img_bytes: bytes):
92
+ encoded = base64.b64encode(img_bytes).decode("utf-8")
93
+ return f'<a href="data:image/png;charset=utf-8;base64,{encoded}" download="amr-graph.png">Download graph</a>'
94
+
95
+ img = visualized.pipe(format="png")
96
+ st.markdown(create_download_link(img), unsafe_allow_html=True)
97
+
98
+ # Additional info
99
+ st.subheader("Model output and Penman graph")
100
+ st.write("The linearized output of the model (after some post-processing) is:")
101
+ st.code(linearized)
102
+ st.write("When converted into Penman, it looks like this:")
103
+ st.code(penman.encode(graph))
104
 
105
 
106
  ########################
107
  # Information, socials #
108
  ########################
109
+ st.header("Project: SignON 🀟")
110
 
111
  st.markdown("""
112
  <div style="display: flex">
 
120
  """, unsafe_allow_html=True)
121
 
122
 
123
+ st.header("Contact βœ’οΈ")
124
 
125
  st.markdown("Would you like additional functionality in the demo? Or just want to get in touch?"
126
  " Give me a shout on [Twitter](https://twitter.com/BramVanroy)"