File size: 3,630 Bytes
b804c93
 
17f9f88
 
 
07f0183
17f9f88
 
 
07f0183
8792ec4
07f0183
8792ec4
17f9f88
 
 
07f0183
17f9f88
 
07f0183
17f9f88
 
 
 
 
 
 
07f0183
17f9f88
 
 
 
 
 
07f0183
 
17f9f88
 
 
07f0183
 
17f9f88
 
 
b804c93
 
 
17f9f88
 
 
 
 
 
b804c93
c5851c8
 
 
 
 
 
 
 
 
26958d6
c5851c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cc10a4
 
 
 
 
 
c5851c8
8cc10a4
 
26958d6
8cc10a4
26958d6
8cc10a4
 
c5851c8
 
 
 
 
ecd55db
c5851c8
b804c93
c5851c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torchaudio
import gradio as gr
import os
import numpy as np

from transformers import Wav2Vec2Processor, Wav2Vec2Model
from safetensors.torch import load_file
import torch.nn as nn

from huggingface_hub import hf_hub_download

# Hub repository that hosts both the fine-tuned checkpoint and the processor files
_HUB_REPO = "creativepurus/accent-wav2vec2"

# Fetch the checkpoint file from the Hub; hf_hub_download returns its path
model_path = hf_hub_download(filename="model.safetensors", repo_id=_HUB_REPO)

# Processor that turns raw audio samples into model inputs
processor = Wav2Vec2Processor.from_pretrained(_HUB_REPO)

# Read the checkpoint tensors onto the CPU
state_dict = load_file(model_path, device="cpu")

# Define the same model architecture used during training
class Wav2Vec2Classifier(nn.Module):
    """Wav2Vec2 encoder followed by mean pooling, dropout and a 2-way linear head."""

    def __init__(self):
        super().__init__()
        # The attribute names below determine the parameter names of this
        # module, so they must stay as they are for checkpoint loading.
        backbone = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-960h")
        self.wav2vec2 = backbone
        self.dropout = nn.Dropout(p=0.3)
        self.classifier = nn.Linear(
            in_features=backbone.config.hidden_size,
            out_features=2,
        )

    def forward(self, input_values, attention_mask=None):
        """Return the classification logits for a batch of inputs.

        Args:
            input_values: Inputs forwarded unchanged to the Wav2Vec2 encoder.
            attention_mask: Optional mask forwarded unchanged to the encoder.

        Returns:
            The output of the linear head applied to the pooled encoder states.
        """
        encoded = self.wav2vec2(input_values, attention_mask=attention_mask)
        # Average over dim 1 -- presumably the frame/time axis of a
        # (batch, frames, hidden) tensor; the mask is not used for pooling.
        clip_embedding = encoded.last_hidden_state.mean(dim=1)
        return self.classifier(self.dropout(clip_embedding))

# Build the classifier in inference mode, then restore the fine-tuned weights
model = Wav2Vec2Classifier().eval()
model.load_state_dict(state_dict)

# Prediction function
_TARGET_SAMPLE_RATE = 16000  # sampling rate handed to the processor
_ACCENT_LABELS = {0: "Canadian English", 1: "England English"}


def predict_accent(audio):
    """Classify the accent spoken in an audio file.

    Args:
        audio: Path to an audio file, as delivered by the Gradio audio
            component in ``filepath`` mode. ``None`` (or an empty path) is
            what arrives when the button is pressed before a clip is given.

    Returns:
        ``"Canadian English"`` or ``"England English"``, or a short message
        asking for input when no usable audio was supplied.
    """
    # Nothing uploaded/recorded yet: answer politely instead of crashing.
    if not audio:
        return "Please upload or record an audio clip first."

    waveform, sample_rate = torchaudio.load(audio)
    if waveform.numel() == 0:
        return "The audio clip is empty. Please provide a longer recording."

    # torchaudio.load yields (channels, frames) by default. Average the
    # channels so stereo input becomes a single clip; otherwise the processor
    # would treat each channel as a separate batch item.
    if waveform.dim() > 1 and waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sample_rate != _TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=_TARGET_SAMPLE_RATE
        )
        waveform = resampler(waveform)

    samples = waveform.reshape(-1).numpy()  # always 1-D, even for tiny clips
    input_values = processor(
        samples, return_tensors="pt", sampling_rate=_TARGET_SAMPLE_RATE
    ).input_values

    with torch.no_grad():
        logits = model(input_values)

    # Arg-max over the class axis of the single clip, not over the flattened
    # tensor, so the result is always a valid key of the label table.
    predicted_class_id = logits.argmax(dim=-1)[0].item()
    return _ACCENT_LABELS[predicted_class_id]

# # Gradio UI
# interface = gr.Interface(
#     fn=predict_accent,
#     inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio (WAV)"),
#     outputs=gr.Textbox(label="Predicted Accent"),
#     title="Accent Classification",
#     description="This app classifies English accents as either Canadian or England using a fine-tuned Wav2Vec2 model.",
#     allow_flagging="never"
# )


# Gradio UI with gr.Blocks and custom styling

# Styles addressed through element ids: "predict-btn" is set on the button
# below, "author-section" on the footer <div>.
custom_css = """
#predict-btn {
    background-color: orange !important;
    color: white !important;
    font-weight: bold;
}
#author-section {
    font-size: 18px;
    font-weight: 500;
}
"""

# Page layout: heading, an input column (audio + button) next to an output
# column (predicted label), then an author footer.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("## 🗣️ Accent Classification App")
    gr.Markdown("This app classifies English accents as either **Canadian** or **England** using a fine-tuned Wav2Vec2 model.")
    
    with gr.Row():
        with gr.Column():
            # type="filepath": the callback receives a path on disk
            audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio (WAV)")
            predict_button = gr.Button("Predict Accent", elem_id="predict-btn")
        with gr.Column():
            result_output = gr.Textbox(label="Predicted Accent")

    # Run the classifier on the selected clip and show the label in the textbox
    predict_button.click(fn=predict_accent, inputs=audio_input, outputs=result_output)

    gr.Markdown("---")
    gr.Markdown("""
    <div id="author-section">
    👨‍💻 Created by <strong>Anand Purushottam</strong>  
    🔗 <a href="https://github.com/creativepurus" target="_blank">GitHub</a> | 
    <a href="https://linkedin.com/in/creativepurus" target="_blank">LinkedIn</a>
    </div>
    """,
    )

demo.launch()