rasmodev commited on
Commit
4bd969c
1 Parent(s): 72ee688

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -79
app.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  import pickle
5
  import catboost
6
  from sklearn.impute import SimpleImputer
7
- import requests
8
 
9
  # Load the saved model and unique values:
10
  with open("model_and_key_components.pkl", "rb") as f:
@@ -47,94 +46,90 @@ st.sidebar.markdown("**Year of Migration**: Enter the year of migration for the
47
  st.sidebar.markdown("**Country of Birth**: Choose the individual's birth country (e.g., United-States, Other).")
48
  st.sidebar.markdown("**Importance of Record**: Enter the weight of the instance (numeric value, e.g., 0.9).")
49
 
50
- # Create input fields for user input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  col1, col2, col3 = st.columns(3)
52
 
53
  with col1:
54
- age = st.number_input("Age", min_value=0)
55
- gender = st.selectbox("Gender", ["Male", "Female"])
56
- education = st.selectbox("Education", unique_values['education'])
57
- worker_class = st.selectbox("Class of Worker", unique_values['worker_class'])
58
- marital_status = st.selectbox("Marital Status", unique_values['marital_status'])
59
- race = st.selectbox("Race", unique_values['race'])
60
- is_hispanic = st.selectbox("Hispanic Origin", unique_values['is_hispanic'])
61
- employment_commitment = st.selectbox("Full/Part-Time Employment", unique_values['employment_commitment'])
62
- wage_per_hour = st.number_input("Wage Per Hour", min_value=0)
 
63
 
64
  with col2:
65
- working_week_per_year = st.number_input("Weeks Worked Per Year", min_value=0)
66
- industry_code = st.selectbox("Category Code of Industry", unique_values['industry_code'])
67
- industry_code_main = st.selectbox("Major Industry Code", unique_values['industry_code_main'])
68
- occupation_code = st.selectbox("Category Code of Occupation", unique_values['occupation_code'])
69
- occupation_code_main = st.selectbox("Major Occupation Code", unique_values['occupation_code_main'])
70
- total_employed = st.number_input("Number of Persons Worked for Employer", min_value=0)
71
- household_stat = st.selectbox("Detailed Household and Family Status", unique_values['household_stat'])
72
- household_summary = st.selectbox("Detailed Household Summary", unique_values['household_summary'])
73
- vet_benefit = st.selectbox("Veteran Benefits", unique_values['vet_benefit'])
74
 
75
  with col3:
76
- tax_status = st.selectbox("Tax Filer Status", unique_values['tax_status'])
77
- gains = st.number_input("Gains", min_value=0)
78
- losses = st.number_input("Losses", min_value=0)
79
- stocks_status = st.number_input("Dividends from Stocks", min_value=0)
80
- citizenship = st.selectbox("Citizenship", unique_values['citizenship'])
81
- mig_year = st.selectbox("Migration Year", unique_values['mig_year'])
82
- country_of_birth_own = st.selectbox("Country of Birth", unique_values['country_of_birth_own'])
83
- importance_of_record = st.number_input("Importance of Record", min_value=0.0)
84
 
85
- # Button to trigger prediction
86
  if st.button("Predict"):
87
- # Create a dictionary of user input
88
- user_input = {
89
- "age": int(age),
90
- "gender": gender,
91
- "education": education,
92
- "worker_class": worker_class,
93
- "marital_status": marital_status,
94
- "race": race,
95
- "is_hispanic": is_hispanic,
96
- "employment_commitment": employment_commitment,
97
- "wage_per_hour": int(wage_per_hour),
98
- "working_week_per_year": int(working_week_per_year),
99
- "industry_code": int(industry_code),
100
- "industry_code_main": industry_code_main,
101
- "occupation_code": int(occupation_code),
102
- "occupation_code_main": occupation_code_main,
103
- "total_employed": int(total_employed),
104
- "household_stat": household_stat,
105
- "household_summary": household_summary,
106
- "vet_benefit": int(vet_benefit),
107
- "tax_status": tax_status,
108
- "gains": int(gains),
109
- "losses": int(losses),
110
- "stocks_status": int(stocks_status),
111
- "citizenship": citizenship,
112
- "mig_year": int(mig_year),
113
- "country_of_birth_own": country_of_birth_own,
114
- "importance_of_record": float(importance_of_record)
115
- }
116
 
117
- # Send a POST request to the FastAPI server
118
- response = requests.post("https://rasmodev-income-prediction-fastapi.hf.space/predict/", json=user_input)
119
-
120
 
121
- # Check if the request was successful
122
- if response.status_code == 200:
123
- prediction_data = response.json()
124
-
125
- # Display prediction result to the user
126
- st.subheader("Prediction Result")
127
-
128
- # Determine income prediction and format message
129
- if prediction_data['income_prediction'] == "Income over $50K":
130
- st.success("This individual is predicted to have an income of over $50K.")
131
- else:
132
- st.error("This individual is predicted to have an income of under $50K")
133
-
134
- # Display prediction probability
135
- st.subheader("Prediction Probability")
136
- probability = prediction_data['prediction_probability']
137
- st.write(f"The probability of the individual having an income over $50K is: {probability:.2f}")
138
  else:
139
- st.error("Error: Unable to get prediction")
 
 
 
 
140
 
 
4
  import pickle
5
  import catboost
6
  from sklearn.impute import SimpleImputer
 
7
 
8
  # Load the saved model and unique values:
9
  with open("model_and_key_components.pkl", "rb") as f:
 
46
  st.sidebar.markdown("**Country of Birth**: Choose the individual's birth country (e.g., United-States, Other).")
47
  st.sidebar.markdown("**Importance of Record**: Enter the weight of the instance (numeric value, e.g., 0.9).")
48
 
49
+ # Create the input fields in the order of your DataFrame
50
+ input_data = {
51
+ 'age': 0, # Default values, you can change these as needed
52
+ 'gender': unique_values['gender'][0],
53
+ 'education': unique_values['education'][0],
54
+ 'worker_class': unique_values['worker_class'][0],
55
+ 'marital_status': unique_values['marital_status'][0],
56
+ 'race': unique_values['race'][0],
57
+ 'is_hispanic': unique_values['is_hispanic'][0],
58
+ 'employment_commitment': unique_values['employment_commitment'][0],
59
+ 'employment_stat': unique_values['employment_stat'][0],
60
+ 'wage_per_hour': 0, # Default value
61
+ 'working_week_per_year': 0, # Default value
62
+ 'industry_code': 0, # Default value
63
+ 'industry_code_main': unique_values['industry_code_main'][0],
64
+ 'occupation_code': 0, # Default value
65
+ 'occupation_code_main': unique_values['occupation_code_main'][0],
66
+ 'total_employed': 0, # Default value
67
+ 'household_stat': unique_values['household_stat'][0],
68
+ 'household_summary': unique_values['household_summary'][0],
69
+ 'vet_benefit': 0, # Default value
70
+ 'tax_status': unique_values['tax_status'][0],
71
+ 'gains': 0, # Default value
72
+ 'losses': 0, # Default value
73
+ 'stocks_status': 0, # Default value
74
+ 'citizenship': unique_values['citizenship'][0],
75
+ 'mig_year': 0,
76
+ 'country_of_birth_own': 'United-States',
77
+ 'importance_of_record': 0.0 # Default value
78
+ }
79
+
80
+ # Create the input fields
81
  col1, col2, col3 = st.columns(3)
82
 
83
  with col1:
84
+ input_data['age'] = st.number_input("Age", min_value=0, key='age')
85
+ input_data['gender'] = st.selectbox("Gender", unique_values['gender'], key='gender')
86
+ input_data['education'] = st.selectbox("Education", unique_values['education'], key='education')
87
+ input_data['worker_class'] = st.selectbox("Class of Worker", unique_values['worker_class'], key='worker_class')
88
+ input_data['marital_status'] = st.selectbox("Marital Status", unique_values['marital_status'], key='marital_status')
89
+ input_data['race'] = st.selectbox("Race", unique_values['race'], key='race')
90
+ input_data['is_hispanic'] = st.selectbox("Hispanic Origin", unique_values['is_hispanic'], key='is_hispanic')
91
+ input_data['employment_commitment'] = st.selectbox("Full/Part-Time Employment", unique_values['employment_commitment'], key='employment_commitment')
92
+ input_data['employment_stat'] = st.selectbox("Has Own Business Or Is Self Employed", unique_values['employment_stat'], key='employment_stat')
93
+ input_data['wage_per_hour'] = st.number_input("Wage Per Hour", min_value=0, key='wage_per_hour')
94
 
95
  with col2:
96
+ input_data['working_week_per_year'] = st.number_input("Weeks Worked Per Year", min_value=0, key='working_week_per_year')
97
+ input_data['industry_code'] = st.selectbox("Category Code of Industry", unique_values['industry_code'], key='industry_code')
98
+ input_data['industry_code_main'] = st.selectbox("Major Industry Code", unique_values['industry_code_main'], key='industry_code_main')
99
+ input_data['occupation_code'] = st.selectbox("Category Code of Occupation", unique_values['occupation_code'], key='occupation_code')
100
+ input_data['occupation_code_main'] = st.selectbox("Major Occupation Code", unique_values['occupation_code_main'], key='occupation_code_main')
101
+ input_data['total_employed'] = st.number_input("Number of Persons Worked for Employer", min_value=0, key='total_employed')
102
+ input_data['household_stat'] = st.selectbox("Detailed Household and Family Status", unique_values['household_stat'], key='household_stat')
103
+ input_data['household_summary'] = st.selectbox("Detailed Household Summary", unique_values['household_summary'], key='household_summary')
104
+ input_data['vet_benefit'] = st.selectbox("Veteran Benefits", unique_values['vet_benefit'], key='vet_benefit')
105
 
106
  with col3:
107
+ input_data['tax_status'] = st.selectbox("Tax Filer Status", unique_values['tax_status'], key='tax_status')
108
+ input_data['gains'] = st.number_input("Gains", min_value=0, key='gains')
109
+ input_data['losses'] = st.number_input("Losses", min_value=0, key='losses')
110
+ input_data['stocks_status'] = st.number_input("Dividends from Stocks", min_value=0, key='stocks_status')
111
+ input_data['citizenship'] = st.selectbox("Citizenship", unique_values['citizenship'], key='citizenship')
112
+ input_data['mig_year'] = st.selectbox("Migration Year", unique_values['mig_year'], key='migration_year')
113
+ input_data['country_of_birth_own'] = st.selectbox("Country of Birth", unique_values['country_of_birth_own'], key='country_of_birth_own')
114
+ input_data['importance_of_record'] = st.number_input("Importance of Record", min_value=0, key='importance_of_record')
115
 
116
+ # Button to make predictions
117
  if st.button("Predict"):
118
+ # Transform the input data to a DataFrame for prediction
119
+ input_df = pd.DataFrame([input_data])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
+ # Make predictions
122
+ prediction = dt_model.predict(input_df)
123
+ prediction_proba = dt_model.predict_proba(input_df)
124
 
125
+ # Display prediction result
126
+ st.subheader("Prediction")
127
+ if prediction[0] == 1:
128
+ st.success("This individual is predicted to have an income of over $50K.")
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  else:
130
+ st.error("This individual is predicted to have an income of under $50K")
131
+
132
+ # Show prediction probability
133
+ st.subheader("Prediction Probability")
134
+ st.write(f"The probability of the individual having an income over $50K is: {prediction_proba[0][1]:.2f}")
135