"""Training script for the low-light enhancement network defined in ``model``.

Run as a script; see the ``argparse`` options at the bottom of the file for
the available hyper-parameters and paths.
"""
import argparse
import os
import sys
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim
import torchvision
from torchvision import transforms

import dataloader
import model
import Myloss


def weights_init(m):
    """Initialise the parameters of a single module in place.

    Intended to be passed to ``nn.Module.apply``. The module is matched by
    substring on its class name:

    * names containing ``'Conv'``: weights drawn from N(0.0, 0.02);
    * names containing ``'BatchNorm'``: weights drawn from N(1.0, 0.02) and
      bias filled with 0.

    Any other module is left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is ``True``
    for every non-empty string, so ``--load_pretrain False`` would enable
    the option. This converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False: it was the only way to pass a false
    # value under the previous ``type=bool`` behaviour.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


def train(config):
    """Train ``model.enhance_net_nopool`` on the low-light image dataset.

    Args:
        config: namespace-like object providing the attributes
            ``load_pretrain``, ``pretrain_dir``, ``lowlight_images_path``,
            ``train_batch_size``, ``num_workers``, ``lr``, ``weight_decay``,
            ``num_epochs``, ``grad_clip_norm``, ``display_iter``,
            ``snapshot_iter`` and ``snapshots_folder``.

    Side effects:
        Sets ``CUDA_VISIBLE_DEVICES`` to ``'0'``, prints the loss every
        ``display_iter`` iterations and writes the model ``state_dict`` to
        ``<snapshots_folder>/Epoch<epoch>.pth`` every ``snapshot_iter``
        iterations (the file for a given epoch is overwritten each time).
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    DCE_net = model.enhance_net_nopool().cuda()
    DCE_net.apply(weights_init)
    if config.load_pretrain:
        DCE_net.load_state_dict(torch.load(config.pretrain_dir))

    train_dataset = dataloader.lowlight_loader(config.lowlight_images_path)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
    )

    L_color = Myloss.L_color()
    L_spa = Myloss.L_spa()
    L_exp = Myloss.L_exp(16, 0.6)
    L_TV = Myloss.L_TV()

    optimizer = torch.optim.Adam(
        DCE_net.parameters(), lr=config.lr, weight_decay=config.weight_decay)

    DCE_net.train()

    for epoch in range(config.num_epochs):
        for iteration, img_lowlight in enumerate(train_loader):
            img_lowlight = img_lowlight.cuda()

            # The network returns three values; only the second and third
            # are used by the losses below.
            _, enhanced_image, A = DCE_net(img_lowlight)

            Loss_TV = 200 * L_TV(A)
            loss_spa = torch.mean(L_spa(enhanced_image, img_lowlight))
            loss_col = 5 * torch.mean(L_color(enhanced_image))
            loss_exp = 10 * torch.mean(L_exp(enhanced_image))

            # best_loss
            loss = Loss_TV + loss_spa + loss_col + loss_exp

            # Clear gradients left over from the previous step so they do
            # not accumulate across iterations.
            optimizer.zero_grad()
            loss.backward()
            # ``clip_grad_norm_`` is the in-place, non-deprecated spelling of
            # ``clip_grad_norm``.
            torch.nn.utils.clip_grad_norm_(
                DCE_net.parameters(), config.grad_clip_norm)
            optimizer.step()

            if ((iteration + 1) % config.display_iter) == 0:
                print("Loss at iteration", iteration+1, ":", loss.item())
            if ((iteration + 1) % config.snapshot_iter) == 0:
                # os.path.join keeps the path valid whether or not the
                # folder argument ends with a separator.
                snapshot_path = os.path.join(
                    config.snapshots_folder, "Epoch" + str(epoch) + '.pth')
                torch.save(DCE_net.state_dict(), snapshot_path)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # Input Parameters
    parser.add_argument('--lowlight_images_path', type=str, default="data/train_data/")
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--grad_clip_norm', type=float, default=0.1)
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--val_batch_size', type=int, default=4)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--display_iter', type=int, default=10)
    parser.add_argument('--snapshot_iter', type=int, default=10)
    parser.add_argument('--snapshots_folder', type=str, default="snapshots/")
    # _str2bool instead of bool: bool("False") is True.
    parser.add_argument('--load_pretrain', type=_str2bool, default=False)
    parser.add_argument('--pretrain_dir', type=str, default="snapshots/Epoch99.pth")

    config = parser.parse_args()

    # exist_ok avoids the check-then-create race and makedirs also handles
    # nested snapshot directories.
    os.makedirs(config.snapshots_folder, exist_ok=True)

    train(config)