Aman K committed on
Commit
a1c1315
1 Parent(s): 5ae3a9f

prepared alignment model to be loaded using FlaxAutoModel

Files changed (3)
  1. config.json +4 -1
  2. flax_model.msgpack +3 -0
  3. flax_modeling_alignment.py +181 -0
config.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "amankhandelia/mms_alignment_model",
+  "_name_or_path": "amankhandelia/flax_mms_alignment_model",
   "activation_dropout": 0.1,
   "adapter_attn_dim": null,
   "adapter_kernel_size": 3,
@@ -9,6 +9,9 @@
   "architectures": [
     "Wav2Vec2ForAudioFrameClassification"
   ],
+  "auto_map": {
+    "FlaxAutoModel": "flax_modeling_alignment.FlaxWav2Vec2ForAudioFrameClassification"
+  },
   "attention_dropout": 0.0,
   "bos_token_id": 1,
   "classifier_proj_size": 256,
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a9a569d76919565b0879dec40c3f545a00a03cd839820a248058dc021e862a6
+size 1261893241
flax_modeling_alignment.py ADDED
@@ -0,0 +1,181 @@
+from typing import Optional, Union
+
+import jax
+import jax.numpy as jnp
+import flax.linen as nn
+
+from transformers.modeling_flax_outputs import FlaxCausalLMOutput
+from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config
+from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
+    FlaxWav2Vec2FeatureEncoder,
+    FlaxWav2Vec2FeatureProjection,
+    FlaxWav2Vec2StableLayerNormEncoder,
+    FlaxWav2Vec2Adapter,
+    FlaxWav2Vec2PreTrainedModel,
+    FlaxWav2Vec2BaseModelOutput,
+)
+
+
+class FlaxWav2Vec2Module(nn.Module):
+    config: Wav2Vec2Config
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
+        self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
+        if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0:
+            self.masked_spec_embed = self.param(
+                "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
+            )
+
+        if self.config.do_stable_layer_norm:
+            self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
+        else:
+            raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
+
+        self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None
+
+    def __call__(
+        self,
+        input_values,
+        attention_mask=None,
+        mask_time_indices=None,
+        deterministic=True,
+        output_attentions=None,
+        output_hidden_states=None,
+        freeze_feature_encoder=False,
+        return_dict=None,
+    ):
+        extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)
+
+        # make sure that no loss is computed on padded inputs
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
+
+        hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
+        if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
+            hidden_states = jnp.where(
+                jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
+                jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
+                hidden_states,
+            )
+
+        encoder_outputs = self.encoder(
+            hidden_states,
+            attention_mask=attention_mask,
+            deterministic=deterministic,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = encoder_outputs[0]
+
+        if self.adapter is not None:
+            hidden_states = self.adapter(hidden_states)
+
+        if not return_dict:
+            return (hidden_states, extract_features) + encoder_outputs[1:]
+
+        return FlaxWav2Vec2BaseModelOutput(
+            last_hidden_state=hidden_states,
+            extract_features=extract_features,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def _get_feat_extract_output_lengths(
+        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
+    ):
+        """
+        Computes the output length of the convolutional layers
+        """
+
+        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            # 1D convolutional layer output length formula taken
+            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+            return (input_length - kernel_size) // stride + 1
+
+        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
+        return input_lengths
+
+    def _get_feature_vector_attention_mask(
+        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
+    ):
+        # Effectively attention_mask.sum(-1), but not inplace to be able to run
+        # in inference mode.
+        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
+
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+
+        batch_size = attention_mask.shape[0]
+
+        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
+        # these two operations make sure that all values
+        # before the output lengths indices are attended to
+        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
+        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
+        return attention_mask
+
+
+class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
+    module_class = FlaxWav2Vec2Module
+
+
+class FlaxWav2Vec2ForAudioFrameClassificationModule(nn.Module):
+    config: Wav2Vec2Config
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
+        self.classifier = nn.Dense(
+            self.config.num_labels,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            dtype=self.dtype,
+        )
+
+    def __call__(
+        self,
+        input_values,
+        attention_mask=None,
+        mask_time_indices=None,
+        deterministic=True,
+        output_attentions=None,
+        output_hidden_states=None,
+        freeze_feature_encoder=False,
+        return_dict=None,
+    ):
+        outputs = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+            mask_time_indices=mask_time_indices,
+            deterministic=deterministic,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            freeze_feature_encoder=freeze_feature_encoder,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+
+        logits = self.classifier(hidden_states)
+
+        if not return_dict:
+            return (logits,) + outputs[2:]
+
+        return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+
+
+class FlaxWav2Vec2ForAudioFrameClassification(FlaxWav2Vec2PreTrainedModel):
+    module_class = FlaxWav2Vec2ForAudioFrameClassificationModule
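
For context, a hedged end-to-end sketch of how the exported class would be used once loaded. The feature-extractor step assumes the repo also ships a preprocessor_config.json, and the 16 kHz input plus log-softmax over frame logits follow the usual MMS forced-alignment recipe rather than anything in this commit. With wav2vec2's default conv stack (strides 5, 2, 2, 2, 2, 2, 2), _conv_out_length above reduces the waveform by a factor of 5*2^6 = 320, i.e. one logits frame per 20 ms of 16 kHz audio.

import jax
import numpy as np
from transformers import AutoFeatureExtractor, FlaxAutoModel

repo_id = "amankhandelia/flax_mms_alignment_model"  # assumed repo id
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)  # assumes a preprocessor config is present
model = FlaxAutoModel.from_pretrained(repo_id, trust_remote_code=True)

# One second of 16 kHz audio as a stand-in waveform.
waveform = np.zeros(16000, dtype=np.float32)
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="np")

# Forward pass through FlaxWav2Vec2ForAudioFrameClassification:
# logits has shape (batch, frames, num_labels), one frame per ~20 ms.
outputs = model(inputs.input_values)
log_probs = jax.nn.log_softmax(outputs.logits, axis=-1)
# log_probs would feed a CTC/Viterbi forced-alignment step downstream (not part of this file).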