import sys, os, json
# Resolve the repository root ("Recurrent-Parameter-Generation") from this file's absolute path,
# put it on sys.path, and make it the working directory so the relative paths below resolve.
root = os.sep + os.sep.join(__file__.split(os.sep)[1:__file__.split(os.sep).index("Recurrent-Parameter-Generation")+1])
sys.path.append(root)
os.chdir(root)
with open("./workspace/config.json", "r") as f:
    additional_config = json.load(f)
USE_WANDB = additional_config["use_wandb"]
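# Note: ./workspace/config.json is not shown in this file. Only the "use_wandb" key is read here,
# so a minimal file (an assumed example; the real config may contain more keys) would be:
#   {"use_wandb": true}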

import math
import random
import warnings
from _thread import start_new_thread
warnings.filterwarnings("ignore", category=UserWarning)
if USE_WANDB: import wandb

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import autocast

from bitsandbytes import optim
from model import ClassConditionMambaDiffusion as Model
from model.diffusion import DDPMSampler, DDIMSampler
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from accelerate.utils import DistributedDataParallelKwargs
from accelerate.utils import AutocastKwargs
from accelerate import Accelerator

from dataset import ClassInput_ViTTiny
from torch.utils.data import DataLoader


class ClassInput_ViTTiny_Dataset(ClassInput_ViTTiny):
    # Test checkpoints used as conditions/targets, plus where generated parameters are written
    # and how they are evaluated. Note: generate() below calls .format(class_index) on both
    # generated_path and test_command; without a "{}" placeholder the index is silently dropped.
    data_path = "./dataset/condition_classinput_inference/checkpoint_test"
    generated_path = "./workspace/classinput/generated.pth"
    test_command = f"python ./dataset/condition_classinput_inference/test.py "

config = {
    # dataset setting
    "dataset": None,
    "dim_per_token": 8192,
    "sequence_length": 'auto',
    # train setting
    "batch_size": 16,
    "num_workers": 16,
    "total_steps": 120000,
    "learning_rate": 0.00003,
    "weight_decay": 0.0,
    "save_every": 120000//50,
    "print_every": 50,
    "autocast": lambda i: 5000 < i < 90000,
    "checkpoint_save_path": "./checkpoint",
    # test setting
    "test_batch_size": 1,
    "generated_path": ClassInput_ViTTiny_Dataset.generated_path,
    "test_command": ClassInput_ViTTiny_Dataset.test_command,
    # model setting
    "model_config": {
        # mamba config
        "num_permutation": "auto",
        "d_condition": 1024,
        "d_model": 8192,
        "d_state": 128,
        "d_conv": 4,
        "expand": 2,
        "num_layers": 2,
        # diffusion config
        "diffusion_batch": 512,
        "layer_channels": [1, 32, 64, 128, 64, 32, 1],
        "model_dim": "auto",
        "condition_dim": "auto",
        "kernel_size": 7,
        "sample_mode": DDPMSampler,
        "beta": (0.0001, 0.02),
        "T": 1000,
        "forward_once": True,
    },
    "tag": "generalization",
}
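
# Illustration (an assumption, not part of the original script): "autocast" above is not a flag but
# a schedule over the global training step, so mixed precision would only be active between steps
# 5000 and 90000 of the 120000 total; "save_every" likewise gives the step interval between saves.
# A hypothetical helper showing how such entries are typically queried inside a training loop:
def _example_step_flags(step: int):
    use_amp = config["autocast"](step)                   # True only while 5000 < step < 90000
    should_save = (step + 1) % config["save_every"] == 0  # every 2400 steps
    return use_amp, should_save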


# Data
print('==> Preparing data..')
train_set = ClassInput_ViTTiny_Dataset(dim_per_token=config["dim_per_token"])
test_set = ClassInput_ViTTiny_Dataset(dim_per_token=config["dim_per_token"])
print("checkpoint number:", train_set.real_length)

# Resolve the "auto" entries in config from the dataset (number of permutation states,
# condition/model token widths, and the tokenized sequence length).
if config["model_config"]["num_permutation"] == "auto":
    config["model_config"]["num_permutation"] = train_set.max_permutation_state
if config["model_config"]["condition_dim"] == "auto":
    config["model_config"]["condition_dim"] = config["model_config"]["d_model"]
if config["model_config"]["model_dim"] == "auto":
    config["model_config"]["model_dim"] = config["dim_per_token"]
if config["sequence_length"] == "auto":
    config["sequence_length"] = train_set.sequence_length
    print(f"sequence length: {config['sequence_length']}")
else:
    assert train_set.sequence_length == config["sequence_length"], f"sequence_length={train_set.sequence_length}"


# Model
print('==> Building model..')
Model.config = config["model_config"]
model = Model(
    sequence_length=config["sequence_length"],
    positional_embedding=train_set.get_position_embedding(
        positional_embedding_dim=config["model_config"]["d_model"],
    ),
)
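
# --- Omitted section ----------------------------------------------------------------------------
# The original script continues here with the optimizer / LR-scheduler / Accelerator setup and the
# training loop; that part is not included in this excerpt. generate() below relies on two names
# defined in it: `accelerator` (a Hugging Face accelerate.Accelerator) and `mask` (a tensor marking
# the valid, non-padding positions of the tokenized parameter sequence).
# The sketch below is only an assumption of what that setup might roughly look like, inferred from
# the imports and config above (bitsandbytes AdamW8bit, linear warmup + cosine decay, Accelerator
# with DDP kwargs); it is NOT the author's original code, and the helper name
# `_sketch_training_setup` is hypothetical.
def _sketch_training_setup():
    optimizer = optim.AdamW8bit(
        model.parameters(),
        lr=config["learning_rate"],
        weight_decay=config["weight_decay"],
    )
    scheduler = SequentialLR(
        optimizer,
        schedulers=[
            LinearLR(optimizer, start_factor=1e-4, end_factor=1.0, total_iters=1000),
            CosineAnnealingLR(optimizer, T_max=config["total_steps"]),
        ],
        milestones=[1000],
    )
    train_loader = DataLoader(
        train_set,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        shuffle=True,
        drop_last=True,
        persistent_workers=True,
    )
    accelerator = Accelerator(kwargs_handlers=[
        DistributedDataParallelKwargs(find_unused_parameters=False),
        AutocastKwargs(enabled=False),
    ])
    prepared_model, prepared_optimizer, prepared_loader = accelerator.prepare(model, optimizer, train_loader)
    return accelerator, prepared_model, prepared_optimizer, prepared_loader, scheduler
# --------------------------------------------------------------------------------------------------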


# Generate
def generate(save_path=config["generated_path"], need_test=True):
    print("\n==> Generating..")
    model.eval()
    # Draw a random binary class-condition vector from the test set and turn its bits into a
    # zero-padded decimal class index (the bit string is parsed as a base-2 integer).
    _, condition = test_set[random.randint(0, len(test_set)-1)]
    class_index = str(int("".join([str(int(i)) for i in condition]), 2)).zfill(4)
    with torch.no_grad():
        prediction = model(sample=True, condition=condition[None], permutation_state=False)
        generated_norm = torch.nanmean((prediction.cpu() * mask).abs())
    print("Generated_norm:", generated_norm.item())
    if USE_WANDB and accelerator.is_main_process:
        wandb.log({"generated_norm": generated_norm.item()})
    if accelerator.is_main_process:
        train_set.save_params(prediction, save_path=save_path.format(class_index))
        if need_test:
            # Evaluate the generated checkpoint in a background thread so generation is not blocked.
            start_new_thread(os.system, (config["test_command"].format(class_index),))
    model.train()
    return prediction
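
# Worked example (hypothetical values, not from the original script): the class-condition bits are
# packed into a zero-padded decimal index exactly as in generate() above, e.g.
#   [1, 1, 1, 1, 1, 0, 1, 1, 1, 1]  ->  "1111101111" (base 2)  ->  1007  ->  "1007"
def _example_class_index(condition_bits):
    return str(int("".join(str(int(b)) for b in condition_bits), 2)).zfill(4)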