import argparse
import datetime
import logging
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from src.ddpm.dataset import create_dataloader
from src.ddpm.diffusion import Diffusion
from src.ddpm.modules import UNet
from src.gan.gankits import get_decoder, process_onehot
from src.smb.level import lvlhcat, save_batch
from src.utils.filesys import getpath

# Fourth root of the per-tile sprite frequencies in the training corpus.
# Used below to derive per-category temperatures so rare tiles are not
# drowned out by common ones in the reconstruction loss.
sprite_counts = np.power(np.array([
    74977, 15252, 572591, 5826, 1216, 7302, 237, 237, 2852, 1074,
    235, 304, 48, 96, 160, 1871, 936, 186, 428, 80, 428
]), 1 / 4)
min_count = np.min(sprite_counts)


def setup_logging(run_name, beta_schedule):
    model_path = os.path.join("models", beta_schedule, run_name)
    result_path = os.path.join("results", beta_schedule, run_name)
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(result_path, exist_ok=True)
    return model_path, result_path


# Train the DDPM model.
def train(args):
    path = getpath(args.res_path)
    os.makedirs(path, exist_ok=True)
    dataloader = create_dataloader(batch_size=args.batch_size, shuffle=True, num_workers=0)
    device = 'cpu' if args.gpuid < 0 else f'cuda:{args.gpuid}'
    model = UNet().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(device=device, schedule=args.beta_schedule)
    # Per-category temperatures in (0, 1]: the rarest sprite gets 1, common sprites get smaller values.
    temperatures = torch.tensor(min_count / sprite_counts, dtype=torch.float32).to(device)

    l = len(dataloader)
    for epoch in range(args.resume_from + 1, args.resume_from + args.epochs + 1):
        logging.info(f"Starting epoch {epoch}:")
        epoch_loss = {'rec_loss': 0, 'mse': 0, 'loss': 0}
        for images in dataloader:
            images = images.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)  # random timesteps, ints in [1, 1000]
            x_t, noise = diffusion.noise_images(images, t)  # x_t: noised image at step t; noise: the Gaussian noise added
            predicted_noise = model(x_t.float(), t.float())  # predicted noise eps_theta
            original_img = images.argmax(dim=1)  # ground-truth tile indices, batch x 14 x 14
            reconstructed_img = diffusion.sample_only_final(x_t, t, predicted_noise, temperatures)
            rec_loss = -reconstructed_img.log_prob(original_img).sum(dim=(1, 2)).mean()  # NLL per level, batch mean
            mse_loss = mse(noise.float(), predicted_noise.float())
            loss = 0.001 * rec_loss + mse_loss

            epoch_loss['rec_loss'] += rec_loss.item()
            epoch_loss['mse'] += mse_loss.item()
            epoch_loss['loss'] += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(
            '\nIteration: %d' % epoch,
            'rec_loss: %.5g' % (epoch_loss['rec_loss'] / l),
            'mse: %.5g' % (epoch_loss['mse'] / l)
        )

        if epoch % 1000 == 0:
            # Checkpoint the model and render sample levels.
            itpath = getpath(path, f'it{epoch}')
            os.makedirs(itpath, exist_ok=True)
            model.save(getpath(itpath, 'ddpm.pth'))  # itpath already includes `path`
            lvls = []
            init_latvecs = torch.tensor(np.load(getpath('analysis/initial_seg.npy')))
            gan = get_decoder()
            init_seg_onehots = gan(init_latvecs.view(*init_latvecs.shape, 1, 1))
            for init_seg_onehot in init_seg_onehots:
                # Prepend the GAN-decoded initial segment to 25 DDPM-sampled segments
                # and concatenate them horizontally into one level.
                seg_onehots = diffusion.sample(model, n=25)[-1]
                a = init_seg_onehot.view(1, *init_seg_onehot.shape)
                b = seg_onehots.detach().cpu()
                segs = process_onehot(torch.cat([a, b], dim=0))
                level = lvlhcat(segs)
                lvls.append(level)
            save_batch(lvls, getpath(path, 'samples.lvls'))
    model.save(getpath(path, 'ddpm.pth'))


def launch():
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10000)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--res_path", type=str, default='exp_data/DDPM')
    parser.add_argument("--gpuid", type=int, default=0)
parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--beta_schedule", type=str, default="quadratic", choices=['linear', 'quadratic', 'sigmoid']) parser.add_argument("--run_name", type=str, default=f"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}") parser.add_argument("--resume_from", type=int, default=0) args = parser.parse_args() train(args) if __name__ == "__main__": launch()