import sys, os, json
# locate the repository root (the "Recurrent-Parameter-Generation" directory) and run from there
root = os.sep + os.sep.join(__file__.split(os.sep)[1:__file__.split(os.sep).index("Recurrent-Parameter-Generation")+1])
sys.path.append(root)
os.chdir(root)
with open("./workspace/config.json", "r") as f:
    additional_config = json.load(f)
USE_WANDB = additional_config["use_wandb"]
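# ./workspace/config.json is expected to provide at least: {"use_wandb": <bool>}
# (a minimal, hypothetical example: {"use_wandb": false})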
					
						

import math
import random
import warnings
from _thread import start_new_thread
warnings.filterwarnings("ignore", category=UserWarning)
if USE_WANDB: import wandb
					
						

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import autocast
					
						

from bitsandbytes import optim
from model import ClassConditionMambaDiffusion as Model
from model.diffusion import DDPMSampler, DDIMSampler
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from accelerate.utils import DistributedDataParallelKwargs, AutocastKwargs
from accelerate import Accelerator
					
						

from dataset import ClassInput_ViTTiny
from torch.utils.data import DataLoader
					
						


class ClassInput_ViTTiny_Dataset(ClassInput_ViTTiny):
    data_path = "./dataset/condition_classinput_inference/checkpoint_test"
    # "{}" is filled with the 4-digit class index when generate() saves a checkpoint
    generated_path = "./workspace/classinput/generated_{}.pth"
    # assumption: test.py takes the generated checkpoint path as its argument
    test_command = "python ./dataset/condition_classinput_inference/test.py " + generated_path
					
						


config = {
    # data
    "dataset": None,
    "dim_per_token": 8192,
    "sequence_length": "auto",
    # training
    "batch_size": 16,
    "num_workers": 16,
    "total_steps": 120000,
    "learning_rate": 0.00003,
    "weight_decay": 0.0,
    "save_every": 120000 // 50,  # i.e. every 2400 steps
    "print_every": 50,
    "autocast": lambda i: 5000 < i < 90000,
    "checkpoint_save_path": "./checkpoint",
    # evaluation
    "test_batch_size": 1,
    "generated_path": ClassInput_ViTTiny_Dataset.generated_path,
    "test_command": ClassInput_ViTTiny_Dataset.test_command,
    # model
    "model_config": {
        "num_permutation": "auto",
        # mamba backbone
        "d_condition": 1024,
        "d_model": 8192,
        "d_state": 128,
        "d_conv": 4,
        "expand": 2,
        "num_layers": 2,
        # diffusion head
        "diffusion_batch": 512,
        "layer_channels": [1, 32, 64, 128, 64, 32, 1],
        "model_dim": "auto",
        "condition_dim": "auto",
        "kernel_size": 7,
        "sample_mode": DDPMSampler,
        "beta": (0.0001, 0.02),
        "T": 1000,
        "forward_once": True,
    },
    "tag": "generalization",
}
					
						


print('==> Preparing data..')
train_set = ClassInput_ViTTiny_Dataset(dim_per_token=config["dim_per_token"])
test_set = ClassInput_ViTTiny_Dataset(dim_per_token=config["dim_per_token"])
print("checkpoint number:", train_set.real_length)
					
						

# resolve "auto" hyperparameters from the dataset
if config["model_config"]["num_permutation"] == "auto":
    config["model_config"]["num_permutation"] = train_set.max_permutation_state
if config["model_config"]["condition_dim"] == "auto":
    config["model_config"]["condition_dim"] = config["model_config"]["d_model"]
if config["model_config"]["model_dim"] == "auto":
    config["model_config"]["model_dim"] = config["dim_per_token"]
if config["sequence_length"] == "auto":
    config["sequence_length"] = train_set.sequence_length
    print(f"sequence length: {config['sequence_length']}")
else:
    assert train_set.sequence_length == config["sequence_length"], f"sequence_length={train_set.sequence_length}"
					
						


print('==> Building model..')
Model.config = config["model_config"]  # hyperparameters are handed to the model via a class attribute
model = Model(
    sequence_length=config["sequence_length"],
    positional_embedding=train_set.get_position_embedding(
        positional_embedding_dim=config["model_config"]["d_model"],
    ),
)
					
						
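

# --------------------------------------------------------------------------
# NOTE: the names below (`train_loader`, `optimizer`, `scheduler`,
# `accelerator`, `mask`) are what generate() further down relies on. This
# block is a minimal sketch assuming standard bitsandbytes/Accelerate usage
# consistent with the imports above; the specific choices (8-bit AdamW, a
# 1000-step linear warmup, the zero-padding convention for `mask`, and the
# wandb project name) are assumptions, not confirmed details of this repo.
# --------------------------------------------------------------------------
print('==> Building optimizer and accelerator..')
train_loader = DataLoader(
    dataset=train_set,
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
    shuffle=True,
    drop_last=True,
)
optimizer = optim.AdamW8bit(  # 8-bit AdamW, matching `from bitsandbytes import optim`
    params=model.parameters(),
    lr=config["learning_rate"],
    weight_decay=config["weight_decay"],
)
scheduler = SequentialLR(  # assumed schedule: linear warmup, then cosine decay
    optimizer=optimizer,
    schedulers=[
        LinearLR(optimizer, start_factor=1e-4, total_iters=1000),
        CosineAnnealingLR(optimizer, T_max=config["total_steps"] - 1000),
    ],
    milestones=[1000],
)
accelerator = Accelerator(
    kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)],
)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

# `mask` marks valid parameter positions with 1.0 and padding with NaN so the
# torch.nanmean in generate() ignores padded tokens; treating zeros as padding
# is an assumption about how the dataset pads its parameter tokens.
example_param, _ = train_set[0]
mask = torch.ones_like(example_param)
mask[example_param == 0.] = float("nan")

if USE_WANDB and accelerator.is_main_process:
    wandb.init(project="Recurrent-Parameter-Generation", name=config["tag"])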
					
						


def generate(save_path=config["generated_path"], need_test=True):
    print("\n==> Generating..")
    model.eval()
    # draw a random 0/1 class-membership vector from the test set as the condition
    _, condition = test_set[random.randint(0, len(test_set) - 1)]
    # read the 0/1 vector as a binary number and zero-pad it to a 4-digit class index
    class_index = str(int("".join([str(int(i)) for i in condition]), 2)).zfill(4)
    with torch.no_grad():
        prediction = model(sample=True, condition=condition[None], permutation_state=False)
        generated_norm = torch.nanmean((prediction.cpu() * mask).abs())
    print("Generated_norm:", generated_norm.item())
    if USE_WANDB and accelerator.is_main_process:
        wandb.log({"generated_norm": generated_norm.item()})
    if accelerator.is_main_process:
        train_set.save_params(prediction, save_path=save_path.format(class_index))
        if need_test:
            # run the evaluation script in a background thread so generation does not block
            start_new_thread(os.system, (config["test_command"].format(class_index),))
    model.train()
    return prediction
					
						
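

# --------------------------------------------------------------------------
# NOTE: a minimal training-loop sketch under the config above. The keyword
# names in the loss call `model(x=..., condition=...)` are assumptions about
# ClassConditionMambaDiffusion's training API, not confirmed code; adapt them
# to the real model before use.
# --------------------------------------------------------------------------
print('==> Training..')
model.train()
step = 0
while step < config["total_steps"]:
    for param, condition in train_loader:
        optimizer.zero_grad()
        # config["autocast"] opens a mixed-precision window between steps 5000 and 90000
        with autocast(enabled=config["autocast"](step)):
            loss = model(x=param, condition=condition)  # ASSUMED call signature
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        if step % config["print_every"] == 0 and accelerator.is_main_process:
            print(f"step: {step}, loss: {loss.item():.6f}")
            if USE_WANDB:
                wandb.log({"loss": loss.item()})
        # every save_every steps: checkpoint the model, then generate and test parameters
        if (step + 1) % config["save_every"] == 0:
            if accelerator.is_main_process:
                os.makedirs(config["checkpoint_save_path"], exist_ok=True)
                state = accelerator.unwrap_model(model).state_dict()
                torch.save(state, os.path.join(config["checkpoint_save_path"], f"{config['tag']}.pth"))
            generate(need_test=True)
        step += 1
        if step >= config["total_steps"]:
            break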