Doven
update code.
f7009b3
import sys, os, json
root = os.sep + os.sep.join(__file__.split(os.sep)[1:__file__.split(os.sep).index("Recurrent-Parameter-Generation")+1])
sys.path.append(root)
os.chdir(root)
with open("./workspace/config.json", "r") as f:
additional_config = json.load(f)
USE_WANDB = additional_config["use_wandb"]
# other
import math
import random
import warnings
from _thread import start_new_thread
warnings.filterwarnings("ignore", category=UserWarning)
if USE_WANDB: import wandb
# torch
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import autocast
# model
from bitsandbytes import optim
from model import ClassConditionMambaDiffusion as Model
from model.diffusion import DDPMSampler, DDIMSampler
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from accelerate.utils import DistributedDataParallelKwargs
from accelerate.utils import AutocastKwargs
from accelerate import Accelerator
# dataset
from dataset import ClassInput_ViTTiny
from torch.utils.data import DataLoader
class ClassInput_ViTTiny_Dataset(ClassInput_ViTTiny):
data_path = "./dataset/condition_classinput_inference/checkpoint_test"
generated_path = "./workspace/classinput/generated.pth"
test_command = f"python ./dataset/condition_classinput_inference/test.py "
config = {
# dataset setting
"dataset": None,
"dim_per_token": 8192,
"sequence_length": 'auto',
# train setting
"batch_size": 16,
"num_workers": 16,
"total_steps": 120000,
"learning_rate": 0.00003,
"weight_decay": 0.0,
"save_every": 120000//50,
"print_every": 50,
"autocast": lambda i: 5000 < i < 90000,
"checkpoint_save_path": "./checkpoint",
# test setting
"test_batch_size": 1, # fixed, don't change this
"generated_path": ClassInput_ViTTiny_Dataset.generated_path,
"test_command": ClassInput_ViTTiny_Dataset.test_command,
# to log
"model_config": {
"num_permutation": "auto",
# mamba config
"d_condition": 1024,
"d_model": 8192,
"d_state": 128,
"d_conv": 4,
"expand": 2,
"num_layers": 2,
# diffusion config
"diffusion_batch": 512,
"layer_channels": [1, 32, 64, 128, 64, 32, 1],
"model_dim": "auto",
"condition_dim": "auto",
"kernel_size": 7,
"sample_mode": DDPMSampler,
"beta": (0.0001, 0.02),
"T": 1000,
"forward_once": True,
},
"tag": "generalization",
}
# Data
print('==> Preparing data..')
train_set = ClassInput_ViTTiny_Dataset(dim_per_token=config["dim_per_token"])
test_set = ClassInput_ViTTiny_Dataset(dim_per_token=config["dim_per_token"])
# sample = train_set[0][0]
print("checkpoint number:", train_set.real_length)
# print("input shape:", sample.shape)
# print("useful ratio:", torch.where(torch.isnan(sample), 0., 1.).mean())
# mask = torch.where(torch.isnan(sample), torch.nan, 1.)
if config["model_config"]["num_permutation"] == "auto":
config["model_config"]["num_permutation"] = train_set.max_permutation_state
if config["model_config"]["condition_dim"] == "auto":
config["model_config"]["condition_dim"] = config["model_config"]["d_model"]
if config["model_config"]["model_dim"] == "auto":
config["model_config"]["model_dim"] = config["dim_per_token"]
if config["sequence_length"] == "auto":
config["sequence_length"] = train_set.sequence_length
print(f"sequence length: {config['sequence_length']}")
else: # set fixed sequence_length
assert train_set.sequence_length == config["sequence_length"], f"sequence_length={train_set.sequence_length}"
# train_loader = DataLoader(
# dataset=train_set,
# batch_size=config["batch_size"],
# num_workers=config["num_workers"],
# persistent_workers=True,
# drop_last=True,
# shuffle=True,
# )
#
# Model
print('==> Building model..')
Model.config = config["model_config"]
model = Model(
sequence_length=config["sequence_length"],
positional_embedding=train_set.get_position_embedding(
positional_embedding_dim=config["model_config"]["d_model"],
), # positional_embedding
) # model setting is in model
#
# # Optimizer
# print('==> Building optimizer..')
# optimizer = optim.AdamW8bit(
# params=model.parameters(),
# lr=config["learning_rate"],
# weight_decay=config["weight_decay"],
# ) # optimizer
# scheduler = CosineAnnealingLR(
# optimizer=optimizer,
# T_max=config["total_steps"],
# ) # scheduler
#
# # accelerator
# if __name__ == "__main__":
# kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# accelerator = Accelerator(kwargs_handlers=[kwargs,])
# model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
#
#
# # wandb
# if __name__ == "__main__" and USE_WANDB and accelerator.is_main_process:
# wandb.login(key=additional_config["wandb_api_key"])
# wandb.init(project="Recurrent-Parameter-Generation", name=config['tag'], config=config,)
# Training
# print('==> Defining training..')
# def train():
# if not USE_WANDB:
# train_loss = 0
# this_steps = 0
# print("==> Start training..")
# model.train()
# for batch_idx, (param, condition) in enumerate(train_loader):
# optimizer.zero_grad()
# # train
# # noinspection PyArgumentList
# with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=config["autocast"](batch_idx))):
# loss = model(
# output_shape=param.shape,
# x_0=param,
# condition=condition,
# permutation_state=None,
# )
# accelerator.backward(loss)
# optimizer.step()
# if accelerator.is_main_process:
# scheduler.step()
# # to logging losses and print and save
# if USE_WANDB and accelerator.is_main_process:
# wandb.log({"train_loss": loss.item()})
# elif USE_WANDB:
# pass # don't print
# else: # not use wandb
# train_loss += loss.item()
# this_steps += 1
# if this_steps % config["print_every"] == 0:
# print('Loss: %.6f' % (train_loss/this_steps))
# this_steps = 0
# train_loss = 0
# if batch_idx % config["save_every"] == 0 and accelerator.is_main_process:
# os.makedirs(config["checkpoint_save_path"], exist_ok=True)
# state = accelerator.unwrap_model(model).state_dict()
# torch.save(state, os.path.join(config["checkpoint_save_path"],
# f"{__file__.split('/')[-1].split('.')[0]}.pth"))
# generate(save_path=config["generated_path"], need_test=True)
# if batch_idx >= config["total_steps"]:
# break
def generate(save_path=config["generated_path"], need_test=True):
print("\n==> Generating..")
model.eval()
_, condition = test_set[random.randint(0, len(test_set)-1)]
class_index = str(int("".join([str(int(i)) for i in condition]), 2)).zfill(4)
with torch.no_grad():
prediction = model(sample=True, condition=condition[None], permutation_state=False)
generated_norm = torch.nanmean((prediction.cpu() * mask).abs())
print("Generated_norm:", generated_norm.item())
if USE_WANDB and accelerator.is_main_process:
wandb.log({"generated_norm": generated_norm.item()})
if accelerator.is_main_process:
train_set.save_params(prediction, save_path=save_path.format(class_index))
if need_test:
start_new_thread(os.system, (config["test_command"].format(class_index),))
model.train()
return prediction
# if __name__ == '__main__':
# train()
# del train_loader # deal problems by dataloader
# print("Finished Training!")
# exit(0)