# Gradcof / app.py
# judebebo32 — "Update app.py" (commit c5c0350, verified), 3.16 kB
import gradio as gr
import pandas as pd
import pickle
import os
# Load pickle files
def load_pickle_file(file_name, base_dir=None):
    """Load and return a pickled object from disk.

    Failures are reported by *returning* a message string instead of
    raising, so callers detect a failed load with ``isinstance(result, str)``.

    Args:
        file_name: Name (or relative path) of the pickle file.
        base_dir: Directory that ``file_name`` is resolved against. Defaults
            to the directory containing this script, which matches the
            original behavior.

    Returns:
        The unpickled object on success. Otherwise one of these strings:
        ``"File <file_name> not found."`` or
        ``"An error occurred while loading <file_name>: <error>"``.

    Security:
        ``pickle.load`` can execute arbitrary code embedded in the file.
        Only load pickle files you trust.
    """
    if base_dir is None:
        base_dir = os.path.dirname(__file__)
    file_path = os.path.join(base_dir, file_name)
    # EAFP: open the file directly rather than calling os.path.exists() first.
    # This avoids a check-then-use race and a redundant filesystem call.
    try:
        with open(file_path, 'rb') as file:
            return pickle.load(file)
    except FileNotFoundError:
        return f"File {file_name} not found."
    except Exception as e:  # keep the "report, don't raise" contract
        return f"An error occurred while loading {file_name}: {e}"
# Load model and label encoder.
# load_pickle_file reports failures by returning an error string, so a str
# result here means that artifact could not be loaded.
model = load_pickle_file('best_model.pkl')
label_encoder = load_pickle_file('label_encoder.pkl')
_load_errors = [obj for obj in (model, label_encoder) if isinstance(obj, str)]
if _load_errors:
    # Report only the failures. Interpolating a successfully loaded object
    # would dump its (possibly very large) repr into the message.
    # RuntimeError subclasses Exception, so existing handlers still match.
    raise RuntimeError(
        "Error loading model or label encoder: " + " | ".join(_load_errors)
    )
# Prediction function
def predict_coffee_type(
    time_of_day,
    coffee_strength,
    sweetness_level,
    milk_type,
    coffee_temperature,
    flavored_coffee,
    caffeine_tolerance,
    coffee_bean,
    coffee_size,
    dietary_preferences,
):
    """Recommend a coffee type from the ten questionnaire answers.

    The answers are placed, in parameter order, into columns ``Token_0`` ..
    ``Token_9`` of a one-row DataFrame. That frame is one-hot encoded with
    ``pd.get_dummies`` and aligned to the module-level ``model``'s
    ``feature_names_in_`` before prediction. The module-level
    ``label_encoder`` then maps the predicted encoded label back to its name.

    Returns:
        A string of the form ``"Recommended Coffee: <label>"``.
    """
    answers = [
        time_of_day,
        coffee_strength,
        sweetness_level,
        milk_type,
        coffee_temperature,
        flavored_coffee,
        caffeine_tolerance,
        coffee_bean,
        coffee_size,
        dietary_preferences,
    ]
    input_data = pd.DataFrame(
        {f'Token_{i}': [value] for i, value in enumerate(answers)}
    )
    # One-hot encode the single input row.
    input_encoded = pd.get_dummies(input_data)
    # Align to the training layout in one step:
    # - training columns absent from this row are added, filled with 0;
    # - dummy columns the model never saw are dropped;
    # - columns are put in the model's order.
    # Inserting missing columns one by one instead fragments the DataFrame
    # and can raise pandas' PerformanceWarning on wide feature sets.
    input_encoded = input_encoded.reindex(
        columns=model.feature_names_in_, fill_value=0
    )
    # Predict the coffee type and decode it back to its label.
    prediction = model.predict(input_encoded)[0]
    coffee_type = label_encoder.inverse_transform([prediction])[0]
    return f"Recommended Coffee: {coffee_type}"
# Set up Gradio interface.
# (label, choices) for each dropdown. The order must match
# predict_coffee_type's parameters, because Gradio passes the input
# component values to ``fn`` positionally.
_DROPDOWN_SPECS = [
    ("Time of Day", ['morning', 'afternoon', 'evening']),
    ("Coffee Strength", ['mild', 'regular', 'strong']),
    ("Sweetness Level", ['unsweetened', 'lightly sweetened', 'sweet']),
    ("Milk Type", ['none', 'regular', 'skim', 'almond']),
    ("Coffee Temperature", ['hot', 'iced', 'cold brew']),
    ("Flavored Coffee", ['yes', 'no']),
    ("Caffeine Tolerance", ['low', 'medium', 'high']),
    ("Coffee Bean", ['Arabica', 'Robusta', 'blend']),
    ("Coffee Size", ['small', 'medium', 'large']),
    ("Dietary Preferences", ['none', 'vegan', 'lactose-intolerant']),
]

interface = gr.Interface(
    fn=predict_coffee_type,
    # Use gr.Dropdown: the old gr.inputs.* namespace was deprecated in
    # Gradio 3 and removed in Gradio 4 (AttributeError at import time).
    inputs=[
        gr.Dropdown(choices=choices, label=label)
        for label, choices in _DROPDOWN_SPECS
    ],
    outputs="text",
    title="Coffee Type Prediction",
)
def _launch_app():
    """Start the Gradio server for the coffee prediction interface."""
    interface.launch()


# Start the server only when executed as a script, not on import.
if __name__ == "__main__":
    _launch_app()