DS2 commited on
Commit
825ab88
1 Parent(s): 77d9b4b

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +9 -5
inference.py CHANGED
@@ -1,8 +1,9 @@
1
  import joblib
2
  import pandas as pd
3
 
4
- # Load the model
5
  model = joblib.load("time_slot_model.joblib")
 
6
 
7
  def preprocess_time(time_str):
8
  hour = int(time_str.split(':')[0])
@@ -18,15 +19,18 @@ def categorize_time(hour):
18
  elif 12 <= hour < 14:
19
  return 'early_noon'
20
  elif 14 <= hour < 16:
21
- return 'early_noon'
22
  elif 16 <= hour < 17:
23
  return 'dusk'
24
  else:
25
  return 'other'
26
 
27
- def predict(user_input):
28
- # Preprocess and predict
29
  user_input = pd.get_dummies(pd.Series([user_input]), prefix='Pref')
30
- user_input = user_input.reindex(columns=X.columns, fill_value=0)
 
 
31
  prediction = model.predict(user_input)
32
  return prediction[0]
 
 
1
  import joblib
2
  import pandas as pd
3
 
4
+ # Load the model and columns
5
  model = joblib.load("time_slot_model.joblib")
6
+ columns = joblib.load("columns.joblib")
7
 
8
  def preprocess_time(time_str):
9
  hour = int(time_str.split(':')[0])
 
19
  elif 12 <= hour < 14:
20
  return 'early_noon'
21
  elif 14 <= hour < 16:
22
+ return 'noon'
23
  elif 16 <= hour < 17:
24
  return 'dusk'
25
  else:
26
  return 'other'
27
 
28
+ def predict(user_name, user_input):
29
+ # Process the user input
30
  user_input = pd.get_dummies(pd.Series([user_input]), prefix='Pref')
31
+ user_input = user_input.reindex(columns=columns, fill_value=0)
32
+
33
+ # Predict the preferred time slot
34
  prediction = model.predict(user_input)
35
  return prediction[0]
36
+