File size: 6,277 Bytes
408a733 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import os
import io
import csv
import gradio as gr
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_io as tfio
import matplotlib.pyplot as plt
from tensorflow import keras
from huggingface_hub import from_pretrained_keras
# Configuration
# Accent labels. The position of each label is used as the class index when
# the averaged model predictions are mapped back to a name, so the order
# matters.
class_names = [
    "Irish",
    "Midlands",
    "Northern",
    "Scottish",
    "Southern",
    "Welsh",
    "Not a speech",
]
# Fetch the YAMNet model from TensorFlow Hub. It is called on the audio
# waveform to obtain the embeddings consumed by the dense model below.
yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")

# Fetch the dense accent-classification model from the Hugging Face Hub.
model = from_pretrained_keras(
    pretrained_model_name_or_path="fbadine/uk_ireland_accent_classification",
)
def load_16k_audio_wav(filename):
    """Read a wav audio file and resample it to 16000 Hz.

    Adapted from the TensorFlow tutorial:
    https://www.tensorflow.org/tutorials/audio/transfer_learning_audio

    Args:
        filename: Path of the wav file to read.

    Returns:
        The decoded single-channel audio, with the channel axis squeezed
        out, after resampling to 16000 Hz.
    """
    raw_bytes = tf.io.read_file(filename)
    # Decode as one channel, then drop the trailing channel axis.
    waveform, source_rate = tf.audio.decode_wav(raw_bytes, desired_channels=1)
    waveform = tf.squeeze(waveform, axis=-1)
    # The source rate is cast to int64 before being handed to the resampler.
    return tfio.audio.resample(
        waveform,
        rate_in=tf.cast(source_rate, dtype=tf.int64),
        rate_out=16000,
    )
def mic_to_tensor(recorded_audio_file):
    """Convert a microphone recording into a peak-normalised 16000 Hz tensor.

    Args:
        recorded_audio_file: ``(sample_rate, samples)`` pair as produced by
            ``gr.Audio(source="microphone")``. When ``samples`` has more than
            one axis, axis 1 is averaged away (presumably the channel axis;
            confirm against the Gradio version in use).

    Returns:
        The waveform resampled to 16000 Hz and divided by its largest
        absolute sample. An all-zero (silent) recording is returned as zeros
        rather than NaN.
    """
    sample_rate, audio = recorded_audio_file
    audio_wav = tf.constant(audio, dtype=tf.float32)
    if tf.rank(audio_wav) > 1:
        audio_wav = tf.reduce_mean(audio_wav, axis=1)
    audio_wav = tfio.audio.resample(audio_wav, rate_in=sample_rate, rate_out=16000)
    # Peak normalisation. A silent recording has a peak of 0; a plain division
    # would turn every sample into NaN and poison the model input, so use the
    # variant that yields 0 when the denominator is 0.
    peak = tf.reduce_max(tf.abs(audio_wav))
    return tf.math.divide_no_nan(audio_wav, peak)
def tensor_to_predictions(audio_tensor):
    """Run a waveform through YAMNet and then through the dense accent model.

    Args:
        audio_tensor: Waveform tensor passed to the YAMNet model.

    Returns:
        A ``(predictions, mel_spectrogram)`` tuple: the dense model's output
        for the YAMNet embeddings, and the third value returned by YAMNet.
    """
    # YAMNet returns three values; the first one (scores) is not needed here.
    _, yamnet_embeddings, spectrogram = yamnet_model(audio_tensor)
    # The embeddings are the input of the accent recognition model.
    return model.predict(yamnet_embeddings), spectrogram
def predict_accent(recorded_audio_file, uploaded_audio_file):
    """Handle a click on the "Predict" button.

    The microphone recording takes precedence; the uploaded file is used only
    when no recording is present.

    Args:
        recorded_audio_file: Value of the microphone input, or ``None``.
        uploaded_audio_file: Path of the uploaded wav file, or ``None``.

    Returns:
        A two-element list: a dict mapping every class name to its prediction
        score averaged over axis 0 of the model output, and the figure
        produced by ``generate_top_scoring_plot``.

    Raises:
        ValueError: If neither a recording nor an uploaded file was provided.
    """
    # Transform input to tensor
    if recorded_audio_file:
        audio_tensor = mic_to_tensor(recorded_audio_file)
    elif uploaded_audio_file:
        audio_tensor = load_16k_audio_wav(uploaded_audio_file)
    else:
        # Fail with an explicit message instead of passing None to the wav
        # reader, which would surface as an obscure TensorFlow error.
        raise ValueError(
            "No audio provided: record your voice or upload a wav file first."
        )

    # Model inference (the spectrogram is not used by this handler).
    predictions, _ = tensor_to_predictions(audio_tensor)

    # Output 1 - one averaged score per accent; the mean is computed once.
    mean_predictions = predictions.mean(axis=0)
    top_scoring_labels_output = {
        name: float(score) for name, score in zip(class_names, mean_predictions)
    }

    # Output 2 - per-row score plot.
    top_scoring_plot_output = generate_top_scoring_plot(predictions)

    return [top_scoring_labels_output, top_scoring_plot_output]
def clear_inputs_and_outputs():
    """Return the reset values used when the user clicks the "Clear" button.

    Returns:
        A list of four ``None`` values, one for each component wired to the
        button (two inputs and two outputs).
    """
    return [None] * 4
def generate_top_scoring_plot(predictions):
    """Build the figure showing the model scores, best class first.

    Adapted from the Keras example
    https://keras.io/examples/audio/uk_ireland_accent_recognition/
    (tinyurl.com/4a8xn7at).

    Args:
        predictions: 2-D array of model outputs; axis 0 is averaged to rank
            the classes and axis 1 is indexed by class.

    Returns:
        A matplotlib ``Figure`` with one row per class, ordered from the
        highest to the lowest mean score.
    """
    # Local import: the figure is built directly instead of through pyplot so
    # that it is not kept alive in pyplot's global figure registry (which
    # would grow with every request) and does not touch pyplot's shared
    # "current figure" state.
    from matplotlib.figure import Figure

    # Rank the classes by their mean score, highest first.
    mean_predictions = np.mean(predictions, axis=0)
    top_class_indices = np.argsort(mean_predictions)[::-1]

    fig = Figure(figsize=(10, 2))
    ax = fig.subplots()
    ax.imshow(
        predictions[:, top_class_indices].T,
        aspect="auto",
        interpolation="nearest",
        cmap="gray_r",
    )

    # patch_padding = (PATCH_WINDOW_SECONDS / 2) / PATCH_HOP_SECONDS
    # values from the model documentation
    # NOTE(review): 0.025 / 0.01 look like YAMNet's STFT window/hop rather
    # than its patch window/hop - confirm; this only affects the x limits.
    patch_padding = (0.025 / 2) / 0.01
    ax.set_xlim(-patch_padding - 0.5, predictions.shape[0] + patch_padding - 0.5)

    # Label each row with its class name, in ranked order.
    row_positions = range(len(class_names))
    ax.set_yticks(list(row_positions))
    ax.set_yticklabels(
        [class_names[top_class_indices[row]] for row in row_positions]
    )
    # Inverted limits keep the best-scoring class on the top row.
    ax.set_ylim(len(class_names) - 0.5, -0.5)

    return fig
# Script entry point: build the Gradio interface and start it.
# NOTE(review): the nesting below was reconstructed from the layout calls;
# confirm that the button row belongs inside the input column.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        # Page header and short description.
        gr.Markdown(
            """
<center><h1>English speaker accent recognition using Transfer Learning</h1></center> \
This space is a demo of an English (precisely UK & Ireland) accent classification model using Keras.<br> \
In this space, you can record your voice or upload a wav file and the model will predict the accent of the audio<br><br>
"""
        )

        with gr.Row():
            # Left column: the two audio inputs and the action buttons.
            with gr.Column():
                mic_input = gr.Audio(
                    source="microphone",
                    label="Record your own voice",
                )
                upl_input = gr.Audio(
                    source="upload",
                    type="filepath",
                    label="Upload a wav file",
                )

                with gr.Row():
                    clr_btn = gr.Button(value="Clear", variant="secondary")
                    prd_btn = gr.Button(value="Predict")

            # Right column: label scores and the score plot.
            with gr.Column():
                lbl_output = gr.Label(label="Top Predictions")
                with gr.Group():
                    gr.Markdown("<center>Prediction per time slot</center>")
                    plt_output = gr.Plot(
                        label="Prediction per time slot",
                        show_label=False,
                    )

        # "Clear" resets both inputs and both outputs.
        clr_btn.click(
            fn=clear_inputs_and_outputs,
            inputs=[],
            outputs=[mic_input, upl_input, lbl_output, plt_output],
        )
        # "Predict" runs inference on whichever input was provided.
        prd_btn.click(
            fn=predict_accent,
            inputs=[mic_input, upl_input],
            outputs=[lbl_output, plt_output],
        )

    demo.launch(debug=True)
|