404Brain-Not-Found-yeah committed on
Commit
4082be1
·
verified ·
1 Parent(s): 68afaa5

Upload train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +161 -0
train_model.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import librosa
4
+ import pandas as pd
5
+ from sklearn.model_selection import train_test_split, cross_val_score
6
+ from sklearn.ensemble import RandomForestClassifier
7
+ from sklearn.preprocessing import StandardScaler
8
+ import joblib
9
+ import warnings
10
+ import soundfile as sf
11
+ import logging
12
+ import traceback
13
+ import sys
14
+
15
# Configure verbose logging: every record at DEBUG level or above is sent
# both to stdout and to the file 'training.log'.
logging.basicConfig(
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('training.log'),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Silence every warning for the remainder of the process.
warnings.filterwarnings('ignore')
27
+
28
def extract_features(file_path):
    """Turn one audio file into a flat feature vector.

    The file is first probed with ``soundfile``; then at most the first
    30 seconds are loaded with ``librosa`` at the file's own sample rate.
    The result concatenates, in this order:

    * the mean along axis 1 of 13 MFCCs,
    * the variance along axis 1 of the same MFCCs,
    * the mean along axis 1 of the chroma STFT.

    Args:
        file_path: Path of the audio file to analyse.

    Returns:
        A 1-D numpy array, or ``None`` if the file does not exist, cannot
        be read, is empty, lasts less than one second, or if any step
        raises. Every failure is logged; nothing is raised to the caller.
    """
    try:
        logger.info(f"Starting feature extraction for: {file_path}")

        # Guard: nothing to do when the path is missing.
        if not os.path.exists(file_path):
            logger.error(f"File does not exist: {file_path}")
            return None

        # Probe the container with soundfile so unreadable files are
        # rejected before the decode below.
        try:
            with sf.SoundFile(file_path) as sf_file:
                logger.info(f"Audio file info: {sf_file.samplerate}Hz, {sf_file.channels} channels")
        except Exception as e:
            logger.error(f"Error reading audio file with soundfile: {str(e)}\n{traceback.format_exc()}")
            return None

        # Decode at most 30 seconds; sr=None keeps the native sample rate.
        try:
            logger.info("Loading audio file...")
            y, sr = librosa.load(file_path, duration=30, sr=None)
            if not len(y):
                logger.error("Audio file is empty")
                return None
            logger.info(f"Successfully loaded audio: {len(y)} samples, {sr}Hz sample rate")
        except Exception as e:
            logger.error(f"Error loading audio: {str(e)}\n{traceback.format_exc()}")
            return None

        # Guard: reject clips shorter than one second.
        duration = len(y) / sr
        logger.info(f"Audio duration: {duration:.2f} seconds")
        if duration < 1.0:
            logger.error("Audio file is too short (less than 1 second)")
            return None

        features_dict = {}

        # 1. MFCC statistics: mean and variance of each of 13 coefficients.
        try:
            logger.info("Extracting MFCC features...")
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
            features_dict['mfccs_mean'] = mfccs.mean(axis=1)
            features_dict['mfccs_var'] = mfccs.var(axis=1)
            logger.info(f"MFCC features shape: {mfccs.shape}")
        except Exception as e:
            logger.error(f"Error extracting MFCC: {str(e)}\n{traceback.format_exc()}")
            return None

        # 2. Chroma statistics: mean of each chroma row.
        try:
            logger.info("Extracting chroma features...")
            chroma = librosa.feature.chroma_stft(y=y, sr=sr)
            features_dict['chroma'] = chroma.mean(axis=1)
            logger.info(f"Chroma features shape: {chroma.shape}")
        except Exception as e:
            logger.error(f"Error extracting chroma features: {str(e)}\n{traceback.format_exc()}")
            return None

        # 3. Join the three groups into one vector, in a fixed order.
        try:
            logger.info("Combining features...")
            ordered_keys = ('mfccs_mean', 'mfccs_var', 'chroma')
            features = np.concatenate([features_dict[key] for key in ordered_keys])
            logger.info(f"Final feature vector shape: {features.shape}")
            return features
        except Exception as e:
            logger.error(f"Error combining features: {str(e)}\n{traceback.format_exc()}")
            return None

    except Exception as e:
        # Safety net for anything the step-specific handlers did not cover.
        logger.error(f"Unexpected error in feature extraction: {str(e)}\n{traceback.format_exc()}")
        return None
105
+
106
def prepare_dataset(n_samples=100, n_features=38, seed=42):
    """Generate a synthetic, label-balanced dataset.

    No audio is read here: the features are standard-normal noise drawn
    independently of the labels, so the data is only useful for exercising
    the training pipeline end to end.

    Args:
        n_samples: Number of rows to generate.
        n_features: Number of columns per row. The default of 38
            corresponds to 26 MFCC statistics + 12 chroma values.
        seed: Seed for the private random generator.

    Returns:
        A ``(features, labels)`` tuple. ``features`` has shape
        ``(n_samples, n_features)``. ``labels`` has shape ``(n_samples,)``;
        its first ``n_samples // 2`` entries are 1.0 and the rest are 0.0.
    """
    # Use the synthetic dataset directly.
    print("Using synthetic dataset for initial deployment...")

    # A private legacy generator yields the same stream as
    # np.random.seed(seed) followed by np.random.normal(...), but leaves
    # the process-wide NumPy random state untouched.
    rng = np.random.RandomState(seed)
    synthetic_features = rng.normal(0, 1, (n_samples, n_features))

    # Balanced labels; the negative class absorbs the extra row when
    # n_samples is odd so features and labels always have equal length.
    n_positive = n_samples // 2
    synthetic_labels = np.concatenate(
        [np.ones(n_positive), np.zeros(n_samples - n_positive)]
    )

    return synthetic_features, synthetic_labels
120
+
121
def train_and_evaluate_model():
    """Train a random-forest classifier, report its accuracy and save it.

    Steps:
        1. Get features and labels from ``prepare_dataset()``.
        2. Split into stratified train/test sets (80/20).
        3. Fit a ``StandardScaler`` on the training split only and train
           a ``RandomForestClassifier`` on the scaled training data.
        4. Report 5-fold cross-validation accuracy (scaling is re-fit
           inside every fold) and accuracy on the held-out test split.
        5. Save the model and scaler with ``joblib`` into a ``models``
           directory next to this file.

    Returns:
        ``(model, scaler)``: the fitted classifier and the scaler that was
        fit on the training split.
    """
    # Imported locally: wrapping the scaler and the classifier in a
    # pipeline makes cross-validation re-fit the scaler on each fold's
    # training part instead of reusing statistics from the whole dataset.
    from sklearn.pipeline import make_pipeline

    # Prepare dataset
    print("Preparing dataset...")
    X, y = prepare_dataset()

    # Split BEFORE scaling so the test rows never influence the scaler.
    # Stratifying keeps the class ratio the same in both splits.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    # Scale features using statistics from the training split only.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Train model
    print("Training model...")
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train_scaled, y_train)

    # Evaluate model
    print("Evaluating model...")
    cv_pipeline = make_pipeline(
        StandardScaler(),
        RandomForestClassifier(n_estimators=100, random_state=42),
    )
    cv_scores = cross_val_score(cv_pipeline, X, y, cv=5)
    print(f"Cross-validation scores: {cv_scores}")
    print(f"Average CV score: {cv_scores.mean():.3f} (+/- {cv_scores.std() * 2:.3f})")
    test_accuracy = model.score(X_test_scaled, y_test)
    print(f"Held-out test accuracy: {test_accuracy:.3f}")

    # Save model and scaler
    print("Saving model and scaler...")
    model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
    os.makedirs(model_dir, exist_ok=True)

    model_path = os.path.join(model_dir, "model.joblib")
    scaler_path = os.path.join(model_dir, "scaler.joblib")

    joblib.dump(model, model_path)
    joblib.dump(scaler, scaler_path)

    return model, scaler
159
+
160
# Script entry point: train and save the model only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    train_and_evaluate_model()