import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from AdaIN import AdaINNet
from utils import TrainSet
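
# Trains the AdaIN style-transfer decoder on (content, style) image pairs.
# Example invocation (the filename train.py is an assumption):
#   python train.py --content_dir ./content --style_dir ./style --epochs 4 --cuda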


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
    parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
    parser.add_argument('--epochs', type=int, default=1, help='number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')
    parser.add_argument('--resume', type=int, default=0, help='continue training from this epoch checkpoint')
    parser.add_argument('--cuda', action='store_true', help='use CUDA if available')
    args = parser.parse_args()

    device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

    # Create the output directories up front so the save calls below cannot fail.
    check_point_dir = './check_point/'
    weights_dir = './weights/'
    os.makedirs(check_point_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)
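
    # TrainSet is assumed to pair one content image with one style image per
    # sample (see utils.py); the loader reshuffles the pairs every epoch.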
    train_set = TrainSet(args.content_dir, args.style_dir)
    train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)

    vgg_model = torch.load('vgg_normalized.pth')
    model = AdaINNet(vgg_model).to(device)
    decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-6)
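    # Only the decoder is optimised: in AdaIN the VGG encoder stays frozen,
    # which is why the optimizer is given model.decoder.parameters() alone.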

    # Running loss sums; total_num counts the images they cover so the sums
    # can be converted to per-image averages when logging.
    total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
    total_num = 0
    losses = []
    iteration = 0

    if args.resume > 0:
        states = torch.load(check_point_dir + 'epoch_' + str(args.resume) + '.pth')
        model.decoder.load_state_dict(states['decoder'])
        decoder_optimizer.load_state_dict(states['decoder_optimizer'])
        losses = states['losses']
        iteration = states['iteration']
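    # Epoch numbering is 1-based: resuming from checkpoint N continues with
    # epoch N + 1, while a fresh run (resume == 0) starts at epoch 1.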

    for epoch in range(args.resume + 1, args.epochs + 1):
        print('Begin epoch: %i/%i' % (epoch, args.epochs))
        train_tqdm = tqdm(train_loader)
        # Report the averages carried over from the end of the previous epoch.
        train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f'
                                   % (total_loss, content_loss, style_loss))
        losses.append((iteration, total_loss, content_loss, style_loss))
        total_loss, content_loss, style_loss = 0.0, 0.0, 0.0

        for content_batch, style_batch in train_tqdm:
            content_batch = content_batch.to(device)
            style_batch = style_batch.to(device)

            # The model returns the two AdaIN losses; the style term is
            # weighted by 10 relative to the content term.
            loss_content, loss_style = model(content_batch, style_batch)
            loss_scaled = loss_content + 10 * loss_style

            decoder_optimizer.zero_grad()
            loss_scaled.backward()
            decoder_optimizer.step()

            # Accumulate per-image loss sums for the periodic logging below.
            batch_num = style_batch.size(0)
            total_loss += loss_scaled.item() * batch_num
            content_loss += loss_content.item() * batch_num
            style_loss += loss_style.item() * batch_num
            total_num += batch_num
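            # Every 100 iterations: turn the running sums into per-image
            # averages, show them on the progress bar, log them, and reset.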
            if iteration % 100 == 0 and iteration > 0:
                total_loss /= total_num
                content_loss /= total_num
                style_loss /= total_num
                print('')
                train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f'
                                           % (total_loss, content_loss, style_loss))
                losses.append((iteration, total_loss, content_loss, style_loss))
                total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
                total_num = 0
            # At an epoch boundary, average whatever is left so the next
            # epoch's opening log entry is a mean rather than a raw sum; the
            # elif avoids a division by zero when the block above has just
            # reset total_num.
            elif iteration % int(np.ceil(len(train_loader.dataset) / args.batch_size)) == 0 and iteration > 0:
                total_loss /= total_num
                content_loss /= total_num
                style_loss /= total_num
                total_num = 0

            iteration += 1
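
        # End of epoch: save a full checkpoint for --resume, the decoder
        # weights alone for inference, and a text dump of the loss log.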
        print('Finished epoch: %i/%i' % (epoch, args.epochs))
        states = {'decoder': model.decoder.state_dict(),
                  'decoder_optimizer': decoder_optimizer.state_dict(),
                  'losses': losses, 'iteration': iteration}
        torch.save(states, check_point_dir + 'epoch_%i.pth' % epoch)
        torch.save(model.decoder.state_dict(), weights_dir + 'decoder_epoch_%i.pth' % epoch)
        np.savetxt('losses', losses, fmt='%i,%.4f,%.4f,%.4f')


if __name__ == '__main__':
    main()