Spaces:
Running
Running
# This script trains the model using the decoder loss directly.
# Deep learning | |
import torch | |
from torch_optimizer.lamb import Lamb | |
from trainer import TrainerDirectDecoder | |
# Parallel | |
from torch.utils.data.distributed import DistributedSampler | |
from torch.distributed import init_process_group, destroy_process_group | |
# Data | |
from utils import MoleculeModule, get_optim_groups | |
from torch.utils.data import DataLoader | |
# Standard library | |
import os | |
import args | |
def ddp_setup():
    """Join the NCCL process group and bind this process to its local GPU.

    The GPU index is read from the ``LOCAL_RANK`` environment variable, which
    is expected to be exported by the distributed launcher (e.g. ``torchrun``);
    if it is missing, ``os.environ`` raises ``KeyError``.
    """
    init_process_group(backend="nccl")
    gpu_index = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(gpu_index)
def load_train_objs(config): | |
# load data | |
train_loader = MoleculeModule( | |
config.max_len, | |
config.train_load, | |
config.data_root | |
) | |
train_loader.setup() | |
loader = DataLoader( | |
train_loader.pubchem, | |
batch_size=config.n_batch, | |
pin_memory=True, | |
shuffle=False, | |
collate_fn=train_loader.text_encoder.process, | |
sampler=DistributedSampler(train_loader.pubchem), | |
num_workers=config.n_workers | |
) | |
# load model | |
if config.smi_ted_version == 'v1': | |
from smi_ted_light.load import Smi_ted | |
elif config.smi_ted_version == 'v2': | |
from smi_ted_large.load import Smi_ted | |
model = Smi_ted(config, train_loader.get_vocab()).to('cuda') | |
model.apply(model._init_weights) | |
# load optimizer | |
optim_groups = get_optim_groups(model) | |
optimizer = torch.optim.AdamW(optim_groups, lr=config.lr_decoder, betas=(0.9, 0.99), fused=True) | |
return loader, model, optimizer | |
def main(
    config,
    save_every: int,
    total_epochs: int,
    save_checkpoint_path: str,
    load_checkpoint_path: str
):
    """Run decoder-loss training for this process inside a DDP process group.

    Args:
        config: parsed command-line options, forwarded to ``load_train_objs``
            and ``TrainerDirectDecoder``.
        save_every: forwarded unchanged to ``TrainerDirectDecoder``.
        total_epochs: number of epochs passed to ``trainer.train``.
        save_checkpoint_path: forwarded unchanged to ``TrainerDirectDecoder``.
        load_checkpoint_path: forwarded unchanged to ``TrainerDirectDecoder``.
    """
    ddp_setup()
    try:
        # training objects
        train_data, model, optimizer = load_train_objs(config)
        # init trainer
        trainer = TrainerDirectDecoder(
            model,
            train_data,
            optimizer,
            save_every,
            save_checkpoint_path,
            load_checkpoint_path,
            config,
        )
        trainer.train(total_epochs)
    finally:
        # Tear down the process group even if setup or training raises.
        # Previously an exception skipped this call entirely.
        destroy_process_group()
if __name__ == '__main__':
    # Bind the parsed options to ``config`` rather than ``args``, so the
    # imported ``args`` module (which provides the parser) is not shadowed.
    parser = args.get_parser()
    config = parser.parse_args()
    main(
        config,
        save_every=config.checkpoint_every,
        total_epochs=config.max_epochs,
        save_checkpoint_path=config.save_checkpoint_path,
        load_checkpoint_path=config.load_checkpoint_path,
    )