Spaces:
Configuration error
Configuration error
| import streamlit as st | |
| import pandas as pd | |
| import joblib | |
| from huggingface_hub import hf_hub_download | |
| from modelConnector import ModelConnector | |
# ===========================
# LOAD MODEL & DATASET
# ===========================
st.title("π Is Click Predictor")


@st.cache_resource
def _load_model(path):
    """Deserialize the trained model stored at *path*.

    Streamlit re-executes the whole script on every widget interaction;
    caching the loaded object avoids unpickling the model on each rerun.
    """
    # NOTE(security): joblib.load unpickles arbitrary objects, i.e. it runs
    # whatever the downloaded artifact contains. Only load from a trusted repo.
    return joblib.load(path)


# Download the trained model from Hugging Face (kept at module level so
# `model_path` stays available to the rest of the script)
model_path = hf_hub_download(repo_id="taimax13/is_click_predictor", filename="rf_model.pkl")
rf_model = _load_model(model_path)
st.success("β Model Loaded Successfully!")
# ===========================
# LOAD DATA FROM HUGGING FACE
# ===========================
st.sidebar.header("Dataset Selection")


@st.cache_data
def _read_csv(path, **read_csv_kwargs):
    """Read the CSV at *path* into a DataFrame, cached across script reruns.

    Extra keyword arguments are forwarded to ``pandas.read_csv``. The cache is
    keyed on the arguments only, so edits to the file on disk are not picked
    up until the cache is cleared.
    """
    return pd.read_csv(path, **read_csv_kwargs)


# # Download required dataset files
# X_test_path = hf_hub_download(repo_id="taimax13/is_click_data", filename="X_test_1st(1).csv")
# y_test_path = hf_hub_download(repo_id="taimax13/is_click_data", filename="y_test_1st.csv")
# train_data_path = hf_hub_download(repo_id="taimax13/is_click_data", filename="train_dataset_full - train_dataset_full (1).csv")

# NOTE(review): X_test_path points at a file named "y_test_1st (1).csv",
# whereas the hub variant above uses "X_test_1st(1).csv". This looks like the
# features are being read from a copy of the labels file — confirm the
# intended local filename.
X_test_path = "HuggingFaceRepo/data/y_test_1st (1).csv"
y_test_path = "HuggingFaceRepo/data/y_test_1st.csv"
train_data_path = "HuggingFaceRepo/data/train_dataset_full - train_dataset_full.csv"

# Load datasets
X_test = _read_csv(X_test_path)
# header=None: the first line of the labels file is read as data, not as a header
y_test = _read_csv(y_test_path, header=None)
train_data = _read_csv(train_data_path)
st.info(f"β Loaded datasets: **Train: {len(train_data)} rows**, **Test: {len(X_test)} rows**")
# Initialize Model Connector
model_connector = ModelConnector()
st.title("π Is Click Predictor - Train, Retrain, and Predict")

# ===========================
# CHECK MODEL STATUS
# ===========================
# Compare against None explicitly instead of relying on truthiness: estimator
# objects may define __len__ (scikit-learn ensembles do), which makes
# `if model:` unreliable as a "model is present" test.
# NOTE(review): presumably `model` is None when nothing has been trained —
# confirm against ModelConnector.
if model_connector.model is not None:
    st.success("β Model Loaded Successfully!")
else:
    st.warning("β No model found. Please train one first.")

# ===========================
# TRAIN MODEL IF NOT FOUND
# ===========================
if st.button("π Train Model"):
    st.info("π Training model...")
    message = model_connector.train_model()
    # train_model() returns the status message shown to the user
    st.success(message)

# ===========================
# RETRAIN MODEL
# ===========================
if st.button("π Retrain Model"):
    st.info("π Retraining model with latest data...")
    message = model_connector.retrain_model()
    st.success(message)
# ===========================
# SELECT A DATA SAMPLE
# ===========================
st.sidebar.header("Select a Test Sample for Prediction")

# Flatten the labels to 1-D so they can be attached as a single column, and
# fail with a readable message (instead of an opaque pandas ValueError) when
# the label file does not line up row-for-row with the test features.
actual_labels = y_test.to_numpy().ravel()
if len(actual_labels) != len(X_test):
    st.error(
        f"Label count ({len(actual_labels)}) does not match the number of "
        f"test rows ({len(X_test)}); check the test CSV files."
    )
    st.stop()
if len(X_test) == 0:
    # With no rows the selectbox has nothing to offer and the lookup below would fail.
    st.warning("The test dataset has no rows to select from.")
    st.stop()

# Merge X_test with y_test for selection (without labels affecting prediction)
X_test["actual_click"] = actual_labels

# Allow user to pick a row
selected_index = st.sidebar.selectbox("Choose a test sample index", X_test.index)
selected_row = X_test.loc[selected_index].drop("actual_click")  # Exclude actual label

# Display selected row
st.write("### Selected Data Sample:")
st.dataframe(selected_row.to_frame().T)  # Display as a one-row table
# ===========================
# MAKE PREDICTION & EXPORT CSV
# ===========================
# NOTE(review): the leading characters in the UI strings below look like
# mis-encoded emoji; they are left untouched here — confirm the intended glyphs.
if st.button("Predict Click"):
    # Convert the selected row to a one-row DataFrame for model input.
    # A row extracted as a Series carries a single dtype for all values, so
    # re-infer per-column dtypes after transposing back to a frame.
    input_data = selected_row.to_frame().T.infer_objects()

    # Make prediction (predict returns one value per input row; take the only one)
    prediction = rf_model.predict(input_data)[0]

    # Add prediction to DataFrame
    input_data["is_click_predicted"] = prediction

    # Serialize once and reuse the same bytes for the on-disk copy and the download
    csv_filename = "prediction_result.csv"
    csv_bytes = input_data.to_csv(index=False).encode("utf-8")
    with open(csv_filename, "wb") as csv_file:
        csv_file.write(csv_bytes)

    # Display Prediction Result
    st.subheader("Prediction Result")
    if prediction == 1:
        st.success("π’ The model predicts: **User WILL CLICK on the ad!**")
    else:
        st.warning("π΄ The model predicts: **User WILL NOT CLICK on the ad.**")

    # Provide download button for prediction result
    st.download_button(
        label="π₯ Download Prediction Result",
        data=csv_bytes,
        file_name=csv_filename,
        mime="text/csv",
    )

st.markdown("---")
st.info("Select a test row from the **left panel**, click **'Predict Click'**, and download the prediction result as a CSV.")