IanNathaniel commited on
Commit
1db5083
1 Parent(s): a98f4ad

Delete lowlight_train.py

Browse files
Files changed (1) hide show
  1. lowlight_train.py +0 -124
lowlight_train.py DELETED
@@ -1,124 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torchvision
4
- import torch.backends.cudnn as cudnn
5
- import torch.optim
6
- import os
7
- import sys
8
- import argparse
9
- import time
10
- import dataloader
11
- import model
12
- import Myloss
13
- import numpy as np
14
- from torchvision import transforms
15
-
16
-
17
- def weights_init(m):
18
- classname = m.__class__.__name__
19
- if classname.find('Conv') != -1:
20
- m.weight.data.normal_(0.0, 0.02)
21
- elif classname.find('BatchNorm') != -1:
22
- m.weight.data.normal_(1.0, 0.02)
23
- m.bias.data.fill_(0)
24
-
25
-
26
-
27
-
28
-
29
- def train(config):
30
-
31
- os.environ['CUDA_VISIBLE_DEVICES']='0'
32
-
33
- DCE_net = model.enhance_net_nopool().cuda()
34
-
35
- DCE_net.apply(weights_init)
36
- if config.load_pretrain == True:
37
- DCE_net.load_state_dict(torch.load(config.pretrain_dir))
38
- train_dataset = dataloader.lowlight_loader(config.lowlight_images_path)
39
-
40
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True)
41
-
42
-
43
-
44
- L_color = Myloss.L_color()
45
- L_spa = Myloss.L_spa()
46
-
47
- L_exp = Myloss.L_exp(16,0.6)
48
- L_TV = Myloss.L_TV()
49
-
50
-
51
- optimizer = torch.optim.Adam(DCE_net.parameters(), lr=config.lr, weight_decay=config.weight_decay)
52
-
53
- DCE_net.train()
54
-
55
- for epoch in range(config.num_epochs):
56
- for iteration, img_lowlight in enumerate(train_loader):
57
-
58
- img_lowlight = img_lowlight.cuda()
59
-
60
- enhanced_image_1,enhanced_image,A = DCE_net(img_lowlight)
61
-
62
- Loss_TV = 200*L_TV(A)
63
-
64
- loss_spa = torch.mean(L_spa(enhanced_image, img_lowlight))
65
-
66
- loss_col = 5*torch.mean(L_color(enhanced_image))
67
-
68
- loss_exp = 10*torch.mean(L_exp(enhanced_image))
69
-
70
-
71
- # best_loss
72
- loss = Loss_TV + loss_spa + loss_col + loss_exp
73
- #
74
-
75
-
76
- optimizer.zero_grad()
77
- loss.backward()
78
- torch.nn.utils.clip_grad_norm(DCE_net.parameters(),config.grad_clip_norm)
79
- optimizer.step()
80
-
81
- if ((iteration+1) % config.display_iter) == 0:
82
- print("Loss at iteration", iteration+1, ":", loss.item())
83
- if ((iteration+1) % config.snapshot_iter) == 0:
84
-
85
- torch.save(DCE_net.state_dict(), config.snapshots_folder + "Epoch" + str(epoch) + '.pth')
86
-
87
-
88
-
89
-
90
- if __name__ == "__main__":
91
-
92
- parser = argparse.ArgumentParser()
93
-
94
- # Input Parameters
95
- parser.add_argument('--lowlight_images_path', type=str, default="data/train_data/")
96
- parser.add_argument('--lr', type=float, default=0.0001)
97
- parser.add_argument('--weight_decay', type=float, default=0.0001)
98
- parser.add_argument('--grad_clip_norm', type=float, default=0.1)
99
- parser.add_argument('--num_epochs', type=int, default=200)
100
- parser.add_argument('--train_batch_size', type=int, default=8)
101
- parser.add_argument('--val_batch_size', type=int, default=4)
102
- parser.add_argument('--num_workers', type=int, default=4)
103
- parser.add_argument('--display_iter', type=int, default=10)
104
- parser.add_argument('--snapshot_iter', type=int, default=10)
105
- parser.add_argument('--snapshots_folder', type=str, default="snapshots/")
106
- parser.add_argument('--load_pretrain', type=bool, default= False)
107
- parser.add_argument('--pretrain_dir', type=str, default= "snapshots/Epoch99.pth")
108
-
109
- config = parser.parse_args()
110
-
111
- if not os.path.exists(config.snapshots_folder):
112
- os.mkdir(config.snapshots_folder)
113
-
114
-
115
- train(config)
116
-
117
-
118
-
119
-
120
-
121
-
122
-
123
-
124
-