Spaces:
Runtime error
Runtime error
from matplotlib import artist | |
import cv2 | |
import numpy as np | |
import torch | |
import torch.utils.data as data | |
import pandas as pd | |
import os | |
from PIL import Image | |
from .utils import cvtColor, preprocess_input | |
from .utils_aug import CenterCrop, ImageNetPolicy, RandomResizedCrop, Resize | |
class DataGenerator(data.Dataset):
    """Dataset of artwork images combined with tabular features, targeting ``Prices``.

    Each item is ``(all_features, y)`` where ``all_features`` stacks the
    preprocessed image (channel-first) with one extra channel built by tiling
    the row's non-price tabular features to ``input_shape``.  ``y`` is a
    length-1 array holding the row's ``Prices`` value taken from
    ``all_features`` (i.e. standardized like the other numeric columns).

    Rows are addressed POSITIONALLY (``0 .. len(self) - 1``), so the image,
    the tabular features and the label always come from the same row even if
    ``annotation_lines`` carries a non-default (e.g. shuffled) index.

    Args:
        annotation_lines: pandas DataFrame, one row per sample.  Its
            ``'Duration (s)'`` column is joined onto ``image_root`` to form the
            image path (NOTE(review): the column name does not look like a
            path column -- confirm against the dataset).  Column positions
            2, 4, 5, 7, 8, 9, 10, 12, 16 are used as tabular features and must
            yield a numeric ``'Prices'`` column.
        input_shape: (height, width) of the network input.
        random: apply random augmentation (training) instead of the
            deterministic resize/letterbox path.
        autoaugment_flag: use the RandomResizedCrop + ImageNetPolicy pipeline
            (``AutoAugment``) instead of ``get_random_data``.
        image_root: directory that image paths are joined onto.
    """

    def __init__(self, annotation_lines, input_shape, random=True, autoaugment_flag=True,
                 image_root='datasets/archive/Dataset'):
        self.artwork_data = annotation_lines
        self.input_shape = input_shape
        self.random = random
        self.image_root = image_root
        #------------------------------#
        # Whether to use the AutoAugment-style data augmentation
        #------------------------------#
        self.autoaugment_flag = autoaugment_flag
        if self.autoaugment_flag:
            self.resize_crop = RandomResizedCrop(input_shape)
            self.policy = ImageNetPolicy()
        # Used by the deterministic branch of AutoAugment.
        self.resize = Resize(input_shape[0] if input_shape[0] == input_shape[1] else input_shape)
        self.center_crop = CenterCrop(input_shape)
        # NOTE(review): standardization statistics and the set of one-hot
        # columns are computed from THIS DataFrame only, so a train and a
        # validation split may be scaled differently / have a different
        # number of dummy columns.  Confirm this is intended.
        self.all_features = self.get_all_features(self.artwork_data)
        self._cache_tabular()

    def _cache_tabular(self):
        """Precompute per-row arrays from ``self.all_features``.

        Doing the ``drop``/dtype conversion once here avoids copying the whole
        feature table on every ``__getitem__`` call.  Call again if
        ``all_features`` is replaced after construction.
        """
        self._other_features = self.all_features.drop(labels='Prices', axis=1).to_numpy(dtype=np.float32)
        self._prices = self.all_features['Prices'].to_numpy()

    def __len__(self):
        """Return the number of rows in the annotation DataFrame."""
        return self.artwork_data.shape[0]

    def _load_image(self, index):
        """Open the image of row ``index`` (positional) and convert it to RGB."""
        relative_path = self.artwork_data['Duration (s)'].iloc[index]
        annotation_path = os.path.join(self.image_root, relative_path)
        image = Image.open(annotation_path)
        return cvtColor(image)

    def _tabular_item(self, index):
        """Return ``(feature_plane, y)`` for row ``index`` (positional).

        ``feature_plane`` has shape ``(1, *input_shape)``: the row's non-price
        features tiled / truncated by ``np.resize`` to fill the plane.
        ``y`` has shape ``(1,)`` and holds the row's ``Prices`` value.
        """
        plane = np.resize(self._other_features[index], self.input_shape)
        plane = np.expand_dims(plane, 0)
        y = np.expand_dims(np.array(self._prices[index]), axis=-1)
        return plane, y

    def __getitem__(self, index):
        """Return ``(all_features, y)`` for row ``index``.

        ``all_features`` is the preprocessed image transposed to channel-first
        and concatenated with the tabular feature plane along axis 0
        (presumably 3 RGB channels + 1 -> 4 channels; assumes the augmented
        image is ``input_shape`` sized -- TODO confirm for every pipeline).
        """
        #------------------------------#
        # Read the image and convert it to RGB
        #------------------------------#
        image = self._load_image(index)
        if self.autoaugment_flag:
            image = self.AutoAugment(image, random=self.random)
        else:
            image = self.get_random_data(image, self.input_shape, random=self.random)
        # Tabular input (all features except the price) and the price label.
        other_features, y = self._tabular_item(index)
        image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), [2, 0, 1])
        all_features = np.concatenate((image, other_features), axis=0)
        return all_features, y

    def rand(self, a=0, b=1):
        """Return a uniform random float in ``[a, b)`` drawn from NumPy's global RNG."""
        return np.random.rand() * (b - a) + a

    def get_all_features(self, all_features):
        """Select and preprocess the tabular feature columns.

        Keeps column positions 2, 4, 5, 7, 8, 9, 10, 12, 16.  Non-object
        columns are standardized to zero mean and unit sample standard
        deviation, and their missing values are set to 0.  Object columns are
        then one-hot encoded, with an extra indicator column for missing
        values.

        Args:
            all_features: the raw annotation DataFrame (not modified).

        Returns:
            A new DataFrame of preprocessed features.
        """
        all_features = all_features.iloc[:, [2, 4, 5, 7, 8, 9, 10, 12, 16]].copy()
        numeric_features = all_features.dtypes[all_features.dtypes != 'object'].index
        all_features[numeric_features] = all_features[numeric_features].apply(
            lambda x: (x - x.mean()) / (x.std()))
        # After standardization every feature has mean 0, so missing values
        # (and 0/0 from constant columns) can simply be replaced with 0.
        all_features[numeric_features] = all_features[numeric_features].fillna(0)
        # dummy_na=True treats a missing value as its own category and creates
        # an indicator column for it.
        all_features = pd.get_dummies(all_features, dummy_na=True)
        return all_features

    def get_random_data(self, image, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
        """Resize ``image`` to ``input_shape`` with optional random augmentation.

        Args:
            image: PIL RGB image.
            input_shape: (height, width) of the output.
            jitter: maximum relative aspect-ratio distortion.
            hue, sat, val: ranges of the random HSV gains.
            random: if False, letterbox deterministically (no randomness).

        Returns:
            An ``(h, w, 3)`` array: float32 when ``random`` is False, otherwise
            uint8 after flip / rotation / HSV jitter.
        """
        #------------------------------#
        # Source image width/height and target height/width
        #------------------------------#
        iw, ih = image.size
        h, w = input_shape
        if not random:
            scale = min(w / iw, h / ih)
            nw = int(iw * scale)
            nh = int(ih * scale)
            dx = (w - nw) // 2
            dy = (h - nh) // 2
            #---------------------------------#
            # Pad the remaining area with gray bars
            #---------------------------------#
            image = image.resize((nw, nh), Image.BICUBIC)
            new_image = Image.new('RGB', (w, h), (128, 128, 128))
            new_image.paste(image, (dx, dy))
            return np.array(new_image, np.float32)
        #------------------------------------------#
        # Randomly rescale and distort the aspect ratio
        #------------------------------------------#
        new_ar = iw / ih * self.rand(1 - jitter, 1 + jitter) / self.rand(1 - jitter, 1 + jitter)
        scale = self.rand(.75, 1.5)
        if new_ar < 1:
            nh = int(scale * h)
            nw = int(nh * new_ar)
        else:
            nw = int(scale * w)
            nh = int(nw / new_ar)
        image = image.resize((nw, nh), Image.BICUBIC)
        #------------------------------------------#
        # Place at a random offset on a gray canvas (may crop if larger)
        #------------------------------------------#
        dx = int(self.rand(0, w - nw))
        dy = int(self.rand(0, h - nh))
        new_image = Image.new('RGB', (w, h), (128, 128, 128))
        new_image.paste(image, (dx, dy))
        image = new_image
        #------------------------------------------#
        # Random horizontal flip and rotation
        #------------------------------------------#
        flip = self.rand() < .5
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        rotate = self.rand() < .5
        if rotate:
            # NOTE(review): randint's upper bound is exclusive, so angles are
            # drawn from [-15, 14] rather than a symmetric range.
            angle = np.random.randint(-15, 15)
            a, b = w / 2, h / 2
            M = cv2.getRotationMatrix2D((a, b), angle, 1)
            image = cv2.warpAffine(np.array(image), M, (w, h), borderValue=[128, 128, 128])
        image_data = np.array(image, np.uint8)
        #---------------------------------#
        # Color-space (HSV) augmentation: random per-channel gains
        # NOTE(review): with sat/val = 1.5 the gains can be negative; the
        # clip below then maps those channels to 0.
        #---------------------------------#
        r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
        hue_ch, sat_ch, val_ch = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
        dtype = image_data.dtype
        #---------------------------------#
        # Apply the gains through lookup tables (8-bit OpenCV hue is 0..179)
        #---------------------------------#
        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
        image_data = cv2.merge((cv2.LUT(hue_ch, lut_hue), cv2.LUT(sat_ch, lut_sat), cv2.LUT(val_ch, lut_val)))
        return cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)

    def AutoAugment(self, image, random=True):
        """Apply the AutoAugment-style pipeline to a PIL image.

        When ``random`` is False: deterministic resize + center crop.
        Otherwise: random resized crop, random horizontal flip, then a random
        ImageNet augmentation policy.
        """
        if not random:
            image = self.resize(image)
            image = self.center_crop(image)
            return image
        #------------------------------------------#
        # Resize and random crop
        #------------------------------------------#
        image = self.resize_crop(image)
        #------------------------------------------#
        # Random horizontal flip
        #------------------------------------------#
        flip = self.rand() < .5
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        #------------------------------------------#
        # Random augmentation policy
        #------------------------------------------#
        image = self.policy(image)
        return image
def detection_collate(batch):
    """Collate ``(features, target)`` samples into a pair of float32 tensors.

    Args:
        batch: iterable of 2-tuples, e.g. the ``(all_features, y)`` items
            produced by ``DataGenerator.__getitem__``.

    Returns:
        ``(images, targets)``: the stacked features and the stacked targets,
        each converted to a ``torch.FloatTensor``.
    """
    samples = list(batch)
    stacked_features = np.array([features for features, _ in samples])
    stacked_targets = np.array([target for _, target in samples])
    images = torch.from_numpy(stacked_features).type(torch.FloatTensor)
    targets = torch.from_numpy(stacked_targets).type(torch.FloatTensor)
    return images, targets