| from huggingface_hub import hf_hub_download |
| from torch import nn |
| from transformers import Wav2Vec2ConformerModel |
| from safetensors.torch import load_file |
| from torch_state_bridge import state_bridge |
| import torch |
| import torch.nn.functional as F |
| import torchaudio |
| import librosa |
|
|
class Op(nn.Module):
    """Adapter that lets a plain callable be used wherever an ``nn.Module`` is expected.

    Args:
        func: Callable applied to the input in :meth:`forward`.
        allow_self: When true, ``func`` is called as ``func(module, x)`` so it
            can read attributes stored on this module; otherwise it is called
            as ``func(x)``.
    """

    def __init__(self, func, allow_self=False):
        super().__init__()
        self.func = func
        self.allow_self = allow_self

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result."""
        # Prepend this module to the argument list only when the callable asked for it.
        call_args = (self, x) if self.allow_self else (x,)
        return self.func(*call_args)
|
|
class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
    """``Wav2Vec2ConformerModel`` rewired with a log-mel front end and RNN-T heads.

    :meth:`init_weights` swaps the parent's feature extractor convolutions for
    2-D convolutions over a log-mel spectrogram and attaches an embedding +
    LSTM prediction network together with linear encoder/prediction/joint
    projections. :meth:`forward` takes ``(waveform, lengths)`` and returns the
    greedy decoding produced by :meth:`postprocessing`.
    """

    def init_weights(self):
        """Replace/add the custom sub-modules, then run the parent initialiser.

        Modules are created in a fixed order; keep it stable because it
        determines the order of parameters in the module tree.
        """
        del self.encoder.pos_conv_embed
        config = self.config
        n_mels = 80
        joint_out = config.vocab_size // len(config.languages) + 1

        # Set by ``preprocessing``; read by ``_mask_hidden_states``.
        self.cache_length = None
        self.enc = nn.Linear(config.hidden_size, config.joint_hidden)
        self.pred = nn.Linear(config.pred_hidden, config.joint_hidden)
        self.joint = nn.Linear(config.joint_hidden, joint_out)
        self.embed = nn.Embedding(config.vocab_size + 1, config.pred_hidden, padding_idx=config.vocab_size)
        self.lstm = nn.LSTM(config.pred_hidden, config.pred_hidden, config.lstm_layer, batch_first=True)
        self.act = nn.ReLU()
        self.spec = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=160, win_length=400, center=False)
        # One shared masking module; it reads ``cache_pad_mask`` from itself,
        # which ``_mask_hidden_states`` refreshes on every forward pass.
        self.mask_layer = Op(lambda self_obj, x: x.masked_fill(self_obj.cache_pad_mask.unsqueeze(1), 0), True)
        self.register_buffer(
            "mel_fb",
            torch.tensor(librosa.filters.mel(sr=config.sampling_rate, n_fft=self.spec.n_fft, n_mels=n_mels)),
        )

        # Swap every feature-extractor convolution for a 2-D one.
        for idx, layer in enumerate(self.feature_extractor.conv_layers):
            old = layer.conv
            kernel = old.kernel_size[0]
            if not config.multilingual or idx == 0:
                layer.conv = nn.Conv2d(old.in_channels, old.out_channels, kernel, old.stride, 1)
                layer.layer_norm = nn.Identity()
            else:
                # Grouped convolution followed by a 1x1 convolution.
                layer.conv = nn.Sequential(
                    nn.Conv2d(old.in_channels, old.out_channels, kernel, old.stride, 1, groups=old.out_channels),
                    nn.Conv2d(old.in_channels, old.out_channels, 1),
                )
        self.feature_extractor.conv_layers.append(Op(lambda x: x.transpose(1, 2)))

        # The projection input is channels * (mel bins left after the strided convs).
        self.feature_projection.projection = nn.Linear(
            config.conv_dim[-1] * self.calc_length(n_mels, repeat_num=config.num_feat_extract_layers),
            config.hidden_size,
        )
        self.feature_projection.layer_norm = Op(lambda x: x.permute(0, 2, 1, 3).flatten(2))

        for layer in self.encoder.layers:
            conv_module = layer.conv_module
            # Zero padded positions right after the GLU of every conformer conv module.
            conv_module.glu = nn.Sequential(conv_module.glu, self.mask_layer)
            # Biases are allocated uninitialised (``torch.empty``); presumably they
            # are filled from the checkpoint — TODO confirm.
            conv_module.pointwise_conv1.bias = nn.Parameter(torch.empty(conv_module.pointwise_conv1.out_channels))
            conv_module.pointwise_conv2.bias = nn.Parameter(torch.empty(conv_module.pointwise_conv2.out_channels))
            conv_module.depthwise_conv.bias = nn.Parameter(torch.empty(conv_module.depthwise_conv.out_channels))
        self.encoder.layer_norm = nn.Identity()

        if config.multilingual:
            # One joint output layer per language; see ``change_language``.
            self.lang_joint_net = nn.ModuleDict(
                {lang: nn.Linear(config.joint_hidden, joint_out) for lang in config.languages.values()}
            )

        self.eps = 2 ** -24
        # Multiplied by the encoder frame index in ``postprocessing`` to produce
        # start times; presumably seconds per encoder frame, assuming every
        # feature-extractor layer halves the time axis — TODO confirm.
        self.denorm = (2 ** config.num_feat_extract_layers) * self.spec.hop_length / config.sampling_rate
        self.scaler = config.hidden_size ** (1 / 2)
        return super().init_weights()

    def _mask_hidden_states(self, hidden_states, mask_time_indices=None, attention_mask=None):
        """Scale the hidden states, refresh the padding mask, then defer to the parent.

        Requires ``self.cache_length`` to have been set by ``preprocessing``.
        """
        hidden_states = hidden_states * self.scaler
        positions = torch.arange(hidden_states.size(1), device=hidden_states.device).unsqueeze(0)
        # True where a time step lies at or beyond the valid length of its row.
        self.mask_layer.cache_pad_mask = positions >= self.cache_length.unsqueeze(1)
        return super()._mask_hidden_states(hidden_states, mask_time_indices, attention_mask)

    def calc_length(self, lengths, padding=1, kernel_size=3, stride=2, repeat_num=1):
        """Return the output length after ``repeat_num`` stacked convolutions.

        Applies ``(L + 2 * padding - kernel_size) // stride + 1`` repeatedly.
        ``lengths`` may be an ``int`` or an integer tensor.
        """
        for _ in range(repeat_num):
            lengths = (lengths + 2 * padding - kernel_size) // stride + 1
        return lengths

    def preprocessing(self, x):
        """Turn ``(waveform, lengths)`` into normalised log-mel features.

        Args:
            x: Pair ``(waveform, lengths)``; the waveform is indexed as
                ``[batch, samples]`` and ``lengths`` holds per-row sample counts.

        Returns:
            The features with the last two axes transposed, zero-padded along
            the frame axis to a multiple of ``config.pad_to``.

        Side effects:
            Stores the per-row encoder length in ``self.cache_length``.
        """
        waveform, num_samples = x
        frames = (num_samples // self.spec.hop_length + 1).long()
        # Pre-emphasis: y[t] = x[t] - preemph * x[t - 1]; the first sample is kept as is.
        waveform = torch.cat(
            (waveform[:, :1], waveform[:, 1:] - self.config.preemph * waveform[:, :-1]), 1
        )
        feats = (self.mel_fb @ self.spec(waveform) + self.eps).log()
        T = feats.size(-1)
        pad_mask = (torch.arange(T, device=feats.device)[None] >= frames[:, None])[:, None]
        feats = feats.masked_fill(pad_mask, 0)

        mean = feats.sum(-1) / frames[:, None]
        denom = torch.clamp(frames[:, None] - 1, min=1)
        # Mask *before* squaring so padded frames do not add ``mean ** 2`` to the
        # variance; the mean and the denominator only count valid frames as well.
        centered = (feats - mean[..., None]).masked_fill(pad_mask, 0)
        std = ((centered ** 2).sum(-1) / denom + 1e-5).sqrt()
        feats = centered / std[..., None]

        self.cache_length = self.calc_length(frames, repeat_num=self.config.num_feat_extract_layers).long()
        return F.pad(feats, (0, (-T) % self.config.pad_to)).transpose(1, 2)

    def forward(self, input_values):
        """Featurise, encode and greedily decode ``input_values``.

        Args:
            input_values: ``(waveform, lengths)`` pair accepted by ``preprocessing``.

        Returns:
            The ``(tokens, starts, lengths)`` tuple from ``postprocessing``.
        """
        features = self.preprocessing(input_values)
        encoded = super().forward(features).last_hidden_state
        return self.postprocessing(encoded)

    def postprocessing(self, enc_out):
        """Greedy RNN-T decoding over the encoder output.

        For every encoder frame, up to ``config.max_symbols_per_step`` symbols
        are emitted per batch row; a row's prediction-network state only
        advances when it emits a non-blank symbol.

        Args:
            enc_out: Encoder output indexed as ``[batch, time, hidden]``.

        Returns:
            ``(tokens, starts, lengths)``: emitted token ids (filled with
            ``config.pad_id``), the matching start values ``t * self.denorm``
            (filled with ``-1``), and the number of emitted tokens per row.
        """
        B, T, _ = enc_out.shape
        cfg = self.config
        blank = cfg.blank_id
        max_symbols = cfg.max_symbols_per_step
        max_len = T * max_symbols
        device, dtype = enc_out.device, enc_out.dtype

        tokens = torch.full((B, max_len), cfg.pad_id, dtype=torch.long, device=device)
        starts = torch.full((B, max_len), -1.0, dtype=dtype, device=device)
        lengths = torch.zeros(B, dtype=torch.long, device=device)
        hx = torch.zeros(cfg.lstm_layer, B, self.lstm.hidden_size, dtype=dtype, device=device)
        cx = torch.zeros_like(hx)
        last = torch.full((B, 1), blank, dtype=torch.long, device=device)

        enc_proj = self.enc(enc_out)

        for t in range(T):
            frame = enc_proj[:, t:t + 1]
            t_sec = torch.full((B, 1), t * self.denorm, dtype=dtype, device=device)

            for _ in range(max_symbols):
                pred_out, (hx_new, cx_new) = self.lstm(self.embed(last), (hx, cx))
                symbol = self.joint(self.act(frame + self.pred(pred_out))).squeeze(1).argmax(-1)
                emitted = symbol.ne(blank)

                # Every row emitted blank: no state changes, so any further
                # iteration on this frame would recompute the same result.
                if not emitted.any():
                    break

                # Keep the new LSTM state / last token only for rows that emitted.
                keep = emitted.view(1, B, 1)
                hx = torch.where(keep, hx_new, hx)
                cx = torch.where(keep, cx_new, cx)
                last = torch.where(emitted.unsqueeze(1), symbol.unsqueeze(1), last)

                # Write each emitted symbol at that row's current output position.
                idx = lengths[emitted].unsqueeze(1).clamp(max=max_len - 1)
                tokens[emitted] = tokens[emitted].scatter(1, idx, symbol[emitted].unsqueeze(1))
                starts[emitted] = starts[emitted].scatter(1, idx, t_sec[emitted])
                lengths[emitted] += 1

        return tokens, starts, lengths

    def change_language(self, language):
        """Copy the weights of ``lang_joint_net[language]`` into ``self.joint``.

        ``lang_joint_net`` is only created when ``config.multilingual`` is true.
        """
        self.joint.load_state_dict(self.lang_joint_net[language].state_dict())

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config=None,
        cache_dir=None,
        ignore_mismatched_sizes=False,
        force_download=False,
        local_files_only=False,
        token=None,
        revision="main",
        use_safetensors=None,
        weights_only=True,
        **kwargs,
    ):
        """Download ``<language or 'all'>.safetensors`` and build the model from it.

        Args:
            pretrained_model_name_or_path: Hub repository id holding the weights.
            config: Model configuration (required). It is copied, so the
                caller's object is left untouched.
            language: Optional keyword argument. When omitted or falsy the
                multilingual checkpoint ``all.safetensors`` is loaded and the
                multilingual architecture overrides are applied to the config.
            cache_dir, force_download, local_files_only, token, revision:
                Forwarded both to the weight download and to the parent loader.

        Raises:
            ValueError: If ``config`` is ``None``.
        """
        import copy

        if config is None:
            raise ValueError(
                "Wav2Vec2ConformerRNNT.from_pretrained requires an explicit `config` argument."
            )
        # Work on a copy so repeated loads with one config object do not leak
        # the multilingual overrides into a later single-language load.
        config = copy.deepcopy(config)
        config.language = kwargs.pop("language", None)
        config.multilingual = not config.language
        if config.multilingual:
            config.hidden_size = 1024
            config.num_hidden_layers = 24
            config.conv_depthwise_kernel_size = 9
            config.conv_stride = [2, 2, 2]
            config.conv_kernel = [3, 3, 3]
            config.conv_dim = [256, 256, 256]
            config.feat_extract_norm = "group"
            config.intermediate_size = config.hidden_size * 4
            config.num_feat_extract_layers = len(config.conv_dim)
            config.lstm_layer = 2

        # Honour the caller's hub options for the weight file as well.
        weights_path = hf_hub_download(
            pretrained_model_name_or_path,
            f"{config.language or 'all'}.safetensors",
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
        )
        kwargs['state_dict'] = load_file(weights_path)
        return super().from_pretrained(
            None,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )

    @staticmethod
    def _load_pretrained_model(model, state_dict, checkpoint_files, load_config):
        """Rename checkpoint keys to this model's layout, then defer to the parent.

        The ``source,target`` rules below are handed to ``state_bridge``; the
        ``{n}`` / ``{(...)}`` placeholders are interpreted by that library and
        are intentionally *not* Python f-string fields.

        NOTE(review): this is a ``@staticmethod`` that uses zero-argument
        ``super()``; that form binds to the first positional argument
        (``model``), so ``model`` must be an instance of this class.
        """
        changes = (
            "\n"
            "preprocessor.featurizer.fb,mel_fb\n"
            "preprocessor.featurizer.window,spec.window\n"
            "norm_feed_forward1,ffn1_layer_norm\n"
            "norm_feed_forward2,ffn2_layer_norm\n"
            "feed_forward1.linear1,ffn1.intermediate_dense\n"
            "feed_forward1.linear2,ffn1.output_dense\n"
            "feed_forward2.linear1,ffn2.intermediate_dense\n"
            "feed_forward2.linear2,ffn2.output_dense\n"
            "norm_self_att,self_attn_layer_norm\n"
            "norm_out,final_layer_norm\n"
            "norm_conv,conv_module.layer_norm\n"
            ".conv.,.conv_module.\n"
            "decoder.prediction.dec_rnn.lstm,lstm\n"
            "decoder.prediction.embed,embed\n"
            "joint.enc,enc\n"
            "joint.pred,pred\n"
            "joint.joint_net.2,lang_joint_net\n"
            "encoder.pre_encode.conv_module.0,feature_extractor.conv_layers.0.conv\n"
            "encoder.pre_encode.out,feature_projection.projection\n"
        )
        multilingual = model.config.multilingual
        if not multilingual:
            changes += "encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n/2)}.conv\n"
            changes += f"lang_joint_net.{model.config.language},joint\n"
        else:
            changes += "encoder.pre_encode.conv_module.{n},encoder.pre_encode.conv_module.{(n-2)}\n"
            changes += "encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//3+1)}.conv.{(n%3)}\n"
        state_dict = state_bridge(state_dict, changes)
        if not multilingual:
            # Drop whatever is still filed under ``lang_joint_net`` after the renaming.
            state_dict = {k: v for k, v in state_dict.items() if "lang_joint_net" not in k}
        # The checkpoint stores the filterbank with an extra leading axis.
        state_dict['mel_fb'] = state_dict['mel_fb'].squeeze(0)
        # CTC head weights have no counterpart in this model.
        state_dict.pop('ctc_decoder.decoder_layers.0.bias', None)
        state_dict.pop('ctc_decoder.decoder_layers.0.weight', None)
        return super()._load_pretrained_model(model, state_dict, checkpoint_files, load_config)