KevinGeng commited on
Commit
f498134
1 Parent(s): 6d26a9c

Upload lightning_module.py

Browse files
Files changed (1) hide show
  1. lightning_module.py +41 -0
lightning_module.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torch
3
+ import torch.nn as nn
4
+ import os
5
+ import numpy as np
6
+ import hydra
7
+ from model import load_ssl_model, PhonemeEncoder, DomainEmbedding, LDConditioner, Projection
8
+
9
+
10
+ class BaselineLightningModule(pl.LightningModule):
11
+ def __init__(self, cfg):
12
+ super().__init__()
13
+ self.cfg = cfg
14
+ self.construct_model()
15
+ self.save_hyperparameters()
16
+
17
+ def construct_model(self):
18
+ self.feature_extractors = nn.ModuleList([
19
+ load_ssl_model(cp_path='wav2vec_small.pt'),
20
+ DomainEmbedding(3,128),
21
+ ])
22
+ output_dim = sum([ feature_extractor.get_output_dim() for feature_extractor in self.feature_extractors])
23
+ output_layers = [
24
+ LDConditioner(judge_dim=128,num_judges=3000,input_dim=output_dim)
25
+ ]
26
+ output_dim = output_layers[-1].get_output_dim()
27
+ output_layers.append(
28
+ Projection(hidden_dim=2048,activation=torch.nn.ReLU(),range_clipping=False,input_dim=output_dim)
29
+
30
+ )
31
+
32
+ self.output_layers = nn.ModuleList(output_layers)
33
+
34
+ def forward(self, inputs):
35
+ outputs = {}
36
+ for feature_extractor in self.feature_extractors:
37
+ outputs.update(feature_extractor(inputs))
38
+ x = outputs
39
+ for output_layer in self.output_layers:
40
+ x = output_layer(x,inputs)
41
+ return x