# Spaces: Runtime error
from datetime import datetime
from joblib import load
from PIL import Image
import pandas as pd
import random
import streamlit as st

# Streamlit re-runs this whole script on every widget interaction, so
# expensive one-time work is stored in st.session_state and reused.
session = st.session_state

# Load the trained classifier once per browser session instead of reading
# it from disk again on every rerun.
if 'model' not in session:
    session.model = load("model_title_update.joblib")
model = session.model

# if 'df' not in session:
#     session.df = pd.read_csv("model_log.csv").drop_duplicates()

# Random per-session identifier; in this file it is only referenced by the
# commented-out prediction-logging code. Not suitable as a security token.
if 'user_id' not in session:
    session.user_id = random.randint(1, 1_000_000)
st.markdown('# Titanic Workshop - Would you be able to survive the Titanic disaster?')
# Open the banner in a context manager so the file handle PIL keeps for lazy
# decoding is released once st.image has consumed the image (this code runs
# again on every rerun, so an unclosed handle would leak each time).
with Image.open('titanic.jpeg') as image:
    st.image(image)
st.markdown("## ")
st.markdown("### Input your passenger data: ")

age = st.number_input('Age', min_value=1, max_value=80)
sex = st.radio("Sex", ("Male", "Female"))
name = st.text_input("Name", "")
# The set of selectable titles depends on the chosen sex.
if sex == "Male":
    title = st.radio("Title", ("Mr", "Doctor", "Master", "Reverend"))
else:
    title = st.radio("Title", ("Doctor", "Miss", "Mrs"))
embark_port = st.radio(
    "Port of Embarkation",
    ('Cherbourg', 'Queenstown', 'Southampton'))
p_class = st.radio(
    "Passenger Class",
    ('Class 1', 'Class 2', 'Class 3'))
# These values become the SibSp / Parch features. In the Titanic dataset they
# count siblings+spouses and parents+children respectively (matching the
# variable names), so the labels must ask for both groups, not just one.
siblings = st.number_input('Number of Siblings/Spouses Aboard', min_value=0, max_value=5)
parents = st.number_input("Number of Parents/Children Aboard", min_value=0, max_value=6)
# Translate the widget selections into the flat feature row handed to the
# model. Every flag is an int (1 or 0). Key order matters: it becomes the
# column order of the DataFrame built from this dict.
port_initial = embark_port[0]
passenger_data = {
    # Age and sex
    'Age': age,
    'is_male': int(sex == "Male"),
    # One-hot passenger class
    'Pclass_class 1': int(p_class == "Class 1"),
    'Pclass_class 2': int(p_class == "Class 2"),
    'Pclass_class 3': int(p_class == "Class 3"),
    # One-hot embarkation port, keyed on the port name's first letter
    'Embarked_C': int(port_initial == "C"),
    'Embarked_Q': int(port_initial == "Q"),
    'Embarked_S': int(port_initial == "S"),
    # Family counts
    'SibSp': siblings,
    'Parch': parents,
    # One-hot title; the "Doctor"/"Reverend" labels map to Dr/Rev columns
    'Mr': int(title == "Mr"),
    'Miss': int(title == "Miss"),
    'Mrs': int(title == "Mrs"),
    "Master": int(title == "Master"),
    'Dr': int(title == "Doctor"),
    "Rev": int(title == "Reverend"),
    # No radio option yields an "Other" title, so this is always 0
    'Other': 0,
    "name_len": len(name),
}
if st.button("Calculate my chances of survival"):
    # One-row frame; its columns follow passenger_data's key order.
    passenger_data_df = pd.DataFrame([passenger_data])
    prediction = model.predict(passenger_data_df)[0]
    # Column 1 of predict_proba for the single row: presumably P(survived),
    # assuming the model's classes are ordered [0, 1] -- TODO confirm.
    prediction_proba = model.predict_proba(passenger_data_df)[0, 1]
    if prediction == 0:
        st.markdown("# YOU'RE DEAD :skull:")
    elif sex == "Male":
        st.markdown("# YOU MADE IT, :man-swimming:")
    else:
        st.markdown("# YOU MADE IT, :woman-swimming:")
    # Let the format spec scale and round in one step; the previous
    # round(p, 2) * 100 displayed float noise such as "56.99999999999999%".
    st.write(f"Probability of survival: {prediction_proba:.0%}.")
    # Disabled: log each prediction to model_log.csv (kept for reference).
    # NOTE(review): as written, this appends the whole accumulated session.df
    # on every click and relies on the read/dedupe/rewrite below to clean up.
    # passenger_data['user_id'] = session.user_id
    # passenger_data['date'] = str(datetime.now())
    # passenger_data["survived"] = prediction
    # passenger_data["survival_prob"] = prediction_proba
    # session.df = pd.concat(
    #     [
    #         session.df,
    #         pd.DataFrame([passenger_data])
    #     ],
    #     ignore_index=True
    # )
    # session.df.drop_duplicates().to_csv('model_log.csv', mode="a", header=False, index=False)
    # session.df = pd.read_csv("model_log.csv")
    # session.df = session.df.drop_duplicates()
    # session.df.to_csv("model_log.csv", index=False)