"""Transform an augmented model back to the normal HF-supported version,
i.e. remove the first (AUG) module from the feature extractor's conv stack.

Usage: run from the directory containing the augmented checkpoint; the
cleaned model is saved back to the same directory.
"""

from augmentation import AUG  # project-local augmentation layer
from torch import nn

import transformers
from transformers import AutoModelForPreTraining  # NOTE(review): currently unused; kept in case the script is switched back to Wav2Vec2ForPreTraining
from transformers import Wav2Vec2ForCTC
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2GroupNormConvLayer,
    Wav2Vec2LayerNormConvLayer,
    Wav2Vec2NoLayerNormConvLayer,
)


def patch_init(cls):
    """Monkey-patch ``cls.__init__`` so the feature extractor is built with an
    extra AUG module prepended to ``conv_layers``.

    This makes ``from_pretrained`` accept a checkpoint that was saved with the
    augmentation module in place (the extra leading entry shifts all
    ``conv_layers.N`` state-dict keys by one).

    :param cls: the transformers feature-extractor class to patch in place.
    """
    __class__ = cls  # provide a closure cell so zero-arg super() works in new_init

    def new_init(self, config):
        # BUG FIX: the original patch never called super().__init__(), so
        # nn.Module internals (e.g. _modules) were never set up and the
        # self.conv_layers assignment below would raise
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        # Reproduce the stock HF construction of the conv stack.
        if config.feat_extract_norm == "group":
            conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [
                Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1)
                for i in range(config.num_feat_extract_layers - 1)
            ]
        elif config.feat_extract_norm == "layer":
            conv_layers = [
                Wav2Vec2LayerNormConvLayer(config, layer_id=i)
                for i in range(config.num_feat_extract_layers)
            ]
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )
        # Prepend the augmentation module so checkpoint keys line up.
        conv_layers.insert(0, AUG())
        self.conv_layers = nn.ModuleList(conv_layers)

    cls.__init__ = new_init


patch_init(transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureExtractor)

# model_path = "pytorch_model.bin"  # Wav2Vec2ForPreTraining

# Load the augmented checkpoint from the current directory.
model = Wav2Vec2ForCTC.from_pretrained(".")

# Return the model to its normal state: drop the first (AUG) conv layer.
# nn.ModuleList (not Sequential) matches the structure HF builds natively,
# and produces identical state-dict keys ("conv_layers.N....").
model.wav2vec2.feature_extractor.conv_layers = nn.ModuleList(
    list(model.wav2vec2.feature_extractor.conv_layers.children())[1:]
)
model.save_pretrained(".")

# NOTE(review), original author's workaround log:
# "replace with [this config] temporarily then save model; patching didn't
# work — the augmented layout got loaded so far for some reason."
#   "conv_dim":    [1, 512, 512, 512, 512, 512, 512, 512],
#   "conv_kernel": [10, 10, 3, 3, 3, 3, 2, 2],
#   "conv_stride": [5, 5, 2, 2, 2, 2, 2, 2]