GMC-IQA / utils /dataset /folders.py
Zevin2023's picture
MoC-IQA
07e1105
raw
history blame
No virus
6.1 kB
import torch.utils.data as data
import torch
from PIL import Image
import os
import scipy.io
import numpy as np
import csv
from openpyxl import load_workbook
import cv2
class LIVEC(data.Dataset):
    """Dataset backed by the ``AllImages_release`` / ``AllMOS_release`` .mat files.

    Presumably the LIVE Challenge ("LIVE-C") database -- the name is the only
    evidence here, confirm against the caller.
    """

    def __init__(self, root, index, transform):
        """Collect image paths and normalised scores for the chosen entries.

        Args:
            root: Dataset directory containing ``Data/`` and ``Images/``.
            index: Iterable of integer positions into the retained entries
                (i.e. positions counted after the first 7 are dropped).
            transform: Callable forwarded to ``get_item`` for every sample.
        """
        meta_dir = os.path.join(root, 'Data')

        names_mat = scipy.io.loadmat(os.path.join(meta_dir, 'AllImages_release.mat'))
        # Only entries 7..1168 are kept; the first 7 are discarded
        # (presumably they are not part of the scored set -- TODO confirm).
        names = names_mat['AllImages_release'][7:1169]

        scores_mat = scipy.io.loadmat(os.path.join(meta_dir, 'AllMOS_release.mat'))
        scores = scores_mat['AllMOS_release'].astype(np.float32)[0][7:1169]

        chosen = list(index)
        paths = [os.path.join(root, 'Images', names[k][0][0]) for k in chosen]
        # Scores are min-max scaled over the selected subset only.
        targets = normalization([scores[k] for k in chosen])

        self.samples = paths
        self.gt = targets
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(image, score)`` pair produced by ``get_item``.

        Args:
            index (int): Position within the selected samples.

        Returns:
            tuple: (image, score) after ``self.transform`` has been applied.
        """
        return get_item(self.samples, self.gt, index, self.transform)

    def __len__(self):
        """Number of selected samples."""
        return len(self.samples)
class Koniq10k(data.Dataset):
    """Dataset reading names and MOS from ``koniq10k_distributions_sets.csv``."""

    def __init__(self, root, index, transform):
        """Collect image paths and normalised scores for the chosen rows.

        Args:
            root: Dataset directory holding the CSV file and ``1024x768/``.
            index: Iterable of integer row positions (0-based, header excluded).
            transform: Callable forwarded to ``get_item`` for every sample.
        """
        names, scores = [], []
        table_path = os.path.join(root, 'koniq10k_distributions_sets.csv')
        with open(table_path) as handle:
            for record in csv.DictReader(handle):
                names.append(record['image_name'])
                # Stored as a 0-d float32 array, one per CSV row.
                scores.append(np.array(float(record['MOS'])).astype(np.float32))

        chosen = list(index)
        paths = [os.path.join(root, '1024x768', names[k]) for k in chosen]
        # Scores are min-max scaled over the selected subset only.
        targets = normalization([scores[k] for k in chosen])

        self.samples = paths
        self.gt = targets
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(image, score)`` pair produced by ``get_item``.

        Args:
            index (int): Position within the selected samples.

        Returns:
            tuple: (image, score) after ``self.transform`` has been applied.
        """
        return get_item(self.samples, self.gt, index, self.transform)

    def __len__(self):
        """Number of selected samples."""
        return len(self.samples)
class SPAQ(data.Dataset):
    """Dataset reading names and z-scored MOS from a CSV score table.

    NOTE(review): the CSV file name and the ``1024x768`` image folder look
    KonIQ-10k-specific rather than SPAQ-specific -- confirm against the
    on-disk layout that callers actually pass as ``root``.
    """

    def __init__(self, root, index, transform):
        """Collect image paths and normalised scores for the chosen rows.

        Args:
            root: Dataset directory holding the CSV file and ``1024x768/``.
            index: Iterable of integer row positions (0-based, header excluded).
            transform: Callable forwarded to ``get_item`` for every sample.
        """
        imgname = []
        mos_all = []
        csv_file = os.path.join(root, 'koniq10k_scores_and_distributions.csv')
        with open(csv_file) as f:
            for row in csv.DictReader(f):
                imgname.append(row['image_name'])
                mos_all.append(np.array(float(row['MOS_zscore'])).astype(np.float32))

        chosen = list(index)
        sample = [os.path.join(root, '1024x768', imgname[item]) for item in chosen]
        # The scores collected above live in ``mos_all``; the names previously
        # used here (``labels`` / ``norm_target``) are not defined in this
        # module and raised NameError. Use the same min-max scaling as the
        # sibling dataset classes.
        gt = normalization([mos_all[item] for item in chosen])

        self.samples, self.gt = sample, gt
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(image, score)`` pair produced by ``get_item``.

        Uses the shared ``get_item`` loader so that ``self.transform`` receives
        the same ``{'img': ..., 'gt': ...}`` dict as in the other dataset
        classes of this module.

        Args:
            index (int): Position within the selected samples.

        Returns:
            tuple: (image, score) after ``self.transform`` has been applied.
        """
        return get_item(self.samples, self.gt, index, self.transform)

    def __len__(self):
        """Number of selected samples."""
        return len(self.samples)
class BID(data.Dataset):
    """Dataset reading image numbers and scores from ``DatabaseGrades.xlsx``."""

    def __init__(self, root, index, transform):
        """Collect image paths and normalised scores for the chosen entries.

        Args:
            root: Directory holding ``DatabaseGrades.xlsx`` and the images.
            index: Iterable of integer positions into the parsed entries.
            transform: Callable forwarded to ``get_item`` for every sample.
        """
        names, scores = [], []
        sheet = load_workbook(os.path.join(root, 'DatabaseGrades.xlsx')).active

        # The row iterator only drives the loop; cells are fetched by explicit
        # row number starting at row 2 (row 1 is skipped, presumably a header).
        row_no = 1
        for _ in sheet.rows:
            row_no += 1
            img_num = sheet.cell(row=row_no, column=1).value
            names.append("DatabaseImage%04d.JPG" % (img_num))
            score = sheet.cell(row=row_no, column=2).value
            scores.append(np.array(score).astype(np.float32))
            # Stop after spreadsheet row 587.
            if row_no == 587:
                break

        chosen = list(index)
        paths = [os.path.join(root, names[k]) for k in chosen]
        # Scores are min-max scaled over the selected subset only.
        targets = normalization([scores[k] for k in chosen])

        self.samples = paths
        self.gt = targets
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(image, score)`` pair produced by ``get_item``.

        Args:
            index (int): Position within the selected samples.

        Returns:
            tuple: (image, score) after ``self.transform`` has been applied.
        """
        return get_item(self.samples, self.gt, index, self.transform)

    def __len__(self):
        """Number of selected samples."""
        return len(self.samples)
def get_item(samples, gt, index, transform):
    """Load one image, pair it with its score and run the transform.

    Args:
        samples: Sequence of image paths.
        gt: Sequence of scores aligned with ``samples``.
        index: Position of the requested sample.
        transform: Callable taking a dict with keys ``'img'`` and ``'gt'`` and
            returning a mapping with the same keys; the returned ``'gt'`` must
            support ``.type(torch.FloatTensor)`` (presumably a torch tensor).

    Returns:
        tuple: ``(img, gt)`` taken from the transformed mapping, with ``gt``
        cast to ``torch.FloatTensor``.
    """
    img_path = samples[index]
    score = gt[index]
    record = transform({'img': load_image(img_path), 'gt': score})
    return record['img'], record['gt'].type(torch.FloatTensor)
def getFileName(path, suffix):
    """List the entries of a directory whose extension equals ``suffix``.

    Args:
        path: Directory to scan (non-recursive).
        suffix: Extension including the leading dot, e.g. ``'.png'``; the
            comparison is case-sensitive.

    Returns:
        list: Matching entry names (not full paths), in ``os.listdir`` order.
    """
    return [
        entry
        for entry in os.listdir(path)
        if os.path.splitext(entry)[1] == suffix
    ]
def load_image(img_path, size=(224, 224)):
    """Read an image file and return it as a scaled, channel-first RGB array.

    Args:
        img_path: Path of the image file to read.
        size: ``(width, height)`` handed to ``cv2.resize``. Defaults to
            ``(224, 224)``, the value that used to be hard-coded.

    Returns:
        numpy.ndarray: float32 array of shape ``(3, height, width)`` in RGB
        channel order, with the pixel values divided by 255.

    Raises:
        OSError: If OpenCV cannot read or decode ``img_path``.
    """
    d_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the failure only shows up as an opaque error inside resize.
    if d_img is None:
        raise OSError("Could not read image file: %s" % (img_path,))
    d_img = cv2.resize(d_img, size, interpolation=cv2.INTER_CUBIC)
    # OpenCV decodes to BGR; convert to RGB channel order.
    d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
    d_img = d_img.astype('float32') / 255
    # HWC -> CHW.
    return np.transpose(d_img, (2, 0, 1))
def normalization(data):
    """Min-max scale a collection of scores to the range [0, 1].

    Args:
        data: Sequence (or array) of numeric scores.

    Returns:
        list: One float64 array per row of ``data`` reshaped to ``(-1, 1)``
        (for a flat sequence of n scores: n arrays of shape ``(1,)``). An
        empty input yields an empty list. If every score is identical the
        result is all zeros instead of NaN from a division by zero.
    """
    values = np.array(data)
    if values.size == 0:
        # np.min / np.max raise on empty input; there is nothing to scale.
        return []

    low = np.min(values)
    span = np.max(values) - low
    if span == 0:
        # Constant scores: avoid 0/0, which would turn every target into NaN.
        scaled = np.zeros_like(values, dtype='float')
    else:
        scaled = (values - low) / span
    return list(scaled.astype('float').reshape(-1, 1))