# adain/train.py: AdaIN style-transfer training script (Chengkai Yang)
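"""Training script for AdaIN arbitrary style transfer.

Trains the decoder of an AdaINNet; the encoder side is initialized from the
pre-trained, normalized VGG weights in vgg_normalized.pth and is not optimized.
Checkpoints are written to ./check_point/ and decoder weights to ./weights/.
"""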
import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import TrainSet
from AdaIN import AdaINNet

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
    parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
    parser.add_argument('--epochs', type=int, default=1, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
    parser.add_argument('--resume', type=int, default=0, help='Continue training from epoch')
    parser.add_argument('--cuda', action='store_true', help='Use CUDA')
    args = parser.parse_args()

    device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

    check_point_dir = './check_point/'
    weights_dir = './weights/'
    # Create the output directories up front so the torch.save calls below cannot fail.
    os.makedirs(check_point_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)

    # TrainSet is expected to yield (content, style) image tensor pairs.
    train_set = TrainSet(args.content_dir, args.style_dir)
    train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)

    # Build the network around the pre-trained VGG encoder; only the decoder
    # is trained, so the optimizer covers model.decoder alone.
    vgg_model = torch.load('vgg_normalized.pth')
    model = AdaINNet(vgg_model).to(device)
    decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-6)

    total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
    total_num = 0  # samples seen since the running averages were last reset
    losses = []
    iteration = 0
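
    # For reference, the AdaIN operation the paper defines (AdaIN.py is assumed
    # to implement it; mu/sigma are channel-wise statistics of the encoded
    # features):
    #
    #   AdaIN(x, y) = sigma(y) * (x - mu(x)) / sigma(x) + mu(y)
    #
    # The training objective used below is L = L_content + 10 * L_style.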

    if args.resume > 0:
        # Restore decoder weights, optimizer state, and logging state from the checkpoint.
        states = torch.load(check_point_dir + 'epoch_' + str(args.resume) + '.pth', map_location=device)
        model.decoder.load_state_dict(states['decoder'])
        decoder_optimizer.load_state_dict(states['decoder_optimizer'])
        losses = states['losses']
        iteration = states['iteration']

    for epoch in range(args.resume + 1, args.epochs + 1):
        print('Begin epoch: %i/%i' % (epoch, args.epochs))
        train_tqdm = tqdm(train_loader)
        # Show (and record) the averages carried over from the previous epoch.
        train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f'
                                   % (total_loss, content_loss, style_loss))
        losses.append((iteration, total_loss, content_loss, style_loss))
        total_loss, content_loss, style_loss = 0.0, 0.0, 0.0

        for content_batch, style_batch in train_tqdm:
            content_batch = content_batch.to(device)
            style_batch = style_batch.to(device)

            loss_content, loss_style = model(content_batch, style_batch)
            # Weight the style loss by 10, as in the AdaIN paper.
            loss_scaled = loss_content + 10 * loss_style

            decoder_optimizer.zero_grad()
            loss_scaled.backward()
            decoder_optimizer.step()

            # Accumulate per-sample sums so the averages stay exact even when
            # the final batch is smaller than batch_size.
            batch_num = style_batch.size(0)
            total_loss += loss_scaled.item() * batch_num
            content_loss += loss_content.item() * batch_num
            style_loss += loss_style.item() * batch_num
            total_num += batch_num

            if iteration % 100 == 0 and iteration > 0:
                # Every 100 iterations, log the running averages, then reset them.
                total_loss /= total_num
                content_loss /= total_num
                style_loss /= total_num
                print('')
                train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f'
                                           % (total_loss, content_loss, style_loss))
                losses.append((iteration, total_loss, content_loss, style_loss))
                total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
                total_num = 0

            # At an epoch boundary, fold any leftover partial window into the
            # averages; the total_num check guards against a division by zero
            # when the boundary coincides with the 100-iteration reset above.
            if iteration % int(np.ceil(len(train_loader.dataset) / args.batch_size)) == 0 \
                    and iteration > 0 and total_num > 0:
                total_loss /= total_num
                content_loss /= total_num
                style_loss /= total_num
                total_num = 0

            iteration += 1

        print('Finished epoch: %i/%i' % (epoch, args.epochs))
        # Save a full checkpoint (for --resume) and the decoder weights alone.
        states = {'decoder': model.decoder.state_dict(),
                  'decoder_optimizer': decoder_optimizer.state_dict(),
                  'losses': losses, 'iteration': iteration}
        torch.save(states, check_point_dir + 'epoch_%i.pth' % epoch)
        torch.save(model.decoder.state_dict(), weights_dir + 'decoder_epoch_%i.pth' % epoch)

    # After training, dump the logged (iteration, total, content, style) rows as CSV.
    np.savetxt('losses', losses, fmt='%i,%.4f,%.4f,%.4f')
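
# Reloading the trained decoder later might look like this (a sketch; the epoch
# number and map_location are examples):
#   model = AdaINNet(torch.load('vgg_normalized.pth'))
#   model.decoder.load_state_dict(
#       torch.load('./weights/decoder_epoch_1.pth', map_location='cpu'))
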
if __name__ == '__main__':
main()
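
# Example invocations (dataset paths are placeholders):
#   python train.py --content_dir ./coco --style_dir ./wikiart --epochs 2 --cuda
#   python train.py --content_dir ./coco --style_dir ./wikiart --epochs 4 --resume 2 --cuda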