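"""Training entry point for the Lang2Mol language-to-molecule diffusion model.

Builds a TransformerNetModel over the Tokenizer vocabulary, wraps it in a
SpacedDiffusion process with a "sqrt" noise schedule, and trains it end to
end with TrainLoop on Lang2molDataset_train, one spawned worker per process.
"""
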
import argparse

import torch.distributed as dist
import torch.multiprocessing as mp
import wandb
from transformers import set_seed

from src.improved_diffusion import dist_util
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion.resample import create_named_schedule_sampler
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.script_util import (
    add_dict_to_argparser,
    model_and_diffusion_defaults,
)
from src.improved_diffusion.train_util import TrainLoop
from src.improved_diffusion.transformer_model import TransformerNetModel
from src.scripts.mydatasets import Lang2molDataset_train, get_dataloader
from src.scripts.mytokenizers import Tokenizer


def main_worker(rank, world_size):
    args = create_argparser().parse_args()
    set_seed(args.seed)

    # Log to Weights & Biases from the main process only, so multi-process
    # runs do not create duplicate experiments.
    if rank == 0:
        wandb.login(key=args.wandb_token)
        wandb.init(
            project="ACL_Lang2Mol",
            config=args.__dict__,
        )

    dist_util.setup_dist(rank, world_size)
    tokenizer = Tokenizer()

    model = TransformerNetModel(
        in_channels=args.model_in_channels,
        model_channels=args.model_model_channels,
        dropout=args.model_dropout,
        vocab_size=len(tokenizer),
        hidden_size=args.model_hidden_size,
        num_attention_heads=args.model_num_attention_heads,
        num_hidden_layers=args.model_num_hidden_layers,
    )
    # Optionally warm-start from an existing checkpoint.
    if args.model_path != "":
        model.load_state_dict(
            dist_util.load_state_dict(args.model_path, map_location="cpu")
        )
    model.train()

    print("Total params:", sum(p.numel() for p in model.parameters()))
    print(
        "Total trainable params:",
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    )
    print("Tokenizer vocab length:", len(tokenizer))

    # use_timesteps covers every step, so no respacing is applied and the
    # full "sqrt" noise schedule is used.
    diffusion = SpacedDiffusion(
        use_timesteps=list(range(args.diffusion_steps)),
        betas=gd.get_named_beta_schedule("sqrt", args.diffusion_steps),
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )

    schedule_sampler = create_named_schedule_sampler("uniform", diffusion)

    print("Loading data...")
    train_dataset = Lang2molDataset_train(
        dir=args.dataset_path,
        tokenizer=tokenizer,
        split="train",
        corrupt_prob=0.0,
        token_max_length=512,
        dataset_name=args.dataset_name,
    )
    dataloader = get_dataloader(train_dataset, args.batch_size, rank, world_size)
    print("Finished loading data")

    TrainLoop(
        model=model,
        diffusion=diffusion,
        data=dataloader,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        checkpoint_path=args.checkpoint_path,
        gradient_clipping=args.gradient_clipping,
        eval_data=None,
        eval_interval=args.eval_interval,
    ).run_loop()
    dist.destroy_process_group()


def create_argparser():
    defaults = dict()
    text_defaults = dict(
        wandb_token="",
        batch_size=16,
        cache_mode="no",
        checkpoint_path="checkpoints",
        class_cond=False,
        config="ll",
        config_name="QizhiPei/biot5-base-text2mol",
        dataset_path="dataset",
        diffusion_steps=2000,
        dropout=0.01,
        e2e_train="",
        ema_rate="0.9999",
        emb_scale_factor=1.0,
        eval_interval=2000,
        experiment="random",
        experiment_mode="lm",
        fp16_scale_growth=0.001,
        gradient_clipping=2.4,
        image_size=8,
        in_channel=16,
        learn_sigma=False,
        log_interval=1000,
        logits_mode=1,
        lr=0.00005,
        lr_anneal_steps=500000,
        microbatch=-1,
        modality="e2e-tgt",
        model_arch="transformer",
        noise_level=0.0,
        noise_schedule="sqrt",
        num_channels=128,
        num_heads=4,
        num_heads_upsample=-1,
        num_res_blocks=2,
        out_channel=16,
        padding_mode="pad",
        predict_xstart=True,
        preprocessing_num_workers=1,
        rescale_learned_sigmas=True,
        rescale_timesteps=True,
        resume_checkpoint="",
        save_interval=50000,
        schedule_sampler="uniform",
        seed=42,
        timestep_respacing="",
        training_mode="e2e",
        use_bert_tokenizer="no",
        use_checkpoint=False,
        use_fp16=False,
        use_kl=False,
        use_scale_shift_norm=True,
        weight_decay=0.0,
        model_in_channels=32,
        model_model_channels=128,
        model_dropout=0.01,
        model_hidden_size=1024,
        model_num_attention_heads=16,
        model_num_hidden_layers=12,
        dataset_name="",
        model_path="",
    )
    # text_defaults is applied second, so its values override any overlapping
    # keys from model_and_diffusion_defaults().
    defaults.update(model_and_diffusion_defaults())
    defaults.update(text_defaults)
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    # Spawn one training process; set world_size to the number of available
    # GPUs to train with multiple workers.
    world_size = 1
    mp.spawn(main_worker, args=(world_size,), nprocs=world_size, join=True)
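
# Example launch (single process). The script name "train.py" below is an
# assumption (the filename is not given here), and "my_dataset" is a
# placeholder; all flags are generated by create_argparser() from the
# defaults above, so unspecified options keep their default values:
#
#   python train.py --dataset_path dataset --dataset_name my_dataset \
#       --checkpoint_path checkpoints --batch_size 16 --wandb_token <token>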