import torch
import torchaudio
import torch.nn as nn
from types import SimpleNamespace

import modules.p1cupe.model_utils as model_utils



class FinetuneXLSR(nn.Module):
    def __init__(self, hp, input_wav_length, freeze_feature_encoder=False):
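        """
        Args:
            hp: hyperparameter object. Fields read in this file (inferred from
                usage here, not an authoritative schema): `noise_level`,
                `phoneme_classes`, and optionally `xlrs_reset`.
            input_wav_length: number of samples in one training window.
            freeze_feature_encoder: if True, freeze the convolutional
                feature extractor during fine-tuning.
        """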
        super().__init__()

        self.hp = hp
        self.noise_level = hp.noise_level
        self.xls_dim = 1024
        self.output_dim = hp.phoneme_classes + 1  # 66 phonemes + blank token for CTC
        
        # Load pre-trained XLSR model
        #bundle = torchaudio.pipelines.WAV2VEC2_XLSR53  # referenced as 'xa' in yaml
        bundle = torchaudio.pipelines.WAV2VEC2_XLSR_300M    # referenced as 'xb' in yaml
        
        self.XLSR = bundle.get_model()
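
        # The bundled wav2vec 2.0 models operate on 16 kHz mono waveforms
        # (bundle.sample_rate == 16000); inputs must be resampled accordingly.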

        # Optionally re-initialize all XLSR parameters:
        def reset_parameters(module):
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()
                #print("reset_parameters for ", module)

        if getattr(self.hp, 'xlrs_reset', False):
            print("reset_parameters for XLSR")
            self.XLSR.apply(reset_parameters)
        
        self.freeze_feature_encoder = freeze_feature_encoder
        # Optionally freeze only the feature encoder
        # It's common practice to keep the feature encoder frozen while fine-tuning the rest
        if self.freeze_feature_encoder:
            # Freeze the convolutional feature extractor
            for param in self.XLSR.feature_extractor.parameters():
                param.requires_grad = False

            # Optionally, the feature projection could also be frozen:
            # for param in self.XLSR.encoder.feature_projection.parameters():
            #     param.requires_grad = False
            
        # Final classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.xls_dim, self.xls_dim),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(self.xls_dim, self.output_dim)
        )

        
        #self.layer_dims = self.make_layer_sizer()
        self.layer_dims = model_utils.ModelUtils.extract_layer_dims(self)
        # Number of output frames produced for one input window. torch.ceil
        # only supports floating-point tensors, so cast to int after the ceil.
        frames = model_utils.ModelUtils.calculate_layer_sizes(self.layer_dims, torch.tensor([input_wav_length]), -1)[0]
        self.frames_per_window = torch.ceil(frames - 1).int()
        self.model_utils = model_utils.ModelUtils(self.layer_dims, input_wav_length, self.frames_per_window)
        #self.model_utils.print_model_info()
        

    
    def update_frames_per_window(self, input_wav_length):
        frames = self.model_utils.calculate_layer_sizes(self.layer_dims, torch.tensor([input_wav_length]), -1)[0]
        self.frames_per_window = torch.ceil(frames - 1).int()
        print("frames_per_window (frames per clip if disable_windowing):", self.frames_per_window.item())
        return self.frames_per_window
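
    # Illustrative usage (hypothetical values): after switching to 2 s clips,
    #   model.update_frames_per_window(32000)
    # recomputes how many output frames one clip yields, so downstream code
    # (e.g. CTC input lengths) can be sized accordingly.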

    def forward(self, x):
        if self.training and self.noise_level > 0:
            x = x + torch.randn_like(x) * self.noise_level
        
        #x = x.unsqueeze(1)
        
        # Remove torch.no_grad() to allow gradients to flow through XLSR
        #if self.freeze_feature_encoder:
        #    with torch.no_grad():
        #        features, _ = self.XLSR.extract_features(x)
        #else:
        features, _ = self.XLSR.extract_features(x)
        features = features[-1]  # hidden states of the last transformer layer
        
        logits = self.classifier(features)
        return logits
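

# A minimal training-step sketch (not part of the original module) showing how
# the CTC logits above could be consumed. Assumptions: the blank token sits at
# index 0, and `targets`/`target_lengths` are hypothetical tensors supplied by
# the data pipeline; adapt both to the real training setup.
def ctc_loss_sketch(model, batch_wavs, targets, target_lengths):
    logits = model(batch_wavs)                           # (batch, frames, classes)
    log_probs = logits.log_softmax(-1).transpose(0, 1)   # CTCLoss expects (frames, batch, classes)
    input_lengths = torch.full((logits.size(0),), logits.size(1), dtype=torch.long)
    ctc = nn.CTCLoss(blank=0, zero_infinity=True)
    return ctc(log_probs, targets, input_lengths, target_lengths)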




def test_model():

    #bundle = torchaudio.pipelines.WAV2VEC2_XLSR53
    #model = bundle.get_model().to(device)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(torch.__version__)
    print(torchaudio.__version__)
    torch.random.manual_seed(0)
    print(device)

    # Minimal stand-in for the hyperparameter object: the fields below are
    # inferred from how FinetuneXLSR reads `hp` above (66 phoneme classes per
    # the comment in __init__); replace with the real config in practice.
    hp = SimpleNamespace(noise_level=0.01, phoneme_classes=66)

    # Initialize model without freezing any layers (1 s window at 16 kHz)
    model = FinetuneXLSR(hp, input_wav_length=16000, freeze_feature_encoder=False).to(device)

    print(model.__class__)

    # Resample audio to the sampling rate the model expects
    sample_path = "tmp/data/audio_samples/9860_8338_000010.flac.wav"

    sample_waveform, sample_samplerate = torchaudio.load(sample_path)
    sample_waveform = sample_waveform.to(device)
    batch = torchaudio.functional.resample(sample_waveform, sample_samplerate, 16000)
    print(batch.shape)

    # Run the model without tracking gradients
    with torch.inference_mode():
        logits = model(batch)

    print("logits shape:", tuple(logits.shape))  # (batch, frames, phoneme_classes + 1)

def main():
    test_model()

if __name__ == "__main__":
    main()