|
from . import * |
|
|
|
import sys |
|
import os |
|
|
|
|
|
# Make the sibling ``../data`` directory (resolved relative to this file, not
# the current working directory) importable for the imports that follow.
current_directory = os.path.dirname(os.path.abspath(__file__))
external_directory = os.path.abspath(os.path.join(current_directory, '../data'))

# Only register the directory once, so re-importing or reloading this module
# does not keep growing ``sys.path`` with duplicate entries.
if external_directory not in sys.path:
    sys.path.append(external_directory)
|
|
|
|
|
from data_utils import load_datasets, create_train_valid_dataloaders |
|
from model import init_ldm_model, init_diff_pro_sdf |
|
|
|
|
|
class LdmTrainConfig(TrainConfig):
    """Training configuration for an LDM run.

    Constructing an instance builds everything the run needs and stores it on
    the instance: the model (``init_ldm_model`` wrapped by
    ``init_diff_pro_sdf``), the train/validation dataloaders, and an Adam
    optimizer over the wrapped model's parameters.
    """

    def __init__(
        self,
        params,
        output_dir,
        mode,
        mask_background,
        multi_phrase_label,
        random_pitch_aug,
        debug_mode=False,
    ) -> None:
        """Build the model, dataloaders and optimizer.

        Args:
            params: Hyper-parameter object; ``batch_size`` and
                ``learning_rate`` are read here, and the object is also
                forwarded to the base class and the model factories.
            output_dir: Forwarded to ``TrainConfig.__init__``.
            mode: Forwarded to ``init_ldm_model`` and ``load_datasets``.
            mask_background: Stored and forwarded to ``load_datasets``.
            multi_phrase_label: Stored and forwarded to ``load_datasets``.
            random_pitch_aug: Stored and forwarded to ``load_datasets``.
            debug_mode: When truthy, ``10`` instead of ``None`` is passed as
                the last argument of ``load_datasets``; the flag is also
                forwarded to ``init_ldm_model``.
        """
        # The base class receives ``None`` as its second positional argument.
        # NOTE(review): ``self.device`` used below is presumably set by
        # ``TrainConfig.__init__`` -- confirm against the base class.
        super().__init__(params, None, output_dir)
        self.debug_mode = debug_mode

        # Data options are kept on the instance exactly as given.
        self.mask_background = mask_background
        self.multi_phrase_label = multi_phrase_label
        self.random_pitch_aug = random_pitch_aug

        # Model: the raw LDM first, then the wrapper placed on ``self.device``.
        self.ldm_model = init_ldm_model(mode, params, debug_mode)
        self.model = init_diff_pro_sdf(self.ldm_model, params, self.device)

        # Data: presumably the last argument caps how many items are loaded,
        # keeping debug runs small -- verify against ``load_datasets``.
        if self.debug_mode:
            load_limit = 10
        else:
            load_limit = None
        train_set, valid_set = load_datasets(
            mode,
            multi_phrase_label,
            random_pitch_aug,
            mask_background,
            load_limit,
        )
        loaders = create_train_valid_dataloaders(params.batch_size, train_set, valid_set)
        self.train_dl, self.val_dl = loaders

        # Optimizer: Adam over the wrapped model's parameters.
        trainable_params = self.model.parameters()
        self.optimizer = torch.optim.Adam(trainable_params, lr=params.learning_rate)
|
|