Kaworu17 commited on
Commit
e63bfc0
·
verified ·
1 Parent(s): 4a03abd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -14
app.py CHANGED
@@ -4,11 +4,10 @@ import numpy as np
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
  import soundfile as sf
7
- from scipy.signal import resample # Correct resampling method
8
 
9
  # Load YAMNet model from TensorFlow Hub
10
- yamnet_model_handle = "https://tfhub.dev/google/yamnet/1"
11
- yamnet_model = hub.load(yamnet_model_handle)
12
 
13
  # Load class labels
14
  def load_class_map():
@@ -17,30 +16,30 @@ def load_class_map():
17
  'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
18
  )
19
  with open(class_map_path, 'r') as f:
20
- class_names = [line.strip().split(',')[2] for line in f.readlines()[1:]]
21
- return class_names
22
 
23
  class_names = load_class_map()
24
 
25
  # Classification function
26
  def classify_audio(file_path):
27
  try:
28
- # Load audio file (WAV, MP3, etc.)
29
  audio_data, sample_rate = sf.read(file_path)
30
 
31
- # Convert stereo to mono if needed
32
  if len(audio_data.shape) > 1:
33
  audio_data = np.mean(audio_data, axis=1)
34
 
35
- # Normalize audio
36
  audio_data = audio_data / np.max(np.abs(audio_data))
37
 
38
- # Resample to 16kHz if necessary
39
  target_rate = 16000
40
  if sample_rate != target_rate:
41
  duration = audio_data.shape[0] / sample_rate
42
  new_length = int(duration * target_rate)
43
  audio_data = resample(audio_data, new_length)
 
44
 
45
  # Convert to tensor
46
  waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)
@@ -53,20 +52,20 @@ def classify_audio(file_path):
53
  top_prediction = class_names[top_5[0]]
54
  top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}
55
 
56
- # Create waveform plot
57
  fig, ax = plt.subplots()
58
  ax.plot(audio_data)
59
  ax.set_title("Waveform")
60
- ax.set_xlabel("Time")
61
  ax.set_ylabel("Amplitude")
62
  plt.tight_layout()
63
 
64
  return top_prediction, top_scores, fig
65
 
66
  except Exception as e:
67
- return f"Error processing audio: {e}", {}, None
68
 
69
- # Gradio interface
70
  interface = gr.Interface(
71
  fn=classify_audio,
72
  inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
@@ -80,4 +79,4 @@ interface = gr.Interface(
80
  )
81
 
82
  if __name__ == "__main__":
83
- interface.launch()
 
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
  import soundfile as sf
7
+ from scipy.signal import resample
8
 
9
  # Load YAMNet model from TensorFlow Hub
10
+ yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")
 
11
 
12
  # Load class labels
13
  def load_class_map():
 
16
  'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
17
  )
18
  with open(class_map_path, 'r') as f:
19
+ return [line.strip().split(',')[2] for line in f.readlines()[1:]]
 
20
 
21
  class_names = load_class_map()
22
 
23
  # Classification function
24
  def classify_audio(file_path):
25
  try:
26
+ # Load audio
27
  audio_data, sample_rate = sf.read(file_path)
28
 
29
+ # Convert stereo to mono
30
  if len(audio_data.shape) > 1:
31
  audio_data = np.mean(audio_data, axis=1)
32
 
33
+ # Normalize
34
  audio_data = audio_data / np.max(np.abs(audio_data))
35
 
36
+ # Resample to 16kHz if needed
37
  target_rate = 16000
38
  if sample_rate != target_rate:
39
  duration = audio_data.shape[0] / sample_rate
40
  new_length = int(duration * target_rate)
41
  audio_data = resample(audio_data, new_length)
42
+ sample_rate = target_rate
43
 
44
  # Convert to tensor
45
  waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)
 
52
  top_prediction = class_names[top_5[0]]
53
  top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}
54
 
55
+ # Waveform plot
56
  fig, ax = plt.subplots()
57
  ax.plot(audio_data)
58
  ax.set_title("Waveform")
59
+ ax.set_xlabel("Time (samples)")
60
  ax.set_ylabel("Amplitude")
61
  plt.tight_layout()
62
 
63
  return top_prediction, top_scores, fig
64
 
65
  except Exception as e:
66
+ return f"Error processing audio: {str(e)}", {}, None
67
 
68
+ # Gradio interface (HF-compatible)
69
  interface = gr.Interface(
70
  fn=classify_audio,
71
  inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
 
79
  )
80
 
81
  if __name__ == "__main__":
82
+ interface.launch()