AMR-KELEG commited on
Commit
206ed66
1 Parent(s): 604feee

Batch processing and styling

Browse files
Files changed (4) hide show
  1. .streamlit/config.toml +6 -0
  2. app.py +70 -17
  3. assets/ALDi_logo.svg +3 -0
  4. constants.py +1 -0
.streamlit/config.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [theme]
2
+ primaryColor="#FF8000"
3
+ #backgroundColor="#FFFFFF"
4
+ #secondaryBackgroundColor="#F0F2F6"
5
+ #textColor="#262730"
6
+ #font="sans serif"
app.py CHANGED
@@ -1,13 +1,22 @@
1
  # Hint: this cheatsheet is magic! https://cheat-sheet.streamlit.app/
2
-
3
  import constants
4
- import numpy as np
5
  import pandas as pd
6
  import streamlit as st
 
7
  from transformers import BertForSequenceClassification, AutoTokenizer
8
- import random
9
  import altair as alt
10
  from altair import X, Y, Scale
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  @st.cache_data
@@ -16,47 +25,91 @@ def convert_df(df):
16
  return df.to_csv(index=None).encode("utf-8")
17
 
18
 
19
- def compute_ALDi(inputs):
20
- return random.randint(0, 100) / 100
 
 
21
 
22
 
23
- st.title(constants.TITLE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])
26
 
27
  with tab1:
28
- sent = st.text_input("Arabic Sentence:", placeholder="Enter an Arabic sentence.")
 
 
29
 
30
  # TODO: Check if this is needed!
31
- st.button("Submit")
32
 
33
  if sent:
34
- ALDi_score = compute_ALDi(sent)
35
- st.write(ALDi_score)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  with tab2:
38
  file = st.file_uploader("Upload a file", type=["txt"])
39
  if file is not None:
40
  df = pd.read_csv(file, sep="\t", header=None)
41
  df.columns = ["Sentence"]
42
-
43
- df = pd.concat([df, df, df])
44
- df = pd.concat([df, df, df])
45
- df = pd.concat([df, df, df])
46
  df.reset_index(drop=True, inplace=True)
47
 
48
  # TODO: Run the model
49
- df["ALDi"] = df["Sentence"].apply(lambda s: compute_ALDi(s))
50
 
51
  # A horizontal rule
52
  st.markdown("""---""")
53
 
54
  chart = (
55
  alt.Chart(df.reset_index())
56
- .mark_area(color="violet", opacity=0.5)
57
  .encode(
58
  x=X(field="index", title="Sentence Index"),
59
- y=Y("ALDi", scale=Scale(domain=[0, 1]))
60
  )
61
  )
62
  st.altair_chart(chart.interactive(), use_container_width=True)
 
1
  # Hint: this cheatsheet is magic! https://cheat-sheet.streamlit.app/
 
2
  import constants
 
3
  import pandas as pd
4
  import streamlit as st
5
+ import matplotlib.pyplot as plt
6
  from transformers import BertForSequenceClassification, AutoTokenizer
7
+
8
  import altair as alt
9
  from altair import X, Y, Scale
10
+ import base64
11
+
12
+
13
+ @st.cache_data
14
+ def render_svg(svg):
15
+ """Renders the given svg string."""
16
+ b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
17
+ html = rf'<p align="center"> <img src="data:image/svg+xml;base64,{b64}"/> </p>'
18
+ c = st.container()
19
+ c.write(html, unsafe_allow_html=True)
20
 
21
 
22
  @st.cache_data
 
25
  return df.to_csv(index=None).encode("utf-8")
26
 
27
 
28
+ @st.cache_resource
29
+ def load_model(model_name):
30
+ model = BertForSequenceClassification.from_pretrained(model_name)
31
+ return model
32
 
33
 
34
+ tokenizer = AutoTokenizer.from_pretrained(constants.MODEL_NAME)
35
+ model = load_model(constants.MODEL_NAME)
36
+
37
+
38
+ def compute_ALDi(sentences):
39
+ # TODO: Perform inference in batches
40
+ progress_text = "Computing ALDi..."
41
+ my_bar = st.progress(0, text=progress_text)
42
+
43
+ BATCH_SIZE = 4
44
+ output_logits = []
45
+ for first_index in range(0, len(sentences), BATCH_SIZE):
46
+ inputs = tokenizer(
47
+ sentences[first_index : first_index + BATCH_SIZE],
48
+ return_tensors="pt",
49
+ padding=True,
50
+ )
51
+ outputs = model(**inputs).logits.reshape(-1).tolist()
52
+ output_logits = output_logits + [max(min(o, 1), 0) for o in outputs]
53
+ my_bar.progress(
54
+ min((first_index + BATCH_SIZE) / len(sentences), 1), text=progress_text
55
+ )
56
+ my_bar.empty()
57
+ return output_logits
58
+
59
+
60
+ render_svg(open("assets/ALDi_logo.svg").read())
61
 
62
  tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])
63
 
64
  with tab1:
65
+ sent = st.text_input(
66
+ "Arabic Sentence:", placeholder="Enter an Arabic sentence.", on_change=None
67
+ )
68
 
69
  # TODO: Check if this is needed!
70
+ clicked = st.button("Submit")
71
 
72
  if sent:
73
+ ALDi_score = compute_ALDi([sent])[0]
74
+
75
+ ORANGE_COLOR = "#FF8000"
76
+ fig, ax = plt.subplots(figsize=(8, 1))
77
+ fig.patch.set_facecolor("none")
78
+ ax.set_facecolor("none")
79
+
80
+ ax.spines["left"].set_color(ORANGE_COLOR)
81
+ ax.spines["bottom"].set_color(ORANGE_COLOR)
82
+ ax.tick_params(axis="x", colors=ORANGE_COLOR)
83
+
84
+ ax.spines[["right", "top"]].set_visible(False)
85
+
86
+ ax.barh(y=[0], width=[ALDi_score], color=ORANGE_COLOR)
87
+ ax.set_xlim(0, 1)
88
+ ax.set_ylim(-1, 1)
89
+ ax.set_title(f"ALDi score is: {round(ALDi_score, 3)}", color=ORANGE_COLOR)
90
+ ax.get_yaxis().set_visible(False)
91
+ ax.set_xlabel("ALDi score", color=ORANGE_COLOR)
92
+ st.pyplot(fig)
93
 
94
  with tab2:
95
  file = st.file_uploader("Upload a file", type=["txt"])
96
  if file is not None:
97
  df = pd.read_csv(file, sep="\t", header=None)
98
  df.columns = ["Sentence"]
 
 
 
 
99
  df.reset_index(drop=True, inplace=True)
100
 
101
  # TODO: Run the model
102
+ df["ALDi"] = compute_ALDi(df["Sentence"].tolist())
103
 
104
  # A horizontal rule
105
  st.markdown("""---""")
106
 
107
  chart = (
108
  alt.Chart(df.reset_index())
109
+ .mark_area(color="darkorange", opacity=0.5)
110
  .encode(
111
  x=X(field="index", title="Sentence Index"),
112
+ y=Y("ALDi", scale=Scale(domain=[0, 1])),
113
  )
114
  )
115
  st.altair_chart(chart.interactive(), use_container_width=True)
assets/ALDi_logo.svg ADDED
constants.py CHANGED
@@ -1,3 +1,4 @@
1
  CHOICE_TEXT = "Input Text"
2
  CHOICE_FILE = "Upload File"
3
  TITLE = "ALDi: Arabic Level of Dialectness"
 
 
1
  CHOICE_TEXT = "Input Text"
2
  CHOICE_FILE = "Upload File"
3
  TITLE = "ALDi: Arabic Level of Dialectness"
4
+ MODEL_NAME = "AMR-KELEG/toy_regression_model"