# Source: LicenseGAN/utils/preprocess.py
# Last commit: 白鹭先生, "新增SwinIR模型" (add SwinIR model), db5513e
import numpy as np
import os
import matplotlib.image as mpimage
import argparse
import functools
from utils import add_arguments, print_arguments
from dask.distributed import LocalCluster
from dask import bag as dbag
from dask.diagnostics import ProgressBar
from typing import Tuple
from PIL import Image
import cv2
#-----------------------------------#
# 对四个点坐标排序
#-----------------------------------#
def order_points(pts):
    """Arrange four 2-D points as top-left, top-right, bottom-right, bottom-left.

    The top-left point is the one with the smallest coordinate sum and the
    bottom-right the one with the largest.  The top-right point has the
    smallest (second - first) coordinate difference and the bottom-left the
    largest.

    Args:
        pts: array of four points, one (first, second) coordinate pair per row.

    Returns:
        A (4, 2) float32 array holding the points in the order above.
    """
    ordered = np.zeros((4, 2), dtype="float32")
    coord_sum = pts.sum(axis=1)
    coord_diff = np.diff(pts, axis=1)
    # Row index in `pts` for each output slot: tl, tr, br, bl.
    picks = (
        np.argmin(coord_sum),
        np.argmin(coord_diff),
        np.argmax(coord_sum),
        np.argmax(coord_diff),
    )
    for slot, source_row in enumerate(picks):
        ordered[slot] = pts[source_row]
    return ordered
#-----------------------------------#
# 透射变换纠正车牌图片
#-----------------------------------#
def four_point_transform(image, pts):
    """Warp the quadrilateral described by `pts` into an upright rectangle.

    The corners are first ordered with `order_points`.  The output size is the
    longer of the two opposite edges in each direction, truncated to an int.

    Args:
        image: source image passed to `cv2.warpPerspective`.
        pts: the four corner points of the region to rectify.

    Returns:
        The perspective-corrected image of size (maxWidth, maxHeight).
    """
    def _edge_length(p, q):
        # Euclidean distance between two corner points.
        return np.sqrt(((p[0] - q[0]) ** 2) + ((p[1] - q[1]) ** 2))

    src = order_points(pts)
    top_left, top_right, bottom_right, bottom_left = src

    # Output width: longer of the bottom and top edges.
    bottom_edge = _edge_length(bottom_right, bottom_left)
    top_edge = _edge_length(top_right, top_left)
    out_w = max(int(bottom_edge), int(top_edge))

    # Output height: longer of the right and left edges.
    right_edge = _edge_length(top_right, bottom_right)
    left_edge = _edge_length(top_left, bottom_left)
    out_h = max(int(right_edge), int(left_edge))

    # Destination corners, in the same tl, tr, br, bl order as `src`.
    target = np.array(
        [
            [0, 0],
            [out_w - 1, 0],
            [out_w - 1, out_h - 1],
            [0, out_h - 1],
        ],
        dtype="float32",
    )

    matrix = cv2.getPerspectiveTransform(src, target)
    return cv2.warpPerspective(image, matrix, (out_w, out_h))
# Dataset statistics that I gathered in development
#-----------------------------------#
# 用于过滤感知质量较低的不良图片
#-----------------------------------#
# Reference mean of a plate crop after scaling pixels by 1/255.  processImage
# rejects a crop whose mean is 10 * IMAGE_MEAN_STD or more away from it.
IMAGE_MEAN = 0.5
# Unit of tolerance for the brightness check above.
IMAGE_MEAN_STD = 0.028
# Reference standard deviation of a plate crop after scaling pixels by 1/255.
# processImage rejects a crop whose std is at or below
# IMG_STD - 10 * IMG_STD_STD (low contrast).
IMG_STD = 0.28
# Unit of tolerance for the contrast check above.
IMG_STD_STD = 0.01
def readImage(fileName: str) -> np.ndarray:
    """Load the image stored at `fileName` via `matplotlib.image.imread`."""
    return mpimage.imread(fileName)
#-----------------------------------#
# 从文件名中提取车牌的坐标
#-----------------------------------#
def parseLabel(label: str) -> Tuple[np.ndarray, np.ndarray]:
    """Extract the four plate corner points encoded in a file name.

    The fourth '-'-separated field of `label` is expected to hold four points
    joined by '_', each written as two integers joined by '&'.

    Args:
        label: the image file name.

    Returns:
        A tuple (coor, center): `coor` is the integer array of the four points
        (one point per row) and `center` is their per-column mean truncated to
        integers.
    """
    vertices = label.split('-')[3].split('_')
    # Index explicitly so a field with fewer than four points still raises.
    points = [[int(value) for value in vertices[k].split('&')] for k in range(4)]
    coor = np.array(points)
    return coor, np.mean(coor, axis=0).astype(int)
#-----------------------------------#
# 根据车牌坐标裁剪出车牌图像
#-----------------------------------#
# def cropImage(image: np.ndarray, coor: np.ndarray, center: np.ndarray) -> np.ndarray:
# image = four_point_transform(image, coor)
# return image
def cropImage(image: np.ndarray, coor: np.ndarray, center: np.ndarray) -> np.ndarray:
    """Cut a fixed-size window centred on the plate out of `image`.

    The smallest window among 64x32, 128x64, 192x96 and 256x128
    (width x height) whose half-extents strictly exceed the plate's extents
    from its centre is used.

    Args:
        image: source image indexed as [row, column, ...].
        coor: the four plate corners, one (x, y) pair per row.
        center: integer (x, y) centre of the plate.

    Returns:
        The cropped window, or an empty array when the plate does not fit the
        largest window or the window does not lie inside the image.
    """
    # Largest corner offsets to the right of / below the centre.
    maxW = np.max(coor[:, 0] - center[0])
    maxH = np.max(coor[:, 1] - center[1])

    # Candidate (width, height) windows, smallest first.
    for w, h in ((64, 32), (128, 64), (192, 96), (256, 128)):
        if maxW < w // 2 and maxH < h // 2:
            halfW, halfH = w // 2, h // 2
            break
    else:
        return np.array([])  # plate too large, discard

    cx, cy = center[0], center[1]
    # Axis 0 of the image is rows (y) and axis 1 is columns (x), so the
    # vertical bound is shape[0] and the horizontal bound is shape[1].
    imgH, imgW = image.shape[0], image.shape[1]
    if cy - halfH < 0 or cy + halfH >= imgH or \
            cx - halfW < 0 or cx + halfW >= imgW:
        return np.array([])  # window would leave the image
    return image[cy - halfH:cy + halfH, cx - halfW:cx + halfW]
#-----------------------------------#
# 保存车牌图片
#-----------------------------------#
def saveImage(image: np.ndarray, fileName: str, outDir: str) -> int:
    """Write a plate crop into the sub-folder of `outDir` matching its size.

    Crops that are 64, 128 or 192 pixels wide are saved unchanged into
    '64_32', '128_64' and '192_96' respectively.  Any other width is resized
    to 192x96 and saved into '192_96'.

    Args:
        image: the plate crop; an empty array means "nothing to save".
        fileName: file name to save under.
        outDir: output root that contains the size sub-folders.

    Returns:
        1 when an image was written, 0 when `image` was empty.
    """
    if image.shape[0] == 0:
        return 0

    # Crop width -> destination sub-folder.
    subDirs = {64: '64_32', 128: '128_64', 192: '192_96'}
    subDir = subDirs.get(image.shape[1])
    if subDir is None:
        # Resize large images down to the biggest stored size.
        image = np.asarray(Image.fromarray(image).resize((192, 96)))
        subDir = '192_96'
    mpimage.imsave(os.path.join(outDir, subDir, fileName), image)
    return 1
#-----------------------------------#
# 包装成一个函数,以便将处理区分到不同目录
#-----------------------------------#
def processImage(file: str, inputDir: str, outputDir: str, subFolder: str) -> int:
    """Crop, quality-filter and save the plate of a single image.

    Args:
        file: image file name, which also encodes the plate corners.
        inputDir: dataset root directory.
        outputDir: root directory for the saved crops.
        subFolder: dataset sub-folder inside `inputDir` that contains `file`.

    Returns:
        1 when a crop was saved, 0 when the image was discarded.
    """
    coor, center = parseLabel(file)
    image = readImage(os.path.join(inputDir, subFolder, file))
    plate = cropImage(image, coor, center)
    if plate.shape[0] == 0:
        return 0

    # Statistics are taken on pixel values scaled by 1/255.
    scaled = plate / 255.0
    brightness = np.mean(scaled)
    contrast = np.std(scaled)

    # bad brightness
    darkest = IMAGE_MEAN - 10*IMAGE_MEAN_STD
    brightest = IMAGE_MEAN + 10*IMAGE_MEAN_STD
    if brightness <= darkest or brightness >= brightest:
        return 0

    # low contrast
    if contrast <= IMG_STD - 10*IMG_STD_STD:
        return 0

    return saveImage(plate, file, outputDir)
def main(argv):
    """Crop the plates of every dataset sub-folder into the output directory.

    Args:
        argv: parsed arguments providing `jobNum`, `inputDir` and `outputDir`.
    """
    jobNum = int(argv.jobNum)
    outputDir = argv.outputDir
    inputDir = argv.inputDir

    # Create each size folder independently: an already existing output
    # directory must not prevent the missing sub-folders from being created.
    for shape in ['64_32', '128_64', '192_96']:
        os.makedirs(os.path.join(outputDir, shape), exist_ok=True)

    # NOTE(review): only a LocalCluster is created here and no Client is
    # connected in this function -- confirm compute() really runs on it.
    client = LocalCluster(n_workers=jobNum, threads_per_worker=5)  # IO intensive, more threads
    try:
        print('* number of workers:{}, \n* input dir:{}, \n* output dir:{}\n\n'.format(jobNum, inputDir, outputDir))
        for subFolder in ['ccpd_green', 'ccpd_base', 'ccpd_db', 'ccpd_fn', 'ccpd_rotate', 'ccpd_tilt', 'ccpd_weather']:
            fileList = os.listdir(os.path.join(inputDir, subFolder))
            print('* {} images found in {}. Start processing ...'.format(len(fileList), subFolder))
            toDo = dbag.from_sequence(fileList, npartitions=jobNum*30).persist()  # persist the bag in memory
            toDo = toDo.map(processImage, inputDir, outputDir, subFolder)
            # Scope the progress bar to this computation so that bars do not
            # pile up in the global callback registry across sub-folders.
            with ProgressBar(minimum=2.0):
                result = toDo.compute()
            print('* image cropped: {}. Done ...'.format(sum(result)))
    finally:
        client.close()  # shut down the cluster even if processing failed
if __name__ == "__main__":
    # Command-line entry point: declare the options, echo them, run main().
    parser = argparse.ArgumentParser(description=__doc__)
    add_arg = functools.partial(add_arguments, argparser=parser)
    # (name, type, default, help) for every supported option.
    cliOptions = (
        ('jobNum', int, 4, '处理图片的线程数'),
        ('inputDir', str, 'datasets/CCPD2020', '输入图片目录'),
        ('outputDir', str, 'datasets/CCPD2020_new', '保存图片目录'),
    )
    for option in cliOptions:
        add_arg(*option)
    args = parser.parse_args()
    print_arguments(args)
    main(args)