# Hint: this cheatsheet is magic! https://cheat-sheet.streamlit.app/
import base64
import re

import altair as alt
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
import torch
from altair import X, Y, Scale
from transformers import BertForSequenceClassification, AutoTokenizer

import constants


def preprocess_text(arabic_text):
    """Apply preprocessing to the given Arabic text.

    URLs are removed first, then every ASCII Latin letter is dropped.

    Args:
        arabic_text: The Arabic text to be preprocessed.

    Returns:
        The preprocessed Arabic text.
    """
    no_urls = re.sub(
        r"(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b",
        "",
        arabic_text,
        flags=re.MULTILINE,
    )
    no_english = re.sub(r"[a-zA-Z]", "", no_urls)
    return no_english


@st.cache_data
def render_svg(svg):
    """Render the given SVG string as an inline image.

    Args:
        svg: The SVG document as a string.
    """
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    # NOTE(review): the markup literal in this file was empty and unterminated,
    # so the page could not even be parsed. This is the minimal data-URI <img>
    # tag that consumes `b64`; re-add any alignment/width wrapper the original
    # markup had.
    html = f'<img src="data:image/svg+xml;base64,{b64}"/>'
    c = st.container()
    c.write(html, unsafe_allow_html=True)


@st.cache_data
def convert_df(df):
    """Serialize the dataframe to UTF-8 encoded CSV bytes (without the index)."""
    # IMPORTANT: Cache the conversion to prevent computation on every rerun
    return df.to_csv(index=None).encode("utf-8")


@st.cache_resource
def load_model(model_name):
    """Load the sequence-classification model once and share it across reruns."""
    model = BertForSequenceClassification.from_pretrained(model_name)
    return model


@st.cache_resource
def load_tokenizer(model_name):
    """Load the tokenizer once instead of on every Streamlit rerun."""
    return AutoTokenizer.from_pretrained(model_name)


tokenizer = load_tokenizer(constants.MODEL_NAME)
model = load_model(constants.MODEL_NAME)


def compute_ALDi(sentences):
    """Computes the ALDi score for the given sentences.

    Sentences are preprocessed, scored in small batches, and every raw model
    output is clipped to the [0, 1] range. A progress bar is shown while the
    batches are processed and removed afterwards, even if scoring fails.

    Args:
        sentences: A list of Arabic sentences.

    Returns:
        A list of ALDi scores for the given sentences.
    """
    progress_text = "Computing ALDi..."
    my_bar = st.progress(0, text=progress_text)

    BATCH_SIZE = 4
    output_logits = []

    preprocessed_sentences = [preprocess_text(s) for s in sentences]
    n_sentences = len(preprocessed_sentences)

    try:
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            for first_index in range(0, n_sentences, BATCH_SIZE):
                inputs = tokenizer(
                    preprocessed_sentences[first_index : first_index + BATCH_SIZE],
                    return_tensors="pt",
                    padding=True,
                    # Keep over-long sentences within the model's input limit.
                    truncation=True,
                )
                outputs = model(**inputs).logits.reshape(-1).tolist()
                output_logits.extend(max(min(o, 1), 0) for o in outputs)
                my_bar.progress(
                    min((first_index + BATCH_SIZE) / n_sentences, 1),
                    text=progress_text,
                )
    finally:
        my_bar.empty()
    return output_logits


def _read_sentences(uploaded_file):
    """Return the non-blank lines of an uploaded text file, one sentence each.

    The file is read line by line rather than through a CSV parser so that
    tabs, quote characters and strings such as "NA" inside a sentence are kept
    verbatim.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``.

    Returns:
        A list of sentences (possibly empty).

    Raises:
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    text = uploaded_file.getvalue().decode("utf-8-sig")
    return [line for line in text.splitlines() if line.strip()]


with open("assets/ALDi_logo.svg", encoding="utf-8") as logo_file:
    render_svg(logo_file.read())

tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])

with tab1:
    sent = st.text_input(
        "Arabic Sentence:", placeholder="Enter an Arabic sentence.", on_change=None
    )  # TODO: Check if this is needed!

    # `clicked` is not consulted below: scoring is driven by `sent` alone.
    clicked = st.button("Submit")

    if sent:
        ALDi_score = compute_ALDi([sent])[0]

        ORANGE_COLOR = "#FF8000"
        fig, ax = plt.subplots(figsize=(8, 1))
        fig.patch.set_facecolor("none")
        ax.set_facecolor("none")

        ax.spines["left"].set_color(ORANGE_COLOR)
        ax.spines["bottom"].set_color(ORANGE_COLOR)
        ax.tick_params(axis="x", colors=ORANGE_COLOR)

        ax.spines[["right", "top"]].set_visible(False)

        ax.barh(y=[0], width=[ALDi_score], color=ORANGE_COLOR)
        ax.set_xlim(0, 1)
        ax.set_ylim(-1, 1)
        ax.set_title(f"ALDi score is: {round(ALDi_score, 3)}", color=ORANGE_COLOR)
        ax.get_yaxis().set_visible(False)
        ax.set_xlabel("ALDi score", color=ORANGE_COLOR)
        st.pyplot(fig)
        # Release the figure; otherwise pyplot keeps one alive per rerun.
        plt.close(fig)

        print(sent)
        # Explicit encoding so Arabic text is written correctly on any locale.
        with open("logs.txt", "a", encoding="utf-8") as f:
            f.write(sent + "\n")

with tab2:
    file = st.file_uploader("Upload a file", type=["txt"])
    if file is not None:
        try:
            sentences = _read_sentences(file)
        except UnicodeDecodeError:
            sentences = None
            st.error("The uploaded file must be UTF-8 encoded text.")

        if sentences is not None and not sentences:
            st.warning("The uploaded file does not contain any sentences.")

        if sentences:
            df = pd.DataFrame({"Sentence": sentences})
            df["ALDi"] = compute_ALDi(sentences)

            # A horizontal rule
            st.markdown("""---""")

            chart = (
                alt.Chart(df.reset_index())
                .mark_area(color="darkorange", opacity=0.5)
                .encode(
                    x=X(field="index", title="Sentence Index"),
                    y=Y("ALDi", scale=Scale(domain=[0, 1])),
                )
            )
            st.altair_chart(chart.interactive(), use_container_width=True)

            col1, col2 = st.columns([4, 1])

            with col1:
                # Display the output
                st.table(
                    df,
                )

            with col2:
                # Add a download button
                csv = convert_df(df)
                st.download_button(
                    label=":file_folder: Download predictions as CSV",
                    data=csv,
                    file_name="ALDi_scores.csv",
                    mime="text/csv",
                )