Chittrarasu commited on
Commit
3369b7c
·
1 Parent(s): 81c6189
Files changed (1) hide show
  1. service/prediction_service.py +14 -1
service/prediction_service.py CHANGED
@@ -1,6 +1,7 @@
1
  from sentence_transformers import SentenceTransformer
2
  import os
3
  from huggingface_hub import hf_hub_download
 
4
 
5
  # Get the Hugging Face token from environment variable
6
  hf_token = os.getenv('HF_TOKEN')
@@ -8,6 +9,7 @@ hf_token = os.getenv('HF_TOKEN')
8
  # Hugging Face Model ID and local model directory
9
  hf_model_id = 'Alibaba-NLP/gte-base-en-v1.5'
10
  model_dir = '/tmp/sentence_transformer' # Use /tmp for write permissions
 
11
 
12
  # Create model directory if not exists
13
  os.makedirs(model_dir, exist_ok=True)
@@ -21,9 +23,21 @@ else:
21
  print(f"Loading model from local directory: {model_dir}")
22
  model = SentenceTransformer(model_dir, trust_remote_code=True) # Added trust_remote_code=True
23
 
 
 
 
 
 
 
 
 
24
  # Define predict_label function
25
  def predict_label(text):
26
  try:
 
 
 
 
27
  # Ensure input is a list for the model
28
  if not isinstance(text, list):
29
  text = [text]
@@ -49,4 +63,3 @@ def predict_label(text):
49
  # Log the exception for debugging
50
  print(f"Error in predict_label: {e}")
51
  return "Error", 0.0
52
-
 
1
  from sentence_transformers import SentenceTransformer
2
  import os
3
  from huggingface_hub import hf_hub_download
4
+ from joblib import load # <-- Import this to load the model
5
 
6
  # Get the Hugging Face token from environment variable
7
  hf_token = os.getenv('HF_TOKEN')
 
9
  # Hugging Face Model ID and local model directory
10
  hf_model_id = 'Alibaba-NLP/gte-base-en-v1.5'
11
  model_dir = '/tmp/sentence_transformer' # Use /tmp for write permissions
12
+ clf_model_path = '/tmp/logistic_regression_model.pkl'
13
 
14
  # Create model directory if not exists
15
  os.makedirs(model_dir, exist_ok=True)
 
23
  print(f"Loading model from local directory: {model_dir}")
24
  model = SentenceTransformer(model_dir, trust_remote_code=True) # Added trust_remote_code=True
25
 
26
+ # ✅ Load the logistic regression model and define clf globally
27
+ clf = None # Initialize as None
28
+ if os.path.exists(clf_model_path):
29
+ clf = load(clf_model_path) # Load the logistic regression model
30
+ print("Logistic Regression model loaded successfully.")
31
+ else:
32
+ print("Logistic Regression model not found. Ensure it is saved in /tmp.")
33
+
34
  # Define predict_label function
35
  def predict_label(text):
36
  try:
37
+ # Check if clf is loaded
38
+ if clf is None:
39
+ raise ValueError("Logistic Regression model is not loaded.")
40
+
41
  # Ensure input is a list for the model
42
  if not isinstance(text, list):
43
  text = [text]
 
63
  # Log the exception for debugging
64
  print(f"Error in predict_label: {e}")
65
  return "Error", 0.0