Spaces:
Runtime error
Runtime error
import gradio as gr | |
import pandas as pd | |
import pickle | |
import os | |
# Helper function to load pickle files
def load_pickle_file(file_name, base_dir=None):
    """Load and return the object pickled in ``file_name``.

    The path is resolved relative to ``base_dir``. By default this is the
    directory containing this script, not the current working directory.

    Failures are reported by *returning* a message string rather than
    raising. Callers therefore detect failure with
    ``isinstance(result, str)``. Note that a pickle whose payload is itself
    a ``str`` is indistinguishable from an error under this convention.

    Args:
        file_name: Name of the pickle file, relative to ``base_dir``.
        base_dir: Directory to resolve ``file_name`` against. ``None`` (the
            default) means the directory of this module.

    Returns:
        The unpickled object on success. Otherwise one of:
        ``"File <file_name> not found."`` when the file does not exist, or
        ``"An error occurred while loading <file_name>: <error>"`` for any
        other failure (permissions, corrupt data, missing classes, ...).
    """
    try:
        if base_dir is None:
            base_dir = os.path.dirname(__file__)
        file_path = os.path.join(base_dir, file_name)
        # SECURITY: pickle.load can execute arbitrary code from the file.
        # Only load artifacts shipped with this app, never user uploads.
        # Opening directly (EAFP) avoids the race between an exists() check
        # and open().
        with open(file_path, 'rb') as file:
            return pickle.load(file)
    except FileNotFoundError:
        return f"File {file_name} not found."
    except Exception as e:  # boundary: surface any load failure as a message
        return f"An error occurred while loading {file_name}: {e}"
# Load the pre-trained model and label encoder
model = load_pickle_file('best_model.pkl')
label_encoder = load_pickle_file('label_encoder.pkl')
# Ensure model and label encoder are loaded correctly.
# load_pickle_file signals failure by returning an error-message string, so a
# str here means that artifact could not be loaded. Only the failing
# artifacts are reported, so a successfully loaded object's (possibly huge)
# repr is not dumped into the message. Failing fast at import time keeps the
# app from starting a UI with nothing to predict with.
_load_errors = [
    result for result in (model, label_encoder) if isinstance(result, str)
]
if _load_errors:
    raise RuntimeError(
        "Error loading model or label encoder: " + " | ".join(_load_errors)
    )
# Define the prediction function
def predict_coffee_type(
    time_of_day,
    coffee_strength,
    sweetness_level,
    milk_type,
    coffee_temperature,
    flavored_coffee,
    caffeine_tolerance,
    coffee_bean,
    coffee_size,
    dietary_preferences,
):
    """Predict a coffee type from the user's ten preference answers.

    Each answer is placed in a one-row DataFrame under the columns
    ``Token_0`` ... ``Token_9`` (in parameter order). The row is one-hot
    encoded with ``pd.get_dummies`` and then aligned to
    ``model.feature_names_in_``:

    - dummy columns the model expects but this row lacks are filled with 0;
    - dummy columns the model does not know (e.g. an unseen category value)
      are dropped;
    - columns are put in the training order.

    Uses the module-level ``model`` (needs ``feature_names_in_`` and
    ``predict``) and ``label_encoder`` (needs ``inverse_transform``) globals.

    Returns:
        A display string of the form ``"Recommended Coffee: <label>"``.
    """
    answers = [
        time_of_day,
        coffee_strength,
        sweetness_level,
        milk_type,
        coffee_temperature,
        flavored_coffee,
        caffeine_tolerance,
        coffee_bean,
        coffee_size,
        dietary_preferences,
    ]
    # Column names must match the Token_<i> names used at training time.
    input_data = pd.DataFrame(
        {f'Token_{i}': [answer] for i, answer in enumerate(answers)}
    )

    # One-hot encode, then align to the training features in a single step.
    # reindex adds missing columns as 0, drops unknown ones and fixes the
    # order. Inserting missing columns one at a time instead fragments the
    # DataFrame, and pandas warns about that once many columns are added.
    input_encoded = pd.get_dummies(input_data).reindex(
        columns=model.feature_names_in_, fill_value=0
    )

    # Make prediction, then map the encoded class back to its label.
    prediction = model.predict(input_encoded)[0]
    coffee_type = label_encoder.inverse_transform([prediction])[0]
    return f"Recommended Coffee: {coffee_type}"
# Set up Gradio interface
# One (label, choices) pair per dropdown. The order MUST match the
# positional parameters of predict_coffee_type, because Gradio passes the
# component values to fn in the order the input components are listed.
_DROPDOWN_SPECS = [
    ("Time of Day", ['morning', 'afternoon', 'evening']),
    ("Coffee Strength", ['mild', 'regular', 'strong']),
    ("Sweetness Level", ['unsweetened', 'lightly sweetened', 'sweet']),
    ("Milk Type", ['none', 'regular', 'skim', 'almond']),
    ("Coffee Temperature", ['hot', 'iced', 'cold brew']),
    ("Flavored Coffee", ['yes', 'no']),
    ("Caffeine Tolerance", ['low', 'medium', 'high']),
    ("Coffee Bean", ['Arabica', 'Robusta', 'blend']),
    ("Coffee Size", ['small', 'medium', 'large']),
    ("Dietary Preferences", ['none', 'vegan', 'lactose-intolerant']),
]

interface = gr.Interface(
    fn=predict_coffee_type,
    # gr.Dropdown replaces the legacy gr.inputs.Dropdown. The gr.inputs
    # namespace was deprecated in Gradio 3.x and removed in Gradio 4.0, where
    # accessing it raises AttributeError at import time.
    inputs=[
        gr.Dropdown(choices=choices, label=label)
        for label, choices in _DROPDOWN_SPECS
    ],
    outputs="text",
    title="Coffee Type Prediction",
)
# Start the web UI only when this file is executed directly as a script,
# not when it is imported by another module.
def _launch_app():
    """Serve the module-level Gradio ``interface``."""
    interface.launch()


if __name__ == "__main__":
    _launch_app()