Luis committed
Commit 3259b0d
Parent: 1905ba9

add yamnet

Files changed (2)
  1. app.py +72 -4
  2. miaow_16k.wav +0 -0
app.py CHANGED
@@ -1,7 +1,75 @@
+
+# https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1
+
+import tensorflow as tf
+import tensorflow_hub as hub
+import numpy as np
+import csv
+
+import matplotlib.pyplot as plt
+from IPython.display import Audio
+from scipy.io import wavfile
+import scipy.signal
+
+import os
+
 import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!!"
 
-iface = gr.Interface(fn=greet, inputs="audio", outputs="text")
-iface.launch()
+
+# Load the model.
+model = hub.load('https://tfhub.dev/google/yamnet/1')
+
+
+# Find the name of the class with the top score when mean-aggregated across frames.
+def class_names_from_csv(class_map_csv_text):
+  """Returns list of class names corresponding to score vector."""
+  class_names = []
+  with tf.io.gfile.GFile(class_map_csv_text) as csvfile:
+    reader = csv.DictReader(csvfile)
+    for row in reader:
+      class_names.append(row['display_name'])
+
+  return class_names
+
+
+class_map_path = model.class_map_path().numpy()
+class_names = class_names_from_csv(class_map_path)
+
+
+def ensure_sample_rate(original_sample_rate, waveform,
+                       desired_sample_rate=16000):
+  """Resample waveform if required."""
+  if original_sample_rate != desired_sample_rate:
+    desired_length = int(round(float(len(waveform)) /
+                               original_sample_rate * desired_sample_rate))
+    waveform = scipy.signal.resample(waveform, desired_length)
+  return desired_sample_rate, waveform
+
+
+os.system("wget https://storage.googleapis.com/audioset/miaow_16k.wav")  # fetch the example clip
+
+
+def inference(audio):
+  # wav_file_name = 'speech_whistling2.wav'
+  wav_file_name = audio
+  sample_rate, wav_data = wavfile.read(wav_file_name)
+  sample_rate, wav_data = ensure_sample_rate(sample_rate, wav_data)
+
+  waveform = wav_data / tf.int16.max  # normalize int16 PCM to [-1.0, 1.0]
+
+  # Run the model, check the output.
+  scores, embeddings, spectrogram = model(waveform)
+
+  scores_np = scores.numpy()
+  spectrogram_np = spectrogram.numpy()
+  inferred_class = class_names[scores_np.mean(axis=0).argmax()]
+
+  return f'The main sound is: {inferred_class}'
+
+
+examples = [['miaow_16k.wav']]
+title = "yamnet"
+description = "An audio event classifier trained on the AudioSet dataset to predict audio events from the AudioSet ontology."
+gr.Interface(inference, gr.inputs.Audio(type="filepath"), "text", examples=examples, title=title,
+             description=description).launch(enable_queue=True)
+
miaow_16k.wav ADDED
Binary file (216 kB).
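
For reviewers who want to sanity-check the new classification path without launching the Gradio interface, here is a minimal sketch. It assumes the same dependencies as app.py and that the bundled miaow_16k.wav (already mono 16 kHz int16 PCM) is in the working directory.

import csv

import tensorflow as tf
import tensorflow_hub as hub
from scipy.io import wavfile

# Same model and class map used by app.py.
model = hub.load('https://tfhub.dev/google/yamnet/1')
with tf.io.gfile.GFile(model.class_map_path().numpy()) as f:
    class_names = [row['display_name'] for row in csv.DictReader(f)]

# The example clip is already at YAMNet's expected 16 kHz sample rate.
sample_rate, wav_data = wavfile.read('miaow_16k.wav')
waveform = wav_data / tf.int16.max  # scale int16 samples to [-1.0, 1.0]

# YAMNet returns per-frame scores over the 521 AudioSet classes; averaging
# over frames and taking the argmax picks the dominant class, exactly as
# inference() does in app.py.
scores, embeddings, spectrogram = model(waveform)
print(class_names[scores.numpy().mean(axis=0).argmax()])

The printed label should match what the Space reports for the same example file.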