SanderGi committed
Commit 20b52a3 · 1 Parent(s): 7796889

hubert phoneme + quick test model

app/app.py CHANGED
@@ -4,7 +4,7 @@
 import gradio as gr
 import pandas as pd
 
-from tasks import start_eval_task, get_status
+from tasks import start_eval_task, get_status, run_sample_inference
 from hf import get_or_create_leaderboard
 
 from codes import CODES
@@ -205,6 +205,16 @@ with gr.Blocks(
             outputs=result,
         )
 
+        gr.Markdown("---\n### Test Model")
+        test_audio = gr.Audio(interactive=True, format="wav")
+        test_btn = gr.Button("Run")
+        test_result = gr.Textbox(label="Test Result")
+        test_btn.click(
+            fn=run_sample_inference,
+            inputs=[test_audio, model_id, model_type, output_code],
+            outputs=test_result,
+        )
+
     with gr.TabItem("📊 Submission Status"):
         query = gr.Textbox(
             label="Model ID or Task ID",
app/inference.py CHANGED
@@ -3,8 +3,9 @@
 import torch
 from transformers import AutoProcessor, AutoModelForCTC
 from espnet2.bin.s2t_inference import Speech2Text
+from inference_huberphoneme import HuBERTPhoneme, Tokenizer
 
-MODEL_TYPES = ["Transformers CTC", "POWSM"]
+MODEL_TYPES = ["Transformers CTC", "POWSM", "HuBERTPhoneme"]
 
 DEVICE = (
     "cuda"
@@ -78,6 +79,23 @@ def transcribe_transformers_ctc(audio, model) -> str:
     return processor.decode(predicted_ids[0])
 
 
+# ===========================================================================
+# ============================== HuBERTPhoneme ==============================
+def load_hubert_phoneme(model_id, device=DEVICE):
+    model = HuBERTPhoneme.from_pretrained(model_id).to(device).eval()
+    tokenizer = Tokenizer(with_blank=model.ctc_training)
+    return model, tokenizer, device
+
+
+def transcribe_hubert_phoneme(audio, model) -> str:
+    model, tokenizer, device = model
+    with torch.inference_mode():
+        output, _ = model.inference(torch.from_numpy(audio).to(device).unsqueeze(0))
+    predictions = output.argmax(dim=-1).squeeze().cpu()
+    arpabet = tokenizer.decode(predictions.unique_consecutive())
+    return arpabet
+
+
 # ===========================================================================
 
 
@@ -86,6 +104,8 @@ def load_model(model_id, type, device=DEVICE):
         return load_powsm(model_id, device=device)
     elif type == "Transformers CTC":
         return load_transformers_ctc(model_id, device=device)
+    elif type == "HuBERTPhoneme":
+        return load_hubert_phoneme(model_id, device=device)
     else:
         raise ValueError("Unsupported model type: " + str(type))
 
@@ -95,5 +115,7 @@ def transcribe(audio, type, model) -> str:
         return transcribe_powsm(audio, model)
     elif type == "Transformers CTC":
         return transcribe_transformers_ctc(audio, model)
+    elif type == "HuBERTPhoneme":
+        return transcribe_hubert_phoneme(audio, model)
     else:
         raise ValueError("Unsupported model type: " + str(type))
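The new backend threads through the same two dispatch points (`load_model`, `transcribe`) as the existing ones. A hedged usage sketch — the model id is a placeholder, and `audio` stands in for a real 16 kHz mono float32 recording:

```python
import numpy as np
from inference import load_model, transcribe

audio = np.zeros(16_000, dtype=np.float32)  # one second of silence at 16 kHz

# "your-username/hubert-phoneme" is a placeholder model id
model = load_model("your-username/hubert-phoneme", "HuBERTPhoneme")
print(transcribe(audio, "HuBERTPhoneme", model))  # space-separated ARPAbet string
```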
app/inference_huberphoneme.py ADDED
@@ -0,0 +1,133 @@
+# Adapted from https://github.com/bootphon/spokenlm-phoneme
+
+import torch
+import torchaudio
+from huggingface_hub import PyTorchModelHubMixin
+from torch import Tensor, nn
+from torchaudio.models.wav2vec2 import components
+from torchaudio.pipelines import HUBERT_BASE
+from typing import Iterable
+
+
+class Tokenizer:
+    # fmt:off
+    PHONEMES = {
+        "SIL": 0, "AA": 1, "AE": 2, "AH": 3, "AO": 4, "AW": 5, "AY": 6, "B": 7,
+        "CH": 8, "D": 9, "DH": 10, "EH": 11, "ER": 12, "EY": 13, "F": 14, "G": 15,
+        "HH": 16, "IH": 17, "IY": 18, "JH": 19, "K": 20, "L": 21, "M": 22, "N": 23,
+        "NG": 24, "OW": 25, "OY": 26, "P": 27, "R": 28, "S": 29, "SH": 30, "T": 31,
+        "TH": 32, "UH": 33, "UW": 34, "V": 35, "W": 36, "Y": 37, "Z": 38, "ZH": 39,
+    }
+    # fmt:on
+
+    def __init__(self, with_blank: bool = False) -> None:
+        self.token_to_id = self.PHONEMES | {"<pad>": self.pad_id}
+        self.id_to_token = {v: k for k, v in self.token_to_id.items()}
+        self.with_blank = with_blank
+
+    @property
+    def vocab_size(self) -> int:
+        if self.with_blank:
+            return len(self.PHONEMES) + 1
+        return len(self.PHONEMES)
+
+    @property
+    def silence_id(self) -> int:
+        return self.PHONEMES["SIL"]
+
+    @property
+    def pad_id(self) -> int:
+        return len(self.PHONEMES)
+
+    def encode(self, phones: "list[str] | str") -> torch.LongTensor:
+        if isinstance(phones, str):
+            phones = phones.split(" ")
+        return torch.LongTensor([self.token_to_id[phone] for phone in phones])
+
+    def decode(self, tokens: Iterable[int]) -> str:
+        return " ".join(
+            self.id_to_token[int(token)]
+            for token in tokens
+            if token < self.pad_id and int(token) != self.silence_id
+        )
+
+
+FINETUNING_HUBERT_CONFIG = {
+    "encoder_projection_dropout": 0,
+    "encoder_attention_dropout": 0,
+    "encoder_ff_interm_dropout": 0.1,
+    "encoder_dropout": 0,
+    "encoder_layer_drop": 0.1,  # In torchaudio: 0.05
+    "mask_prob": 0.75,  # In torchaudio: 0.65
+    "mask_channel_prob": 0.5,
+    "mask_channel_length": 10,  # In torchaudio and fairseq: 64. This is the value for pretraining.
+    "num_classes": 500,  # Number of classes during HuBERT pretraining.
+}
+
+
+class HuBERTPhoneme(nn.Module, PyTorchModelHubMixin):
+    def __init__(self, freeze_encoder: bool = True, ctc_training: bool = False) -> None:
+        """Initialize the model.
+
+        Parameters
+        ----------
+        freeze_encoder : bool, optional
+            Whether to freeze the Transformer encoder of HuBERT, by default True.
+            The convolutional layers are always frozen.
+        """
+        super().__init__()
+        self.model = torchaudio.models.hubert_pretrain_base(**FINETUNING_HUBERT_CONFIG)
+        self.model.wav2vec2.load_state_dict(HUBERT_BASE.get_model().state_dict())
+        self.aux = nn.Linear(
+            HUBERT_BASE._params["encoder_embed_dim"],
+            Tokenizer(with_blank=ctc_training).vocab_size,
+        )
+        self.freeze_encoder = freeze_encoder
+        self.ctc_training = ctc_training
+
+    def forward(
+        self, waveforms: Tensor, lengths: "Tensor | None" = None
+    ) -> "tuple[Tensor, Tensor | None]":
+        """Extract logits during training, with masking."""
+        if self.freeze_encoder:
+            with torch.no_grad():
+                x, out_len = self.model.wav2vec2.feature_extractor(waveforms, lengths)
+                padding_mask = components._get_padding_mask(x, out_len)
+                x, attention_mask = self.model.wav2vec2.encoder._preprocess(x, out_len)  # type: ignore
+                x, _ = self.model.mask_generator(x, padding_mask)
+                x = self.model.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)  # type: ignore
+        else:
+            with torch.no_grad():
+                x, out_len = self.model.wav2vec2.feature_extractor(waveforms, lengths)
+                padding_mask = components._get_padding_mask(x, out_len)
+            x, attention_mask = self.model.wav2vec2.encoder._preprocess(x, out_len)  # type: ignore
+            x, _ = self.model.mask_generator(x, padding_mask)
+            x = self.model.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)  # type: ignore
+        logits = self.aux(x)
+        return logits, out_len
+
+    def inference(
+        self, waveforms: Tensor, lengths: "Tensor | None" = None
+    ) -> "tuple[Tensor, Tensor | None]":
+        """Extract logits during inference. No masking is applied."""
+        x, out_len = self.model.wav2vec2(waveforms, lengths)
+        logits = self.aux(x)
+        return logits, out_len
+
+    @torch.jit.export
+    def extract_features(
+        self, waveforms: Tensor, lengths: "Tensor | None" = None
+    ) -> "tuple[list[Tensor], Tensor | None]":
+        """Extract features from intermediate layers. No masking is applied."""
+        x, out_len = self.model.wav2vec2.extract_features(waveforms, lengths)
+        x.append(self.aux(x[-1]))
+        return x, out_len
+
+    def train(self, mode: bool = True) -> "HuBERTPhoneme":
+        """Override the train method to set the encoder in eval mode if it is frozen."""
+        if self.freeze_encoder:
+            self.model.wav2vec2.eval()
+        else:
+            self.model.wav2vec2.train(mode)
+        self.aux.train(mode)
+        return self
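Two notes on this file. First, the two branches of `forward` differ only in how far `torch.no_grad()` extends: with `freeze_encoder=True` the whole encoder runs without gradients, otherwise only the (always frozen) convolutional feature extractor does. Second, `transcribe_hubert_phoneme` in `app/inference.py` turns frame-level predictions into ARPAbet via `unique_consecutive()` plus `Tokenizer.decode`: repeated frame labels are collapsed, then `SIL` and the blank/pad id are dropped. A small self-contained sketch of that decoding path (the frame sequence is made up):

```python
import torch
from inference_huberphoneme import Tokenizer

tok = Tokenizer(with_blank=True)  # blank/pad id is 40 (= len(PHONEMES))
frames = torch.LongTensor([0, 0, 16, 16, 3, 3, 3, 40, 21, 25, 25])  # frame-level ids
collapsed = frames.unique_consecutive()  # tensor([ 0, 16,  3, 40, 21, 25])
print(tok.decode(collapsed))  # "HH AH L OW" — SIL (0) and blank (40) filtered out
```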
app/tasks.py CHANGED
@@ -5,6 +5,8 @@ import multiprocessing
 from typing import TypedDict
 from datetime import datetime
 
+import librosa
+import numpy as np
 
 from metrics import per, fer
 from datasets import load_from_disk
@@ -127,3 +129,25 @@ def _eval_task(task: Task, leaderboard_lock):
     except Exception as e:
         task["status"] = "failed"
         task["error"] = str(e)
+
+
+def run_sample_inference(audio, model_id: str, model_type: str, phone_code: str):
+    clear_cache()
+
+    # Load model
+    model = load_model(model_id, model_type)
+
+    # Format audio as monochannel 16 kHz float32
+    sample_rate, wav_array = audio
+    wav_array = wav_array.astype(np.float32)
+    if wav_array.ndim == 2 and wav_array.shape[1] == 2:
+        wav_array = np.mean(wav_array, axis=1)
+    wav_array = librosa.resample(y=wav_array, orig_sr=sample_rate, target_sr=16_000)
+
+    # Transcribe
+    transcript = transcribe(wav_array, model_type, model)
+    if phone_code != "ipa":
+        transcript = convert(transcript, phone_code, "ipa")
+
+    clear_cache()
+    return transcript
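The preprocessing step can be exercised on its own. A standalone sketch assuming a Gradio-style `(sample_rate, int16 array)` tuple as input; note the int16 samples are cast to float32 without rescaling to [-1, 1], so backends sensitive to input scale may need an explicit division by 32768:

```python
import librosa
import numpy as np

sample_rate = 44_100
stereo = np.random.randint(-2**15, 2**15, size=(sample_rate, 2), dtype=np.int16)

wav = stereo.astype(np.float32)
if wav.ndim == 2 and wav.shape[1] == 2:
    wav = np.mean(wav, axis=1)  # downmix stereo to mono
wav = librosa.resample(y=wav, orig_sr=sample_rate, target_sr=16_000)
print(wav.shape, wav.dtype)  # (16000,) float32
```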
requirements.txt CHANGED
@@ -6,9 +6,9 @@ datasets==4.0.0
 pandas==2.3.3
 numpy==2.0.2
 panphon==0.21.2
-torch==2.8.0
-torchaudio==2.8.0
-torchcodec==0.6.0
+torch==2.9.1
+torchaudio==2.9.1
+torchcodec==0.8.0
 transformers==4.56.0
 phonemizer==3.3.0
 espnet==202509
@@ -17,3 +17,4 @@ espnet-model-zoo==0.1.7
 # UI
 gradio==5.12.0
 protobuf==6.32.0
+pydantic==2.10.6
requirements_lock.txt CHANGED
@@ -90,8 +90,8 @@ propcache==0.3.2
 protobuf==6.32.0
 pyarrow==21.0.0
 pycparser==2.23
-pydantic==2.11.7
-pydantic_core==2.33.2
+pydantic==2.10.6
+pydantic_core==2.27.2
 pydub==0.25.1
 Pygments==2.19.2
 pyparsing==3.2.3
@@ -127,10 +127,10 @@ sympy==1.14.0
 threadpoolctl==3.6.0
 tokenizers==0.22.0
 tomlkit==0.13.3
-torch==2.8.0
+torch==2.9.1
 torch-complex==0.4.4
-torchaudio==2.8.0
-torchcodec==0.6.0
+torchaudio==2.9.1
+torchcodec==0.8.0
 torchmetrics==1.8.2
 tqdm==4.67.1
 transformers==4.56.0