# GMC-IQA / models/gc_loss.py
# Author: Zevin2023 — MoC-IQA (commit 07e1105)
# Hugging Face file-page residue: raw | history | blame | contribute | delete; "No virus"; 6.16 kB
import torch.nn as nn
import torch
import numpy as np
# class GC_Loss(nn.Module):
# def __init__(self, queue_len=800):
# super(GC_Loss, self).__init__()
# self.pred_queue = list()
# self.gt_queue = list()
# self.queue_len = 0
# self.queue_max_len = queue_len
# print('CCWD Length: ', queue_len)
# self.l1_loss = torch.nn.L1Loss().cuda()
# self.l2_loss = torch.nn.MSELoss().cuda()
# def enqueue(self, pred, gt):
# bs = pred.shape[0]
# self.queue_len = self.queue_len + bs
# self.pred_queue = self.pred_queue + pred.cpu().detach().numpy().tolist()
# self.gt_queue = self.gt_queue + gt.cpu().detach().numpy().tolist()
# if self.queue_len > self.queue_max_len:
# self.dequeue(self.queue_len - self.queue_max_len)
# self.queue_len = self.queue_max_len
# def dequeue(self, n):
# for index in range(n):
# self.pred_queue.pop(0)
# self.gt_queue.pop(0)
# def clear(self):
# self.pred_queue.clear()
# self.gt_queue.clear()
# def forward(self, x, y):
# x_queue = self.pred_queue.copy()
# y_queue = self.gt_queue.copy()
# # Gather all values (current batch + queue)
# x_all = torch.cat((x, torch.tensor(x_queue).cuda()), dim=0)
# y_all = torch.cat((y, torch.tensor(y_queue).cuda()), dim=0)
# # Estimate the mean and standard deviation
# x_bar = torch.mean(x_all, dim=0)
# x_std = torch.std(x_all, dim=0)
# y_bar = torch.mean(y_all, dim=0)
# y_std = torch.std(y_all, dim=0)
# # Estimate the PLCC of the batch predictions relative to all values
# diff_x_plcc = (x - x_bar) # [bs, 1]
# diff_y_plcc = (y - y_bar) # [bs, 1]
# x1 = torch.sum(torch.mul(diff_x_plcc, diff_y_plcc))
# x2_1 = torch.sqrt(torch.sum(torch.mul(diff_x_plcc, diff_x_plcc)))
# x2_2 = torch.sqrt(torch.sum(torch.mul(diff_y_plcc, diff_y_plcc)))
# # Standardize all values
# diff_x = (x_all - x_bar) / x_std # [bs, 1]
# diff_y = (y_all - y_bar) / y_std # [bs, 1]
# rank_x = diff_x.reshape(-1, 1)
# rank_y = diff_y.reshape(-1, 1)
# rank_x = rank_x - rank_x.transpose(1, 0)
# rank_y = rank_y - rank_y.transpose(1, 0)
# # Estimate the ranks of all values
# rank_x = torch.sum(1 / 2 * (1 + torch.erf(rank_x)), dim=1)
# rank_y = torch.sum(1 / 2 * (1 + torch.erf(rank_y)), dim=1)
# # Compute the mean and standard deviation of the estimated ranks
# rank_x_bar = torch.mean(rank_x, dim=0)
# rank_x_std = torch.std(rank_x, dim=0)
# rank_y_bar = torch.mean(rank_y, dim=0)
# rank_y_std = torch.std(rank_y, dim=0)
# # Estimate the SROCC of the batch predictions relative to all values
# rank_x_ = (x - rank_x_bar) / rank_x_std # [bs, 1]
# rank_y_ = (y - rank_y_bar) / rank_y_std # [bs, 1]
# x1_rank = torch.sum(torch.mul(rank_x_, rank_y_))
# x2_1_rank = torch.sqrt(torch.sum(torch.mul(rank_x_, rank_x_)))
# x2_2_rank = torch.sqrt(torch.sum(torch.mul(rank_y_, rank_y_)))
# self.enqueue(x, y)
# return (0.5 * ((1 - x1 / (x2_1 * x2_2)) + (1 - (x1_rank / (x2_1_rank * x2_2_rank)))) + 1) * self.l2_loss(x, y)
class GC_Loss(nn.Module):
    """Global-consistency (GC) loss for score regression.

    An MSE term is scaled by how poorly the current batch correlates with its
    ground truth. The correlation is measured against a running queue of
    previously seen predictions / labels:

        GC = (alpha * (1 - PLCC) + beta * (1 - SROCC) + gamma) * MSE(x, y)

    PLCC is a Pearson-style correlation of the batch, centred on the mean of
    batch + queue (see ``consistency``). SROCC is the same statistic computed
    on differentiable soft ranks (see ``ppra``). Each forward pass appends
    the batch to the queue, which keeps at most ``queue_len`` of the most
    recent samples.

    Args:
        queue_len: maximum number of past samples kept in the queue.
        alpha: weight of the ``1 - PLCC`` penalty.
        beta: weight of the ``1 - SROCC`` penalty.
        gamma: constant term of the multiplier applied to the MSE.
    """

    def __init__(self, queue_len=800, alpha=0.5, beta=0.5, gamma=1):
        super(GC_Loss, self).__init__()
        # History of past predictions / labels, oldest first, stored as
        # detached host-side Python lists.
        self.pred_queue = list()
        self.gt_queue = list()
        self.queue_len = 0
        self.queue_max_len = queue_len
        print('The queue length is: ', self.queue_max_len)
        # MSELoss has no parameters or buffers, so it needs no device placement.
        self.mse = torch.nn.MSELoss()
        self.alpha, self.beta, self.gamma = alpha, beta, gamma

    def consistency(self, pred_data, gt_data):
        """Correlate a batch with its targets, centred on global means.

        Args:
            pred_data: ``(pred_one_batch, pred_all)``. ``pred_all`` (batch +
                queue) only supplies the mean used for centring.
            gt_data: ``(gt_one_batch, gt_all)``, same layout for the labels.

        Returns:
            Scalar tensor ``sum(dp * dg) / (||dp|| * ||dg||)``, where ``dp`` and
            ``dg`` are the batch values minus the respective global means.
            By Cauchy-Schwarz the value lies in [-1, 1].

        NOTE: if either centred batch is all zeros this is 0/0, i.e. NaN.
        """
        pred_one_batch, pred_all = pred_data
        gt_one_batch, gt_all = gt_data
        diff_pred = pred_one_batch - torch.mean(pred_all)
        diff_gt = gt_one_batch - torch.mean(gt_all)
        x1 = torch.sum(torch.mul(diff_pred, diff_gt))
        x2_1 = torch.sqrt(torch.sum(torch.mul(diff_pred, diff_pred)))
        x2_2 = torch.sqrt(torch.sum(torch.mul(diff_gt, diff_gt)))
        return x1 / (x2_1 * x2_2)

    def ppra(self, x):
        """Pairwise Preference-based Rank Approximation.

        Standardises ``x`` and approximates the rank of element ``i`` as
        ``sum_j Phi(z_i - z_j)``, where ``Phi`` is the standard normal CDF.
        This is a differentiable surrogate for the hard rank. The
        self-comparison term contributes 0.5 to every rank.

        Args:
            x: tensor of scores; it is flattened to one dimension.

        Returns:
            1-D tensor of soft ranks, one per element of ``x``.
        """
        x_bar, x_std = torch.mean(x), torch.std(x)
        x_n_T = ((x - x_bar) / x_std).reshape(-1, 1)
        # pairwise[i, j] = z_i - z_j
        pairwise = x_n_T - x_n_T.transpose(1, 0)
        # Standard normal CDF: Phi(t) = 0.5 * (1 + erf(t / sqrt(2))).
        return torch.sum(0.5 * (1 + torch.erf(pairwise / 2 ** 0.5)), dim=1)

    @torch.no_grad()
    def enqueue(self, pred, gt):
        """Append a batch to the history queues, keeping only the newest items.

        Values are detached, copied to the host and stored as Python lists, so
        the queues hold neither an autograd graph nor device memory.

        Args:
            pred: batch predictions.
            gt: batch ground-truth scores.
        """
        self.pred_queue.extend(pred.detach().cpu().tolist())
        self.gt_queue.extend(gt.detach().cpu().tolist())
        self.queue_len = len(self.pred_queue)
        overflow = self.queue_len - self.queue_max_len
        if overflow > 0:
            self.dequeue(overflow)

    @torch.no_grad()
    def dequeue(self, n):
        """Drop the ``n`` oldest entries from both queues.

        Args:
            n: number of entries to remove; non-positive values are a no-op.
        """
        if n <= 0:
            return
        # One slice deletion instead of n list.pop(0) calls, each of which
        # shifts the whole list.
        del self.pred_queue[:n]
        del self.gt_queue[:n]
        self.queue_len = len(self.pred_queue)

    def clear(self):
        """Empty both queues and reset the length counter."""
        self.pred_queue.clear()
        self.gt_queue.clear()
        # Without this reset, the next enqueue sees a stale length, immediately
        # "overflows", and discards the batch it just added.
        self.queue_len = 0

    def forward(self, x, y):
        """Compute the GC loss for a batch, then push the batch into the queue.

        Args:
            x: predicted scores for the batch.
            y: ground-truth scores, expected to match ``x`` in shape.

        Returns:
            Scalar tensor
            ``(alpha * (1 - PLCC) + beta * (1 - SROCC) + gamma) * MSE(x, y)``.
        """
        # Queue tensors are built from plain lists, so the history shifts the
        # statistics but receives no gradient. They are placed on the inputs'
        # device instead of a hard-coded GPU.
        pred_hist = torch.tensor(self.pred_queue, device=x.device)
        gt_hist = torch.tensor(self.gt_queue, device=y.device)
        x_all = torch.cat((x, pred_hist), dim=0)
        y_all = torch.cat((y, gt_hist), dim=0)

        PLCC = self.consistency((x, x_all), (y, y_all))
        PGC = 1 - PLCC

        rank_x = self.ppra(x_all)
        rank_y = self.ppra(y_all)
        # The batch occupies the leading positions of x_all / y_all.
        SROCC = self.consistency(
            (rank_x[:x.shape[0]], rank_x), (rank_y[:y.shape[0]], rank_y)
        )
        SGC = 1 - SROCC

        GC = (self.alpha * PGC + self.beta * SGC + self.gamma) * self.mse(x, y)
        # Enqueue only after the loss is built, so the batch does not appear
        # twice in x_all / y_all.
        self.enqueue(x, y)
        return GC
if __name__ == '__main__':
    # Quick GPU smoke test: build the loss and evaluate it on one toy batch.
    loss_fn = GC_Loss().cuda()
    preds = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float).cuda()
    labels = torch.tensor([6, 7, 8, 9, 15], dtype=torch.float).cuda()
    print(loss_fn(preds, labels))