Feature Extraction
Transformers
Safetensors
English
rd_distiller
speech
automatic-speech-recognition
custom_code
Instructions to use Hguimaraes/rd_hubert with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Hguimaraes/rd_hubert with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Hguimaraes/rd_hubert", trust_remote_code=True)
```
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Hguimaraes/rd_hubert", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Builder for Distiller | |
| Author: Heng-Jui Chang (https://github.com/vectominist) | |
| """ | |
| import torch | |
| from torch import nn | |
| from .configuration_distiller import DistillerConfig | |
| from .distiller_w2v2_modules import ( | |
| ConvFeatureExtractionModel, | |
| GradMultiply, | |
| ) | |
| from .distiller_modules import ( | |
| TransformerEncoder, | |
| SplitLinear, | |
| ) | |
class DistillerModel(nn.Module):
    """
    Distiller Model

    Student network built from a convolutional waveform feature extractor, an
    optional linear projection, a Transformer encoder (replaced by a bare GELU
    when ``config.encoder_layers`` is 0) and an optional prediction head.

    ``config.task_emb_type`` selects how the prediction targets ("tasks") are
    handled:
        "add"         - a learned task embedding is added to the projected features
        "concat"      - a learned task embedding is concatenated to the raw
                        features before the projection
        "concat-last" - a task embedding table is created and the head input
                        size is enlarged by ``config.task_emb_size``
        "expand-last" - single pass; the head output is split into ``n_tasks`` parts
        "self-hidden" - the encoder input and its per-layer hidden states are
                        returned as the predictions
        "none"        - no task embedding

    ``config.out_layer_type`` must be "expand-last", "none" or "self-hidden".
    """

    def __init__(self, config: "DistillerConfig"):
        """
        Builds all sub-modules from `config`.

        Raises:
            AssertionError: if the task-embedding settings are inconsistent
                (e.g. non-positive embedding size, or a `pred_layer_id` whose
                length differs from `n_tasks`).
            NotImplementedError: on an unknown `task_emb_type` or `out_layer_type`.
        """
        super().__init__()

        self.config = config

        # NOTE(security): `eval` runs arbitrary Python, so the config must come
        # from a trusted source. It is kept (rather than swapped for
        # `ast.literal_eval`) because the spec may be a list *expression*
        # instead of a pure literal -- TODO confirm against the shipped configs.
        # Each entry is unpacked below as (dim, kernel_size, stride).
        self.conv_layers = eval(config.extractor_conv_feature_layers)
        feat_emb_dim = self.conv_layers[-1][0]
        self.feature_extractor = ConvFeatureExtractionModel(
            self.conv_layers,
            dropout=config.extractor_dropout,
            mode=config.extractor_mode,
            conv_bias=False,
        )
        self.feature_grad_mult = config.feature_grad_mult

        self.n_tasks = config.n_tasks
        self.task_emb_type = config.task_emb_type

        # Input size of the prediction head; grows for "concat-last".
        final_emb_size = config.encoder_embed_dim
        if self.task_emb_type == "add":
            self.task_embedding = nn.Embedding(config.n_tasks, config.encoder_embed_dim)
            nn.init.normal_(self.task_embedding.weight, 0.0, 0.1)
        elif self.task_emb_type == "concat":
            assert config.task_emb_size > 0
            # The projection below has to consume feature + task embedding.
            feat_emb_dim += config.task_emb_size
            self.task_embedding = nn.Embedding(config.n_tasks, config.task_emb_size)
        elif self.task_emb_type == "concat-last":
            assert config.task_emb_size > 0
            self.task_embedding = nn.Embedding(config.n_tasks, config.task_emb_size)
            final_emb_size += config.task_emb_size
        elif self.task_emb_type == "expand-last":
            self.pred_layer_id = config.pred_layer_id
            assert self.n_tasks == len(self.pred_layer_id)
            print(
                f"[DistillerModel] - Expands the output dimension by {self.n_tasks} times"
            )
            print(f"[DistillerModel] - Pred layers: {self.pred_layer_id}")
        elif self.task_emb_type == "self-hidden":
            self.pred_layer_id = config.pred_layer_id
            assert self.n_tasks == len(self.pred_layer_id)
            # One prediction per encoder layer plus the encoder input.
            assert self.n_tasks == config.encoder_layers + 1
            print("[DistillerModel] - Predicting with self-hidden layers")
            print(f"[DistillerModel] - Pred layers: {self.pred_layer_id}")
        elif self.task_emb_type == "none":
            print(
                f"[DistillerModel] - Disabled task embedding (predicts only layer {self.n_tasks})"
            )
        else:
            raise NotImplementedError(f"Unknown task emb type {self.task_emb_type}")

        # Project only when the extractor width differs from the encoder width.
        self.post_extract_proj = (
            nn.Linear(feat_emb_dim, config.encoder_embed_dim)
            if feat_emb_dim != config.encoder_embed_dim
            else None
        )

        if config.encoder_layers > 0:
            self.encoder = TransformerEncoder(config)
        else:
            self.encoder = nn.GELU()

        # NOTE(review): `final_dim` is computed but not used by any head built
        # below; kept so the set of config attributes read stays unchanged.
        final_dim = config.final_dim * (
            1 if self.task_emb_type != "expand-last" else self.n_tasks
        )

        inter_dim = config.out_layer_inter_dim
        inter_dim = inter_dim if inter_dim > 0 else final_emb_size

        print(f"[DistillerModel] - Out layer type: {config.out_layer_type}")
        if config.out_layer_type == "expand-last":
            assert self.task_emb_type == "expand-last"
            print(f"[DistillerModel] - Inter dim = {inter_dim}")
            self.output_layer = nn.Sequential(
                nn.Linear(final_emb_size, inter_dim * self.n_tasks),
                nn.GELU(),
                SplitLinear(inter_dim, self.n_tasks, config.final_dim),
            )
        elif config.out_layer_type in {"none", "self-hidden"}:
            self.output_layer = None
        else:
            raise NotImplementedError(f"Unknown out layer type {config.out_layer_type}")

    def forward_feature(self, wave, pad_mask):
        """
        Forward feature extractor

        Input:
            wave (FloatTensor): B x T_wave
            pad_mask (BoolTensor): B x T_wave
        Output:
            feat: B x T x D
            pad_mask: B x T (re-computed at the frame rate)
        """
        if self.feature_grad_mult > 0:
            feat = self.feature_extractor(wave)
            if self.feature_grad_mult != 1.0:
                feat = GradMultiply.apply(feat, self.feature_grad_mult)
        else:
            # A non-positive multiplier runs the extractor without gradients.
            with torch.no_grad():
                feat = self.feature_extractor(wave)

        feat = feat.transpose(1, 2)  # B x T x D
        pad_mask = self.cal_pad_mask(pad_mask, feat.shape[1])

        return feat, pad_mask

    def forward(self, wave, pad_mask, task_id=None, get_hidden=False, no_pred=False):
        """
        Forward function
        Input:
            wave (FloatTensor): B x T_wave
            pad_mask (BoolTensor): B x T_wave
            task_id (LongTensor): N >= 1
                Only used by the task-embedding modes; a plain list is
                converted to a tensor and None selects every task.
            get_hidden (bool): also return the encoder's per-layer hidden states
            no_pred (bool): skip the prediction head (pred is returned as None)
        Output:
            feat: B x T x D, extractor output
            feat_final: BN x T x D, encoder input
            pred: B x N x T x D, or None when `no_pred` is set
            pad_mask: BN x T
            layer_hiddens: list of hidden states (only when `get_hidden`)
        """
        feat, pad_mask = self.forward_feature(wave, pad_mask)

        if self.task_emb_type not in ["none", "expand-last", "self-hidden"]:
            if task_id is None:
                task_id = self.generate_task_id(feat.device)
            elif isinstance(task_id, list):
                task_id = torch.LongTensor(task_id).to(feat.device)
            task_embs = self.task_embedding(task_id)
            # N x D
            n_sz = len(task_id)
        else:
            n_sz = 1
        b_sz, t_sz, _ = feat.shape

        if self.task_emb_type == "add":
            # Add embs to feature
            if self.post_extract_proj is not None:
                feat_final = self.post_extract_proj(feat)
            else:
                feat_final = feat
            # (B x 1 x T x D) + (1 x N x 1 x D) broadcasts to B x N x T x D
            feat_final = feat_final.unsqueeze(1) + task_embs.unsqueeze(0).unsqueeze(2)
        elif self.task_emb_type == "concat":
            # Concatenates embs to feature
            feat_final = torch.cat(
                [
                    feat.unsqueeze(1).expand(-1, n_sz, -1, -1),
                    task_embs.unsqueeze(0).unsqueeze(2).expand(b_sz, -1, t_sz, -1),
                ],
                dim=-1,
            )
            if self.post_extract_proj is not None:
                feat_final = self.post_extract_proj(feat_final)
        else:
            # NOTE(review): "concat-last" also lands here, so its task
            # embeddings are looked up above but never merged into the
            # features in this method -- confirm whether that mode is supported.
            if self.post_extract_proj is not None:
                feat_final = self.post_extract_proj(feat)
            else:
                feat_final = feat
            feat_final = feat_final.unsqueeze(1)
        # feat_final: B x N x T x D or B x 1 x T x D

        # Fold the task axis into the batch axis for the encoder.
        pad_mask = pad_mask.unsqueeze(1).expand(-1, n_sz, -1).reshape(b_sz * n_sz, t_sz)
        # BN x T
        feat_final = feat_final.reshape(b_sz * n_sz, t_sz, -1)
        # BN x T x D

        layer_hiddens = []
        if self.config.encoder_layers > 0:
            # "self-hidden" needs the hidden states regardless of the caller's flag.
            get_hidden_tmp = (
                True if (self.task_emb_type == "self-hidden") else get_hidden
            )
            # The encoder receives the inverted mask (non-zero where padded).
            hidden, layer_hiddens = self.encoder(
                feat_final, ~pad_mask.bool(), get_hidden=get_hidden_tmp
            )
        else:
            hidden = self.encoder(feat_final)

        if not no_pred:
            if self.task_emb_type == "self-hidden":
                pred = torch.stack([feat_final] + layer_hiddens, dim=1)
            else:
                pred = self.output_layer(hidden).reshape(b_sz, n_sz, t_sz, -1)
            # B x N x T x D
        else:
            pred = None

        if (not no_pred) and self.task_emb_type == "expand-last":
            assert n_sz == 1, n_sz
            # Split the widened last axis into one slice per task.
            pred = (
                pred.squeeze(1)
                .reshape(b_sz, t_sz, self.n_tasks, -1)
                .permute(0, 2, 1, 3)
            )
            # B x N x T x D

        if get_hidden:
            return feat, feat_final, pred, pad_mask, layer_hiddens
        else:
            return feat, feat_final, pred, pad_mask

    def cal_pad_mask(self, pad_mask, max_len):
        """
        Calculates pad mask after conv.

        Input:
            pad_mask: B x T_wave, entries > 0 mark valid samples
            max_len (int): number of frames produced by the extractor
        Output:
            B x max_len mask with the dtype and device of `pad_mask`; the first
            `valid length` frames of every row are 1 and the rest are 0.
        """
        # Number of valid samples per row, pushed through every conv layer's
        # output-length formula.
        pad_len = (pad_mask > 0).sum(1).long()
        for _, k_size, s_size in self.conv_layers:
            pad_len = (pad_len - k_size) // s_size + 1

        # Compare frame indices with the per-row length instead of slicing row
        # by row: a negative length (e.g. a fully padded row) then gives an
        # all-zero row, whereas slicing with a negative start would count from
        # the end and leave most of the row marked valid.
        frame_idx = torch.arange(max_len, device=pad_mask.device)
        new_pad_mask = (frame_idx.unsqueeze(0) < pad_len.unsqueeze(1)).to(
            pad_mask.dtype
        )

        return new_pad_mask

    def generate_task_id(self, device):
        """Returns the ids of all tasks, 0 .. n_tasks - 1, as a LongTensor on `device`."""
        return torch.arange(self.n_tasks, device=device, dtype=torch.long)