Spaces:
Sleeping
Sleeping
| """ | |
| Tabular Flower Classifier - Gradio App | |
| Homework 3 - GUI Module | |
| Author: Anyu Huang | |
| Model Source: its-zion-18/flowers-tabular-autolguon-predictor | |
| This app loads an AutoGluon TabularPredictor from a ZIP file | |
| and exposes a simple Gradio interface to make predictions and show class | |
| probabilities. | |
| """ | |
| # ============================================================================ | |
| # IMPORTS | |
| # ============================================================================ | |
| import os | |
| import shutil | |
| import zipfile | |
| import pathlib | |
| import pandas as pd | |
| import gradio as gr | |
| import numpy as np | |
| from autogluon.tabular import TabularPredictor | |
# ============================================================================
# CONFIGURATION
# ============================================================================
# Archive (relative to the working directory) that holds the saved predictor.
ZIP_FILENAME: str = "autogluon_predictor_dir.zip"
# Scratch directory the archive is unpacked into; wiped on every startup.
EXTRACT_DIR: pathlib.Path = pathlib.Path() / "predictor_native"
# ============================================================================
# MODEL LOADING
# ============================================================================
def _find_predictor_root(base_dir):
    """Locate the saved-predictor directory inside an extracted archive.

    Walks ``base_dir`` and collects two kinds of candidates:
      * directories containing a ``predictor.pkl`` file (the file AutoGluon
        writes at the root of a saved TabularPredictor -- re-check if the
        AutoGluon version changes), and
      * directories containing a ``models`` subfolder (the looser heuristic
        this app originally relied on).
    The shallowest ``predictor.pkl`` candidate wins. The ``models`` heuristic
    is only used when no ``predictor.pkl`` is found. ``__MACOSX`` folders are
    skipped: macOS zip tools add them as resource-fork mirrors of the real
    tree, and they can contain a ``models`` folder of their own.

    Args:
        base_dir: Directory the archive was extracted into.

    Returns:
        str | None: Path of the chosen directory, or None if nothing matched.
    """
    pkl_hits = []
    models_hits = []
    for root, dirs, files in os.walk(str(base_dir)):
        # Pruning ``dirs`` in place stops os.walk from descending into
        # __MACOSX. Sorting makes the traversal order deterministic.
        dirs[:] = sorted(d for d in dirs if d != "__MACOSX")
        if "predictor.pkl" in files:
            pkl_hits.append(root)
        if "models" in dirs:
            models_hits.append(root)

    for hits in (pkl_hits, models_hits):
        if hits:
            # Fewest path components == shallowest; tie-break by path text.
            return min(hits, key=lambda p: (len(pathlib.PurePath(p).parts), p))
    return None


def load_predictor():
    """
    Extract and load an AutoGluon TabularPredictor from a ZIP file.

    Workflow:
      1) Check that ZIP_FILENAME exists as a regular file (relative paths are
         resolved against the current working directory).
      2) Wipe and re-create EXTRACT_DIR, then extract the archive into it.
      3) Locate the predictor root with ``_find_predictor_root``. If nothing
         matches, fall back to EXTRACT_DIR itself. Then load it.

    Returns:
        TabularPredictor: Loaded predictor ready for inference.

    Raises:
        FileNotFoundError: If ZIP_FILENAME does not exist or is not a file.
    """
    zip_path = pathlib.Path(ZIP_FILENAME)
    if not zip_path.is_file():
        raise FileNotFoundError(f"ZIP file not found: {ZIP_FILENAME}")
    print(f"Found ZIP file: {ZIP_FILENAME}")

    # Start from an empty directory so files from a previous run can't linger.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

    print("Extracting predictor...")
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(str(EXTRACT_DIR))

    predictor_root = _find_predictor_root(EXTRACT_DIR)
    if predictor_root is None:
        # Fallback: assume the archive's top level is the predictor itself.
        predictor_root = str(EXTRACT_DIR)
    print(f"Loading predictor from: {predictor_root}")
    # NOTE(review): require_py_version_match=False presumably relaxes
    # AutoGluon's check that the saving and loading Python versions match.
    return TabularPredictor.load(predictor_root, require_py_version_match=False)
# Initialize predictor once at startup
print("Loading AutoGluon TabularPredictor...")
PREDICTOR = load_predictor()
print("Predictor loaded successfully!")


def _input_feature_names(predictor):
    """Return the raw input feature names the predictor expects, or [].

    ``predict`` builds rows from raw user input, so the column names must be
    the ones the model was trained on *before* AutoGluon's preprocessing.
    ``feature_metadata`` describes the post-preprocessing features, which can
    be renamed, dropped or expanded. The original-stage names are therefore
    tried first, and ``feature_metadata`` is only the last resort.
    """
    features_fn = getattr(predictor, "features", None)
    if callable(features_fn):
        try:
            return list(features_fn(feature_stage="original"))
        except (TypeError, ValueError):
            # Older AutoGluon versions may not accept ``feature_stage``.
            pass
    for attr in ("feature_metadata_in", "feature_metadata"):
        metadata = getattr(predictor, attr, None)
        if metadata is not None:
            return list(metadata.get_features())
    return []


# Metadata helpers (feature names & label)
FEATURE_COLS = _input_feature_names(PREDICTOR)
TARGET_COL = getattr(PREDICTOR, "label", "target")
print(f"Features: {FEATURE_COLS}")
print(f"Target: {TARGET_COL}")
# ============================================================================
# PREDICTION FUNCTION
# ============================================================================
def _coerce_ui_value(val):
    """Convert one UI value into a model input cell.

    An empty string becomes 0.0. Anything ``float()`` accepts becomes a float,
    so slider and number values stay numeric. Anything else is returned
    unchanged, e.g. ``None`` or a non-numeric string for a categorical column.
    """
    try:
        return float(val) if val != "" else 0.0
    except (TypeError, ValueError, OverflowError):
        return val


def predict(*feature_values):
    """
    Build a single-row DataFrame from UI inputs and get prediction + probabilities.

    Args:
        *feature_values: Values from the UI widgets, in FEATURE_COLS order.
            The UI may supply fewer values than there are features, because it
            renders only the first few. Features without a value are sent as
            NaN (missing) instead of being dropped from the row.

    Returns:
        (proba_dict, message)
        proba_dict: dict(label -> probability) sorted by probability,
            descending (gr.Label shows the top-N). Empty on error.
        message: Markdown summary with the predicted label and confidence,
            or an error message if anything failed.
    """
    try:
        # Map UI inputs onto the model's feature columns, in order.
        n_given = min(len(FEATURE_COLS), len(feature_values))
        input_data = {
            col: _coerce_ui_value(val)
            for col, val in zip(FEATURE_COLS[:n_given], feature_values[:n_given])
        }
        # Features that have no UI widget are filled with NaN. Omitting them
        # would leave the row without columns AutoGluon expects at predict time.
        for col in FEATURE_COLS[n_given:]:
            input_data[col] = np.nan
        print(f"Input data: {input_data}")

        # Build a DataFrame row for inference
        X = pd.DataFrame([input_data])
        print(f"DataFrame shape: {X.shape}")
        print(f"DataFrame columns: {X.columns.tolist()}")

        # Predicted label (or regression value)
        pred = PREDICTOR.predict(X)
        pred_value = pred.iloc[0]
        print(f"Prediction: {pred_value}")

        # Class probabilities (if classifier). If predict_proba fails, e.g. for
        # regression, synthesize 100% on the prediction.
        try:
            proba_df = PREDICTOR.predict_proba(X)
            if isinstance(proba_df, pd.Series):
                # Normalize to DataFrame shape if AG returns a Series
                proba_df = proba_df.to_frame().T
            proba_dict = {
                str(col): float(proba_df[col].iloc[0]) for col in proba_df.columns
            }
            # Sort highest to lowest
            proba_dict = dict(
                sorted(proba_dict.items(), key=lambda item: item[1], reverse=True)
            )
        except Exception as e:
            print(f"Error getting probabilities: {e}")
            # Regression or unsupported proba: show pseudo-confidence
            proba_dict = {str(pred_value): 1.0}

        # Human-readable summary (confidence = max probability * 100)
        confidence = max(proba_dict.values()) * 100 if proba_dict else 100
        message = f"**Prediction:** {pred_value}\n**Confidence:** {confidence:.2f}%"
        return proba_dict, message
    except Exception as e:
        # UI boundary: report the failure in the app and log the traceback.
        error_msg = f"**Error:** {str(e)}\n\nPlease check the logs for details."
        print(f"Prediction error: {e}")
        import traceback
        traceback.print_exc()
        return {}, error_msg
# ============================================================================
# EXAMPLES (quick-start presets for the first 4 features)
# ============================================================================
# gr.Examples expects exactly one value per input component. The UI renders
# widgets only for FEATURE_COLS[:10], so every row is padded with 0.0 or
# truncated to that widget count. Keep the 10 in sync with the UI.
_N_UI_INPUTS = min(len(FEATURE_COLS), 10)
_BASE_EXAMPLES = [
    [5.1, 3.5, 1.4, 0.2],
    [7.0, 3.2, 4.7, 1.4],
    [6.3, 3.3, 6.0, 2.5],
]
EXAMPLES = [
    (row + [0.0] * _N_UI_INPUTS)[:_N_UI_INPUTS] for row in _BASE_EXAMPLES
]
# ============================================================================
# GRADIO UI
# ============================================================================
with gr.Blocks(title="Tabular Flower Classifier", theme=gr.themes.Soft()) as demo:
    # Page header with usage hint.
    gr.Markdown("""
# Tabular Flower Classifier
This app uses an **AutoGluon TabularPredictor** to classify flowers based on their features.
Adjust the feature values below and click **Predict** to see the classification results.
""")

    with gr.Row():
        # Left column: one input widget per feature (at most ten).
        with gr.Column(scale=1):
            gr.Markdown("### Input Features")
            # The first four features get 0-10 sliders (default 5.0). Later
            # ones get compact number boxes (default 0.0).
            feature_inputs = [
                gr.Slider(minimum=0, maximum=10, value=5.0, label=name)
                if position < 4
                else gr.Number(value=0.0, label=name)
                for position, name in enumerate(FEATURE_COLS[:10])
            ]
            predict_btn = gr.Button("Predict", variant="primary", size="lg")

        # Right column: textual summary plus probability bars.
        with gr.Column(scale=1):
            gr.Markdown("### Prediction Results")
            prediction_output = gr.Markdown(value="*Adjust features and click Predict*")
            proba_display = gr.Label(num_top_classes=5, label="Top 5 Class Probabilities")

    # Both the button and the example presets fill the same two outputs, in
    # the (proba_dict, message) order that predict() returns.
    result_components = [proba_display, prediction_output]
    predict_btn.click(fn=predict, inputs=feature_inputs, outputs=result_components)

    gr.Markdown("### Example flower measurements")
    gr.Examples(
        examples=EXAMPLES,
        inputs=feature_inputs,
        outputs=result_components,
        fn=predict,
        cache_examples=False,
    )

    gr.Markdown("""
---
### About
- **Model**: AutoGluon TabularPredictor
- **Task**: Flower classification based on measurements
- **Features**: Adjust the sliders/inputs above to test different flower measurements
""")
# ============================================================================
# ENTRY POINT
# ============================================================================
def main():
    """Start the Gradio server for the ``demo`` app."""
    demo.launch()


if __name__ == "__main__":
    main()