import csv
import os

import gradio as gr
import numpy as np
import scipy.signal
import tensorflow as tf
import tensorflow_hub as hub
from matplotlib import pyplot as plt
from pydub import AudioSegment
from scipy.io import wavfile
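# Load the pretrained YAMNet audio event classifier from TensorFlow Hub.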
model = hub.load('https://tfhub.dev/google/yamnet/1')

# Print intermediate values while debugging.
debug = True

def class_names_from_csv(class_map_csv_text):
  """Returns list of class names corresponding to score vector."""
  class_names = []
  with tf.io.gfile.GFile(class_map_csv_text) as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
      class_names.append(row['display_name'])
  return class_names
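
# Build the human-readable class-name list from the model's class map CSV.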
class_map_path = model.class_map_path().numpy()
class_names = class_names_from_csv(class_map_path)

def ensure_sample_rate(original_sample_rate, waveform,
                       desired_sample_rate=16000):
  """Resample waveform if required."""
  if original_sample_rate != desired_sample_rate:
    desired_length = int(round(float(len(waveform)) /
                               original_sample_rate * desired_sample_rate))
    waveform = scipy.signal.resample(waveform, desired_length)
  return desired_sample_rate, waveform
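
# Download a sample 16 kHz cat-meow clip to use as the Gradio example.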
os.system("wget https://storage.googleapis.com/audioset/miaow_16k.wav")

def inference(audio):
  wav_file_name = audio
  if debug: print(f'read, wav_file_name: {wav_file_name}')

  # Gradio hands us a file path; convert mp3 uploads to wav first.
  if wav_file_name.endswith('.mp3'):
    new_wav = convMp3ToWav(wav_file_name)
    os.remove(wav_file_name)
    wav_file_name = new_wav
    if debug: print(f'convMp3ToWav, wav_file_name: {wav_file_name}')

  sample_rate, wav_data = wavfile.read(wav_file_name)

  if debug: print(f'read, wav_data: {wav_data}')
  if debug: print(f'read, sample_rate: {sample_rate}, wav_data: {wav_data.shape}')

  # YAMNet expects mono 16 kHz input.
  sample_rate, wav_data = ensure_sample_rate(sample_rate, wav_data)
  if debug: print(f'ensure_sample_rate, sample_rate: {sample_rate}, wav_data: {wav_data.shape}')
  if debug: print(f'ensure_single_channel, wav_data.ndim: {wav_data.ndim}')
  if wav_data.ndim >= 2: wav_data = wav_data[:, 0]
  if debug: print(f'ensure_single_channel, wav_data: {wav_data.shape}')
  if debug: print(f'ensured, wav_data: {wav_data}')

  # Normalize int16 samples to [-1.0, 1.0].
  waveform = wav_data / tf.int16.max

  scores, embeddings, spectrogram = model(waveform)

  scores_np = scores.numpy()
  spectrogram_np = spectrogram.numpy()

  # Average the per-frame scores, then rank classes by mean score.
  mean_scores = scores_np.mean(axis=0)
  scores_np_sorted = np.sort(mean_scores)
  scores_np_arg_sorted = np.argsort(mean_scores)

  # Indices of the top-5 classes, highest score first.
  class_index_array = list(scores_np_arg_sorted[-1:-6:-1])
  inferred_class = class_names[class_index_array[0]]
  second_class = class_names[class_index_array[1]]

  float_formatter = "{:.4f}".format
  np.set_printoptions(formatter={'float_kind': float_formatter})
  class_names_str = ', '.join(f'[{class_names[i]}]' for i in class_index_array)
  scores_str = ', '.join('[{:.4f}]'.format(scores_np_sorted[-(i + 1)]) for i in range(5))

  return f'The main sound is: [{inferred_class}], \n\nthe second sound is: [{second_class}]. \n\n classes: {class_names_str}, \n\n scores: {scores_str}'
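
# Convert an mp3 upload to wav via pydub (requires ffmpeg on the host).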
def convMp3ToWav(wav_file_name):
  src = wav_file_name
  dst = wav_file_name + ".wav"

  sound = AudioSegment.from_file(src)
  sound.export(dst, format="wav")
  return dst
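
# Wire the classifier into a Gradio demo.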
examples = [['miaow_16k.wav']]
title = "yamnet"
description = "An audio event classifier trained on the AudioSet dataset to predict audio events from the AudioSet ontology."
gr.Interface(inference, gr.inputs.Audio(type="filepath"), "text", examples=examples, title=title,
             description=description).launch(enable_queue=True)