Spaces:
Running
Running
# Hint: this cheatsheet is magic! https://cheat-sheet.streamlit.app/ | |
import base64

import altair as alt
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
import torch
from altair import X, Y, Scale
from transformers import BertForSequenceClassification, AutoTokenizer

import constants
def render_svg(svg):
    """Display an SVG image, horizontally centered, in the Streamlit page.

    The SVG markup is base64-encoded into a ``data:image/svg+xml`` URI,
    wrapped in a centered ``<p><img/></p>`` snippet, and written as raw HTML
    (``unsafe_allow_html=True``) into a new container.

    Args:
        svg: SVG markup as a ``str``.
    """
    svg_bytes = svg.encode("utf-8")
    b64 = base64.b64encode(svg_bytes).decode("utf-8")
    st.container().write(
        rf'<p align="center"> <img src="data:image/svg+xml;base64,{b64}"/> </p>',
        unsafe_allow_html=True,
    )
@st.cache_data
def convert_df(df):
    """Serialize *df* to UTF-8 encoded CSV bytes for ``st.download_button``.

    Decorated with ``st.cache_data`` so the conversion is not recomputed on
    every Streamlit rerun when the same DataFrame is passed in again.

    Args:
        df: The DataFrame to export.

    Returns:
        The CSV text, without the index column, encoded as UTF-8 ``bytes``.
    """
    return df.to_csv(index=False).encode("utf-8")
@st.cache_resource
def load_model(model_name):
    """Load the ALDi ``BertForSequenceClassification`` model.

    Decorated with ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns, instead of being reloaded from disk or
    the hub every time a widget changes.

    Args:
        model_name: Model identifier or local path accepted by
            ``from_pretrained``.

    Returns:
        The loaded model instance.
    """
    return BertForSequenceClassification.from_pretrained(model_name)
@st.cache_resource
def _load_tokenizer(model_name):
    """Load the tokenizer for *model_name*, cached across Streamlit reruns."""
    return AutoTokenizer.from_pretrained(model_name)


# Module-level tokenizer and model used by compute_ALDi().
tokenizer = _load_tokenizer(constants.MODEL_NAME)
model = load_model(constants.MODEL_NAME)
def compute_ALDi(sentences, batch_size=4):
    """Score each sentence with the ALDi model while showing a progress bar.

    Sentences are tokenized with the module-level ``tokenizer`` and passed
    through the module-level ``model`` in batches of ``batch_size``. The
    logits of each batch are flattened and every value is clipped to [0, 1].
    NOTE(review): this assumes the model head emits exactly one logit per
    sentence, so that the flattened output aligns one-to-one with
    *sentences*. Confirm against the model config.

    Args:
        sentences: A list of sentence strings.
        batch_size: Number of sentences per forward pass (default 4).

    Returns:
        A list of floats in [0, 1], in the same order as *sentences*.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    progress_text = "Computing ALDi..."
    my_bar = st.progress(0, text=progress_text)
    n_sentences = len(sentences)
    scores = []
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for first_index in range(0, n_sentences, batch_size):
            inputs = tokenizer(
                sentences[first_index : first_index + batch_size],
                return_tensors="pt",
                padding=True,
                # BERT has a maximum input length; without truncation a long
                # sentence makes the forward pass fail.
                truncation=True,
            )
            logits = model(**inputs).logits.reshape(-1).tolist()
            scores.extend(min(max(value, 0.0), 1.0) for value in logits)
            my_bar.progress(
                min((first_index + batch_size) / n_sentences, 1), text=progress_text
            )
    my_bar.empty()
    return scores
# Page header: the ALDi logo, followed by one tab per input mode.
# The context manager closes the file handle that a bare open() would leak.
with open("assets/ALDi_logo.svg", encoding="utf-8") as logo_file:
    render_svg(logo_file.read())
tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])
# Tab 1: score a single sentence and draw the score as one horizontal bar.
with tab1:
    sent = st.text_input(
        "Arabic Sentence:", placeholder="Enter an Arabic sentence.", on_change=None
    )
    # NOTE(review): the button's return value was never read. Clicking it
    # only re-runs the script, so it is kept purely as a visual "submit"
    # affordance. Confirm whether it is still wanted.
    st.button("Submit")
    # Skip empty or whitespace-only input instead of sending it to the model.
    if sent.strip():
        ALDi_score = compute_ALDi([sent])[0]
        ORANGE_COLOR = "#FF8000"
        # Transparent background; axes, ticks, and labels styled in orange.
        fig, ax = plt.subplots(figsize=(8, 1))
        fig.patch.set_facecolor("none")
        ax.set_facecolor("none")
        ax.spines["left"].set_color(ORANGE_COLOR)
        ax.spines["bottom"].set_color(ORANGE_COLOR)
        ax.tick_params(axis="x", colors=ORANGE_COLOR)
        ax.spines[["right", "top"]].set_visible(False)
        ax.barh(y=[0], width=[ALDi_score], color=ORANGE_COLOR)
        ax.set_xlim(0, 1)
        ax.set_ylim(-1, 1)
        ax.set_title(f"ALDi score is: {round(ALDi_score, 3)}", color=ORANGE_COLOR)
        ax.get_yaxis().set_visible(False)
        ax.set_xlabel("ALDi score", color=ORANGE_COLOR)
        st.pyplot(fig)
        # pyplot keeps every open figure alive until it is closed; release it
        # so figures do not pile up across Streamlit reruns.
        plt.close(fig)
def _read_sentences(uploaded_file):
    """Return the non-blank lines of an uploaded text file, one sentence each.

    The bytes are decoded as UTF-8 (a leading byte-order mark is dropped)
    and split into lines. Each line is kept verbatim. A tab-separated CSV
    parser would split a line containing a tab into several columns,
    interpret quote characters, and coerce tokens such as "NA" or "123"
    into NaN or numbers that the tokenizer cannot handle.
    """
    text = uploaded_file.getvalue().decode("utf-8-sig")
    return [line for line in text.splitlines() if line.strip()]


# Tab 2: score every line of an uploaded .txt file, then chart, tabulate,
# and offer the results for download.
with tab2:
    uploaded_file = st.file_uploader("Upload a file", type=["txt"])
    if uploaded_file is not None:
        sentences = _read_sentences(uploaded_file)
        if not sentences:
            st.warning("The uploaded file does not contain any sentences.")
        else:
            df = pd.DataFrame({"Sentence": sentences})
            df["ALDi"] = compute_ALDi(sentences)
            # A horizontal rule
            st.markdown("""---""")
            # Area chart of ALDi score per sentence, in file order, y in [0, 1].
            chart = (
                alt.Chart(df.reset_index())
                .mark_area(color="darkorange", opacity=0.5)
                .encode(
                    x=X(field="index", title="Sentence Index"),
                    y=Y("ALDi", scale=Scale(domain=[0, 1])),
                )
            )
            st.altair_chart(chart.interactive(), use_container_width=True)
            col1, col2 = st.columns([4, 1])
            with col1:
                # Display the output
                st.table(df)
            with col2:
                # Add a download button
                st.download_button(
                    label=":file_folder: Download predictions as CSV",
                    data=convert_df(df),
                    file_name="ALDi_scores.csv",
                    mime="text/csv",
                )