# GUI1 / app.py
# Last commit: 5deec10 "Update app.py" (author: its-zion-18)
import os
import shutil
import zipfile
import pathlib
import pandas as pd
import gradio as gr
import huggingface_hub
import autogluon.tabular as agt
import shutil
# --- Model Loading ---
def _reset_cache_dir(cache_dir=None):
    """Delete the local download cache directory (if present) and recreate it empty.

    Note: this logic was originally written as a partial
    ``_prepare_predictor_dir``. That name is defined again further down the
    file, which silently shadowed this version, so it could never run. It now
    has its own name and no longer overrides the real loader.

    Args:
        cache_dir: Directory to reset (str or path-like). Defaults to the
            module-level ``CACHE_DIR``.

    Returns:
        The reset (empty, existing) directory as a ``pathlib.Path``.
    """
    # Only evaluate CACHE_DIR when no explicit directory was given.
    target = pathlib.Path(CACHE_DIR if cache_dir is None else cache_dir)
    if target.exists():
        print(f"Clearing cache directory: {target}")
        shutil.rmtree(target)
    target.mkdir(parents=True, exist_ok=True)
    return target
# --- Settings and Metadata ---
# Hub repository holding the trained predictor, and the zip archive inside it.
# NOTE(review): "autolguon" looks like a typo, but this string is used verbatim
# as the repo id for the download, so only change it after checking the Hub.
MODEL_REPO_ID = "bcueva/2024-24679-tabular-autolguon-predictor"
ZIP_FILENAME = "autogluon_predictor_dir.zip"

# Local directories: where the archive is downloaded, and where it is unpacked.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")

# Model input features, in the same order as the UI inputs / do_predict args.
FEATURE_COLS = ["Capacity_ml", "Height_cm", "Diameter_cm", "Weight_g", "Material"]

# Name of the column the model predicts.
TARGET_COL = "Use_Type"

# Choices offered for the categorical 'Material' feature.
MATERIAL_LABELS = ["Ceramic", "Glass", "Plastic", "Stainless Steel"]

# Class id -> display name for 'Use_Type'. Presumably mirrors the encoding
# used when the model was trained; TODO confirm against the training data.
OUTCOME_LABELS = {0: "Hot", 1: "Cold"}
# --- Model Loading ---
def _prepare_predictor_dir() -> str:
    """Download the zipped AutoGluon predictor from the Hub and unpack it.

    ``ZIP_FILENAME`` is fetched from ``MODEL_REPO_ID`` into ``CACHE_DIR``.
    It is then extracted into ``EXTRACT_DIR``, which is emptied first.

    Returns:
        The directory to pass to ``TabularPredictor.load``, as a str. This is
        the archive's single top-level folder when there is exactly one
        (ignoring ``__MACOSX`` metadata and hidden entries); otherwise it is
        ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # ``local_dir_use_symlinks`` is deprecated and ignored by current
    # huggingface_hub releases, so it is not passed any more.
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
    )

    # Start from an empty extraction directory so files from a previous
    # archive cannot mix with the new ones.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))

    # An archive made by zipping a folder contains that folder as its only
    # top-level entry, so descend into it. Archives made on macOS may also
    # contain a "__MACOSX" folder and dotfiles, which must not count as entries.
    entries = [
        p
        for p in EXTRACT_DIR.iterdir()
        if p.name != "__MACOSX" and not p.name.startswith(".")
    ]
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return str(EXTRACT_DIR)
# Download and unpack the model, then load it once at import time so that
# every prediction request reuses the same predictor object.
# require_py_version_match=False relaxes AutoGluon's check that the predictor
# was saved under the same Python version as the one running this app.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = agt.TabularPredictor.load(
    PREDICTOR_DIR,
    require_py_version_match=False,
)
# --- Prediction Function ---
def _human_label(c, labels=None):
    """Map a raw class label produced by the model to a display string.

    Args:
        c: A class label, e.g. a value returned by ``PREDICTOR.predict`` or a
            column name from ``predict_proba``. Integers, numeric strings and
            arbitrary objects are all accepted.
        labels: Optional mapping of class id -> display name. Defaults to the
            module-level ``OUTCOME_LABELS``.

    Returns:
        The mapped name when ``c`` corresponds, without loss, to one of the
        mapping's keys; otherwise ``str(c)``.
    """
    import numbers  # used only for the lossless-integer check below

    mapping = OUTCOME_LABELS if labels is None else labels

    try:
        as_int = int(c)
    except (ValueError, TypeError, OverflowError):
        # e.g. "Hot", None, float("nan"), float("inf")
        as_int = None
    else:
        # int() truncates real numbers (0.9 -> 0). Only accept the conversion
        # when it loses nothing, so non-integral values are not mislabelled.
        if isinstance(c, numbers.Real) and as_int != c:
            as_int = None
    if as_int is not None and as_int in mapping:
        return mapping[as_int]

    # Fall back to a direct lookup (this supports non-integer keys).
    # Unhashable inputs cannot be keys, so they are simply stringified.
    try:
        return mapping[c]
    except (KeyError, TypeError):
        return str(c)
def do_predict(capacity_ml, height_cm, diameter_cm, weight_g, material):
    """Predict a cup's use type from its physical properties.

    The user does not enter ``Cup_ID``. A dummy value is filled in instead
    (see the comment on ``row`` below).

    Args:
        capacity_ml: Capacity in millilitres.
        height_cm: Height in centimetres.
        diameter_cm: Diameter in centimetres.
        weight_g: Weight in grams.
        material: One of ``MATERIAL_LABELS``.

    Returns:
        A ``{label: probability}`` dict sorted by descending probability, for
        ``gr.Label``. If the predictor cannot produce probabilities, the
        predicted label string is returned instead so the UI still shows a
        result.
    """
    # Cup_ID gets a dummy 0. The original author notes that the model's
    # training data included this column. TODO confirm the predictor needs it.
    row = {
        "Cup_ID": 0,
        "Capacity_ml": capacity_ml,
        "Height_cm": height_cm,
        "Diameter_cm": diameter_cm,
        "Weight_g": weight_g,
        "Material": material,
    }
    X = pd.DataFrame([row])

    pred_label = _human_label(PREDICTOR.predict(X).iloc[0])

    try:
        # as_multiclass=True asks for one column per class (named by class
        # label), including for binary problems. The old Series fallback would
        # have used row indices as class names.
        proba = PREDICTOR.predict_proba(X, as_multiclass=True)
        proba_dict = {
            _human_label(cls): float(val) for cls, val in proba.iloc[0].items()
        }
    except Exception as e:  # best effort: probabilities are optional
        print(f"Could not get probabilities: {e}")
        return pred_label
    return dict(sorted(proba_dict.items(), key=lambda kv: kv[1], reverse=True))
# --- Gradio UI ---
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("# Cup Use Predictor ☕️")
    gr.Markdown("""
Enter the physical properties of a cup to predict its intended use type (Hot or Cold).
This app uses a pre-trained **AutoGluon** model to classify the cup's purpose.
""")
    with gr.Column():
        material = gr.Radio(choices=MATERIAL_LABELS, value="Ceramic", label="Material")
        with gr.Row():
            with gr.Column():
                capacity_ml = gr.Number(value=350, label="Capacity (ml)")
                height_cm = gr.Number(value=10.0, label="Height (cm)")
            with gr.Column():
                diameter_cm = gr.Number(value=8.0, label="Diameter (cm)")
                weight_g = gr.Number(value=250, label="Weight (g)")
        predict_btn = gr.Button("Predict Use Type")
        output_label = gr.Label(num_top_classes=2, label="Prediction")

    # Must stay in the same order as do_predict's parameters.
    inputs = [capacity_ml, height_cm, diameter_cm, weight_g, material]
    predict_btn.click(fn=do_predict, inputs=inputs, outputs=output_label)

    # Clicking an example only fills in the inputs (no fn is attached);
    # the user still presses "Predict Use Type". The comments describe each
    # example's shape; the predicted class is whatever the model returns.
    gr.Examples(
        examples=[
            [478, 8, 7.7, 315, "Ceramic"],  # short, wide, fairly heavy ceramic mug
            [442, 13.8, 6.4, 155, "Glass"],  # tall, narrow drinking glass
            [392, 18, 5.7, 61, "Plastic"],  # tall, very light plastic cup
            [302, 17.5, 5.5, 783, "Stainless Steel"],  # tall, heavy steel flask/tumbler
        ],
        inputs=inputs,
        label="Representative Examples",
        examples_per_page=5,
    )

if __name__ == "__main__":
    demo.launch()