patrickvonplaten commited on
Commit
004e84f
1 Parent(s): 7e72381

upload hubert for seq classification

Browse files
__pycache__/hubert_for_sequence_classification.cpython-38.pyc ADDED
Binary file (3.37 kB). View file
config.json ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "apply_spec_augment": true,
4
+ "architectures": [
5
+ "HubertModel"
6
+ ],
7
+ "attention_dropout": 0.1,
8
+ "bos_token_id": 1,
9
+ "conv_bias": true,
10
+ "conv_dim": [
11
+ 512,
12
+ 512,
13
+ 512,
14
+ 512,
15
+ 512,
16
+ 512,
17
+ 512
18
+ ],
19
+ "conv_kernel": [
20
+ 10,
21
+ 3,
22
+ 3,
23
+ 3,
24
+ 3,
25
+ 2,
26
+ 2
27
+ ],
28
+ "conv_stride": [
29
+ 5,
30
+ 2,
31
+ 2,
32
+ 2,
33
+ 2,
34
+ 2,
35
+ 2
36
+ ],
37
+ "ctc_loss_reduction": "sum",
38
+ "ctc_zero_infinity": false,
39
+ "do_stable_layer_norm": true,
40
+ "eos_token_id": 2,
41
+ "feat_extract_activation": "gelu",
42
+ "feat_extract_dropout": 0.0,
43
+ "feat_extract_norm": "layer",
44
+ "feat_proj_dropout": 0.1,
45
+ "final_dropout": 0.0,
46
+ "gradient_checkpointing": false,
47
+ "hidden_act": "gelu",
48
+ "hidden_dropout": 0.1,
49
+ "hidden_size": 1024,
50
+ "initializer_range": 0.02,
51
+ "intermediate_size": 4096,
52
+ "layer_norm_eps": 1e-05,
53
+ "layerdrop": 0.1,
54
+ "mask_channel_length": 10,
55
+ "mask_channel_min_space": 1,
56
+ "mask_channel_other": 0.0,
57
+ "mask_channel_prob": 0.0,
58
+ "mask_channel_selection": "static",
59
+ "mask_feature_length": 10,
60
+ "mask_feature_prob": 0.0,
61
+ "mask_time_length": 10,
62
+ "mask_time_min_space": 1,
63
+ "mask_time_other": 0.0,
64
+ "mask_time_prob": 0.075,
65
+ "mask_time_selection": "static",
66
+ "model_type": "hubert",
67
+ "num_attention_heads": 16,
68
+ "num_conv_pos_embedding_groups": 16,
69
+ "num_conv_pos_embeddings": 128,
70
+ "num_feat_extract_layers": 7,
71
+ "num_hidden_layers": 24,
72
+ "pad_token_id": 0,
73
+ "transformers_version": "4.9.0.dev0",
74
+ "vocab_size": 32
75
+ }
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9182ceb19f342e64595b56f8cefbe3aad87fc906c46824798741c3221f55956
3
+ size 417861048
hubert_for_sequence_classification.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from transformers.models.wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2Module, FlaxWav2Vec2PreTrainedModel
3
+ from typing import Union
4
+ from transformers import HubertConfig
5
+ from transformers.modeling_flax_outputs import FlaxSequenceClassifierOutput
6
+ import flax.linen as nn
7
+ import jax.numpy as jnp
8
+ import jax
9
+
10
+
11
+ class FlaxHubertForSequenceClassificationModule(nn.Module):
12
+ config: HubertConfig
13
+ dtype: jnp.dtype = jnp.float32
14
+
15
+ def setup(self):
16
+ self.hubert = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
17
+ self.dropout = nn.Dropout(rate=self.config.final_dropout)
18
+ self.reduce = "mean"
19
+
20
+ # binary classification
21
+ self.lm_head = nn.Dense(
22
+ 2,
23
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
24
+ dtype=self.dtype,
25
+ )
26
+
27
+ def __call__(
28
+ self,
29
+ input_values,
30
+ attention_mask=None,
31
+ mask_time_indices=None,
32
+ deterministic=True,
33
+ output_attentions=None,
34
+ output_hidden_states=None,
35
+ return_dict=None,
36
+ ):
37
+ outputs = self.hubert(
38
+ input_values,
39
+ attention_mask=attention_mask,
40
+ mask_time_indices=mask_time_indices,
41
+ deterministic=deterministic,
42
+ output_attentions=output_attentions,
43
+ output_hidden_states=output_hidden_states,
44
+ return_dict=return_dict,
45
+ )
46
+
47
+ hidden_states = outputs[0]
48
+ if self.reduce == "mean":
49
+ hidden_states = jnp.mean(hidden_states, axis=1)
50
+
51
+ hidden_states = jax.nn.relu(hidden_states)
52
+ logits = self.lm_head(hidden_states)
53
+
54
+ if not return_dict:
55
+ return (logits,) + outputs[2:]
56
+
57
+ return FlaxSequenceClassifierOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
58
+
59
+ def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
60
+ """
61
+ Computes the output length of the convolutional layers
62
+ """
63
+
64
+ def _conv_out_length(input_length, kernel_size, stride):
65
+ # 1D convolutional layer output length formula taken
66
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
67
+ return (input_length - kernel_size) // stride + 1
68
+
69
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
70
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
71
+
72
+ return input_lengths
73
+
74
+
75
+ class FlaxHubertPreTrainedModel(FlaxWav2Vec2PreTrainedModel):
76
+ config_class = HubertConfig
77
+ base_model_prefix: str = "hubert"
78
+ module_class: nn.Module = None
79
+
80
+ def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
81
+ return self.module._get_feat_extract_output_lengths(input_lengths)
82
+
83
+
84
+ class FlaxHubertModel(FlaxHubertPreTrainedModel):
85
+ module_class = FlaxWav2Vec2Module
86
+
87
+
88
+ class FlaxHubertForSequenceClassification(FlaxHubertPreTrainedModel):
89
+ module_class = FlaxHubertForSequenceClassificationModule
run_hubert_classifier.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from hubert_for_sequence_classification import FlaxHubertForSequenceClassification, FlaxHubertModel
3
+ import numpy as np
4
+
5
+ # need to do some ugly save/reload because of a bug
6
+ model = FlaxHubertModel.from_pretrained("facebook/hubert-large-ll60k", from_pt=True)
7
+ model.save_pretrained("./")
8
+ model = FlaxHubertForSequenceClassification.from_pretrained("./")
9
+
10
+ dummy_input = np.array(2 * [1024 * [1.0]], dtype=np.float32)
11
+
12
+ logits = model(dummy_input).logits
13
+
14
+ # output shape is (batch_size, 2)
15
+ print("output shape", logits.shape)