# cover/models/evaluator.py
from functools import reduce
import torch
import torch.nn as nn
from .conv_backbone import convnext_3d_small, convnext_3d_tiny, convnextv2_3d_pico, convnextv2_3d_femto, clip_vitL14
from .head import IQAHead, VARHead, VQAHead
from .swin_backbone import SwinTransformer2D as ImageBackbone
from .swin_backbone import SwinTransformer3D as VideoBackbone
from .swin_backbone import swin_3d_small, swin_3d_tiny
class BaseEvaluator(nn.Module):
def __init__(
self, backbone=dict(), vqa_head=dict(),
):
super().__init__()
self.backbone = VideoBackbone(**backbone)
self.vqa_head = VQAHead(**vqa_head)
def forward(self, vclip, inference=True, **kwargs):
if inference:
self.eval()
with torch.no_grad():
feat = self.backbone(vclip)
score = self.vqa_head(feat)
self.train()
return score
else:
feat = self.backbone(vclip)
score = self.vqa_head(feat)
return score
def forward_with_attention(self, vclip):
self.eval()
with torch.no_grad():
feat, avg_attns = self.backbone(vclip, require_attn=True)
score = self.vqa_head(feat)
return score, avg_attns
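
# COVER combines up to three views of a video: a 'semantic' branch (per-frame
# CLIP-IQA+ ViT-L/14 features), a 'technical' branch, and an 'aesthetic'
# branch. The semantic features gate the other two branches through the
# CrossGatingBlock defined at the bottom of this file.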
class COVER(nn.Module):
def __init__(
self,
backbone_size="divided",
backbone_preserve_keys="fragments,resize",
multi=False,
layer=-1,
backbone=dict(
resize={"window_size": (4, 4, 4)}, fragments={"window_size": (4, 4, 4)}
),
divide_head=False,
vqa_head=dict(in_channels=768),
var=False,
):
        super().__init__()
        self.backbone_preserve_keys = backbone_preserve_keys.split(",")
        self.multi = multi
        self.layer = layer
        for key, hypers in backbone.items():
if key not in self.backbone_preserve_keys:
continue
if backbone_size == "divided":
t_backbone_size = hypers["type"]
else:
t_backbone_size = backbone_size
if t_backbone_size == "swin_tiny":
b = swin_3d_tiny(**backbone[key])
elif t_backbone_size == "swin_tiny_grpb":
# to reproduce fast-vqa
b = VideoBackbone()
elif t_backbone_size == "swin_tiny_grpb_m":
# to reproduce fast-vqa-m
b = VideoBackbone(window_size=(4, 4, 4), frag_biases=[0, 0, 0, 0])
elif t_backbone_size == "swin_small":
b = swin_3d_small(**backbone[key])
elif t_backbone_size == "conv_tiny":
b = convnext_3d_tiny(pretrained=True)
elif t_backbone_size == "conv_small":
b = convnext_3d_small(pretrained=True)
elif t_backbone_size == "conv_femto":
b = convnextv2_3d_femto(pretrained=True)
elif t_backbone_size == "conv_pico":
b = convnextv2_3d_pico(pretrained=True)
elif t_backbone_size == "xclip":
raise NotImplementedError
elif t_backbone_size == "clip_iqa+":
b = clip_vitL14(pretrained=True)
else:
raise NotImplementedError
print("Setting backbone:", key + "_backbone")
setattr(self, key + "_backbone", b)
if divide_head:
for key in backbone:
                pre_pool = False  # disabled alternative: False if key == "technical" else True
if key not in self.backbone_preserve_keys:
continue
b = VQAHead(pre_pool=pre_pool, **vqa_head)
print("Setting head:", key + "_head")
setattr(self, key + "_head", b)
        else:
            if var:
                self.vqa_head = VARHead(**vqa_head)
            else:
                self.vqa_head = VQAHead(**vqa_head)
        self.smtc_gate_tech = CrossGatingBlock(
            x_features=768, num_channels=768, block_size=1, grid_size=1,
            upsample_y=False, dropout_rate=0.1, use_bias=True, use_global_mlp=False,
        )
        self.smtc_gate_aesc = CrossGatingBlock(
            x_features=768, num_channels=768, block_size=1, grid_size=1,
            upsample_y=False, dropout_rate=0.1, use_bias=True, use_global_mlp=False,
        )
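    # forward() expects the 'semantic' entry of `vclips` to be visited before
    # 'technical' and 'aesthetic': the gating blocks read feats['semantic'],
    # so the dict's insertion order must put 'semantic' first.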
def forward(
self,
vclips,
inference=True,
return_pooled_feats=False,
return_raw_feats=False,
reduce_scores=False,
pooled=False,
**kwargs
):
        assert not (return_pooled_feats and return_raw_feats), "Please only choose one kind of features to return"
if inference:
self.eval()
with torch.no_grad():
scores = []
feats = {}
for key in vclips:
if key == 'technical' or key == 'aesthetic':
feat = getattr(self, key.split("_")[0] + "_backbone")(
vclips[key], multi=self.multi, layer=self.layer, **kwargs
)
if key == 'technical':
feat_gated = self.smtc_gate_tech(feats['semantic'], feat)
elif key == 'aesthetic':
feat_gated = self.smtc_gate_aesc(feats['semantic'], feat)
if hasattr(self, key.split("_")[0] + "_head"):
scores += [getattr(self, key.split("_")[0] + "_head")(feat_gated)]
else:
scores += [getattr(self, "vqa_head")(feat_gated)]
elif key == 'semantic':
x = vclips[key].squeeze(0)
x = x.permute(1,0,2,3)
feat, _ = getattr(self, key.split("_")[0] + "_backbone")(
x, multi=self.multi, layer=self.layer, **kwargs
)
                            # CLIP-IQA+ ViT-L/14 yields per-frame image features of
                            # shape (t, c) = (16, 768); reshape them into a 5-D
                            # feature map the VQA head can consume.
                            feat = feat.permute(1, 0).contiguous()   # (c, t) = (768, 16)
                            feat = feat.unsqueeze(-1).unsqueeze(-1)  # (c, t, 1, 1)
                            feat_expand = feat.expand(-1, -1, 7, 7)  # (c, t, h, w) = (768, 16, 7, 7)
                            feat_expand = feat_expand.unsqueeze(0)   # (b, c, t, h, w) = (1, 768, 16, 7, 7)
if hasattr(self, key.split("_")[0] + "_head"):
score = getattr(self, key.split("_")[0] + "_head")(feat_expand)
else:
score = getattr(self, "vqa_head")(feat_expand)
scores += [score]
feats[key] = feat_expand
if reduce_scores:
if len(scores) > 1:
scores = reduce(lambda x, y: x + y, scores)
else:
scores = scores[0]
if pooled:
scores = torch.mean(scores, (1, 2, 3, 4))
self.train()
if return_pooled_feats or return_raw_feats:
return scores, feats
return scores
else:
self.train()
scores = []
feats = {}
for key in vclips:
if key == 'technical' or key == 'aesthetic':
feat = getattr(self, key.split("_")[0] + "_backbone")(
vclips[key], multi=self.multi, layer=self.layer, **kwargs
)
if key == 'technical':
feat_gated = self.smtc_gate_tech(feats['semantic'], feat)
elif key == 'aesthetic':
feat_gated = self.smtc_gate_aesc(feats['semantic'], feat)
if hasattr(self, key.split("_")[0] + "_head"):
scores += [getattr(self, key.split("_")[0] + "_head")(feat_gated)]
else:
scores += [getattr(self, "vqa_head")(feat_gated)]
feats[key] = feat
elif key == 'semantic':
scores_semantic_list = []
feats_semantic_list = []
for batch_idx in range(vclips[key].shape[0]):
x = vclips[key][batch_idx].squeeze()
x = x.permute(1,0,2,3)
feat, _ = getattr(self, key.split("_")[0] + "_backbone")(
x, multi=self.multi, layer=self.layer, **kwargs
)
                        # CLIP-IQA+ ViT-L/14 yields per-frame image features of
                        # shape (t, c) = (16, 768); reshape them into a 5-D
                        # feature map the VQA head can consume.
                        feat = feat.permute(1, 0).contiguous()   # (c, t) = (768, 16)
                        feat = feat.unsqueeze(-1).unsqueeze(-1)  # (c, t, 1, 1)
                        feat_expand = feat.expand(-1, -1, 7, 7)  # (c, t, h, w) = (768, 16, 7, 7)
feats_semantic_list.append(feat_expand)
if hasattr(self, key.split("_")[0] + "_head"):
feat_expand = feat_expand.unsqueeze(0) # (b, c, t, w, h) -> (1, 768, 16, 7, 7)
score = getattr(self, key.split("_")[0] + "_head")(feat_expand)
score = score.squeeze(0)
scores_semantic_list.append(score)
else:
feat_expand = feat_expand.unsqueeze(0) # (b, c, t, w, h) -> (1, 768, 16, 7, 7)
score = getattr(self, "vqa_head")(feat_expand)
score = score.squeeze(0)
scores_semantic_list.append(score)
scores_semantic_tensor = torch.stack(scores_semantic_list)
feats[key] = torch.stack(feats_semantic_list)
scores += [scores_semantic_tensor]
if return_pooled_feats:
feats[key] = feat.mean((-3, -2, -1))
if reduce_scores:
if len(scores) > 1:
scores = reduce(lambda x, y: x + y, scores)
else:
scores = scores[0]
                if pooled:
                    scores = torch.mean(scores, (1, 2, 3, 4))
if return_pooled_feats:
return scores, feats
return scores
def forward_head(
self,
feats,
inference=True,
reduce_scores=False,
pooled=False,
**kwargs
):
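        """Apply only the quality heads to pre-extracted backbone features."""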
if inference:
self.eval()
with torch.no_grad():
scores = []
for key in feats:
feat = feats[key]
if hasattr(self, key.split("_")[0] + "_head"):
scores += [getattr(self, key.split("_")[0] + "_head")(feat)]
else:
scores += [getattr(self, "vqa_head")(feat)]
if reduce_scores:
if len(scores) > 1:
scores = reduce(lambda x, y: x + y, scores)
else:
scores = scores[0]
if pooled:
scores = torch.mean(scores, (1, 2, 3, 4))
self.train()
return scores
        else:
            self.train()
            scores = []
            for key in feats:
                feat = feats[key]
                if hasattr(self, key.split("_")[0] + "_head"):
                    scores += [getattr(self, key.split("_")[0] + "_head")(feat)]
                else:
                    scores += [getattr(self, "vqa_head")(feat)]
            if reduce_scores:
                if len(scores) > 1:
                    scores = reduce(lambda x, y: x + y, scores)
                else:
                    scores = scores[0]
                if pooled:
                    scores = torch.mean(scores, (1, 2, 3, 4))
            return scores
class MinimumCOVER(nn.Module):
def __init__(self):
super().__init__()
self.technical_backbone = VideoBackbone()
self.aesthetic_backbone = convnext_3d_tiny(pretrained=True)
self.technical_head = VQAHead(pre_pool=False, in_channels=768)
self.aesthetic_head = VQAHead(pre_pool=False, in_channels=768)
    def forward(self, aesthetic_view, technical_view):
self.eval()
with torch.no_grad():
aesthetic_score = self.aesthetic_head(self.aesthetic_backbone(aesthetic_view))
technical_score = self.technical_head(self.technical_backbone(technical_view))
aesthetic_score_pooled = torch.mean(aesthetic_score, (1,2,3,4))
technical_score_pooled = torch.mean(technical_score, (1,2,3,4))
return [aesthetic_score_pooled, technical_score_pooled]
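# A minimal usage sketch for MinimumCOVER (the (B, C, T, H, W) input shapes
# below are assumptions for illustration, not values fixed by this file):
#
#     model = MinimumCOVER().eval()
#     aesthetic_view = torch.randn(1, 3, 32, 224, 224)
#     technical_view = torch.randn(1, 3, 32, 224, 224)
#     aes_score, tech_score = model(aesthetic_view, technical_view)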
class BaseImageEvaluator(nn.Module):
def __init__(
self, backbone=dict(), iqa_head=dict(),
):
super().__init__()
self.backbone = ImageBackbone(**backbone)
self.iqa_head = IQAHead(**iqa_head)
def forward(self, image, inference=True, **kwargs):
if inference:
self.eval()
with torch.no_grad():
feat = self.backbone(image)
score = self.iqa_head(feat)
self.train()
return score
else:
feat = self.backbone(image)
score = self.iqa_head(feat)
return score
def forward_with_attention(self, image):
self.eval()
with torch.no_grad():
feat, avg_attns = self.backbone(image, require_attn=True)
score = self.iqa_head(feat)
return score, avg_attns
class CrossGatingBlock(nn.Module): #input shape: n, c, h, w
"""Cross-gating MLP block."""
    def __init__(self, x_features, num_channels, block_size, grid_size, cin_y=0,
                 upsample_y=True, use_bias=True, use_global_mlp=True, dropout_rate=0):
super().__init__()
self.cin_y = cin_y
self.x_features = x_features
self.num_channels = num_channels
self.block_size = block_size
self.grid_size = grid_size
self.upsample_y = upsample_y
self.use_bias = use_bias
self.use_global_mlp = use_global_mlp
self.drop = dropout_rate
self.Conv_0 = nn.Linear(self.x_features, self.num_channels)
self.Conv_1 = nn.Linear(self.num_channels, self.num_channels)
self.in_project_x = nn.Linear(self.num_channels, self.num_channels, bias=self.use_bias)
self.gelu1 = nn.GELU(approximate='tanh')
self.out_project_y = nn.Linear(self.num_channels, self.num_channels, bias=self.use_bias)
self.dropout1 = nn.Dropout(self.drop)
    def forward(self, x, y):  # n,c,t,h,w
        # y is the gating signal; x and y must already share a shape
        # (no upsampling of y is performed here).
        assert y.shape == x.shape
        x = x.permute(0, 2, 3, 4, 1).contiguous()  # n,t,h,w,c
        y = y.permute(0, 2, 3, 4, 1).contiguous()  # n,t,h,w,c
x = self.Conv_0(x)
y = self.Conv_1(y)
shortcut_y = y
x = self.in_project_x(x)
gx = self.gelu1(x)
# Apply cross gating
y = y * gx # gating y using x
y = self.out_project_y(y)
y = self.dropout1(y)
y = y + shortcut_y # y = y * x + y
return y.permute(0,4,1,2,3).contiguous() #n,c,t,h,w
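

# Minimal self-check for CrossGatingBlock; a sketch, not part of the model.
# The (1, 768, 16, 7, 7) shape is an assumption mirroring how COVER.forward
# expands semantic features before gating the technical/aesthetic branches.
if __name__ == "__main__":
    gate = CrossGatingBlock(
        x_features=768, num_channels=768, block_size=1, grid_size=1,
        upsample_y=False, dropout_rate=0.1, use_bias=True, use_global_mlp=False,
    )
    semantic = torch.randn(1, 768, 16, 7, 7)   # gating signal (x)
    technical = torch.randn(1, 768, 16, 7, 7)  # features to be gated (y)
    gated = gate(semantic, technical)
    print(gated.shape)  # torch.Size([1, 768, 16, 7, 7])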