白鹭先生 committed on
Commit
905cd18
1 Parent(s): 305fd71
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Author: Egrt
3
+ Date: 2022-01-13 13:34:10
4
+ LastEditors: Egrt
5
+ LastEditTime: 2022-01-13 13:48:57
6
+ FilePath: \LicenseGAN\app.py
7
+ '''
8
+ import os
9
# Install the runtime dependencies before they are imported below.
# The PyPI distribution that provides PyTorch is named `torch`; the `pytorch`
# name on PyPI is only a placeholder package and does not install the library.
os.system('pip install torch')
os.system('pip install gradio==2.5.3')
11
+ from PIL import Image
12
+ from esrgan import ESRGAN
13
+ import gradio as gr
14
+
15
+ esrgan = ESRGAN()
16
+
17
+ # --------模型推理---------- #
18
def inference(img):
    """Shrink the input image to 12x24 (height x width) and super-resolve it.

    Args:
        img: image object exposing a PIL-style ``resize`` method.

    Returns:
        Whatever ``esrgan.generate_1x1_image`` returns for the shrunken image.
    """
    low_res_height, low_res_width = 12, 24
    # resize() takes its target size as (width, height).
    shrunk = img.resize((low_res_width, low_res_height), Image.BICUBIC)
    return esrgan.generate_1x1_image(shrunk)
23
+
24
+ # --------网页信息---------- #
25
# -------- Web page information -------- #
title = "车牌超分辨率重建"
description = "使用生成对抗网络对低分辨率车牌图片进行八倍的超分辨率重建,能够有效的恢复出车牌号。 @西南科技大学智能控制与图像处理研究室"
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.10257' target='_blank'>LicenseGAN: Image Restoration Using Swin Transformer</a> | <a href='https://github.com/JingyunLiang/SwinIR' target='_blank'>Github Repo</a></p>"

# Every .jpg file found in the example directory becomes one single-input example row.
example_img_dir = 'img'
example_img_name = os.listdir(example_img_dir)
examples = [
    [os.path.join(example_img_dir, image_path)]
    for image_path in example_img_name
    if image_path.endswith('.jpg')
]

# Build the interface first, then start it in debug mode.
interface = gr.Interface(
    inference,
    [gr.inputs.Image(type="pil", label="Input")],
    gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples,
)
interface.launch(debug=True)
esrgan.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.backends.cudnn as cudnn
4
+ from PIL import Image
5
+ import cv2
6
+ from nets.esrgan import Generator
7
+ from utils.utils import cvtColor, preprocess_input
8
+
9
+
10
class ESRGAN(object):
    """Inference wrapper that loads a ``Generator`` and runs it on single images.

    Defaults live in ``_defaults`` and can be overridden through keyword
    arguments to the constructor.
    """

    _defaults = {
        # Path of the generator weight file to load.
        "model_path": 'model_data/Generator_ESRGAN.pth',
        # Upsampling factor; must be the same value that was used for training.
        "scale_factor": 8,
        # Whether to run on CUDA; set to False when no GPU is available.
        "cuda": False,
    }

    def __init__(self, **kwargs):
        """Apply the defaults, then the keyword overrides, then load the network."""
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.generate()

    def generate(self):
        """Build the generator, load its weights and switch it to eval mode."""
        self.net = Generator(self.scale_factor)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))

        if self.cuda:
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = True
            self.net = self.net.cuda()

    def generate_1x1_image(self, image):
        """Run the generator on one image and return the result as a PIL image.

        Args:
            image: input image; it is passed through ``cvtColor`` first so that
                non-RGB inputs are converted to RGB.

        Returns:
            ``PIL.Image`` built from the network output after it has been
            min-max rescaled to the 0..255 range.
        """
        # Only RGB input is supported by the network; convert anything else.
        image = cvtColor(image)
        # Normalise with mean/std 0.5, move channels first and add the batch axis.
        normalised = preprocess_input(
            np.array(image, dtype=np.float32), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        )
        image_data = np.expand_dims(np.transpose(normalised, [2, 0, 1]), 0)

        with torch.no_grad():
            image_data = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                image_data = image_data.cuda()

            # Feed the image to the network and keep the first batch element.
            hr_image = self.net(image_data)[0]
            # Undo the 0.5/0.5 normalisation and go back to height x width x channel.
            hr_image = hr_image.detach().cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5

        # Min-max stretch to 0..255.  A constant output has zero range, which
        # previously produced 0/0 = NaN pixels; emit an all-zero image instead.
        lowest = np.min(hr_image)
        value_range = np.max(hr_image) - lowest
        if value_range > 0:
            hr_image = (hr_image - lowest) / value_range * 255
        else:
            hr_image = np.zeros_like(hr_image)

        hr_image = Image.fromarray(np.uint8(hr_image))
        return hr_image
img/0095-1_0-302&358_450&412-450&408_304&412_302&362_448&358-0_0_27_10_33_29_29-80-45.jpg ADDED
img/015-90_87-254&546_483&616-484&622_252&620_255&542_487&544-0_0_18_33_19_30_30-100-38.jpg ADDED
img/015-90_90-187&518_421&597-435&595_192&600_191&520_434&515-0_0_23_27_27_26_19-96-79.jpg ADDED
img/0158984375-90_268-245&462_467&535-467&535_245&529_247&462_467&465-0_0_3_24_27_25_30_32-161-162.jpg ADDED
img/0166796875-89_267-242&423_486&492-483&492_245&492_242&430_486&423-0_0_3_26_26_27_30_29-179-318.jpg ADDED
img/0210546875-92_269-233&488_485&572-482&572_233&559_236&488_485&499-0_0_3_26_33_30_33_32-143-226.jpg ADDED
model_data/Generator_ESRGAN.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c137b3da480f7ad251641ace39d73b2adb30ec0c40662cadce2bf0e80b8fca8
3
+ size 28697247
nets/__pycache__/esrgan.cpython-38.pyc ADDED
Binary file (4.91 kB). View file
nets/__pycache__/srgan.cpython-38.pyc ADDED
Binary file (3.78 kB). View file
nets/esrgan.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
class DenseResidualBlock(nn.Module):
    """Densely connected residual block.

    Five 3x3 convolutions are chained; each one receives the block input
    concatenated with the outputs of all previous convolutions.  The output of
    the last convolution is scaled by ``res_scale`` and added to the input.
    """

    def __init__(self, filters, res_scale=0.2):
        super(DenseResidualBlock, self).__init__()
        self.res_scale = res_scale

        def make_stage(multiplier, activate=True):
            # Stage k sees k * filters input channels and emits `filters` channels.
            conv = nn.Conv2d(multiplier * filters, filters, 3, 1, 1, bias=True)
            if activate:
                return nn.Sequential(conv, nn.GELU())
            return nn.Sequential(conv)

        self.b1 = make_stage(1)
        self.b2 = make_stage(2)
        self.b3 = make_stage(3)
        self.b4 = make_stage(4)
        self.b5 = make_stage(5, activate=False)
        # Plain list on purpose: the stages are already registered as b1..b5.
        self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]

    def forward(self, x):
        stacked = x
        # All stages except the last grow the concatenated feature stack.
        for stage in self.blocks[:-1]:
            stacked = torch.cat([stacked, stage(stacked)], 1)
        final = self.blocks[-1](stacked)
        return final.mul(self.res_scale) + x
34
+
35
class ResidualInResidualDenseBlock(nn.Module):
    """Three ``DenseResidualBlock`` modules in sequence with an outer scaled skip connection."""

    def __init__(self, filters, res_scale=0.2):
        super(ResidualInResidualDenseBlock, self).__init__()
        self.res_scale = res_scale
        inner_blocks = [DenseResidualBlock(filters) for _ in range(3)]
        self.dense_blocks = nn.Sequential(*inner_blocks)

    def forward(self, x):
        residual = self.dense_blocks(x)
        # Scale the residual branch before adding it back to the input.
        return residual.mul(self.res_scale) + x
45
+
46
class UpsampleBLock(nn.Module):
    """Upsample by ``up_scale`` with a 3x3 convolution, pixel shuffle and GELU.

    The convolution expands the channels by ``up_scale ** 2`` so that the pixel
    shuffle returns a tensor with ``in_channels`` channels again.
    """

    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.gelu(self.pixel_shuffle(self.conv(x)))
58
+
59
class Generator(nn.Module):
    """Super-resolution generator.

    Pipeline: input convolution -> residual-in-residual dense blocks -> second
    convolution with a skip from the first -> a chain of 2x upsample blocks ->
    output convolutions.  The number of upsample blocks is
    ``int(log2(scale_factor))``.
    """

    def __init__(self, scale_factor, channels=3, filters=64, num_res_blocks=4):
        super(Generator, self).__init__()
        upsample_block_num = int(math.log(scale_factor, 2))

        # Input convolution
        self.conv1 = nn.Conv2d(channels, filters, kernel_size=3, stride=1, padding=1)

        # Stack of residual-in-residual dense blocks
        trunk = [ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)]
        self.res_blocks = nn.Sequential(*trunk)

        # Convolution applied after the trunk
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)

        # Each upsample block doubles the spatial resolution
        doublers = [UpsampleBLock(filters, 2) for _ in range(upsample_block_num)]
        self.upsample = nn.Sequential(*doublers)

        # Output convolutions
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv2d(filters, channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        shallow = self.conv1(x)
        deep = self.conv2(self.res_blocks(shallow))
        # Skip connection from the input convolution around the trunk.
        fused = torch.add(shallow, deep)
        return self.conv3(self.upsample(fused))
86
+
87
+
88
class Discriminator(nn.Module):
    """Convolutional discriminator that returns one sigmoid score per input image."""

    def __init__(self):
        super(Discriminator, self).__init__()

        # First stage has no batch normalisation.
        stages = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.GELU()]

        # (in_channels, out_channels, stride) of the conv + batch-norm + GELU stages.
        plan = (
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        )
        for c_in, c_out, stride in plan:
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.GELU())

        # Global pooling followed by two 1x1 convolutions acting as the classifier head.
        stages.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(1024, 1, kernel_size=1),
        ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        logits = self.net(x)
        # Flatten to one value per batch element before the sigmoid.
        return torch.sigmoid(logits.view(x.size(0)))
132
+
133
if __name__ == "__main__":
    from torchsummary import summary

    # Choose the device the network should run on: GPU when available, else CPU.
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    model = Generator(8)
    model = model.to(device)
    summary(model, input_size=(3, 12, 24))
140
+
utils/__init__.py ADDED
@@ -0,0 +1 @@
 
1
+ from .utils import *
utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (153 Bytes). View file
utils/__pycache__/dataloader.cpython-38.pyc ADDED
Binary file (4.19 kB). View file
utils/__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.15 kB). View file
utils/__pycache__/utils_fit.cpython-38.pyc ADDED
Binary file (2.15 kB). View file
utils/__pycache__/utils_metrics.cpython-38.pyc ADDED
Binary file (2.34 kB). View file
utils/dataloader.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from random import randint
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+ from torch.utils.data.dataset import Dataset
7
+
8
+ from utils import cvtColor, preprocess_input
9
+ from torch.utils.data import DataLoader
10
+
11
def get_new_img_size(width, height, img_min_side=600):
    """Scale ``(width, height)`` so that the shorter side equals ``img_min_side``.

    The longer side is scaled by the same factor and truncated to an int.

    Returns:
        Tuple ``(resized_width, resized_height)``.
    """
    if width > height:
        # Height is the shorter side.
        f = float(img_min_side) / height
        return int(f * width), int(img_min_side)

    # Width is the shorter side (or both sides are equal).
    f = float(img_min_side) / width
    return int(img_min_side), int(f * height)
22
+
23
class SRGANDataset(Dataset):
    """Dataset that yields (low-resolution, high-resolution) training pairs.

    Each entry of ``train_lines`` is a text line whose first whitespace
    separated token is an image path.  ``lr_shape`` and ``hr_shape`` are
    ``(height, width)`` pairs.
    """

    def __init__(self, train_lines, lr_shape, hr_shape):
        super(SRGANDataset, self).__init__()

        self.train_lines = train_lines
        self.train_batches = len(train_lines)

        self.lr_shape = lr_shape
        self.hr_shape = hr_shape

    def __len__(self):
        return self.train_batches

    def __getitem__(self, index):
        """Return one ``(img_l, img_h)`` pair as normalised channel-first arrays."""
        index = index % self.train_batches

        image_origin = Image.open(self.train_lines[index].split()[0])
        # Half of the samples get full augmentation, the other half a random crop.
        if self.rand() < .5:
            img_h = self.get_random_data(image_origin, self.hr_shape)
        else:
            img_h = self.random_crop(image_origin, self.hr_shape[1], self.hr_shape[0])
        # The low-resolution input is a bicubic downscale of the high-resolution target.
        img_l = img_h.resize((self.lr_shape[1], self.lr_shape[0]), Image.BICUBIC)

        img_h = np.transpose(preprocess_input(np.array(img_h, dtype=np.float32), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), [2, 0, 1])
        img_l = np.transpose(preprocess_input(np.array(img_l, dtype=np.float32), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), [2, 0, 1])
        return np.array(img_l), np.array(img_h)

    def rand(self, a=0, b=1):
        """Return a uniform random float in ``[a, b)``."""
        return np.random.rand() * (b - a) + a

    def get_random_data(self, image, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
        """Resize ``image`` onto a grey canvas of ``input_shape`` with optional augmentation.

        With ``random=False`` the image is letterboxed and returned as a float32
        array.  With ``random=True`` it is rescaled with aspect jitter, pasted
        at a random offset, optionally flipped and rotated, colour jittered in
        HSV space and returned as a ``PIL.Image``.
        """
        # Make sure the image is RGB.
        image = cvtColor(image)
        # Source size and target size.
        iw, ih = image.size
        h, w = input_shape

        if not random:
            scale = min(w / iw, h / ih)
            nw = int(iw * scale)
            nh = int(ih * scale)
            dx = (w - nw) // 2
            dy = (h - nh) // 2

            # Pad the unused area with grey bars.
            image = image.resize((nw, nh), Image.BICUBIC)
            new_image = Image.new('RGB', (w, h), (128, 128, 128))
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image, np.float32)

            return image_data

        # Rescale the image and distort its aspect ratio.
        new_ar = w / h * self.rand(1 - jitter, 1 + jitter) / self.rand(1 - jitter, 1 + jitter)
        scale = self.rand(1, 1.5)
        if new_ar < 1:
            nh = int(scale * h)
            nw = int(nh * new_ar)
        else:
            nw = int(scale * w)
            nh = int(nw / new_ar)
        image = image.resize((nw, nh), Image.BICUBIC)

        # Paste at a random offset; uncovered area stays grey.
        dx = int(self.rand(0, w - nw))
        dy = int(self.rand(0, h - nh))
        new_image = Image.new('RGB', (w, h), (128, 128, 128))
        new_image.paste(image, (dx, dy))
        image = new_image

        # Random horizontal flip.
        flip = self.rand() < .5
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # Random rotation about the canvas centre; from here on `image` may be an ndarray.
        rotate = self.rand() < .5
        if rotate:
            angle = np.random.randint(-15, 15)
            a, b = w / 2, h / 2
            M = cv2.getRotationMatrix2D((a, b), angle, 1)
            image = cv2.warpAffine(np.array(image), M, (w, h), borderValue=[128, 128, 128])

        # Colour jitter in HSV space.
        # NOTE(review): `hue` is sampled here but never applied to the hue
        # channel below; only saturation and value are scaled.
        hue = self.rand(-hue, hue)
        sat = self.rand(1, sat) if self.rand() < .5 else 1 / self.rand(1, sat)
        val = self.rand(1, val) if self.rand() < .5 else 1 / self.rand(1, val)
        x = cv2.cvtColor(np.array(image, np.float32) / 255, cv2.COLOR_RGB2HSV)
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x[:, :, 0] > 360, 0] = 360
        x[:, :, 1:][x[:, :, 1:] > 1] = 1
        x[x < 0] = 0
        image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB) * 255
        return Image.fromarray(np.uint8(image_data))

    def random_crop(self, image, width, height):
        """Return a random ``width`` x ``height`` crop of ``image``.

        Images smaller than ``hr_shape`` are first enlarged so that their
        shorter side equals ``max(hr_shape)``.
        """
        # Same RGB guarantee as get_random_data, so both branches of
        # __getitem__ hand a 3-channel image to the normalisation step.
        image = cvtColor(image)

        # Enlarge images that are too small to crop from.  The resize target is
        # derived from the image's own size so its aspect ratio is preserved
        # (previously the crop size was passed, stretching every small image to
        # one fixed shape).
        if image.size[0] < self.hr_shape[1] or image.size[1] < self.hr_shape[0]:
            resized_width, resized_height = get_new_img_size(
                image.size[0], image.size[1], img_min_side=np.max(self.hr_shape)
            )
            image = image.resize((resized_width, resized_height), Image.BICUBIC)

        # Pick a random top-left corner for the crop window.
        width1 = randint(0, image.size[0] - width)
        height1 = randint(0, image.size[1] - height)

        width2 = width1 + width
        height2 = height1 + height

        image = image.crop((width1, height1, width2, height2))
        return image
150
+
151
def SRGAN_dataset_collate(batch):
    """Stack a batch of ``(img_l, img_h)`` pairs into two arrays.

    Returns:
        Tuple ``(images_l, images_h)`` of ``np.ndarray``.
    """
    # Single pass with strict two-element unpacking of every sample.
    pairs = [(low, high) for low, high in batch]
    images_l = [pair[0] for pair in pairs]
    images_h = [pair[1] for pair in pairs]
    return np.array(images_l), np.array(images_h)
utils/preprocess.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import matplotlib.image as mpimage
4
+ import argparse
5
+ import functools
6
+ from utils import add_arguments, print_arguments
7
+ from dask.distributed import LocalCluster
8
+ from dask import bag as dbag
9
+ from dask.diagnostics import ProgressBar
10
+ from typing import Tuple
11
+ from PIL import Image
12
+
13
+
14
+
15
+ # Dataset statistics that I gathered in development
16
+ #-----------------------------------#
17
+ # 用于过滤感知质量较低的不良图片
18
+ #-----------------------------------#
19
+ IMAGE_MEAN = 0.5
20
+ IMAGE_MEAN_STD = 0.028
21
+
22
+ IMG_STD = 0.28
23
+ IMG_STD_STD = 0.01
24
+
25
+
26
def readImage(fileName: str) -> np.ndarray:
    """Read the image file at ``fileName`` with ``matplotlib.image.imread``."""
    return mpimage.imread(fileName)
29
+
30
+
31
+ #-----------------------------------#
32
+ # 从文件名中提取车牌的坐标
33
+ #-----------------------------------#
34
+
35
+
36
def parseLabel(label: str) -> Tuple[np.ndarray, np.ndarray]:
    """Extract the four plate corner coordinates encoded in a file name.

    The fourth ``-`` separated field of ``label`` holds ``_`` separated points,
    each written as two ``&`` separated integers.

    Returns:
        ``(coor, center)``: the 4x2 integer corner array and the corner mean
        truncated to integers.
    """
    corner_fields = label.split('-')[3].split('_')
    # Exactly the first four points are read.
    corners = np.array([
        [int(value) for value in corner_fields[k].split('&')]
        for k in range(4)
    ])
    centroid = np.mean(corners, axis=0)
    return corners, centroid.astype(int)
45
+
46
+
47
+ #-----------------------------------#
48
+ # 根据车牌坐标裁剪出车牌图像
49
+ #-----------------------------------#
50
+
51
+
52
def cropImage(image: np.ndarray, coor: np.ndarray, center: np.ndarray) -> np.ndarray:
    """Crop a fixed-size window centred on the plate.

    The window is the smallest of 64x32, 128x64, 192x96 and 256x128
    (width x height) whose half-extents exceed the plate's largest positive
    corner offset from ``center``.

    Args:
        image: array indexed as ``[row (y), column (x), ...]``.
        coor: N x 2 array of ``(x, y)`` corner coordinates.
        center: ``(x, y)`` centre of the plate.

    Returns:
        The cropped window, or an empty array when the plate is too large for
        every candidate size or the window does not fit inside the image.
    """
    maxW = np.max(coor[:, 0] - center[0])  # largest x offset of a corner
    maxH = np.max(coor[:, 1] - center[1])  # largest y offset of a corner

    xWanted = [64, 128, 192, 256]
    yWanted = [32, 64, 96, 128]

    found = False
    for w, h in zip(xWanted, yWanted):
        if maxW < w // 2 and maxH < h // 2:
            maxH = h // 2
            maxW = w // 2
            found = True
            break
    if not found:
        # Plate is too large for every candidate window: discard it.
        return np.array([])

    # Rows are the y axis (shape[0]) and columns the x axis (shape[1]); the
    # earlier check compared y against shape[1] and x against shape[0].
    rows, cols = image.shape[0], image.shape[1]
    top, bottom = center[1] - maxH, center[1] + maxH
    left, right = center[0] - maxW, center[0] + maxW
    if top < 0 or bottom >= rows or left < 0 or right >= cols:
        return np.array([])
    return image[top:bottom, left:right]
73
+
74
+ #-----------------------------------#
75
+ # 保存车牌图片
76
+ #-----------------------------------#
77
+
78
+
79
def saveImage(image: np.ndarray, fileName: str, outDir: str) -> int:
    """Save a cropped plate into the sub-folder of ``outDir`` matching its width.

    Widths 64, 128 and 192 are written unchanged to ``64_32``, ``128_64`` and
    ``192_96``.  Any other width is resized to 192x96 and written to
    ``192_96``.

    Returns:
        0 when ``image`` is empty (nothing is written), otherwise 1.
    """
    if image.shape[0] == 0:
        return 0

    imgShape = image.shape
    if imgShape[1] == 64:
        mpimage.imsave(os.path.join(outDir, '64_32', fileName), image)
    elif imgShape[1] == 128:
        mpimage.imsave(os.path.join(outDir, '128_64', fileName), image)
    elif imgShape[1] == 192:
        # Was `== 208`, a width the cropper never produces, so 192-wide crops
        # fell through to the resize branch below.
        mpimage.imsave(os.path.join(outDir, '192_96', fileName), image)
    else:
        # Resize larger crops down to 192x96.
        image = Image.fromarray(image).resize((192, 96))
        image = np.asarray(image)  # back to a numpy array
        mpimage.imsave(os.path.join(outDir, '192_96', fileName), image)
    return 1
95
+
96
+
97
+ #-----------------------------------#
98
+ # 包装成一个函数,以便将处理区分到不同目录
99
+ #-----------------------------------#
100
+
101
def processImage(file: str, inputDir: str, outputDir: str, subFolder: str) -> int:
    """Crop, quality-filter and save the plate of one image file.

    Returns:
        0 when the plate is discarded (cannot be cropped, brightness out of
        range, or contrast too low), otherwise the status of ``saveImage``.
    """
    coor, center = parseLabel(file)
    image = readImage(os.path.join(inputDir, subFolder, file))
    plate = cropImage(image, coor, center)
    if plate.shape[0] == 0:
        return 0

    scaled = plate / 255.0
    brightness = np.mean(scaled)
    contrast = np.std(scaled)

    # Reject plates whose mean brightness is too far from the dataset mean.
    brightness_margin = 10 * IMAGE_MEAN_STD
    if brightness <= IMAGE_MEAN - brightness_margin or brightness >= IMAGE_MEAN + brightness_margin:
        return 0
    # Reject low-contrast plates.
    if contrast <= IMG_STD - 10 * IMG_STD_STD:
        return 0

    return saveImage(plate, file, outputDir)
118
+
119
+
120
def main(argv):
    """Crop the plates of every CCPD sub-folder into ``argv.outputDir``.

    Args:
        argv: parsed arguments providing ``jobNum``, ``inputDir`` and
            ``outputDir``.
    """
    jobNum = int(argv.jobNum)
    outputDir = argv.outputDir
    inputDir = argv.inputDir

    # Create every size folder even when outputDir already exists.  The former
    # try/except around os.mkdir stopped at the first existing directory and
    # left the sub-folders missing.
    for shape in ['64_32', '128_64', '192_96']:
        os.makedirs(os.path.join(outputDir, shape), exist_ok=True)

    # NOTE(review): no distributed Client is attached to this cluster here, so
    # compute() presumably runs on dask's default scheduler — confirm.
    client = LocalCluster(n_workers=jobNum, threads_per_worker=5)
    try:
        for subFolder in ['ccpd_base', 'ccpd_db', 'ccpd_fn', 'ccpd_rotate', 'ccpd_tilt', 'ccpd_weather']:
            fileList = os.listdir(os.path.join(inputDir, subFolder))
            print('* {} images found in {}. Start processing ...'.format(len(fileList), subFolder))
            toDo = dbag.from_sequence(fileList, npartitions=jobNum*30).persist()  # persist the bag in memory
            toDo = toDo.map(processImage, inputDir, outputDir, subFolder)
            # Scope the progress bar to this computation instead of registering
            # a new global callback on every pass of the loop.
            with ProgressBar(minimum=2.0):
                result = toDo.compute()
            print('* image cropped: {}. Done ...'.format(sum(result)))
    finally:
        # Shut the cluster down even when processing fails.
        client.close()
141
+
142
+
143
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    add_arg = functools.partial(add_arguments, argparser=parser)
    # (name, type, default, help text) of every command line option.
    option_table = (
        ('jobNum', int, 4, '处理图片的线程数'),
        ('inputDir', str, 'datasets/CCPD2019', '输入图片目录'),
        ('outputDir', str, 'datasets/CCPD2019_new', '保存图片目录'),
    )
    for option in option_table:
        add_arg(*option)
    args = parser.parse_args()
    print_arguments(args)
    main(args)
utils/utils.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import torch
5
+ import distutils.util
6
+
7
def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
    """Save a side-by-side figure of the first generated and first target image.

    The figure is written to ``results/train_out/epoch_<num_epoch>_results.png``.
    """
    with torch.no_grad():
        test_images = G_net(imgs_lr)

    fig, ax = plt.subplots(1, 2)

    # Hide the tick axes of both panels.
    for panel in ax:
        panel.get_xaxis().set_visible(False)
        panel.get_yaxis().set_visible(False)

    # Left panel: generator output, right panel: target.  Both are mapped back
    # from the 0.5/0.5 normalisation and moved to channel-last layout.
    generated = test_images.cpu().numpy()[0] * 0.5 + 0.5
    ax[0].cla()
    ax[0].imshow(np.transpose(generated, [1, 2, 0]))

    target = imgs_hr.cpu().numpy()[0] * 0.5 + 0.5
    ax[1].cla()
    ax[1].imshow(np.transpose(target, [1, 2, 0]))

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')
    plt.savefig("results/train_out/epoch_" + str(num_epoch) + "_results.png")
    plt.close('all')  # avoid leaking figure memory
27
+
28
+ #---------------------------------------------------------#
29
+ # 将图像转换成RGB图像,防止灰度图在预测时报错。
30
+ # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
31
+ #---------------------------------------------------------#
32
def cvtColor(image):
    """Return ``image`` unchanged when it has three channels, else its RGB conversion.

    An image counts as three-channel when ``np.shape(image)`` has three
    dimensions and the last one is 3; anything else is converted with
    ``image.convert('RGB')``.
    """
    dims = np.shape(image)
    if len(dims) != 3 or dims[2] != 3:
        return image.convert('RGB')
    return image
38
+
39
def preprocess_input(image, mean, std):
    """Scale ``image`` by 1/255, then subtract ``mean`` and divide by ``std``."""
    scaled = image / 255
    centred = scaled - mean
    return centred / std
42
+
43
def get_lr(optimizer):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    nothing = object()
    first_group = next(iter(optimizer.param_groups), nothing)
    if first_group is nothing:
        return None
    return first_group['lr']
46
+
47
def print_arguments(args):
    """Print every attribute of ``args`` as ``name: value``, sorted by name, between two rule lines."""
    settings = vars(args)
    print("----------- Configuration Arguments -----------")
    for name in sorted(settings):
        print("%s: %s" % (name, settings[name]))
    print("------------------------------------------------")
52
+
53
+
54
def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register the option ``--<argname>`` on ``argparser``.

    ``bool`` options are parsed with ``distutils.util.strtobool``; every other
    type is used as the converter directly.  The default value is appended to
    the help text.  Extra keyword arguments are forwarded to ``add_argument``.
    """
    # NOTE(review): distutils was removed from the standard library in Python 3.12.
    converter = type
    if type == bool:
        converter = distutils.util.strtobool

    option = "--" + argname
    help_text = help + ' 默认: %(default)s.'
    argparser.add_argument(option, default=default, type=converter, help=help_text, **kwargs)
utils/utils_fit.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+
4
+ from .utils import show_result, get_lr
5
+ from .utils_metrics import PSNR, SSIM
6
+
7
+
8
def fit_one_epoch(G_model_train, D_model_train, G_model, D_model, VGG_feature_model, G_optimizer, D_optimizer, BCE_loss, MSE_loss, epoch, epoch_size, gen, Epoch, cuda, batch_size, save_interval):
    """Train the generator and the discriminator for one epoch.

    Args:
        G_model_train, D_model_train: models used for the forward/backward passes.
        G_model, D_model: models whose ``state_dict`` is saved every 10th epoch.
        VGG_feature_model: feature extractor used for the perception loss.
        G_optimizer, D_optimizer: optimizers of the two models.
        BCE_loss, MSE_loss: loss callables.
        epoch: zero-based index of the current epoch.
        epoch_size: number of iterations to run; also the divisor of the
            printed and saved average losses.
        gen: iterable yielding ``(lr_images, hr_images)`` numpy batches.
        Epoch: total number of epochs (used for display only).
        cuda: move tensors to the GPU when true.
        batch_size: nominal batch size; kept for interface compatibility, the
            labels are now sized from the actual batch.
        save_interval: a preview figure is saved every ``save_interval`` iterations.
    """
    G_total_loss = 0
    D_total_loss = 0
    G_total_PSNR = 0
    G_total_SSIM = 0

    with tqdm(total=epoch_size,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen):
            if iteration >= epoch_size:
                break

            with torch.no_grad():
                lr_images, hr_images = batch
                lr_images, hr_images = torch.from_numpy(lr_images).type(torch.FloatTensor), torch.from_numpy(hr_images).type(torch.FloatTensor)
                # Size the labels from the batch actually received so that a
                # shorter final batch does not cause a shape mismatch in BCE.
                current_batch = lr_images.size(0)
                y_real, y_fake = torch.ones(current_batch), torch.zeros(current_batch)
                if cuda:
                    lr_images, hr_images, y_real, y_fake = lr_images.cuda(), hr_images.cuda(), y_real.cuda(), y_fake.cuda()

            #-------------------------------------------------#
            #   Train the discriminator
            #-------------------------------------------------#
            D_optimizer.zero_grad()

            # reshape(-1) keeps one score per sample even for a batch of one,
            # where squeeze() would collapse the result to a 0-dim tensor.
            D_result = D_model_train(hr_images).reshape(-1)
            D_real_loss = BCE_loss(D_result, y_real)
            D_real_loss.backward()

            # Detach the fake images: the discriminator update does not need
            # gradients flowing back into the generator.
            G_result = G_model_train(lr_images).detach()
            D_result = D_model_train(G_result).reshape(-1)
            D_fake_loss = BCE_loss(D_result, y_fake)
            D_fake_loss.backward()

            D_optimizer.step()

            D_train_loss = D_real_loss + D_fake_loss

            #-------------------------------------------------#
            #   Train the generator
            #-------------------------------------------------#
            G_optimizer.zero_grad()

            G_result = G_model_train(lr_images)
            image_loss = MSE_loss(G_result, hr_images)

            D_result = D_model_train(G_result).reshape(-1)
            adversarial_loss = BCE_loss(D_result, y_real)

            perception_loss = MSE_loss(VGG_feature_model(G_result), VGG_feature_model(hr_images))

            G_train_loss = image_loss + 1e-3 * adversarial_loss + 2e-6 * perception_loss

            G_train_loss.backward()
            G_optimizer.step()

            G_total_loss += G_train_loss.item()
            D_total_loss += D_train_loss.item()

            with torch.no_grad():
                G_total_PSNR += PSNR(G_result, hr_images).item()
                G_total_SSIM += SSIM(G_result, hr_images).item()

            pbar.set_postfix(**{'G_loss' : G_total_loss / (iteration + 1),
                                'D_loss' : D_total_loss / (iteration + 1),
                                'G_PSNR' : G_total_PSNR / (iteration + 1),
                                'G_SSIM' : G_total_SSIM / (iteration + 1),
                                'lr'     : get_lr(G_optimizer)})
            pbar.update(1)

            if iteration % save_interval == 0:
                show_result(epoch + 1, G_model_train, lr_images, hr_images)

    print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
    print('G Loss: %.4f || D Loss: %.4f ' % (G_total_loss / epoch_size, D_total_loss / epoch_size))
    print('Saving state, iter:', str(epoch+1))

    # Checkpoints are written every 10th epoch only.
    if (epoch + 1) % 10==0:
        torch.save(G_model.state_dict(), 'logs/G_Epoch%d-GLoss%.4f-DLoss%.4f.pth'%((epoch + 1), G_total_loss / epoch_size, D_total_loss / epoch_size))
        torch.save(D_model.state_dict(), 'logs/D_Epoch%d-GLoss%.4f-DLoss%.4f.pth'%((epoch + 1), G_total_loss / epoch_size, D_total_loss / epoch_size))
utils/utils_metrics.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from math import exp
4
+ import numpy as np
5
+
6
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian of length ``window_size`` centred at ``window_size // 2``, normalised to sum to 1."""
    centre = window_size // 2
    weights = []
    for x in range(window_size):
        weights.append(exp(-(x - centre)**2/float(2*sigma**2)))
    gauss = torch.Tensor(weights)
    return gauss / gauss.sum()
9
+
10
def create_window(window_size, channel=1):
    """Build a ``(channel, 1, window_size, window_size)`` Gaussian window (sigma 1.5).

    The 2-D kernel is the outer product of the 1-D Gaussian with itself,
    repeated once per channel.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = column.mm(column.t()).float()
    # Add the two leading axes, then repeat the kernel for every channel.
    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)
    return kernel_4d.expand(channel, 1, window_size, window_size).contiguous()
15
+
16
def SSIM(img1, img2, window_size=11, window=None, size_average=True, full=False):
    """Compute the structural similarity of two image batches.

    Both inputs are mapped with ``(x * 0.5 + 0.5) * 255``; ``img2`` is
    additionally clamped to ``[0, 255]``.

    Args:
        img1, img2: 4-D tensors (batch, channel, height, width).
        window_size: Gaussian window size used when ``window`` is None; it is
            limited to the image height and width.
        window: optional precomputed convolution window.
        size_average: return the mean over everything when true, otherwise
            one value per batch element.
        full: also return the contrast sensitivity term when true.

    Returns:
        The SSIM value(s), or ``(ssim, cs)`` when ``full`` is true.
    """
    img1 = (img1 * 0.5 + 0.5) * 255
    img2 = (img2 * 0.5 + 0.5) * 255
    img2 = torch.clamp(img2, 0.0, 255.0)

    # Dynamic range of the rescaled images.
    L = 255 - 0
    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    (_, channel, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    def local_mean(tensor):
        # Per-channel windowed average without padding.
        return F.conv2d(tensor, window, padding=0, groups=channel)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
    ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)

    if not full:
        return ret
    cs = torch.mean(v1 / v2)  # contrast sensitivity
    return ret, cs
58
+
59
def tf_log10(x):
    """Return the base-10 logarithm of ``x`` as ``ln(x) / ln(10)``."""
    ln_ten = torch.log(torch.tensor(10.0))
    return torch.log(x) / ln_ten
63
+
64
def PSNR(img1, img2):
    """Compute the peak signal-to-noise ratio for a peak value of 255.

    Both inputs are mapped with ``(x * 0.5 + 0.5) * 255``; ``img2`` is
    additionally clamped to ``[0, 255]``.
    """
    max_pixel = 255.0
    img1 = (img1 * 0.5 + 0.5) * 255
    img2 = torch.clamp((img2 * 0.5 + 0.5) * 255, 0.0, 255.0)
    mean_squared_error = torch.mean(torch.pow(img2 - img1, 2))
    return 10.0 * tf_log10((max_pixel ** 2) / mean_squared_error)