# Page-header residue from the upload page (not code):
# KevinGeng's picture
# Upload lightning_module.py
# f498134
# raw
# history blame
# 1.37 kB
import pytorch_lightning as pl
import torch
import torch.nn as nn
import os
import numpy as np
import hydra
from model import load_ssl_model, PhonemeEncoder, DomainEmbedding, LDConditioner, Projection
class BaselineLightningModule(pl.LightningModule):
    """Baseline Lightning model built from feature extractors and output layers.

    The network has two stages, both assembled in ``construct_model``:

    * ``feature_extractors``: an SSL model loaded from ``wav2vec_small.pt`` and
      a ``DomainEmbedding``. Each is called on the same ``inputs`` and their
      results are merged into one dict.
    * ``output_layers``: an ``LDConditioner`` followed by a ``Projection``. They
      are applied in order, each receiving the running value plus the original
      ``inputs``.
    """

    def __init__(self, cfg):
        """Store ``cfg``, build the submodules, then call ``save_hyperparameters``.

        Args:
            cfg: configuration object kept on ``self.cfg``. Note that
                ``construct_model`` does not read it; all layer sizes there are
                hard-coded.
        """
        super().__init__()
        self.cfg = cfg
        self.construct_model()
        self.save_hyperparameters()

    def construct_model(self):
        """Create ``self.feature_extractors`` and ``self.output_layers``.

        Submodules are instantiated in a fixed order: SSL model, domain
        embedding, conditioner, projection. Keep this order if editing, since
        parameter initialisation may depend on it.
        """
        ssl_model = load_ssl_model(cp_path='wav2vec_small.pt')
        domain_embedding = DomainEmbedding(3, 128)
        self.feature_extractors = nn.ModuleList([ssl_model, domain_embedding])

        # The conditioner consumes the merged extractor features, so its input
        # width is the sum of every extractor's reported output dimension.
        merged_dim = sum(
            extractor.get_output_dim() for extractor in self.feature_extractors
        )
        conditioner = LDConditioner(
            judge_dim=128,
            num_judges=3000,
            input_dim=merged_dim,
        )

        # The projection head is sized from whatever the conditioner emits.
        conditioner_dim = conditioner.get_output_dim()
        projection = Projection(
            hidden_dim=2048,
            activation=torch.nn.ReLU(),
            range_clipping=False,
            input_dim=conditioner_dim,
        )
        self.output_layers = nn.ModuleList([conditioner, projection])

    def forward(self, inputs):
        """Run every feature extractor, then chain the output layers.

        Args:
            inputs: passed unchanged to every extractor and every output layer.
                Presumably a dict-like batch; TODO confirm against the data
                loader.

        Returns:
            Whatever the last output layer returns.
        """
        # Each extractor's result is folded in with dict.update, so on a key
        # collision the later extractor's value wins.
        hidden = {}
        for extractor in self.feature_extractors:
            hidden.update(extractor(inputs))

        for layer in self.output_layers:
            hidden = layer(hidden, inputs)
        return hidden