import pytorch_lightning as pl
import torch
import torch.nn as nn
import os
import numpy as np
import hydra

from model import load_ssl_model, PhonemeEncoder, DomainEmbedding, LDConditioner, Projection

class BaselineLightningModule(pl.LightningModule):
    """Baseline model: SSL and domain-embedding feature extractors followed by
    listener-dependent conditioning and a projection head."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.construct_model()
        self.save_hyperparameters()

    def construct_model(self):
        # Feature extractors: a pretrained SSL model loaded from a local
        # checkpoint plus a learned embedding over 3 data domains; each
        # returns a dict of features.
        self.feature_extractors = nn.ModuleList([
            load_ssl_model(cp_path='wav2vec_small.pt'),
            DomainEmbedding(3, 128),
        ])
        output_dim = sum(
            feature_extractor.get_output_dim()
            for feature_extractor in self.feature_extractors
        )
        # Output layers: listener-dependent conditioning over 3000 judges,
        # followed by a projection head that maps the features to a score.
        output_layers = [
            LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
        ]
        output_dim = output_layers[-1].get_output_dim()
        output_layers.append(
            Projection(hidden_dim=2048, activation=torch.nn.ReLU(), range_clipping=False, input_dim=output_dim)
        )

        self.output_layers = nn.ModuleList(output_layers)

    def forward(self, inputs):
        # Collect features from every extractor into a single dict, then pass
        # it through the output layers; each output layer also receives the
        # raw inputs (e.g. for listener conditioning).
        outputs = {}
        for feature_extractor in self.feature_extractors:
            outputs.update(feature_extractor(inputs))
        x = outputs
        for output_layer in self.output_layers:
            x = output_layer(x, inputs)
        return x
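

# --- Usage sketch (not part of the original module) ---
# A minimal example of how this module might be instantiated and called.
# The batch keys ('wav', 'domains', 'judge_id') and the empty cfg are
# assumptions for illustration: the real keys are defined by the dataset /
# collate function that feeds the feature extractors, and the real cfg is
# composed by hydra. Running this also requires the 'wav2vec_small.pt'
# checkpoint to be available locally.
if __name__ == '__main__':
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({})  # hypothetical stand-in for the hydra config
    model = BaselineLightningModule(cfg)
    model.eval()

    batch = {
        'wav': torch.randn(1, 16000),   # ~1 s of 16 kHz audio (assumed key)
        'domains': torch.tensor([0]),   # dataset/domain index (assumed key)
        'judge_id': torch.tensor([0]),  # listener index (assumed key)
    }
    with torch.no_grad():
        score = model(batch)            # predicted score(s)
    print(score.shape)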