|
import torch |
|
import torchaudio |
|
import gradio as gr |
|
import os |
|
import numpy as np |
|
|
|
from transformers import Wav2Vec2Processor, Wav2Vec2Model |
|
from safetensors.torch import load_file |
|
import torch.nn as nn |
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
# Hugging Face Hub repository hosting both the fine-tuned classifier weights
# and the Wav2Vec2 processor that matches them.
MODEL_REPO_ID = "creativepurus/accent-wav2vec2"

# Fetch the fine-tuned checkpoint file (hf_hub_download reuses the local
# Hub cache when the file is already present) and get its local path.
model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="model.safetensors")

# Pre-processing pipeline saved next to the checkpoint.
processor = Wav2Vec2Processor.from_pretrained(MODEL_REPO_ID)

# Read every tensor of the checkpoint onto the CPU as a name -> tensor dict.
state_dict = load_file(model_path, device="cpu")
|
|
|
|
|
class Wav2Vec2Classifier(nn.Module):
    """Wav2Vec2 encoder followed by mean pooling, dropout and a linear head.

    The default arguments reproduce the architecture whose fine-tuned weights
    are loaded from ``model.safetensors``. Changing them yields a module whose
    parameter names/shapes no longer match that checkpoint, so a strict
    ``load_state_dict`` would fail.

    Args:
        base_model_name: Hugging Face id passed to
            ``Wav2Vec2Model.from_pretrained`` to build the encoder.
        num_labels: Number of output classes of the linear head.
        dropout: Dropout probability applied to the pooled features.
    """

    def __init__(self, base_model_name="facebook/wav2vec2-large-960h", num_labels=2, dropout=0.3):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model.from_pretrained(base_model_name)
        self.dropout = nn.Dropout(dropout)
        # The head's input width follows the encoder's hidden size.
        self.classifier = nn.Linear(self.wav2vec2.config.hidden_size, num_labels)

    def forward(self, input_values, attention_mask=None):
        """Return unnormalized class scores for a batch of waveforms.

        Args:
            input_values: Processed audio as produced by the Wav2Vec2
                processor's ``input_values``.
            attention_mask: Optional mask forwarded unchanged to the encoder.

        Returns:
            Logits of shape ``(batch, num_labels)``.
        """
        outputs = self.wav2vec2(input_values, attention_mask=attention_mask)
        # Average the encoder's last hidden state over dim 1 (the frame axis).
        # NOTE: every frame is averaged, including padded ones when a batch is
        # padded; attention_mask is not used for pooling.
        pooled = outputs.last_hidden_state.mean(dim=1)
        return self.classifier(self.dropout(pooled))
|
|
|
|
|
# Build the architecture, then overwrite its weights with the fine-tuned
# checkpoint. load_state_dict is strict by default, so any missing or
# unexpected key fails loudly at startup rather than at prediction time.
model = Wav2Vec2Classifier()
model.load_state_dict(state_dict)

# load_state_dict copies the tensors into the model's own parameters, so the
# loaded dict is no longer needed. Dropping the module-level reference lets it
# be garbage-collected instead of keeping a second copy of the weights alive
# for as long as the app runs.
del state_dict

# Switch to inference behaviour (e.g. Dropout stops dropping activations).
model.eval()
|
|
|
|
|
def predict_accent(audio):
    """Predict whether a speech clip is Canadian English or England English.

    Args:
        audio: Filesystem path to the clip, as delivered by a ``gr.Audio``
            component configured with ``type="filepath"``; ``None`` when the
            user presses the button without supplying audio.

    Returns:
        The human-readable label of the highest-scoring class.

    Raises:
        gr.Error: If no audio was provided, so Gradio shows a readable
            message instead of a traceback from ``torchaudio.load(None)``.
    """
    if audio is None:
        raise gr.Error("Please upload or record an audio clip first.")

    # torchaudio.load returns a (channels, frames) tensor. Down-mix to mono:
    # with a plain squeeze() a stereo clip stays 2-D, the processor treats each
    # channel as a separate batch item, and a flattened argmax over the
    # resulting (2, 2) logits can yield an index (2 or 3) missing from
    # label_map. For mono input this is numerically identical to squeeze().
    waveform, sample_rate = torchaudio.load(audio)
    waveform = waveform.mean(dim=0)

    # The processor/model below are fed 16 kHz audio.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)

    input_values = processor(waveform.numpy(), return_tensors="pt", sampling_rate=16000).input_values

    with torch.no_grad():
        logits = model(input_values)

    # A single 1-D clip gives logits of shape (1, num_labels); pick the
    # best class along the label axis for that one clip.
    predicted_class_id = logits.argmax(dim=-1)[0].item()

    label_map = {0: "Canadian English", 1: "England English"}
    return label_map[predicted_class_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page-level styles handed to gr.Blocks(css=...):
#   #predict-btn    -> the button created with elem_id="predict-btn"
#   #author-section -> the <div id="author-section"> footer in the Markdown
custom_css = """
#predict-btn {
    background-color: orange !important;
    color: white !important;
    font-weight: bold;
}

#author-section {
    font-size: 18px;
    font-weight: 500;
}
"""
|
|
|
# Build the Gradio UI: audio input and a predict button on the left, the
# predicted label on the right, and an author footer below.
with gr.Blocks(css=custom_css) as demo:
    # Emoji were previously mis-encoded (UTF-8 bytes rendered as Greek
    # ISO-8859-7 text, e.g. "π£οΈ"); restored to the intended characters.
    gr.Markdown("## 🗣️ Accent Classification App")
    gr.Markdown(
        "This app classifies English accents as either **Canadian** or "
        "**England** using a fine-tuned Wav2Vec2 model."
    )

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                # predict_accent reads the clip with torchaudio.load, which
                # takes a file path.
                type="filepath",
                label="Upload or Record Audio (WAV)",
            )
            predict_button = gr.Button("Predict Accent", elem_id="predict-btn")
        with gr.Column():
            result_output = gr.Textbox(label="Predicted Accent")

    predict_button.click(fn=predict_accent, inputs=audio_input, outputs=result_output)

    gr.Markdown("---")
    # Kept flush-left: indenting these lines by 4+ spaces would make Markdown
    # render them as a code block instead of raw HTML.
    gr.Markdown(
        """
<div id="author-section">
👨‍💻 Created by <strong>Anand Purushottam</strong>
🔗 <a href="https://github.com/creativepurus" target="_blank">GitHub</a> |
<a href="https://linkedin.com/in/creativepurus" target="_blank">LinkedIn</a>
</div>
""",
    )

demo.launch()