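"""
Train a self-attention GAN on Super Mario Bros level segments: levels are
sliced into fixed-width windows, one-hot encoded, and used to train a
generator/discriminator pair with a hinge loss.
"""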
import os
import csv
import time
import random

import numpy as np  # used by the one-hot encoding below
import torch
import torch.nn.functional as F
from copy import deepcopy
from itertools import product  # enumerates (row, col) cell positions

from src.smb.level import *
from torch.optim import Adam
from src.utils.mymath import crowdivs
from src.utils.filesys import auto_dire, getpath  # getpath is used below; assumed to live alongside auto_dire
from src.utils.img import make_img_sheet
from src.utils.datastruct import batched_iter
from src.gan.gans import SAGenerator, SADiscriminator
from src.gan.gankits import process_onehot, sample_latvec


def get_gan_train_data():
    """Slice every level under 'smb/levels' into (n_types, H, W) one-hot segments."""
    H, W = MarioLevel.height, MarioLevel.seg_width
    data = []
    for lvl, _ in traverse_level_files('smb/levels'):
        num_lvl = lvl.to_num_arr()
        _, length = num_lvl.shape
        for s in range(length - W):
            seg = num_lvl[:, s: s+W]
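            # Scatter-style one-hot encoding: for every cell (i, j), set
            # onehot[tile_type, i, j] = 1. xs holds the tile type at each
            # cell, while ys/zs enumerate the row/column indices in the same
            # flattened row-major order.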
            onehot = np.zeros([MarioLevel.n_types, H, W])
            xs = [seg[i, j] for i, j in product(range(H), range(W))]
            ys = [k // W for k in range(H * W)]
            zs = [k % W for k in range(H * W)]
            onehot[xs, ys, zs] = 1
            data.append(onehot)
    return data
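
# Illustrative check of the returned encoding (assumes at least one level
# file is present under 'smb/levels'):
#   segs = get_gan_train_data()
#   assert segs[0].shape == (MarioLevel.n_types, MarioLevel.height, MarioLevel.seg_width)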


def set_GAN_parser(parser):
    parser.add_argument('--batch_size', type=int, default=128, help='input batch size')
    parser.add_argument('--niter', type=int, default=2000, help='number of GAN training iterations')
    parser.add_argument('--eval_itv', type=int, default=10, help='interval (in iterations) of evaluating and logging')
    parser.add_argument('--save_itv', type=int, default=100, help='interval (in iterations) of saving the model and samples')
    parser.add_argument('--repeatD', type=int, default=5, help='number of D updates per iteration')
    parser.add_argument('--repeatG', type=int, default=1, help='number of G updates per iteration')
    parser.add_argument('--lrD', type=float, default=4e-4, help='learning rate for D, default=4e-4')
    parser.add_argument('--lrG', type=float, default=1e-4, help='learning rate for G, default=1e-4')
    parser.add_argument('--regD', type=float, default=3e-4, help='weight_decay for D, default=3e-4')
    parser.add_argument('--regG', type=float, default=0., help='weight_decay for G, default=0')
    parser.add_argument('--beta1', type=float, default=0., help='beta1 parameter for the Adam optimiser, default=0.')
    parser.add_argument('--beta2', type=float, default=0.9, help='beta2 parameter for the Adam optimiser, default=0.9')
    parser.add_argument('--gpuid', type=int, default=0, help='GPU id; if negative, use CPU')
    parser.add_argument('--res_path', type=str, default='', help='root folder for storing training results')
    parser.add_argument('--weight_clip', type=float, default=0., help='clip discriminator weights into [-this, this] if this > 0')
    parser.add_argument('--noise', type=str, default='uniform', help='type of noise distribution')
    parser.add_argument('--base_channels', type=int, default=32, help='number of channels in the narrowest layer')


def train_GAN(args):
    def evaluate_diversity(levels_):
        # Pairwise distance tables over the generated segments; crowdivs
        # aggregates each table into a single diversity score.
        hamming_tab = np.array([[hamming_dis(l1, l2) for l1 in levels_] for l2 in levels_])
        tpjs_tab = np.array([[tile_pattern_js_div(l1, l2) for l1 in levels_] for l2 in levels_])
        hamming_divs_ = crowdivs(hamming_tab)
        tpjs_divs_ = crowdivs(tpjs_tab)
        return hamming_divs_, tpjs_divs_
    device = 'cpu' if args.gpuid < 0 or not torch.cuda.is_available() else f'cuda:{args.gpuid}'
    netG = SAGenerator(args.base_channels).to(device)
    netD = SADiscriminator(args.base_channels).to(device)
    optG = Adam(netG.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2), weight_decay=args.regG)
    optD = Adam(netD.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2), weight_decay=args.regD)
    data = get_gan_train_data()
    data = [torch.tensor(item, device=device, dtype=torch.float) for item in data]
    if args.res_path == '':
        res_path = auto_dire('training_data', name='GAN')
    else:
        res_path = getpath('training_data/' + args.res_path)
        try:
            os.makedirs(res_path)
        except FileExistsError:
            print(f'Training cancelled: {res_path} already exists')
            return
    with open(getpath(f'{res_path}/NN_architectures.txt'), 'w') as f:
        f.write('=' * 24 + ' Generator ' + '=' * 24 + '\n')
        f.write(str(netG))
        f.write('\n' + '=' * 22 + ' Discriminator ' + '=' * 22 + '\n')
        f.write(str(netD))
    cfgs = deepcopy(vars(args))
    cfgs.pop('entry')
    cfgs['start-time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    with open(f'{res_path}/cfgs.csv', 'w') as f:
        w = csv.writer(f)
        w.writerow(['key', 'value', ''])
        w.writerows(list(cfgs.items()))
    start_time = time.time()
    log_target = open(f'{res_path}/logs.csv', 'w')
    log_writer = csv.writer(log_target)
    log_writer.writerow(['Iterations', 'D-real', 'D-fake', 'Divs-hamming', 'Divs-tpjs', 'Time', ''])
    for t in range(args.niter):
        random.shuffle(data)
        # Train
        for item, n in batched_iter(data, args.batch_size):
            real = torch.stack(item)
            for _ in range(args.repeatD):
                if args.weight_clip > 0:
                    # Optional WGAN-style weight clipping on D
                    for p in netD.parameters():
                        p.data.clamp_(-args.weight_clip, args.weight_clip)
                with torch.no_grad():
                    z = sample_latvec(n, device=device, distribuion=args.noise)
                    fake = netG(z)
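                # Hinge loss for the discriminator: push D(real) above +1
                # and D(fake) below -1.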
                l_real = F.relu(1 - netD(real)).mean()
                l_fake = F.relu(netD(fake) + 1).mean()
                optD.zero_grad()
                l_real.backward()
                l_fake.backward()
                optD.step()
            for _ in range(args.repeatG):
                z = sample_latvec(n, device=device, distribuion=args.noise)
                fake = netG(z)
                optG.zero_grad()
                loss_G = -netD(fake).mean()
                loss_G.backward()
                optG.step()
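        # Periodic evaluation and logging. This block is a sketch: it assumes
        # --eval_itv gates how often a fresh batch of samples is scored and a
        # row matching the logs.csv header above is written; the exact metrics
        # and sample sizes here are illustrative assumptions.
        if t % args.eval_itv == (args.eval_itv - 1):
            netG.eval()
            netD.eval()
            with torch.no_grad():
                real = torch.stack(data[:args.batch_size])
                z = sample_latvec(args.batch_size, device=device, distribuion=args.noise)
                fake = netG(z)
                d_real = netD(real).mean().item()
                d_fake = netD(fake).mean().item()
                levels = process_onehot(fake)
            hamming_divs, tpjs_divs = evaluate_diversity(levels)
            log_writer.writerow([t + 1, d_real, d_fake, hamming_divs, tpjs_divs, time.time() - start_time, ''])
            log_target.flush()
            netG.train()
            netD.train()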
        if t % args.save_itv == (args.save_itv - 1):
            netG.eval()
            netD.eval()
            with torch.no_grad():
                z = sample_latvec(54, device=device, distribuion=args.noise)
                fake = netG(z)
                levels = process_onehot(fake)
            iteration_path = res_path + f'/iteration{t+1}'
            os.makedirs(iteration_path, exist_ok=True)
            imgs = [lvl.to_img() for lvl in levels]
            make_img_sheet(imgs, 9, save_path=f'{iteration_path}/samplesheet.png')
            torch.save(netG, getpath(iteration_path + '/decoder.pth'))
            netD.train()
            netG.train()
    # Save the final model and a sheet of sample levels
    netG.eval()
    netD.eval()
    with torch.no_grad():
        z = sample_latvec(54, device=device, distribuion=args.noise)
        fake = netG(z)
        levels = process_onehot(fake)
    imgs = [lvl.to_img() for lvl in levels]
    torch.save(netG, f'{res_path}/decoder.pth')
    make_img_sheet(imgs, 9, save_path=f'{res_path}/samplesheet.png')
    log_target.close()
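

# Example wiring (hypothetical sketch): build a parser, register the GAN
# options, and launch training. The 'entry' attribute is set by hand because
# train_GAN pops that key before dumping the config; in the full project it
# is presumably injected by a command dispatcher.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    set_GAN_parser(parser)
    args = parser.parse_args()
    args.entry = 'GAN'  # hypothetical placeholder; only popped before the config dump
    train_GAN(args)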