# rota-app / app.py
from functools import partial
from pathlib import Path

from pandas import DataFrame, read_csv, read_excel
import streamlit as st
from more_itertools import ichunked
from stqdm import stqdm

from onnx_model_utils import predict, predict_bulk, max_pred_bulk, RELEASE_TAG
from download import download_link
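
# Number of offense descriptions sent to the ONNX model per bulk-prediction call.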
PRED_BATCH_SIZE = 4
st.set_page_config(page_title="ROTA", initial_sidebar_state="collapsed")
st.markdown(Path("ABOUT.md").read_text())
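
# Single-offense demo: classify one free-text offense description.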
st.markdown("## ✏️ Single Coder Demo")
input_text = st.text_input(
    "Input Offense",
    value="FRAUDULENT USE OF A CREDIT CARD OR DEBT CARD >= $25,000",
)
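# predictions[0] holds the label/score pairs returned for the single input text.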
predictions = predict(input_text)
st.markdown("Predictions")
labels = ["Charge Category"]
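# Display the predictions with each score converted to an integer percentage.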
st.dataframe(
    DataFrame(predictions[0])
    .assign(
        confidence=lambda d: d["score"].apply(lambda s: round(s * 100, 0)).astype(int)
    )
    .drop("score", axis="columns")
)
st.markdown("---")
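
# Bulk coder: upload a CSV/XLSX file and classify every unique offense description in it.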
st.markdown("## 📑 Bulk Coder")
st.warning(
    "⚠️ *Note:* Your input data will be deduplicated"
    " on the selected column to reduce computation requirements."
    " You will need to re-join the results on your offense text column."
)
st.markdown("1️⃣ **Upload File**")
uploaded_file = st.file_uploader("Bulk Upload", type=["xlsx", "csv"])
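# Map each supported extension to the pandas function used to read it.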
file_readers = {"csv": read_csv, "xlsx": partial(read_excel, engine="openpyxl")}
if uploaded_file is not None:
    for filetype, reader in file_readers.items():
        if uploaded_file.name.endswith(filetype):
            df = reader(uploaded_file)
            file_name = uploaded_file.name
            del uploaded_file
            break  # stop once the matching reader has loaded the file
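    # Offer only string (object-dtype) columns and default to the one with the
    # longest average text, which is most likely the offense description column.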
    st.write("2️⃣ **Select Column of Offense Descriptions**")
    string_columns = list(df.select_dtypes("object").columns)
    longest_column = max(
        [(df[c].str.len().mean(), c) for c in string_columns], key=lambda x: x[0]
    )[1]
    selected_column = st.selectbox(
        "Select Column",
        options=list(string_columns),
        index=string_columns.index(longest_column),
    )
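    # Deduplicate on the selected column (as the warning above notes) and free
    # the full DataFrame to keep memory usage down.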
    original_length = len(df)
    df_unique = df.drop_duplicates(subset=[selected_column]).copy()
    del df
    st.markdown(
        f"Uploaded Data Sample `(Deduplicated. N Rows = {len(df_unique)}, Original N = {original_length})`"
    )
    st.dataframe(df_unique.head(20))
    st.write(f"3️⃣ **Predict Using Column: `{selected_column}`**")
    column = df_unique[selected_column].copy()
    del df_unique
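    # Stream the texts through the model in chunks of PRED_BATCH_SIZE,
    # showing progress with stqdm.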
    if st.button("Compute Predictions"):
        input_texts = (value for _, value in column.items())
        n_batches = (len(column) // PRED_BATCH_SIZE) + 1
        bulk_preds = []
        for batch in stqdm(
            ichunked(input_texts, PRED_BATCH_SIZE),
            total=n_batches,
            desc="Bulk Predict Progress",
        ):
            batch_preds = predict_bulk(batch)
            bulk_preds.extend(batch_preds)
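
        # Keep only the top prediction per row: its label and its score as an
        # integer percentage.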
        pred_df = column.to_frame()
        max_preds = max_pred_bulk(bulk_preds)
        pred_df["charge_category_pred"] = [p["label"] for p in max_preds]
        pred_df["charge_category_pred_confidence"] = [
            int(round(p["score"] * 100, 0)) for p in max_preds
        ]
        del column
        del bulk_preds
        del max_preds

        # TODO: Add all scores
        st.write("**Sample Output**")
        st.dataframe(pred_df.head(100))
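
        # Build an HTML download link for the full results table as a CSV file.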
        tmp_download_link = download_link(
            pred_df,
            f"{file_name}-ncrp-predictions.csv",
            "⬇️ Download as CSV",
        )
        st.markdown(tmp_download_link, unsafe_allow_html=True)