Anvit25 committed on
Commit
45f9c09
·
verified ·
1 Parent(s): c99d7cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -84
app.py CHANGED
@@ -8,15 +8,12 @@ import matplotlib.pyplot as plt
8
  import tensorflow as tf
9
  from tensorflow.keras import models
10
 
11
- # --- 1. Configuration & Global Variables ---
12
- # Create a temporary directory for spectrograms if it doesn't exist
13
  TEMP_DIR = "temp_gradio_specs"
14
  os.makedirs(TEMP_DIR, exist_ok=True)
15
-
16
- # Define image size for the model
17
  IMG_SIZE = (224, 224)
18
 
19
- # --- 2. Load Models and Define Classes (Done once on startup) ---
20
  print("πŸš€ Loading machine learning models...")
21
  try:
22
  stage1_model = models.load_model("saved_models/stage1_model.h5")
@@ -25,41 +22,55 @@ try:
25
  print("βœ… Models loaded successfully.")
26
  except Exception as e:
27
  print(f"❌ Error loading models: {e}")
28
- # Exit if models can't be loaded
29
- exit()
30
 
31
- # Define class lists exactly as they were during training
32
  stage1_classes = ["00 - Abnormal", "01 - Normal"]
33
- abnormal_classes = sorted(os.listdir("MelSpectrograms/00 - Abnormal"))
34
- normal_classes = sorted(os.listdir("MelSpectrograms/01 - Normal"))
 
 
 
 
 
 
 
 
 
 
35
 
36
  print(f"Stage 1 Classes: {stage1_classes}")
37
  print(f"Abnormal Sub-classes: {abnormal_classes}")
38
  print(f"Normal Sub-classes: {normal_classes}")
39
 
40
 
41
- # --- 3. Helper Functions and Classes ---
42
-
43
- def save_mel_spectrogram(file_path, save_dir, sr=22050, n_mels=128, hop_length=512, n_fft=2048):
44
  """Generates and saves a Mel Spectrogram from an audio file."""
45
  try:
46
  y, sr = librosa.load(file_path, sr=sr, mono=True)
47
- S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
 
 
48
  S_db = librosa.power_to_db(S, ref=np.max)
49
-
50
  filename = os.path.basename(file_path).replace(".wav", ".png")
51
  save_path = os.path.join(save_dir, filename)
52
 
53
  plt.figure(figsize=(4, 4))
54
- librosa.display.specshow(S_db, sr=sr, hop_length=hop_length, x_axis='time', y_axis='mel', cmap='magma')
 
55
  plt.axis("off")
56
  plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
57
  plt.close()
58
  return save_path
59
  except Exception as e:
60
- print(f"Error creating spectrogram: {e}")
61
  return None
62
 
 
63
  class HierarchicalClassifier:
64
  """A wrapper class for the two-stage prediction logic."""
65
  def __init__(self, stage1_model, abnormal_model, normal_model,
@@ -75,10 +86,18 @@ class HierarchicalClassifier:
75
  def _preprocess_image(self, image_path):
76
  img = tf.keras.utils.load_img(image_path, target_size=self.img_size)
77
  img_array = tf.keras.utils.img_to_array(img) / 255.0
78
- img_array = tf.expand_dims(img_array, 0)
79
- return img_array
80
 
81
  def predict(self, image_path):
 
 
 
 
 
 
 
 
 
82
  img_array = self._preprocess_image(image_path)
83
  stage1_pred = self.stage1_model.predict(img_array, verbose=0)
84
  stage1_idx = np.argmax(stage1_pred)
@@ -98,94 +117,59 @@ class HierarchicalClassifier:
98
  "stage1_confidence": float(np.max(stage1_pred)),
99
  "stage2_class": sub_class,
100
  "stage2_confidence": float(np.max(sub_pred)),
101
- "final_prediction": f"{main_class.split(' - ')[1]} β†’ {sub_class.split(' - ')[1]}"
102
  }
103
 
104
- # Instantiate the classifier with loaded models and classes
105
  classifier = HierarchicalClassifier(
106
  stage1_model, abnormal_model, normal_model,
107
  stage1_classes, abnormal_classes, normal_classes
108
  )
109
 
110
- # --- 4. The Main Prediction Function for Gradio ---
 
111
  def predict_washing_machine_sound(audio_filepath):
112
- """
113
- This is the core function that Gradio will call.
114
- It takes an audio file path, processes it, and returns the formatted result.
115
- """
116
  if audio_filepath is None:
117
  return "Please upload an audio file first.", None
118
 
119
  print(f"Processing file: {audio_filepath}")
120
-
121
- # The spectrogram path needs to be cleaned up after prediction
122
- spec_path = None
123
- try:
124
- # Generate a spectrogram from the input audio file
125
- spec_path = save_mel_spectrogram(audio_filepath, TEMP_DIR)
126
- if not spec_path:
127
- return "Error: Could not generate spectrogram from the audio file.", None
128
-
129
- # Get prediction from the classifier
130
- result = classifier.predict(spec_path)
131
-
132
- # Format the output for better readability
133
- output_text = (
134
- f"🎯 Final Prediction: {result['final_prediction']}\n\n"
135
- f"Confidence Scores:\n"
136
- f"--------------------\n"
137
- f"Stage 1 ({result['stage1_class']}): {result['stage1_confidence']:.4f}\n"
138
- f"Stage 2 ({result['stage2_class']}): {result['stage2_confidence']:.4f}"
139
- )
140
-
141
- # Return the formatted text and the path to the spectrogram image to display it
142
- return output_text, spec_path
143
 
144
- except Exception as e:
145
- print(f"An error occurred during prediction: {e}")
146
- return f"An error occurred: {str(e)}", None
147
-
148
- finally:
149
- # Clean up the generated spectrogram image file after it's been used
150
- # Gradio handles the temp audio file, but we must handle the temp spectrogram
151
- if spec_path and os.path.exists(spec_path):
152
- # Note: Gradio might need the file to display it, so cleaning up here
153
- # might be too early if the image component relies on the path.
154
- # For simplicity, we can let them accumulate in the temp folder or
155
- # implement more complex cleanup later. Let's comment out the immediate delete.
156
- # os.remove(spec_path)
157
- pass
158
-
159
- # --- 5. Build and Launch the Gradio Interface ---
160
- if __name__ == "__main__":
161
- # Define some example audio files from your dataset
162
- example_files = [
163
- "Washing machine/00 - Abnormal/00-2 - Dehydration mode noise/04.wav",
164
- "Washing machine/01 - Normal/01-1 - Washing mode/01.wav",
165
- "Washing machine/00 - Abnormal/00-1 - Bearing noise/02.wav"
166
- ]
167
 
 
 
 
168
  demo = gr.Interface(
169
  fn=predict_washing_machine_sound,
170
- inputs=gr.Audio(type="filepath", label="Upload Washing Machine Audio (.wav)"),
171
  outputs=[
172
  gr.Textbox(label="Prediction Result"),
173
- gr.Image(label="Generated Mel Spectrogram")
174
  ],
175
- title="Washing Machine Sound Classifier",
176
- description="Upload a WAV audio file of a washing machine to classify its operation status. The model performs a two-stage classification: first identifying 'Normal' vs 'Abnormal' sound, then determining the specific sub-type.",
177
- # examples=example_files,
178
- allow_flagging="never"
179
  )
180
 
181
- # Launch the web UI
182
  demo.launch()
183
 
184
- # Clean up the entire temp directory on exit
185
- # This is a simple way to manage temp files
186
  try:
187
- print("\nCleaning up temporary files...")
188
  shutil.rmtree(TEMP_DIR)
189
- print("βœ… Cleanup complete.")
190
  except Exception as e:
191
- print(f"Could not clean up temp files: {e}")
 
8
  import tensorflow as tf
9
  from tensorflow.keras import models
10
 
11
+ # ---------------- 1. Configuration ---------------- #
 
12
  TEMP_DIR = "temp_gradio_specs"
13
  os.makedirs(TEMP_DIR, exist_ok=True)
 
 
14
  IMG_SIZE = (224, 224)
15
 
16
+ # ---------------- 2. Load Models ------------------ #
17
  print("πŸš€ Loading machine learning models...")
18
  try:
19
  stage1_model = models.load_model("saved_models/stage1_model.h5")
 
22
  print("βœ… Models loaded successfully.")
23
  except Exception as e:
24
  print(f"❌ Error loading models: {e}")
25
+ # Do not exit — allows the app to show the error gracefully
26
+ stage1_model = abnormal_model = normal_model = None
27
 
28
+ # Default class lists – replace with actual labels if available
29
  stage1_classes = ["00 - Abnormal", "01 - Normal"]
30
+
31
+ abnormal_classes = (
32
+ sorted(os.listdir("MelSpectrograms/00 - Abnormal"))
33
+ if os.path.exists("MelSpectrograms/00 - Abnormal")
34
+ else ["Bearing noise", "Dehydration mode noise"]
35
+ )
36
+
37
+ normal_classes = (
38
+ sorted(os.listdir("MelSpectrograms/01 - Normal"))
39
+ if os.path.exists("MelSpectrograms/01 - Normal")
40
+ else ["Wash mode", "Spin mode"]
41
+ )
42
 
43
  print(f"Stage 1 Classes: {stage1_classes}")
44
  print(f"Abnormal Sub-classes: {abnormal_classes}")
45
  print(f"Normal Sub-classes: {normal_classes}")
46
 
47
 
48
+ # ---------------- 3. Helper Functions -------------- #
49
+ def save_mel_spectrogram(file_path, save_dir, sr=22050,
50
+ n_mels=128, hop_length=512, n_fft=2048):
51
  """Generates and saves a Mel Spectrogram from an audio file."""
52
  try:
53
  y, sr = librosa.load(file_path, sr=sr, mono=True)
54
+ S = librosa.feature.melspectrogram(
55
+ y=y, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length
56
+ )
57
  S_db = librosa.power_to_db(S, ref=np.max)
58
+
59
  filename = os.path.basename(file_path).replace(".wav", ".png")
60
  save_path = os.path.join(save_dir, filename)
61
 
62
  plt.figure(figsize=(4, 4))
63
+ librosa.display.specshow(S_db, sr=sr, hop_length=hop_length,
64
+ x_axis="time", y_axis="mel", cmap="magma")
65
  plt.axis("off")
66
  plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
67
  plt.close()
68
  return save_path
69
  except Exception as e:
70
+ print(f"❌ Error creating spectrogram: {e}")
71
  return None
72
 
73
+
74
  class HierarchicalClassifier:
75
  """A wrapper class for the two-stage prediction logic."""
76
  def __init__(self, stage1_model, abnormal_model, normal_model,
 
86
  def _preprocess_image(self, image_path):
87
  img = tf.keras.utils.load_img(image_path, target_size=self.img_size)
88
  img_array = tf.keras.utils.img_to_array(img) / 255.0
89
+ return tf.expand_dims(img_array, 0)
 
90
 
91
  def predict(self, image_path):
92
+ if not all([self.stage1_model, self.abnormal_model, self.normal_model]):
93
+ return {
94
+ "final_prediction": "❌ Models not loaded. Please upload models to /saved_models/",
95
+ "stage1_class": "N/A",
96
+ "stage1_confidence": 0,
97
+ "stage2_class": "N/A",
98
+ "stage2_confidence": 0
99
+ }
100
+
101
  img_array = self._preprocess_image(image_path)
102
  stage1_pred = self.stage1_model.predict(img_array, verbose=0)
103
  stage1_idx = np.argmax(stage1_pred)
 
117
  "stage1_confidence": float(np.max(stage1_pred)),
118
  "stage2_class": sub_class,
119
  "stage2_confidence": float(np.max(sub_pred)),
120
+ "final_prediction": f"{main_class.split(' - ')[1]} β†’ {sub_class}"
121
  }
122
 
123
+
124
  classifier = HierarchicalClassifier(
125
  stage1_model, abnormal_model, normal_model,
126
  stage1_classes, abnormal_classes, normal_classes
127
  )
128
 
129
+
130
+ # ---------------- 4. Prediction Function ----------- #
131
  def predict_washing_machine_sound(audio_filepath):
 
 
 
 
132
  if audio_filepath is None:
133
  return "Please upload an audio file first.", None
134
 
135
  print(f"Processing file: {audio_filepath}")
136
+ spec_path = save_mel_spectrogram(audio_filepath, TEMP_DIR)
137
+ if not spec_path:
138
+ return "❌ Could not generate spectrogram from the audio file.", None
139
+
140
+ result = classifier.predict(spec_path)
141
+
142
+ output_text = (
143
+ f"🎯 Final Prediction: {result['final_prediction']}\n\n"
144
+ f"Confidence Scores:\n"
145
+ f"--------------------\n"
146
+ f"Stage 1 ({result['stage1_class']}): {result['stage1_confidence']:.4f}\n"
147
+ f"Stage 2 ({result['stage2_class']}): {result['stage2_confidence']:.4f}"
148
+ )
 
 
 
 
 
 
 
 
 
 
149
 
150
+ return output_text, spec_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+
153
+ # ---------------- 5. Gradio Interface -------------- #
154
+ if __name__ == "__main__":
155
  demo = gr.Interface(
156
  fn=predict_washing_machine_sound,
157
+ inputs=gr.Audio(type="filepath", label="Upload Washing-Machine Audio (.wav)"),
158
  outputs=[
159
  gr.Textbox(label="Prediction Result"),
160
+ gr.Image(label="Generated Mel-Spectrogram")
161
  ],
162
+ title="Washing-Machine Sound Classifier",
163
+ description="Upload a WAV file of washing-machine audio to classify its operation status.",
164
+ allow_flagging="never",
165
+ # examples=[] # ← removed local file examples
166
  )
167
 
 
168
  demo.launch()
169
 
170
+ # Cleanup temp dir after app stops
 
171
  try:
 
172
  shutil.rmtree(TEMP_DIR)
173
+ print("βœ… Cleaned up temporary files.")
174
  except Exception as e:
175
+ print(f"⚠️ Cleanup warning: {e}")