BUT-FIT/DiCoW_v3_2

Tags: Automatic Speech Recognition · Transformers · Safetensors · DiCoW · speech · whisper · multilingual · speaker-diarization · meeting-transcription · BUT-FIT · custom_code
Instructions to use BUT-FIT/DiCoW_v3_2 with libraries, inference providers, notebooks, and local apps.

How to use BUT-FIT/DiCoW_v3_2 with Transformers:

```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="BUT-FIT/DiCoW_v3_2", trust_remote_code=True)

# Or load the model directly
from transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained("BUT-FIT/DiCoW_v3_2", trust_remote_code=True, dtype="auto")
```
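To transcribe audio end to end, the pipeline object accepts a file path or a NumPy array. A minimal sketch, assuming a local 16 kHz mono recording named `meeting.wav` (the filename is illustrative):

```python
from transformers import pipeline

pipe = pipeline(
    "automatic-speech-recognition",
    model="BUT-FIT/DiCoW_v3_2",
    trust_remote_code=True,
)

# "meeting.wav" is a placeholder path; files at other sampling rates are
# resampled by the pipeline's ffmpeg-based loader.
result = pipe("meeting.wav")
print(result["text"])
```

Since DiCoW is diarization-conditioned, the custom code loaded from the repository may additionally expect or compute speaker-activity masks; the call above only exercises the generic ASR pipeline interface.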
The repository's custom encoder implementation (fetched via `trust_remote_code=True`):

```python
import torch
from torch import nn

from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput

# Compatibility fallback across Transformers versions:
# - Transformers <= 4.38: WHISPER_ATTENTION_CLASSES is available
# - Transformers >= 4.39 (including 4.55.0): WHISPER_ATTENTION_CLASSES was
#   removed; use the WhisperAttention dispatcher instead
from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer

try:
    from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
except ImportError:
    from transformers.models.whisper.modeling_whisper import WhisperAttention

    WHISPER_ATTENTION_CLASSES = {
        "eager": WhisperAttention,
        "sdpa": WhisperAttention,
        "flash_attention_2": WhisperAttention,
    }
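    # In Transformers >= 4.39, WhisperAttention itself dispatches to the
    # configured backend (eager/sdpa/flash-attention) via
    # config._attn_implementation, so mapping every key to the same class
    # preserves the old lookup behavior.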

from .FDDT import FDDT
from .config import DiCoWConfig
from .SCBs import SpeakerCommunicationBlock
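# FDDT: Frame-level Diarization-Dependent Transformations, which condition the
# encoder on per-frame speaker activity; SCBs: Speaker Communication Blocks,
# which exchange information across the stacked speaker channels.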

class DiCoWEncoder(WhisperEncoder):
    config_class = DiCoWConfig

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.ctc_weight = config.ctc_weight

        # Optional blocks in front of the auxiliary CTC head; they are only
        # instantiated when the CTC objective is active.
        if config.additional_layer and self.ctc_weight > 0.0:
            self.additional_layer = WhisperEncoderLayer(config)
        if config.additional_self_attention_layer and self.ctc_weight > 0.0:
            self.additional_self_attention_layer = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
                embed_dim=config.d_model,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                config=config,
            )
        if config.sub_sample and self.ctc_weight > 0.0:
            # Two strided convolutions subsample the encoder output 4x before CTC.
            self.subsample_conv1 = nn.Conv1d(
                in_channels=config.d_model,
                out_channels=config.d_model,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            )
            self.subsample_conv2 = nn.Conv1d(
                in_channels=config.d_model,
                out_channels=config.d_model,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            )
        if self.ctc_weight > 0.0:
            # One extra output unit acts as the CTC blank symbol.
            self.lm_head = nn.Linear(config.d_model, config.vocab_size + 1, bias=False)
        self.final_dropout = nn.Dropout(config.final_dropout)

        if config.use_fddt:
            num_fddts = (
                self.config.apply_fddt_to_n_layers
                if self.config.apply_fddt_to_n_layers != -1
                else len(self.layers)
            )
            # The initial FDDT conditions the convolutional features before the
            # first encoder layer; in the initial layers we don't want
            # cross-speaker communication, hence no SCB module here.
            self.initial_fddt = FDDT(
                config.d_model,
                non_target_rate=config.non_target_fddt_value,
                is_diagonal=config.fddt_is_diagonal,
                bias_only=config.fddt_bias_only,
                use_silence=config.fddt_use_silence,
                use_target=config.fddt_use_target,
                use_overlap=config.fddt_use_overlap,
                use_non_target=config.fddt_use_non_target,
                use_interaction=False,
                scb_module=None,
            )
            num_scbs = (
                (self.config.scb_layers if self.config.scb_layers != -1 else len(self.layers))
                if self.config.is_mt
                else 0
            )
            self.scbs_identity_layers = config.encoder_layers - num_scbs
            # One FDDT per conditioned layer; the final `num_scbs` layers also
            # receive a SpeakerCommunicationBlock.
            self.fddts = nn.ModuleList([
                FDDT(
                    config.d_model,
                    non_target_rate=1.0,
                    is_diagonal=config.fddt_is_diagonal,
                    bias_only=config.fddt_bias_only,
                    use_silence=config.fddt_use_silence,
                    use_target=config.fddt_use_target,
                    use_overlap=config.fddt_use_overlap,
                    use_non_target=config.fddt_use_non_target,
                    use_interaction=i >= self.scbs_identity_layers,
                    scb_module=(
                        SpeakerCommunicationBlock(config, scb_method=config.scb_method)
                        if i >= self.scbs_identity_layers
                        else None
                    ),
                )
                for i in range(num_fddts)
            ])
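        # Vocabulary layout (Whisper): 1501 timestamp tokens (<|0.00|> .. <|30.00|>
        # at 50 Hz) sit at the very end, preceded by 6 task tokens, so the id
        # below is the first task token; get_loss() drops ids >= first_task_token
        # from the CTC targets.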
        self.first_task_token = self.config.vocab_size - 30 * 50 - 1 - 6
        self.post_init()

    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict,
        loaded_keys,
        resolved_archive_file,
        pretrained_model_name_or_path,
        **kwargs,
    ):
        # Checkpoints exported from the full seq2seq model prefix encoder weights
        # with "encoder."; strip the prefix so they map onto this module.
        for key in list(state_dict.keys()):
            if key.startswith("encoder."):
                state_dict[key[len("encoder."):]] = state_dict.pop(key)
                loaded_keys.remove(key)
                loaded_keys.append(key[len("encoder."):])
        output = super()._load_pretrained_model(
            model,
            state_dict,
            loaded_keys,
            resolved_archive_file,
            pretrained_model_name_or_path,
            **kwargs,
        )
        return output

    def get_loss(self, logits, labels):
        if labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")
        if self.config.remove_timestamps_from_ctc:
            # Strip timestamp and task tokens (ids >= first_task_token) from the targets.
            labels = torch.nn.utils.rnn.pad_sequence(
                [label[label < self.first_task_token] for label in labels],
                padding_value=-100,
            ).T

        input_lengths = torch.full((logits.shape[0],), fill_value=logits.shape[1], device=logits.device)

        # Assuming that padded tokens are filled with -100 when not being attended to.
        labels_mask = labels >= 0
        target_lengths = labels_mask.sum(-1)

        # ctc_loss doesn't support fp16
        log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

        with torch.backends.cudnn.flags(enabled=True):
            ctc_loss = nn.functional.ctc_loss(
                log_probs,
                labels,
                input_lengths,
                target_lengths,
                blank=logits.shape[-1] - 1,  # the extra lm_head unit is the CTC blank
                reduction=self.config.ctc_loss_reduction,
                zero_infinity=True,
            )
        return ctc_loss

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        stno_mask=None,
        per_group_sizes=None,
    ):
        # For MT-ASR the input has shape (B x S) x F x T: the same recording is
        # stacked once per speaker along the batch axis, so
        # torch.view(B, S, F, -1) recovers an explicit speaker dimension.
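        # E.g. B=2 recordings with S=4 speakers each arrive as 8 feature maps of
        # shape F x T; rows i*4 .. i*4+3 share the audio of recording i and
        # differ only in their STNO conditioning masks.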
        expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
        if input_features.shape[-1] != expected_seq_length:
            if input_features.shape[-1] > expected_seq_length:
                return CausalLMOutput(
                    logits=None,
                    hidden_states=None,
                    attentions=None,
                )
            else:
                raise ValueError(
                    f"Whisper expects the mel input features to be of length {expected_seq_length}, "
                    f"but found {input_features.shape[-1]}. Make sure to pad the input mel features "
                    f"to {expected_seq_length}."
                )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight

        if self.config.use_fddt:
            inputs_embeds = self.initial_fddt(inputs_embeds, stno_mask)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Check that head_mask specifies a value for each layer if provided.
        if head_mask is not None:
            assert head_mask.size()[0] == len(self.layers), (
                f"The head_mask should be specified for {len(self.layers)} layers, "
                f"but it is for {head_mask.size()[0]}."
            )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # Add LayerDrop (see https://arxiv.org/abs/1909.11556 for description).
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            # Condition the layer input on the diarization mask (and, in the
            # last layers, let speakers communicate through the SCB).
            if self.config.use_fddt and idx < len(self.fddts):
                hidden_states = self.fddts[idx](hidden_states, stno_mask)

            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        None,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        None,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                    )
                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            outputs = tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        else:
            outputs = BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
            )

        # Optional auxiliary transform in front of the CTC head. Note that the
        # attribute access below assumes return_dict-style outputs.
        if hasattr(self, "additional_layer"):
            inter_output, = self.additional_layer(
                outputs.last_hidden_state,
                attention_mask=None,
                output_attentions=output_attentions,
                layer_head_mask=None,
            )
        elif hasattr(self, "additional_self_attention_layer"):
            inter_output, _ = self.additional_self_attention_layer(
                outputs.last_hidden_state,
                attention_mask=None,
                output_attentions=output_attentions,
                layer_head_mask=None,
            )
        else:
            inter_output = outputs.last_hidden_state

        inter_output = self.final_dropout(inter_output)
        if hasattr(self, "subsample_conv2"):
            inter_output = self.subsample_conv2(self.subsample_conv1(inter_output.transpose(1, 2))).transpose(1, 2)

        if self.ctc_weight > 0.0:
            logits = self.lm_head(inter_output)
        else:
            logits = None

        return CausalLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
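
The two non-standard `forward` arguments come from diarization: `stno_mask` carries the frame-level Silence/Target/Non-target/Overlap (STNO) activity that conditions the FDDT modules, and `per_group_sizes` describes how the stacked speaker channels group into recordings. A minimal smoke-test sketch follows; the 4-channel mask layout aligned to the 1500 post-convolution frames is an assumption here (the authoritative layout lives in `FDDT.py`), and `model` is the object from the loading snippet above:

```python
import torch

# 128 mel bins x 3000 mel frames matches Whisper large-v3 inputs
# (max_source_positions 1500 x total conv stride 2).
input_features = torch.randn(1, 128, 3000, dtype=model.dtype)

# Assumed STNO layout: (batch, 4 classes, 1500 encoder frames);
# channel 1 = target speaker. Verify against FDDT.py before relying on it.
stno_mask = torch.zeros(1, 4, 1500, dtype=model.dtype)
stno_mask[:, 1] = 1.0  # pretend every frame is target-speaker speech

encoder = model.get_encoder()
out = encoder(input_features, stno_mask=stno_mask)
# `out` is a CausalLMOutput; `out.logits` holds CTC logits when ctc_weight > 0.
```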