import torch
import torchaudio
import gradio as gr
import os
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from safetensors.torch import load_file
import torch.nn as nn
from huggingface_hub import hf_hub_download
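# Download the fine-tuned classifier weights from the Hugging Face Hub
# (cached locally after the first call)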
model_path = hf_hub_download(repo_id="creativepurus/accent-wav2vec2", filename="model.safetensors")
# Load processor
processor = Wav2Vec2Processor.from_pretrained("creativepurus/accent-wav2vec2")
# Load model weights from model.safetensors
state_dict = load_file(model_path, device="cpu")
# Define the same model architecture used during training
class Wav2Vec2Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        # Pretrained wav2vec2 encoder with a linear head for the 2 accent classes
        self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-960h")
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.wav2vec2.config.hidden_size, 2)

    def forward(self, input_values, attention_mask=None):
        outputs = self.wav2vec2(input_values, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        # Mean-pool over the time dimension, then classify the pooled representation
        pooled = hidden_states.mean(dim=1)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)
        return logits
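
# Illustrative shape check (not part of the original app): a batch of one
# 1-second, 16 kHz clip yields one logit per accent class.
#   dummy = torch.zeros(1, 16000)          # (batch, samples)
#   Wav2Vec2Classifier()(dummy).shape      # torch.Size([1, 2])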
# Instantiate and load the model
model = Wav2Vec2Classifier()
model.load_state_dict(state_dict)
model.eval()
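
# Note: inference stays on CPU here; moving the model to GPU would also
# require moving the input tensors inside predict_accent.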
# Prediction function
def predict_accent(audio):
    # Gradio passes None when no clip was uploaded or recorded
    if audio is None:
        return "Please upload or record an audio clip first."
    waveform, sample_rate = torchaudio.load(audio)
    # Downmix stereo recordings to mono so the model sees a single channel
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # wav2vec2 was trained on 16 kHz audio, so resample if needed
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)
    input_values = processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000).input_values
    with torch.no_grad():
        logits = model(input_values)
    predicted_class_id = logits.argmax(dim=-1).item()
    label_map = {0: "Canadian English", 1: "England English"}
    return label_map[predicted_class_id]
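
# Quick sanity check for local runs (illustrative only; "sample.wav" is a
# placeholder path, not a file shipped with this app):
#   print(predict_accent("sample.wav"))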
# Earlier single-call Gradio UI, kept for reference:
# interface = gr.Interface(
#     fn=predict_accent,
#     inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio (WAV)"),
#     outputs=gr.Textbox(label="Predicted Accent"),
#     title="Accent Classification",
#     description="This app classifies English accents as either Canadian English or England English using a fine-tuned Wav2Vec2 model.",
#     allow_flagging="never",
# )
# Gradio UI with gr.Blocks and custom styling
custom_css = """
#predict-btn {
    background-color: orange !important;
    color: white !important;
    font-weight: bold;
}
#author-section {
    font-size: 18px;
    font-weight: 500;
}
"""
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("## 🗣️ Accent Classification App")
    gr.Markdown("This app classifies English accents as either **Canadian English** or **England English** using a fine-tuned Wav2Vec2 model.")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio (WAV)")
            predict_button = gr.Button("Predict Accent", elem_id="predict-btn")
        with gr.Column():
            result_output = gr.Textbox(label="Predicted Accent")
    predict_button.click(fn=predict_accent, inputs=audio_input, outputs=result_output)
    gr.Markdown("---")
    gr.Markdown(
        """
        <div id="author-section">
        👨‍💻 Created by <strong>Anand Purushottam</strong>
        🔗 <a href="https://github.com/creativepurus" target="_blank">GitHub</a> |
        <a href="https://linkedin.com/in/creativepurus" target="_blank">LinkedIn</a>
        </div>
        """
    )
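
# launch() starts the web server; when this file runs as a Hugging Face Space,
# no extra arguments are needed.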
demo.launch()