IanNathaniel committed on
Commit
80e7256
1 Parent(s): 5f7fa96

Upload lowlight_train.py

Browse files
Files changed (1) hide show
  1. lowlight_train.py +124 -0
lowlight_train.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision
4
+ import torch.backends.cudnn as cudnn
5
+ import torch.optim
6
+ import os
7
+ import sys
8
+ import argparse
9
+ import time
10
+ import dataloader
11
+ import model
12
+ import Myloss
13
+ import numpy as np
14
+ from torchvision import transforms
15
+
16
+
17
+ def weights_init(m):
18
+ classname = m.__class__.__name__
19
+ if classname.find('Conv') != -1:
20
+ m.weight.data.normal_(0.0, 0.02)
21
+ elif classname.find('BatchNorm') != -1:
22
+ m.weight.data.normal_(1.0, 0.02)
23
+ m.bias.data.fill_(0)
24
+
25
+
26
+
27
+
28
+
29
def train(config):
    """Train the Zero-DCE enhancement network on low-light images.

    Builds the model, dataset, losses and optimizer from *config*, runs the
    training loop, and periodically logs the loss and saves checkpoints.

    Args:
        config: argparse.Namespace carrying the fields defined in __main__
            (image path, lr, weight_decay, grad_clip_norm, num_epochs,
            train_batch_size, num_workers, display_iter, snapshot_iter,
            snapshots_folder, load_pretrain, pretrain_dir).

    Side effects:
        Pins training to GPU 0 via CUDA_VISIBLE_DEVICES, prints the loss
        every ``display_iter`` iterations, and writes ``Epoch<N>.pth`` state
        dicts into ``config.snapshots_folder`` every ``snapshot_iter``
        iterations.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    DCE_net = model.enhance_net_nopool().cuda()
    DCE_net.apply(weights_init)
    if config.load_pretrain:
        DCE_net.load_state_dict(torch.load(config.pretrain_dir))

    train_dataset = dataloader.lowlight_loader(config.lowlight_images_path)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
    )

    # Zero-DCE loss terms: color constancy, spatial consistency, exposure
    # control (16x16 patches, target level 0.6) and illumination smoothness
    # (total variation on the predicted curve maps A).
    L_color = Myloss.L_color()
    L_spa = Myloss.L_spa()
    L_exp = Myloss.L_exp(16, 0.6)
    L_TV = Myloss.L_TV()

    optimizer = torch.optim.Adam(DCE_net.parameters(), lr=config.lr,
                                 weight_decay=config.weight_decay)

    DCE_net.train()

    for epoch in range(config.num_epochs):
        for iteration, img_lowlight in enumerate(train_loader):

            img_lowlight = img_lowlight.cuda()

            enhanced_image_1, enhanced_image, A = DCE_net(img_lowlight)

            # Weighted sum of the four losses (weights from the Zero-DCE paper).
            Loss_TV = 200 * L_TV(A)
            loss_spa = torch.mean(L_spa(enhanced_image, img_lowlight))
            loss_col = 5 * torch.mean(L_color(enhanced_image))
            loss_exp = 10 * torch.mean(L_exp(enhanced_image))
            loss = Loss_TV + loss_spa + loss_col + loss_exp

            optimizer.zero_grad()
            loss.backward()
            # clip_grad_norm_ is the in-place replacement for the deprecated
            # clip_grad_norm the original called.
            torch.nn.utils.clip_grad_norm_(DCE_net.parameters(),
                                           config.grad_clip_norm)
            optimizer.step()

            if (iteration + 1) % config.display_iter == 0:
                print("Loss at iteration", iteration + 1, ":", loss.item())
            if (iteration + 1) % config.snapshot_iter == 0:
                # os.path.join instead of string concatenation: the original
                # produced "snapshotsEpoch0.pth" when snapshots_folder lacked
                # a trailing slash. NOTE: the file is overwritten within an
                # epoch (name depends on epoch only) — preserved as-is.
                torch.save(DCE_net.state_dict(),
                           os.path.join(config.snapshots_folder,
                                        "Epoch" + str(epoch) + '.pth'))
86
+
87
+
88
+
89
+
90
+ if __name__ == "__main__":
91
+
92
+ parser = argparse.ArgumentParser()
93
+
94
+ # Input Parameters
95
+ parser.add_argument('--lowlight_images_path', type=str, default="data/train_data/")
96
+ parser.add_argument('--lr', type=float, default=0.0001)
97
+ parser.add_argument('--weight_decay', type=float, default=0.0001)
98
+ parser.add_argument('--grad_clip_norm', type=float, default=0.1)
99
+ parser.add_argument('--num_epochs', type=int, default=200)
100
+ parser.add_argument('--train_batch_size', type=int, default=8)
101
+ parser.add_argument('--val_batch_size', type=int, default=4)
102
+ parser.add_argument('--num_workers', type=int, default=4)
103
+ parser.add_argument('--display_iter', type=int, default=10)
104
+ parser.add_argument('--snapshot_iter', type=int, default=10)
105
+ parser.add_argument('--snapshots_folder', type=str, default="snapshots/")
106
+ parser.add_argument('--load_pretrain', type=bool, default= False)
107
+ parser.add_argument('--pretrain_dir', type=str, default= "snapshots/Epoch99.pth")
108
+
109
+ config = parser.parse_args()
110
+
111
+ if not os.path.exists(config.snapshots_folder):
112
+ os.mkdir(config.snapshots_folder)
113
+
114
+
115
+ train(config)
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+