|
''' |
|
Transform augmented model back to normal hf supported version. |
|
i.e. remove the first (augmentation) module from the conv stack.
|
''' |
|
from augmentation import AUG |
|
|
|
from torch import nn |
|
import transformers |
|
from transformers import Wav2Vec2ForCTC, AutoModelForPreTraining |
|
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2NoLayerNormConvLayer, Wav2Vec2LayerNormConvLayer, Wav2Vec2GroupNormConvLayer |
|
|
|
|
|
def patch_init(cls):
    """Monkey-patch ``cls.__init__`` (the HF Wav2Vec2 feature extractor)
    so that constructing it rebuilds the *augmented* conv stack: the stock
    conv layers with an extra ``AUG`` module prepended at index 0.

    This lets ``from_pretrained`` load a checkpoint that was saved with the
    augmentation module still in place.
    """
    # Bind ``__class__`` in this enclosing scope so the zero-argument
    # ``super()`` call inside ``new_init`` resolves against ``cls``
    # (the implicit ``__class__`` closure cell trick).
    __class__ = cls

    def new_init(self, config):
        """Replacement constructor: stock conv layers + AUG at position 0."""
        # nn.Module must be initialized before any sub-module is assigned,
        # otherwise ``self.conv_layers = ...`` raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()

        if config.feat_extract_norm == "group":
            conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [
                Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
            ]
        elif config.feat_extract_norm == "layer":
            conv_layers = [
                Wav2Vec2LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
            ]
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )

        # Prepend the augmentation module so it runs before the first conv
        # layer; this matches the layout the augmented checkpoint was saved with.
        aug = AUG()
        conv_layers.insert(0, aug)
        self.conv_layers = nn.ModuleList(conv_layers)

    cls.__init__ = new_init
|
|
|
# Install the patched __init__ so from_pretrained can instantiate the
# augmented architecture (conv stack with the AUG module at index 0).
patch_init(transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureExtractor)

# Load the augmented checkpoint from the current directory.
model = Wav2Vec2ForCTC.from_pretrained(".")

# Drop the first conv_layers entry (the AUG module) so the remaining stack
# matches the stock HF Wav2Vec2 architecture again.  The surviving layers
# keep their original numeric child indices shifted down by one, which is
# exactly what a vanilla checkpoint expects.
model.wav2vec2.feature_extractor.conv_layers = nn.Sequential(
    *list(model.wav2vec2.feature_extractor.conv_layers.children())[1:]
)

# Overwrite the checkpoint in place with the de-augmented weights.
model.save_pretrained(".")
|
|
|
""" |
|
Workaround: temporarily replace the config with the values below, then save the model. Patching alone didn't work — the original config still got loaded, for some reason.
|
"conv_dim": [ |
|
1, |
|
512, |
|
512, |
|
512, |
|
512, |
|
512, |
|
512, |
|
512 |
|
], |
|
"conv_kernel": [ |
|
10, |
|
10, |
|
3, |
|
3, |
|
3, |
|
3, |
|
2, |
|
2 |
|
], |
|
"conv_stride": [ |
|
5, |
|
5, |
|
2, |
|
2, |
|
2, |
|
2, |
|
2, |
|
2 |
|
""" |