hubert phoneme + quick test model
- app/app.py +11 -1
- app/inference.py +23 -1
- app/inference_huberphoneme.py +133 -0
- app/tasks.py +24 -0
- requirements.txt +4 -3
- requirements_lock.txt +5 -5
app/app.py
CHANGED
@@ -4,7 +4,7 @@
 import gradio as gr
 import pandas as pd
 
-from tasks import start_eval_task, get_status
+from tasks import start_eval_task, get_status, run_sample_inference
 from hf import get_or_create_leaderboard
 
 from codes import CODES
@@ -205,6 +205,16 @@ with gr.Blocks(
             outputs=result,
         )
 
+        gr.Markdown("---\n### Test Model")
+        test_audio = gr.Audio(interactive=True, format="wav")
+        test_btn = gr.Button("Run")
+        test_result = gr.Textbox(label="Test Result")
+        test_btn.click(
+            fn=run_sample_inference,
+            inputs=[test_audio, model_id, model_type, output_code],
+            outputs=test_result,
+        )
+
     with gr.TabItem("📊 Submission Status"):
         query = gr.Textbox(
            label="Model ID or Task ID",
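The test panel above passes the raw gr.Audio value straight to run_sample_inference. Below is a minimal standalone sketch of the same wiring, assuming gr.Audio's default type="numpy" (the callback then receives a (sample_rate, numpy array) tuple); fake_inference is a hypothetical stand-in for the real handler:

# Minimal Gradio sketch of the test-panel wiring; fake_inference is a placeholder.
import gradio as gr

def fake_inference(audio):
    if audio is None:
        return "No audio provided."
    sample_rate, wav_array = audio  # default type="numpy": (int, np.ndarray)
    return f"Got {wav_array.shape[0] / sample_rate:.2f} s of audio at {sample_rate} Hz"

with gr.Blocks() as demo:
    test_audio = gr.Audio(interactive=True, format="wav")
    test_btn = gr.Button("Run")
    test_result = gr.Textbox(label="Test Result")
    test_btn.click(fn=fake_inference, inputs=test_audio, outputs=test_result)

demo.launch()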
app/inference.py
CHANGED
@@ -3,8 +3,9 @@
 import torch
 from transformers import AutoProcessor, AutoModelForCTC
 from espnet2.bin.s2t_inference import Speech2Text
+from inference_huberphoneme import HuBERTPhoneme, Tokenizer
 
-MODEL_TYPES = ["Transformers CTC", "POWSM"]
+MODEL_TYPES = ["Transformers CTC", "POWSM", "HuBERTPhoneme"]
 
 DEVICE = (
     "cuda"
@@ -78,6 +79,23 @@ def transcribe_transformers_ctc(audio, model) -> str:
     return processor.decode(predicted_ids[0])
 
 
+# ===========================================================================
+# ============================== HuBERTPhoneme ==============================
+def load_hubert_phoneme(model_id, device=DEVICE):
+    model = HuBERTPhoneme.from_pretrained(model_id).to(device).eval()
+    tokenizer = Tokenizer(with_blank=model.ctc_training)
+    return model, tokenizer, device
+
+
+def transcribe_hubert_phoneme(audio, model) -> str:
+    model, tokenizer, device = model
+    with torch.inference_mode():
+        output, _ = model.inference(torch.from_numpy(audio).to(device).unsqueeze(0))
+    predictions = output.argmax(dim=-1).squeeze().cpu()
+    arpabet = tokenizer.decode(predictions.unique_consecutive())
+    return arpabet
+
+
 # ===========================================================================
 
 
@@ -86,6 +104,8 @@ def load_model(model_id, type, device=DEVICE):
         return load_powsm(model_id, device=device)
     elif type == "Transformers CTC":
         return load_transformers_ctc(model_id, device=device)
+    elif type == "HuBERTPhoneme":
+        return load_hubert_phoneme(model_id, device=device)
     else:
         raise ValueError("Unsupported model type: " + str(type))
 
@@ -95,5 +115,7 @@ def transcribe(audio, type, model) -> str:
         return transcribe_powsm(audio, model)
     elif type == "Transformers CTC":
         return transcribe_transformers_ctc(audio, model)
+    elif type == "HuBERTPhoneme":
+        return transcribe_hubert_phoneme(audio, model)
     else:
         raise ValueError("Unsupported model type: " + str(type))
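Taken together with the dispatch above, loading and transcribing with the new model type would look roughly like this hypothetical snippet (the repo id "someuser/hubert-phoneme" is a placeholder, and the zero array stands in for a 16 kHz mono float32 waveform):

# Hypothetical end-to-end use of the new "HuBERTPhoneme" branch.
import numpy as np
from inference import load_model, transcribe

model = load_model("someuser/hubert-phoneme", "HuBERTPhoneme")  # placeholder repo id
audio = np.zeros(16_000, dtype=np.float32)  # 1 s of 16 kHz mono audio
print(transcribe(audio, "HuBERTPhoneme", model))  # space-separated ARPAbet phones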
app/inference_huberphoneme.py
ADDED
@@ -0,0 +1,133 @@
+# Adapted from https://github.com/bootphon/spokenlm-phoneme
+
+import torch
+import torchaudio
+from huggingface_hub import PyTorchModelHubMixin
+from torch import Tensor, nn
+from torchaudio.models.wav2vec2 import components
+from torchaudio.pipelines import HUBERT_BASE
+from typing import Iterable
+
+
+class Tokenizer:
+    # fmt:off
+    PHONEMES = {
+        "SIL": 0, "AA": 1, "AE": 2, "AH": 3, "AO": 4, "AW": 5, "AY": 6, "B": 7,
+        "CH": 8, "D": 9, "DH": 10, "EH": 11, "ER": 12, "EY": 13, "F": 14, "G": 15,
+        "HH": 16, "IH": 17, "IY": 18, "JH": 19, "K": 20, "L": 21, "M": 22, "N": 23,
+        "NG": 24, "OW": 25, "OY": 26, "P": 27, "R": 28, "S": 29, "SH": 30, "T": 31,
+        "TH": 32, "UH": 33, "UW": 34, "V": 35, "W": 36, "Y": 37, "Z": 38, "ZH": 39,
+    }
+    # fmt:on
+
+    def __init__(self, with_blank: bool = False) -> None:
+        self.token_to_id = self.PHONEMES | {"<pad>": self.pad_id}
+        self.id_to_token = {v: k for k, v in self.token_to_id.items()}
+        self.with_blank = with_blank
+
+    @property
+    def vocab_size(self) -> int:
+        if self.with_blank:
+            return len(self.PHONEMES) + 1
+        return len(self.PHONEMES)
+
+    @property
+    def silence_id(self) -> int:
+        return self.PHONEMES["SIL"]
+
+    @property
+    def pad_id(self) -> int:
+        return len(self.PHONEMES)
+
+    def encode(self, phones: "list[str] | str") -> torch.LongTensor:
+        if isinstance(phones, str):
+            phones = phones.split(" ")
+        return torch.LongTensor([self.token_to_id[phone] for phone in phones])
+
+    def decode(self, tokens: Iterable[int]) -> str:
+        return " ".join(
+            self.id_to_token[int(token)]
+            for token in tokens
+            if token < self.pad_id and int(token) != self.silence_id
+        )
+
+
+FINETUNING_HUBERT_CONFIG = {
+    "encoder_projection_dropout": 0,
+    "encoder_attention_dropout": 0,
+    "encoder_ff_interm_dropout": 0.1,
+    "encoder_dropout": 0,
+    "encoder_layer_drop": 0.1,  # In torchaudio: 0.05
+    "mask_prob": 0.75,  # In torchaudio: 0.65
+    "mask_channel_prob": 0.5,
+    "mask_channel_length": 10,  # In torchaudio and fairseq: 64. This is the value for pretraining.
+    "num_classes": 500,  # Number of classes during HuBERT pretraining.
+}
+
+
+class HuBERTPhoneme(nn.Module, PyTorchModelHubMixin):
+    def __init__(self, freeze_encoder: bool = True, ctc_training: bool = False) -> None:
+        """Initialize the model.
+
+        Parameters
+        ----------
+        freeze_encoder : bool, optional
+            Whether to freeze the Transformer encoder of HuBERT, by default True.
+            The convolutional layers are always frozen.
+        """
+        super().__init__()
+        self.model = torchaudio.models.hubert_pretrain_base(**FINETUNING_HUBERT_CONFIG)
+        self.model.wav2vec2.load_state_dict(HUBERT_BASE.get_model().state_dict())
+        self.aux = nn.Linear(
+            HUBERT_BASE._params["encoder_embed_dim"],
+            Tokenizer(with_blank=ctc_training).vocab_size,
+        )
+        self.freeze_encoder = freeze_encoder
+        self.ctc_training = ctc_training
+
+    def forward(
+        self, waveforms: Tensor, lengths: "Tensor | None" = None
+    ) -> "tuple[Tensor, Tensor | None]":
+        """Extract logits during training, with masking."""
+        if self.freeze_encoder:
+            with torch.no_grad():
+                x, out_len = self.model.wav2vec2.feature_extractor(waveforms, lengths)
+                padding_mask = components._get_padding_mask(x, out_len)
+                x, attention_mask = self.model.wav2vec2.encoder._preprocess(x, out_len)  # type: ignore
+                x, _ = self.model.mask_generator(x, padding_mask)
+                x = self.model.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)  # type: ignore
+        else:
+            with torch.no_grad():
+                x, out_len = self.model.wav2vec2.feature_extractor(waveforms, lengths)
+            padding_mask = components._get_padding_mask(x, out_len)
+            x, attention_mask = self.model.wav2vec2.encoder._preprocess(x, out_len)  # type: ignore
+            x, _ = self.model.mask_generator(x, padding_mask)
+            x = self.model.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)  # type: ignore
+        logits = self.aux(x)
+        return logits, out_len
+
+    def inference(
+        self, waveforms: Tensor, lengths: "Tensor | None" = None
+    ) -> "tuple[Tensor, Tensor | None]":
+        """Extract logits during inference. No masking is applied."""
+        x, out_len = self.model.wav2vec2(waveforms, lengths)
+        logits = self.aux(x)
+        return logits, out_len
+
+    @torch.jit.export
+    def extract_features(
+        self, waveforms: Tensor, lengths: "Tensor | None" = None
+    ) -> "tuple[list[Tensor], Tensor | None]":
+        """Extract features from intermediate layers. No masking is applied."""
+        x, out_len = self.model.wav2vec2.extract_features(waveforms, lengths)
+        x.append(self.aux(x[-1]))
+        return x, out_len
+
+    def train(self, mode: bool = True) -> "HuBERTPhoneme":
+        """Override the train method to set the encoder in eval mode if it is frozen."""
+        if self.freeze_encoder:
+            self.model.wav2vec2.eval()
+        else:
+            self.model.wav2vec2.train(mode)
+        self.aux.train(mode)
+        return self
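To make the greedy decoding used in transcribe_hubert_phoneme concrete: frame-level argmax ids are collapsed with unique_consecutive, and Tokenizer.decode then drops silence and the pad/blank id. A small illustration with an invented id sequence:

# Illustration of the decode path; the frame ids below are made up.
import torch
from inference_huberphoneme import Tokenizer

tokenizer = Tokenizer(with_blank=True)  # pad/blank id is 40
frame_ids = torch.tensor([0, 0, 16, 16, 3, 3, 3, 40, 21, 25, 25, 0])
collapsed = frame_ids.unique_consecutive()  # tensor([ 0, 16,  3, 40, 21, 25,  0])
print(tokenizer.decode(collapsed))  # "HH AH L OW"  (SIL and <pad> are dropped)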
app/tasks.py
CHANGED
@@ -5,6 +5,8 @@ import multiprocessing
 from typing import TypedDict
 from datetime import datetime
 
+import librosa
+import numpy as np
 
 from metrics import per, fer
 from datasets import load_from_disk
@@ -127,3 +129,25 @@ def _eval_task(task: Task, leaderboard_lock):
     except Exception as e:
         task["status"] = "failed"
         task["error"] = str(e)
+
+
+def run_sample_inference(audio, model_id: str, model_type: str, phone_code: str):
+    clear_cache()
+
+    # Load model
+    model = load_model(model_id, model_type)
+
+    # Format audio as monochannel 16 kHz float32
+    sample_rate, wav_array = audio
+    wav_array = wav_array.astype(np.float32)
+    if wav_array.ndim == 2 and wav_array.shape[1] == 2:
+        wav_array = np.mean(wav_array, axis=1)
+    wav_array = librosa.resample(y=wav_array, orig_sr=sample_rate, target_sr=16_000)
+
+    # Transcribe
+    transcript = transcribe(wav_array, model_type, model)
+    if phone_code != "ipa":
+        transcript = convert(transcript, phone_code, "ipa")
+
+    clear_cache()
+    return transcript
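For reuse outside Gradio, here is a standalone sketch of the audio-formatting step above. It assumes wav uploads arrive as integer PCM (Gradio commonly yields int16), so it additionally rescales integer samples to [-1, 1] before resampling:

# Sketch of the mono/16 kHz preprocessing; the [-1, 1] rescaling is an added assumption.
import librosa
import numpy as np

def to_mono_16k(audio: "tuple[int, np.ndarray]") -> np.ndarray:
    sample_rate, wav_array = audio
    if np.issubdtype(wav_array.dtype, np.integer):
        wav_array = wav_array.astype(np.float32) / np.iinfo(wav_array.dtype).max
    else:
        wav_array = wav_array.astype(np.float32)
    if wav_array.ndim == 2:
        wav_array = wav_array.mean(axis=1)  # stereo -> mono
    return librosa.resample(y=wav_array, orig_sr=sample_rate, target_sr=16_000)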
requirements.txt
CHANGED
@@ -6,9 +6,9 @@ datasets==4.0.0
 pandas==2.3.3
 numpy==2.0.2
 panphon==0.21.2
-torch==2.
-torchaudio==2.
-torchcodec==0.
+torch==2.9.1
+torchaudio==2.9.1
+torchcodec==0.8.0
 transformers==4.56.0
 phonemizer==3.3.0
 espnet==202509
@@ -17,3 +17,4 @@ espnet-model-zoo==0.1.7
 # UI
 gradio==5.12.0
 protobuf==6.32.0
+pydantic==2.10.6
requirements_lock.txt
CHANGED
@@ -90,8 +90,8 @@ propcache==0.3.2
 protobuf==6.32.0
 pyarrow==21.0.0
 pycparser==2.23
-pydantic==2.
-pydantic_core==2.
+pydantic==2.10.6
+pydantic_core==2.27.2
 pydub==0.25.1
 Pygments==2.19.2
 pyparsing==3.2.3
@@ -127,10 +127,10 @@ sympy==1.14.0
 threadpoolctl==3.6.0
 tokenizers==0.22.0
 tomlkit==0.13.3
-torch==2.
+torch==2.9.1
 torch-complex==0.4.4
-torchaudio==2.
-torchcodec==0.
+torchaudio==2.9.1
+torchcodec==0.8.0
 torchmetrics==1.8.2
 tqdm==4.67.1
 transformers==4.56.0