GMC-IQA / utils /iqa_solver.py
Zevin2023's picture
MoC-IQA
07e1105
raw
history blame
No virus
4.58 kB
import torch
from scipy import stats
import numpy as np
from models import monet as MoNet
from models import gc_loss as GC_Loss
from utils.dataset import data_loader
import json
import random
import os
from tqdm import tqdm
def get_data(dataset, data_path='./utils/dataset/dataset_info.json', train_ratio=0.8, seed=None):
    """Look up a dataset and randomly split its image indices into train/test.

    The JSON file at ``data_path`` maps each dataset name to a two-item
    ``[path, img_num]`` entry. Indices ``0 .. img_num - 1`` are shuffled and
    the first ``round(train_ratio * img_num)`` of them form the training split;
    the rest form the test split.

    Args:
        dataset: Key of the dataset inside the JSON file.
        data_path: Path to the dataset-info JSON file.
        train_ratio: Fraction of images assigned to the training split
            (defaults to 0.8, the previously hard-coded value).
        seed: Optional seed for a private RNG so the split is reproducible.
            When ``None`` the global ``random`` module state is used, exactly
            as before, so callers that seed ``random`` globally are unaffected.

    Returns:
        ``(path, train_index, test_index)``; the two index lists are disjoint
        and together cover ``range(img_num)``.

    Raises:
        KeyError: If ``dataset`` is not a key of the JSON file.
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError('train_ratio must be in [0, 1], got {}'.format(train_ratio))
    with open(data_path, 'r', encoding='utf-8') as info_file:
        data_info = json.load(info_file)
    path, img_num = data_info[dataset]
    indices = list(range(img_num))
    if seed is None:
        random.shuffle(indices)
    else:
        # Private generator: reproducible without disturbing global random state.
        random.Random(seed).shuffle(indices)
    split = int(round(train_ratio * len(indices)))
    return path, indices[:split], indices[split:]
def cal_srocc_plcc(pred_score, gt_score):
    """Return the Spearman (SROCC) and Pearson (PLCC) correlations of two score sets.

    Both inputs are flattened first. This makes per-sample predictions
    collected as nested lists (e.g. ``tensor.tolist()`` on a ``(batch, 1)``
    model output) line up element-wise with a flat ground-truth list.

    Args:
        pred_score: Predicted scores; any nesting that flattens to N values.
        gt_score: Ground-truth scores; any nesting that flattens to N values.

    Returns:
        ``(srocc, plcc)`` as Python floats.

    Raises:
        ValueError: If the flattened inputs have different lengths.
    """
    pred = np.asarray(pred_score, dtype=np.float64).ravel()
    gt = np.asarray(gt_score, dtype=np.float64).ravel()
    if pred.shape != gt.shape:
        raise ValueError('pred_score and gt_score must have the same number of values, '
                         'got {} and {}'.format(pred.size, gt.size))
    srocc, _ = stats.spearmanr(pred, gt)
    plcc, _ = stats.pearsonr(pred, gt)
    return float(srocc), float(plcc)
class Solver:
    """Train and evaluate a MoNet quality model on one dataset split.

    ``config`` is read for ``dataset``, ``loss`` ('MAE', 'MSE' or 'GC'),
    ``queue_ratio`` (GC loss only), ``epochs``, ``lr``, ``weight_decay``,
    ``T_max``, ``eta_min`` and ``save_path``. It is also forwarded unchanged
    to ``data_loader.Data_Loader`` and ``MoNet.MoNet``.
    """

    def __init__(self, config):
        path, train_index, test_index = get_data(dataset=config.dataset)
        train_loader = data_loader.Data_Loader(config, path, train_index, istrain=True)
        test_loader = data_loader.Data_Loader(config, path, test_index, istrain=False)
        self.train_data = train_loader.get_data()
        self.test_data = test_loader.get_data()
        print('Traning data number: ', len(train_index))
        print('Testing data number: ', len(test_index))

        if config.loss == 'MAE':
            self.loss = torch.nn.L1Loss().cuda()
        elif config.loss == 'MSE':
            self.loss = torch.nn.MSELoss().cuda()
        elif config.loss == 'GC':
            # Queue length is a fraction (queue_ratio) of the training-set size.
            # NOTE(review): unlike the other losses this one is not moved with
            # .cuda(); confirm GC_Loss handles its own device placement.
            self.loss = GC_Loss.GC_Loss(queue_len=int(len(train_index) * config.queue_ratio))
        else:
            # Raising a bare str is itself a TypeError in Python 3 and hid this
            # message; raise a real exception instead.
            raise ValueError('Only Support MAE, MSE and GC loss.')

        print('Loading MoNet...')
        self.MoNet = MoNet.MoNet(config).cuda()
        self.MoNet.train(True)
        self.epochs = config.epochs
        self.optimizer = torch.optim.Adam(self.MoNet.parameters(), lr=config.lr,
                                          weight_decay=config.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=config.T_max,
                                                                    eta_min=config.eta_min)
        self.model_save_path = os.path.join(config.save_path, 'best_model.pkl')

    def train(self):
        """Run every epoch, checkpointing the best test result.

        After each epoch the model is evaluated with ``test()``. Whenever
        ``test SROCC + test PLCC`` beats the best sum seen so far, the weights
        are saved to ``self.model_save_path``.

        Returns:
            ``(best_srocc, best_plcc)``. Both stay ``0.0`` if no epoch's
            sum ever exceeded zero.
        """
        best_srocc = 0.0
        best_plcc = 0.0
        print('----------------------------------')
        print('Epoch\tTrain_Loss\tTrain_SROCC\tTrain_PLCC\tTest_SROCC\tTest_PLCC')
        for t in range(self.epochs):
            train_loss, train_srocc, train_plcc = self._train_one_epoch()
            test_srocc, test_plcc = self.test()
            if test_srocc + test_plcc > best_srocc + best_plcc:
                best_srocc = test_srocc
                best_plcc = test_plcc
                self._save_checkpoint()
            print('{}\t{}\t{}\t{}\t{}\t{}'.format(t + 1, round(train_loss, 4), round(train_srocc, 4),
                                                  round(train_plcc, 4), round(test_srocc, 4),
                                                  round(test_plcc, 4)))
        print('Best test SROCC {}, PLCC {}'.format(round(best_srocc, 4), round(best_plcc, 4)))
        return best_srocc, best_plcc

    def _train_one_epoch(self):
        """Run one optimisation pass over the training data; return (mean loss, SROCC, PLCC)."""
        epoch_loss = []
        pred_scores = []
        gt_scores = []
        for img, label in tqdm(self.train_data):
            img = img.cuda()
            label = label.view(-1)
            gt_scores.extend(label.tolist())
            label = label.cuda()
            self.optimizer.zero_grad()
            # One score per image, flattened to match label.view(-1). Unlike
            # squeeze(), this also keeps shape (1,) for a batch of one.
            pred = self.MoNet(img).reshape(-1)
            pred_scores.extend(pred.detach().cpu().tolist())
            loss = self.loss(pred, label.float())
            epoch_loss.append(loss.item())
            loss.backward()
            self.optimizer.step()
            # NOTE(review): the cosine schedule advances once per batch, so
            # config.T_max is counted in iterations, not epochs; confirm intended.
            self.scheduler.step()
        train_srocc, train_plcc = cal_srocc_plcc(pred_scores, gt_scores)
        return float(np.mean(epoch_loss)), train_srocc, train_plcc

    def _save_checkpoint(self):
        """Write the current MoNet weights to ``self.model_save_path``."""
        # Create the directory first so a missing save_path does not crash
        # the run only after a full epoch of training.
        save_dir = os.path.dirname(self.model_save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(self.MoNet.state_dict(), self.model_save_path)
        print('Model saved in: ', self.model_save_path)

    def test(self):
        """Evaluate on the test split and return ``(srocc, plcc)``.

        The model is switched to eval mode for the pass. It is always put back
        into training mode afterwards, even if evaluation raises.
        """
        self.MoNet.train(False)
        pred_scores = []
        gt_scores = []
        try:
            with torch.no_grad():
                for img, label in tqdm(self.test_data):
                    img = img.cuda()
                    # Labels are only collected, so they never need to visit the GPU.
                    gt_scores.extend(label.view(-1).tolist())
                    pred = self.MoNet(img)
                    pred_scores.extend(pred.reshape(-1).cpu().tolist())
        finally:
            self.MoNet.train(True)
        return cal_srocc_plcc(pred_scores, gt_scores)