# mohamedamgad2002's picture
# minor changes
# 16b62ad
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pickle import load, dump
import random
import shap
import os
# Restore the fitted preprocessing object and the trained model from the
# pickle files that sit next to the app (paths are relative to the CWD).
with open('preprocessing.pkl', 'rb') as pre_fh:
    preprocessor = load(pre_fh)
with open('model.pkl', 'rb') as model_fh:
    model = load(model_fh)

# Read the dataset and keep only the ten feature columns, in this order.
df = pd.read_csv('Flight_Price.csv')[
    [
        'departure_time', 'stops', 'arrival_time', 'class', 'airline',
        'flight', 'source_city', 'destination_city', 'duration', 'days_left',
    ]
]
cols_names = list(df.columns)

# Preprocessed copy of the whole dataset, re-labelled with the original
# column names; it is what the SHAP explanation is computed on.
df_pre = pd.DataFrame(preprocessor.transform(df), columns=cols_names)
def input_features():
    """Collect the flight parameters from the sidebar widgets.

    Every categorical widget offers the distinct values found in the
    corresponding column of the module-level ``df``; the two sliders are
    bounded by the column's minimum and maximum.

    Returns:
        pandas.DataFrame: a single-row frame (index ``0``) with the columns
        ``departure_time``, ``stops``, ``arrival_time``, ``class``,
        ``airline``, ``flight``, ``source_city``, ``destination_city``,
        ``duration`` and ``days_left``.
    """
    st.sidebar.title('Input Features')
    airline = st.sidebar.selectbox('Airline', df['airline'].unique().tolist())
    source_city = st.sidebar.selectbox('Source City', df['source_city'].unique().tolist())
    destination_city = st.sidebar.selectbox('Destination City', df['destination_city'].unique().tolist())
    departure_time = st.sidebar.selectbox('Departure Time', df['departure_time'].unique().tolist())
    # Options come from the arrival_time column (previously this reused
    # departure_time, so the choices did not reflect the arrival data).
    arrival_time = st.sidebar.selectbox('Arrival Time', df['arrival_time'].unique().tolist())
    stops = st.sidebar.selectbox('Stops', df['stops'].unique().tolist())
    class_reservation = st.sidebar.selectbox('Class', df['class'].unique().tolist())
    days_left = st.sidebar.slider('Days Left', df['days_left'].min(), df['days_left'].max(), 10)
    duration = st.sidebar.slider('Duration', df['duration'].min(), df['duration'].max(), df['duration'].mean())
    # The flight code is not asked for: one of the selected airline's flights
    # is drawn at random, so it can differ between calls with equal inputs.
    flight = random.choice(df[df['airline'] == airline]['flight'].unique().tolist())
    data = {
        'departure_time': departure_time,
        'stops': stops,
        'arrival_time': arrival_time,
        'class': class_reservation,
        'airline': airline,
        'flight': flight,
        'source_city': source_city,
        'destination_city': destination_city,
        'duration': duration,
        'days_left': days_left,
    }
    features = pd.DataFrame(data, index=[0])
    return features
def predict(features, preprocessor, model):
    """Return the model's prediction for ``features`` on the original scale.

    The features are passed through ``preprocessor.transform`` and then
    ``model.predict``; the model output is exponentiated and rounded to
    three decimal places before being returned.
    """
    transformed = preprocessor.transform(features)
    raw_output = model.predict(transformed)
    # Exponentiate the raw model output before rounding.
    rescaled = np.exp(raw_output)
    return np.round(rescaled, decimals=3)
@st.cache_data
def summary_plot():
    """Render a SHAP summary plot for the model over the preprocessed data.

    Builds a ``shap.TreeExplainer`` for the module-level ``model``, computes
    SHAP values for ``df_pre``, draws the summary plot into a new Matplotlib
    figure and hands that figure to Streamlit. Returns ``None``; the function
    is wrapped in ``st.cache_data``.
    """
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(df_pre)
    st.write('Feature Importance')
    fig = plt.figure()
    plt.title('Feature Importance based on SHAP values')
    # show=False stops SHAP from calling plt.show() itself, so the figure is
    # left untouched for st.pyplot to render.
    shap.summary_plot(shap_values, df_pre, show=False)
    st.pyplot(fig, bbox_inches='tight')
    # Release the figure from pyplot's global registry once it is rendered.
    plt.close(fig)
# Banner image, located relative to the current working directory.
image_path = os.path.join(os.getcwd(), 'assets/flight.jpg')

# Sidebar widgets -> single-row frame of user-selected features.
features = input_features()

# Main page: title, banner and an echo of the chosen inputs.
st.write('# Flight Price Prediction')
st.image(image_path)
st.write('---')
st.write('## Specified Input Parameters')
st.write(features)
st.write('---')

# The model is only run when the user presses the button.
predict_clicked = st.button('Predict')
if predict_clicked:
    price = predict(features, preprocessor, model)
    result_html = f"""
<div style="border: 2px solid #4CAF50; padding: 20px; text-align: center; background-color: #f9f9f9; border-radius: 10px;">
<strong style="font-size: 30px; color: #333;">Price Predicted for this trip is: {price[0]:,.2f}</strong>
</div>
"""
    st.markdown(result_html, unsafe_allow_html=True)

# Model Explanation
# NOTE(review): the original indentation is not visible; this section is
# assumed to run at top level rather than inside the button branch - confirm.
summary_plot()
st.write('---')