# Page header carried over from the hosting site (not part of the program):
# creativepurus - "Fixed UI" (ecd55db)
import torch
import torchaudio
import gradio as gr
import os
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from safetensors.torch import load_file
import torch.nn as nn
from huggingface_hub import hf_hub_download
# Fetch the fine-tuned weights file from the Hub; hf_hub_download returns the
# local path of the downloaded file.
model_path = hf_hub_download(
    repo_id="creativepurus/accent-wav2vec2",
    filename="model.safetensors",
)

# The processor is loaded from the same repository as the weights.
processor = Wav2Vec2Processor.from_pretrained("creativepurus/accent-wav2vec2")

# Read the safetensors checkpoint onto the CPU; it is handed to
# model.load_state_dict() once the classifier has been constructed.
state_dict = load_file(model_path, device="cpu")
# Define the same model architecture used during training
class Wav2Vec2Classifier(nn.Module):
    """Wav2Vec2 encoder followed by mean pooling, dropout and a linear head.

    The attribute names ``wav2vec2``, ``dropout`` and ``classifier`` become
    the key prefixes of this module's state dict, so they must stay as they
    are for existing checkpoints to load into this class.

    Args:
        base_checkpoint: Identifier passed to ``Wav2Vec2Model.from_pretrained``
            to build the encoder.
        num_labels: Number of output classes of the linear head.
        dropout_prob: Dropout probability applied to the pooled features.

    The defaults reproduce the previously hard-coded configuration, so
    ``Wav2Vec2Classifier()`` builds exactly the same architecture as before.
    """

    def __init__(
        self,
        base_checkpoint: str = "facebook/wav2vec2-large-960h",
        num_labels: int = 2,
        dropout_prob: float = 0.3,
    ):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model.from_pretrained(base_checkpoint)
        self.dropout = nn.Dropout(dropout_prob)
        # The head's input width follows the encoder's configured hidden size.
        self.classifier = nn.Linear(self.wav2vec2.config.hidden_size, num_labels)

    def forward(self, input_values, attention_mask=None):
        """Return the classification logits for ``input_values``.

        Args:
            input_values: Passed unchanged to the encoder.
            attention_mask: Optional mask, forwarded to the encoder only.

        Returns:
            The linear head's output for the pooled encoder features.
        """
        encoded = self.wav2vec2(input_values, attention_mask=attention_mask)
        # Average the encoder output over dim 1 (presumably the time/frame
        # axis - confirm). NOTE: the mean is not weighted by attention_mask,
        # so any padded positions contribute to the pooled vector as well.
        pooled = encoded.last_hidden_state.mean(dim=1)
        return self.classifier(self.dropout(pooled))
# Instantiate and load the model
model = Wav2Vec2Classifier()
# strict=True is the default: every key in the checkpoint must match a
# parameter or buffer of the freshly built classifier, and vice versa.
model.load_state_dict(state_dict, strict=True)
# Switch to evaluation mode so the dropout layer is inactive at inference.
model.eval()
# Prediction function
def predict_accent(audio):
    """Classify the accent spoken in an audio file.

    Args:
        audio: Path of an audio file readable by ``torchaudio.load``. A falsy
            value (``None`` or ``""``) means no clip was supplied.

    Returns:
        ``"Canadian English"`` or ``"England English"``, or a short hint
        asking for a clip when ``audio`` is empty.
    """
    # The UI passes None when the button is pressed before a clip is chosen;
    # answer with a hint instead of failing inside torchaudio.load.
    if not audio:
        return "Please upload or record an audio clip first."

    target_rate = 16000
    label_map = {0: "Canadian English", 1: "England English"}

    waveform, sample_rate = torchaudio.load(audio)

    # Collapse the channel axis to a single mono signal. The previous
    # squeeze() left multi-channel audio 2-D, which the feature extractor
    # interprets as a batch of clips rather than one clip; the flattened
    # argmax could then fall outside label_map.
    mono = waveform.mean(dim=0)

    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_rate
        )
        mono = resampler(mono)

    input_values = processor(
        mono.numpy(), return_tensors="pt", sampling_rate=target_rate
    ).input_values

    with torch.no_grad():
        logits = model(input_values)

    # Exactly one clip is scored, so take the arg-max over its row of logits.
    predicted_class_id = logits[0].argmax().item()
    return label_map[predicted_class_id]
# # Gradio UI
# interface = gr.Interface(
# fn=predict_accent,
# inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio (WAV)"),
# outputs=gr.Textbox(label="Predicted Accent"),
# title="Accent Classification",
# description="This app classifies English accents as either Canadian or England using a fine-tuned Wav2Vec2 model.",
# allow_flagging="never"
# )
# Gradio UI built with gr.Blocks and custom styling.
# The CSS selectors match the elem_id given to the predict button and the id
# of the author <div> rendered at the bottom of the page.
custom_css = """
#predict-btn {
background-color: orange !important;
color: white !important;
font-weight: bold;
}
#author-section {
font-size: 18px;
font-weight: 500;
}
"""

with gr.Blocks(css=custom_css) as demo:
    # Emoji below were previously mis-encoded (UTF-8 bytes shown as Greek
    # code-page characters); they are restored to the intended glyphs.
    gr.Markdown("## 🗣️ Accent Classification App")
    gr.Markdown("This app classifies English accents as either **Canadian** or **England** using a fine-tuned Wav2Vec2 model.")

    with gr.Row():
        # Left column: audio input and the button that triggers prediction.
        with gr.Column():
            audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio (WAV)")
            predict_button = gr.Button("Predict Accent", elem_id="predict-btn")
        # Right column: predicted label.
        with gr.Column():
            result_output = gr.Textbox(label="Predicted Accent")

    # type="filepath" means predict_accent receives a path to the audio file.
    predict_button.click(fn=predict_accent, inputs=audio_input, outputs=result_output)

    gr.Markdown("---")
    gr.Markdown("""
<div id="author-section">
👨‍💻 Created by <strong>Anand Purushottam</strong>
🔗 <a href="https://github.com/creativepurus" target="_blank">GitHub</a> |
<a href="https://linkedin.com/in/creativepurus" target="_blank">LinkedIn</a>
</div>
""",
    )

demo.launch()