404Brain-Not-Found-yeah commited on
Commit
832d529
1 Parent(s): 2a07da2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -17
app.py CHANGED
@@ -4,12 +4,13 @@ import numpy as np
4
  from predict import extract_features
5
  import os
6
  import tempfile
7
- from huggingface_hub import hf_hub_download
8
  import logging
 
9
 
10
  # Set up logging
11
  logging.basicConfig(
12
- level=logging.INFO,
13
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
14
  )
15
  logger = logging.getLogger(__name__)
@@ -25,20 +26,84 @@ st.set_page_config(
25
  def load_model():
26
  """Load model from Hugging Face Hub"""
27
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  logger.info("Downloading model from Hugging Face Hub...")
29
- model_path = hf_hub_download(
30
- repo_id="404Brain-Not-Found-yeah/healing-music-classifier",
31
- filename="models/model.joblib"
32
- )
33
- scaler_path = hf_hub_download(
34
- repo_id="404Brain-Not-Found-yeah/healing-music-classifier",
35
- filename="models/scaler.joblib"
36
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- logger.info("Loading model and scaler...")
39
- return joblib.load(model_path), joblib.load(scaler_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  except Exception as e:
41
- logger.error(f"Error loading model: {str(e)}")
 
42
  return None, None
43
 
44
  def main():
@@ -70,7 +135,7 @@ def main():
70
  # Load model
71
  model, scaler = load_model()
72
  if model is None or scaler is None:
73
- st.error("Model loading failed. Please try again later.")
74
  return
75
 
76
  progress_bar.progress(50)
@@ -84,9 +149,14 @@ def main():
84
  progress_bar.progress(70)
85
 
86
  # Predict
87
- scaled_features = scaler.transform([features])
88
- healing_probability = model.predict_proba(scaled_features)[0][1]
89
- progress_bar.progress(90)
 
 
 
 
 
90
 
91
  # Display results
92
  st.subheader("Analysis Results")
 
4
  from predict import extract_features
5
  import os
6
  import tempfile
7
+ from huggingface_hub import hf_hub_download, list_repo_files
8
  import logging
9
+ import traceback
10
 
11
  # Set up logging
12
  logging.basicConfig(
13
+ level=logging.DEBUG,
14
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
15
  )
16
  logger = logging.getLogger(__name__)
 
26
  def load_model():
27
  """Load model from Hugging Face Hub"""
28
  try:
29
+ # 首先列出仓库中的所有文件
30
+ logger.info("Listing repository files...")
31
+ try:
32
+ files = list_repo_files("404Brain-Not-Found-yeah/healing-music-classifier")
33
+ logger.info(f"Repository files: {files}")
34
+ st.write("Available files in repository:", files) # 显示在界面上
35
+ except Exception as e:
36
+ logger.error(f"Error listing repository files: {str(e)}\n{traceback.format_exc()}")
37
+ st.error(f"Error listing repository files: {str(e)}")
38
+ return None, None
39
+
40
+ # 创建临时目录
41
+ os.makedirs("temp_models", exist_ok=True)
42
+ logger.info("Created temp_models directory")
43
+
44
  logger.info("Downloading model from Hugging Face Hub...")
45
+ # 下载模型文件
46
+ try:
47
+ model_path = hf_hub_download(
48
+ repo_id="404Brain-Not-Found-yeah/healing-music-classifier",
49
+ filename="models/model.joblib",
50
+ local_dir="temp_models"
51
+ )
52
+ logger.info(f"Model downloaded to: {model_path}")
53
+ st.write(f"Model downloaded to: {model_path}") # 显示在界面上
54
+ except Exception as e:
55
+ logger.error(f"Error downloading model: {str(e)}\n{traceback.format_exc()}")
56
+ st.error(f"Error downloading model: {str(e)}")
57
+ return None, None
58
+
59
+ # 下载scaler文件
60
+ try:
61
+ scaler_path = hf_hub_download(
62
+ repo_id="404Brain-Not-Found-yeah/healing-music-classifier",
63
+ filename="models/scaler.joblib",
64
+ local_dir="temp_models"
65
+ )
66
+ logger.info(f"Scaler downloaded to: {scaler_path}")
67
+ st.write(f"Scaler downloaded to: {scaler_path}") # 显示在界面上
68
+ except Exception as e:
69
+ logger.error(f"Error downloading scaler: {str(e)}\n{traceback.format_exc()}")
70
+ st.error(f"Error downloading scaler: {str(e)}")
71
+ return None, None
72
 
73
+ # 加载模型文件
74
+ try:
75
+ logger.info("Loading model and scaler...")
76
+ # 检查文件是否存在
77
+ if not os.path.exists(model_path):
78
+ logger.error(f"Model file not found at: {model_path}")
79
+ st.error(f"Model file not found at: {model_path}")
80
+ return None, None
81
+ if not os.path.exists(scaler_path):
82
+ logger.error(f"Scaler file not found at: {scaler_path}")
83
+ st.error(f"Scaler file not found at: {scaler_path}")
84
+ return None, None
85
+
86
+ # 检查文件大小
87
+ model_size = os.path.getsize(model_path)
88
+ scaler_size = os.path.getsize(scaler_path)
89
+ logger.info(f"Model file size: {model_size} bytes")
90
+ logger.info(f"Scaler file size: {scaler_size} bytes")
91
+ st.write(f"Model file size: {model_size} bytes")
92
+ st.write(f"Scaler file size: {scaler_size} bytes")
93
+
94
+ model = joblib.load(model_path)
95
+ scaler = joblib.load(scaler_path)
96
+ logger.info("Model and scaler loaded successfully")
97
+ st.success("Model and scaler loaded successfully!") # 显示成功消息
98
+ return model, scaler
99
+ except Exception as e:
100
+ logger.error(f"Error loading model/scaler files: {str(e)}\n{traceback.format_exc()}")
101
+ st.error(f"Error loading model/scaler files: {str(e)}")
102
+ return None, None
103
+
104
  except Exception as e:
105
+ logger.error(f"Unexpected error in load_model: {str(e)}\n{traceback.format_exc()}")
106
+ st.error(f"Unexpected error in load_model: {str(e)}")
107
  return None, None
108
 
109
  def main():
 
135
  # Load model
136
  model, scaler = load_model()
137
  if model is None or scaler is None:
138
+ st.error("Model loading failed. Please check the logs for details.")
139
  return
140
 
141
  progress_bar.progress(50)
 
149
  progress_bar.progress(70)
150
 
151
  # Predict
152
+ try:
153
+ scaled_features = scaler.transform([features])
154
+ healing_probability = model.predict_proba(scaled_features)[0][1]
155
+ progress_bar.progress(90)
156
+ except Exception as e:
157
+ logger.error(f"Error during prediction: {str(e)}\n{traceback.format_exc()}")
158
+ st.error(f"Error during prediction: {str(e)}")
159
+ return
160
 
161
  # Display results
162
  st.subheader("Analysis Results")