import sys, os, json
root = os.sep + os.sep.join(__file__.split(os.sep)[1:__file__.split(os.sep).index("Recurrent-Parameter-Generation")+1])
sys.path.append(root)
os.chdir(root)
with open("./workspace/config.json", "r") as f:
    additional_config = json.load(f)
USE_WANDB = additional_config["use_wandb"]

# other
import math
import random
import warnings
from _thread import start_new_thread
warnings.filterwarnings("ignore", category=UserWarning)
if USE_WANDB: import wandb
# torch
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import autocast
# model
from bitsandbytes import optim
from model import ClassConditionMambaDiffusion as Model
from model.diffusion import DDPMSampler, DDIMSampler
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from accelerate.utils import DistributedDataParallelKwargs
from accelerate.utils import AutocastKwargs
from accelerate import Accelerator
# dataset
from dataset import ClassInput_ViTTiny
from torch.utils.data import DataLoader


class ClassInput_ViTTiny_Dataset(ClassInput_ViTTiny):
    data_path = "./dataset/condition_classinput_inference/checkpoint_test"
    generated_path = "./workspace/classinput/generated.pth"
    test_command = f"python ./dataset/condition_classinput_inference/test.py "


config = {
    # dataset setting
    "dataset": None,
    "dim_per_token": 8192,
    "sequence_length": 'auto',
    # train setting
    "batch_size": 16,
    "num_workers": 16,
    "total_steps": 120000,
    "learning_rate": 0.00003,
    "weight_decay": 0.0,
    "save_every": 120000//50,
    "print_every": 50,
    "autocast": lambda i: 5000 < i < 90000,
    "checkpoint_save_path": "./checkpoint",
    # test setting
    "test_batch_size": 1,  # fixed, don't change this
    "generated_path": ClassInput_ViTTiny_Dataset.generated_path,
    "test_command": ClassInput_ViTTiny_Dataset.test_command,
    # to log
    "model_config": {
        "num_permutation": "auto",
        # mamba config
        "d_condition": 1024,
        "d_model": 8192,
        "d_state": 128,
        "d_conv": 4,
        "expand": 2,
        "num_layers": 2,
        # diffusion config
        "diffusion_batch": 512,
        "layer_channels": [1, 32, 64, 128, 64, 32, 1],
        "model_dim": "auto",
        "condition_dim": "auto",
        "kernel_size": 7,
        "sample_mode": DDPMSampler,
        "beta": (0.0001, 0.02),
        "T": 1000,
        "forward_once": True,
    },
    "tag": "generalization",
}

# Data
print('==> Preparing data..')
train_set = ClassInput_ViTTiny_Dataset(dim_per_token=config["dim_per_token"])
test_set = ClassInput_ViTTiny_Dataset(dim_per_token=config["dim_per_token"])
sample = train_set[0][0]  # one tokenized checkpoint; used to build the NaN mask below
print("checkpoint number:", train_set.real_length)
# print("input shape:", sample.shape)
# print("useful ratio:", torch.where(torch.isnan(sample), 0., 1.).mean())
mask = torch.where(torch.isnan(sample), torch.nan, 1.)  # NaN marks padded positions; generate() needs this mask
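
# Illustration only (hypothetical tensor, not part of the pipeline): multiplying by `mask`
# keeps NaN at padded positions, so nanmean-based statistics cover just the real parameter
# entries; generate() below uses the same trick to report the norm of a generated checkpoint.
#     fake = torch.tensor([[1., float("nan")], [2., 3.]])
#     fake_mask = torch.where(torch.isnan(fake), torch.nan, 1.)
#     torch.nanmean((fake * fake_mask).abs())  # tensor(2.), the mean over the 3 real entries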
if config["model_config"]["num_permutation"] == "auto":
    config["model_config"]["num_permutation"] = train_set.max_permutation_state
if config["model_config"]["condition_dim"] == "auto":
    config["model_config"]["condition_dim"] = config["model_config"]["d_model"]
if config["model_config"]["model_dim"] == "auto":
    config["model_config"]["model_dim"] = config["dim_per_token"]
if config["sequence_length"] == "auto":
    config["sequence_length"] = train_set.sequence_length
    print(f"sequence length: {config['sequence_length']}")
else:  # set fixed sequence_length
    assert train_set.sequence_length == config["sequence_length"], f"sequence_length={train_set.sequence_length}"
# train_loader = DataLoader(
#     dataset=train_set,
#     batch_size=config["batch_size"],
#     num_workers=config["num_workers"],
#     persistent_workers=True,
#     drop_last=True,
#     shuffle=True,
# )

# Model
print('==> Building model..')
Model.config = config["model_config"]
model = Model(
    sequence_length=config["sequence_length"],
    positional_embedding=train_set.get_position_embedding(
        positional_embedding_dim=config["model_config"]["d_model"],
    ),  # positional_embedding
)  # model setting is in model

# # Optimizer
# print('==> Building optimizer..')
# optimizer = optim.AdamW8bit(
#     params=model.parameters(),
#     lr=config["learning_rate"],
#     weight_decay=config["weight_decay"],
# )  # optimizer
# scheduler = CosineAnnealingLR(
#     optimizer=optimizer,
#     T_max=config["total_steps"],
# )  # scheduler

# accelerator
if __name__ == "__main__":
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[kwargs,])
    model = accelerator.prepare(model)  # only the model; optimizer and train_loader stay commented out with the training loop

# wandb
if __name__ == "__main__" and USE_WANDB and accelerator.is_main_process:
    wandb.login(key=additional_config["wandb_api_key"])
    wandb.init(project="Recurrent-Parameter-Generation", name=config['tag'], config=config,)


# Training
# print('==> Defining training..')
# def train():
#     if not USE_WANDB:
#         train_loss = 0
#         this_steps = 0
#     print("==> Start training..")
#     model.train()
#     for batch_idx, (param, condition) in enumerate(train_loader):
#         optimizer.zero_grad()
#         # train
#         # noinspection PyArgumentList
#         with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=config["autocast"](batch_idx))):
#             loss = model(
#                 output_shape=param.shape,
#                 x_0=param,
#                 condition=condition,
#                 permutation_state=None,
#             )
#         accelerator.backward(loss)
#         optimizer.step()
#         if accelerator.is_main_process:
#             scheduler.step()
#         # log losses, print, and save
#         if USE_WANDB and accelerator.is_main_process:
#             wandb.log({"train_loss": loss.item()})
#         elif USE_WANDB:
#             pass  # don't print
#         else:  # not using wandb
#             train_loss += loss.item()
#             this_steps += 1
#             if this_steps % config["print_every"] == 0:
#                 print('Loss: %.6f' % (train_loss/this_steps))
#                 this_steps = 0
#                 train_loss = 0
#         if batch_idx % config["save_every"] == 0 and accelerator.is_main_process:
#             os.makedirs(config["checkpoint_save_path"], exist_ok=True)
#             state = accelerator.unwrap_model(model).state_dict()
#             torch.save(state, os.path.join(config["checkpoint_save_path"],
#                                            f"{__file__.split('/')[-1].split('.')[0]}.pth"))
#             generate(save_path=config["generated_path"], need_test=True)
#         if batch_idx >= config["total_steps"]:
#             break


# Generate
def generate(save_path=config["generated_path"], need_test=True):
    print("\n==> Generating..")
    model.eval()
    _, condition = test_set[random.randint(0, len(test_set)-1)]
    class_index = str(int("".join([str(int(i)) for i in condition]), 2)).zfill(4)  # condition bits read as one binary number, zero-padded to 4 digits
    with torch.no_grad():
        prediction = model(sample=True, condition=condition[None], permutation_state=False)
        generated_norm = torch.nanmean((prediction.cpu() * mask).abs())  # magnitude over real (non-padded) entries only
    print("Generated_norm:", generated_norm.item())
    if USE_WANDB and accelerator.is_main_process:
        wandb.log({"generated_norm": generated_norm.item()})
    if accelerator.is_main_process:
        train_set.save_params(prediction, save_path=save_path.format(class_index))
    if need_test:
        start_new_thread(os.system, (config["test_command"].format(class_index),))
    model.train()
    return prediction


# if __name__ == '__main__':
#     train()
#     del train_loader  # deal with problems caused by the dataloader
#     print("Finished Training!")
#     exit(0)
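
# Sketch of a generation-only entry point (an assumption, not the original __main__, which
# calls the commented-out train() above): load the checkpoint written by the training loop
# and run a single conditional generation. The path and filename mirror the save logic above;
# adjust them if your checkpoint lives elsewhere.
# if __name__ == '__main__':
#     state = torch.load(os.path.join(config["checkpoint_save_path"],
#                                     f"{__file__.split('/')[-1].split('.')[0]}.pth"),
#                        map_location="cpu")
#     accelerator.unwrap_model(model).load_state_dict(state)
#     generate(save_path=config["generated_path"], need_test=False)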