import pytorch_lightning as pl
import torch
import torch.nn as nn
import os
import numpy as np
import hydra

from model import load_ssl_model, PhonemeEncoder, DomainEmbedding, LDConditioner, Projection


class BaselineLightningModule(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.construct_model()
        self.save_hyperparameters()

    def construct_model(self):
        # Feature extractors: a pretrained SSL speech model plus a learned
        # embedding for the data-domain ID.
        self.feature_extractors = nn.ModuleList([
            load_ssl_model(cp_path='wav2vec_small.pt'),
            DomainEmbedding(3, 128),
        ])
        output_dim = sum(
            feature_extractor.get_output_dim()
            for feature_extractor in self.feature_extractors
        )
        # Output head: listener-dependent conditioning followed by a
        # projection to the final score.
        output_layers = [
            LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
        ]
        output_dim = output_layers[-1].get_output_dim()
        output_layers.append(
            Projection(hidden_dim=2048, activation=torch.nn.ReLU(),
                       range_clipping=False, input_dim=output_dim)
        )
        self.output_layers = nn.ModuleList(output_layers)

    def forward(self, inputs):
        # Each feature extractor returns a dict of features; merge them,
        # then pass the merged features through the output layers in order.
        outputs = {}
        for feature_extractor in self.feature_extractors:
            outputs.update(feature_extractor(inputs))
        x = outputs
        for output_layer in self.output_layers:
            x = output_layer(x, inputs)
        return x
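

if __name__ == "__main__":
    # Usage sketch (illustrative only, not part of the training pipeline):
    # shows how the module might be instantiated and queried. The config is
    # normally supplied by hydra; an empty OmegaConf object works here only
    # because construct_model() uses hardcoded hyperparameters. The batch
    # keys below ('wav', 'domains', 'judge_id') are assumptions about what
    # the feature extractors and LDConditioner defined in model.py expect,
    # and 'wav2vec_small.pt' must be present for load_ssl_model to succeed.
    from omegaconf import OmegaConf

    model = BaselineLightningModule(OmegaConf.create({})).eval()
    batch = {
        'wav': torch.randn(1, 16000),                   # assumed key: 1 s of 16 kHz audio
        'domains': torch.zeros(1, dtype=torch.long),    # assumed key: domain ID
        'judge_id': torch.zeros(1, dtype=torch.long),   # assumed key: listener ID
    }
    with torch.no_grad():
        score = model(batch)
    print(score.shape)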
