File size: 2,211 Bytes
70ab770 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
# To run this app, use: streamlit run test.py
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Application header: the page title followed by a one-line description of
# what the app demonstrates.
st.title("Machine Learning Model Visualization")
st.write(
    "This application demonstrates random forest classification "
    "on the iris dataset"
)
# Data acquisition and preparation
@st.cache_data
def load_data():
    """Load the scikit-learn iris dataset as a DataFrame plus its class names.

    The result is memoised by ``st.cache_data``, so reruns of the script reuse
    it instead of rebuilding the frame.

    Returns:
        tuple: ``(df, target_names)`` where ``df`` holds the iris feature
        columns (named after ``feature_names``) followed by a final
        ``target`` column of class codes, and ``target_names`` is the
        dataset's array of class names.
    """
    # Imported lazily; only needed the first time the cache is populated.
    from sklearn.datasets import load_iris

    bunch = load_iris()
    frame = pd.DataFrame(bunch.data, columns=bunch.feature_names).assign(
        target=bunch.target
    )
    return frame, bunch.target_names
data, target_names = load_data()

# Interactive data exploration: the raw table is rendered only while the
# checkbox is ticked.
st.subheader("Dataset Exploration")
show_table = st.checkbox("Display dataset")
if show_table:
    st.dataframe(data)

# Feature selection interface. load_data appends "target" as the last
# column, so every column before it is a candidate feature; the first
# column is preselected.
st.subheader("Feature Selection")
candidate_features = data.columns[:-1]
features = st.multiselect(
    "Select features for model training",
    options=candidate_features,
    default=data.columns[0],
)
if features:
    # Model parameters adjustment. These widgets are shown only once at least
    # one feature is selected, since training needs a non-empty feature matrix.
    st.subheader("Model Parameters")
    n_estimators = st.slider("Number of trees", 1, 100, 10)
    max_depth = st.slider("Maximum tree depth", 1, 20, 5)

    # Model training. st.button is True only on the rerun triggered by the
    # click, so the results below disappear on the next widget interaction.
    if st.button("Train Model"):
        X = data[features]
        y = data["target"]
        # A fixed random_state keeps both the split and the forest
        # reproducible, so accuracy differences come from the user's choices.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.3, random_state=42
        )
        model = RandomForestClassifier(
            n_estimators=n_estimators, max_depth=max_depth, random_state=42
        )
        model.fit(X_train, y_train)

        # Performance evaluation on the held-out 30% split.
        y_pred = model.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        st.success(f"Model accuracy: {accuracy:.4f}")

        # Visualization of feature importance; with a single feature the
        # importance is trivially 1.0, so the chart is skipped.
        if len(features) > 1:
            st.subheader("Feature Importance")
            fig, ax = plt.subplots()
            ax.bar(features, model.feature_importances_)
            # Rotate labels on this figure's own Axes rather than through
            # pyplot's implicit "current axes" state.
            ax.tick_params(axis="x", labelrotation=45)
            st.pyplot(fig)
            # Streamlit reruns the whole script on every interaction; without
            # an explicit close, pyplot keeps every figure alive, memory grows,
            # and matplotlib emits its too-many-open-figures warning.
            plt.close(fig)
else:
    st.warning("Please select at least one feature for model training")