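"""Retrieval evaluation utilities for text / vision / audio dual-encoder models.

Extracts features per modality, scores all query-candidate pairs with
dual-encoder similarities (and their dual-softmax variants), optionally
reranks the top-k candidates with a cross-encoder matching head, and
reports recall metrics via `itm_eval`.
"""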
import datetime
import logging
import time
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from models.criterions import get_sim
from utils.basic_utils import MetricLogger
from utils.distributed import get_rank, get_world_size
logger = logging.getLogger(__name__)
def extract_text_feats(texts, max_txt_l, tokenizer, model, device, return_ids=False):
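    """Tokenize and encode `texts` with the model's text encoder in batches of 256.

    Returns `text_feats` (num_text, Lt, d) and `text_atts` (num_text, Lt);
    with `return_ids=True` the input ids are returned as a third tensor.
    """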
num_text = len(texts)
text_bs = 256
text_feats = []
text_atts = []
if return_ids:
text_ids = []
for i in range(0, num_text, text_bs):
text = texts[i : min(num_text, i + text_bs)]
text_input = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=max_txt_l,
return_tensors="pt",
        ).to(device)  # NOTE: no need to cast the tokenized inputs here
text_feat = model.encode_text(text_input)[0]
text_feats.append(text_feat)
text_atts.append(text_input.attention_mask)
if return_ids:
text_ids.append(text_input.input_ids)
text_feats = torch.cat(text_feats, dim=0)
text_atts = torch.cat(text_atts, dim=0)
if return_ids:
text_ids = torch.cat(text_ids, dim=0)
return text_feats, text_atts, text_ids
else:
return text_feats, text_atts
def extract_vision_feats(data_loader, model, device, config):
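    """Encode all videos/images in `data_loader` with the vision encoder.

    Inputs are cast to fp16/bf16 when `config.use_half_precision` is set, and
    features are kept on CPU when `config.evaluation.eval_offload` is set. With
    `eval_frame_ensemble == "concat"`, per-frame tokens are flattened so each
    video yields a single (1, #frm*L, d) clip.
    """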
if config.use_half_precision:
if config.get('use_bf16', False):
cast_dtype = torch.bfloat16
else:
cast_dtype = torch.float16
else:
cast_dtype = None
image_feats_all = []
pooled_image_feats_all = []
metric_logger = MetricLogger(delimiter=" ")
header = "extracting image feats"
iterator = metric_logger.log_every(data_loader, 100, header)
for image, img_id in iterator:
image = image.to(device, dtype=cast_dtype, non_blocking=True)
image_feat, pooled_image_feat = model.encode_vision(image, test=True)
if len(pooled_image_feat.shape) == 2:
            pooled_image_feat = pooled_image_feat.unsqueeze(1)  # (bsz, d) -> (bsz, 1, d), as av_fusion expects
if config.evaluation.eval_frame_ensemble == "concat":
if len(image_feat.shape) == 4:
image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous()
image_feat = image_feat.unsqueeze(1) # (bsz, 1, #frm*L, d)
else:
            assert config.video_input.num_frames == 1, "only single-frame input is supported"
assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
if config.evaluation.eval_offload:
image_feats_all.append(image_feat.cpu())
pooled_image_feats_all.append(pooled_image_feat.cpu())
else:
image_feats_all.append(image_feat)
pooled_image_feats_all.append(pooled_image_feat)
image_feats_all = torch.cat(image_feats_all, dim=0)
pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0)
return image_feats_all, pooled_image_feats_all
def extract_audio_feats(data_loader, model, device, config):
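    """Encode all audio samples in `data_loader`; the audio counterpart of
    `extract_vision_feats` (a singleton clip dimension is added for fusion)."""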
if config.use_half_precision:
if config.get('use_bf16', False):
cast_dtype = torch.bfloat16
else:
cast_dtype = torch.float16
else:
cast_dtype = None
audio_feats_all = []
pooled_audio_feats_all = []
metric_logger = MetricLogger(delimiter=" ")
header = "extracting audio feats"
iterator = metric_logger.log_every(data_loader, 100, header)
for audio, _ in iterator:
audio = audio.to(device, dtype=cast_dtype, non_blocking=True)
audio_feat, pooled_audio_feat = model.encode_audio(audio, test=True)
        audio_feat = audio_feat.unsqueeze(1)  # add a singleton clip dim, as deep_fusion expects
        pooled_audio_feat = pooled_audio_feat.unsqueeze(1)
if config.evaluation.eval_offload:
audio_feats_all.append(audio_feat.cpu())
pooled_audio_feats_all.append(pooled_audio_feat.cpu())
else:
audio_feats_all.append(audio_feat)
pooled_audio_feats_all.append(pooled_audio_feat)
audio_feats_all = torch.cat(audio_feats_all, dim=0)
pooled_audio_feats_all = torch.cat(pooled_audio_feats_all, dim=0)
return audio_feats_all, pooled_audio_feats_all
def extract_audio_vision_feats(data_loader, model, device, config):
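    """Encode paired (audio, video) batches, combining the audio and vision
    extraction paths above, and return the four feature tensors."""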
if config.use_half_precision:
if config.get('use_bf16', False):
cast_dtype = torch.bfloat16
else:
cast_dtype = torch.float16
else:
cast_dtype = None
audio_feats_all = []
pooled_audio_feats_all = []
image_feats_all = []
pooled_image_feats_all = []
metric_logger = MetricLogger(delimiter=" ")
header = "extracting audio and vision feats"
iterator = metric_logger.log_every(data_loader, 100, header)
for media, _ in iterator:
audio = media[0]
image = media[1]
audio = audio.to(device, dtype=cast_dtype, non_blocking=True)
image = image.to(device, dtype=cast_dtype, non_blocking=True)
audio_feat, pooled_audio_feat = model.encode_audio(audio, test=True)
        audio_feat = audio_feat.unsqueeze(1)  # add a singleton clip dim, as deep_fusion expects
        pooled_audio_feat = pooled_audio_feat.unsqueeze(1)
image_feat, pooled_image_feat = model.encode_vision(image, test=True)
if len(pooled_image_feat.shape) == 2:
            pooled_image_feat = pooled_image_feat.unsqueeze(1)  # (bsz, d) -> (bsz, 1, d), as av_fusion expects
if config.evaluation.eval_frame_ensemble == "concat":
if len(image_feat.shape) == 4:
image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous()
image_feat = image_feat.unsqueeze(1) # (bsz, 1, #frm*L, d)
else:
            assert config.video_input.num_frames == 1, "only single-frame input is supported"
assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
if config.evaluation.eval_offload:
audio_feats_all.append(audio_feat.cpu())
pooled_audio_feats_all.append(pooled_audio_feat.cpu())
image_feats_all.append(image_feat.cpu())
pooled_image_feats_all.append(pooled_image_feat.cpu())
else:
audio_feats_all.append(audio_feat)
pooled_audio_feats_all.append(pooled_audio_feat)
image_feats_all.append(image_feat)
pooled_image_feats_all.append(pooled_image_feat)
audio_feats_all = torch.cat(audio_feats_all, dim=0)
pooled_audio_feats_all = torch.cat(pooled_audio_feats_all, dim=0)
image_feats_all = torch.cat(image_feats_all, dim=0)
pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0)
return audio_feats_all, pooled_audio_feats_all, image_feats_all, pooled_image_feats_all
@torch.no_grad()
def evaluation_wrapper(model, data_loader, tokenizer, device, config, prefix=""):
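    """Dispatch to the evaluation routine matching `config.model.model_cls`.

    Optionally averages scores over multiple prompts per text, then turns each
    score matrix (similarity, dual-softmax, and matching head when available)
    into recall metrics with `itm_eval`. Returns a dict keyed by score name.
    """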
amp_eval_enabled = config.use_half_precision and not (hasattr(config, "deepspeed") and config.deepspeed.enable)
logger.info(f"Begin to eval, model_without_ddp.dtype={model.dtype if hasattr(model, 'dtype') else None}, amp_eval_enabled={amp_eval_enabled}, dtype={torch.bfloat16 if config.get('use_bf16', False) else torch.float16}")
with torch.cuda.amp.autocast(enabled=amp_eval_enabled, dtype=torch.bfloat16 if config.get('use_bf16', False) else torch.float16):
i2t_match, t2i_match = None, None
if "qformer" in config.model.model_cls.lower():
i2t_match, t2i_match, i2t_sim, t2i_sim, i2t_dsl, t2i_dsl = evaluation_qformer(
model, data_loader, tokenizer, device, config
)
elif "blip" in config.model.model_cls.lower():
raise NotImplementedError
elif "clip" in config.model.model_cls.lower() or 'coca' in config.model.model_cls.lower():
i2t_sim, t2i_sim, i2t_dsl, t2i_dsl = evaluation_clip(
model, data_loader, tokenizer, device, config
)
else:
i2t_match, t2i_match, i2t_sim, t2i_sim, i2t_dsl, t2i_dsl = evaluation(
model, data_loader, tokenizer, device, config
)
        if hasattr(data_loader.dataset, "num_prompts"):
            # `n_prompts` captions per target (renamed from `np` to avoid shadowing numpy):
            # reshape the score matrices and average over prompts.
            n_prompts = data_loader.dataset.num_prompts
            logger.info(f"Averaging scores over {n_prompts} prompts per text.")
            nt = len(data_loader.dataset.text) // n_prompts
            if i2t_match is not None:
                i2t_match = i2t_match.reshape((i2t_match.shape[0], nt, n_prompts)).mean(axis=-1)
                t2i_match = t2i_match.reshape((nt, n_prompts, t2i_match.shape[1])).mean(axis=1)
            i2t_sim = i2t_sim.reshape((i2t_sim.shape[0], nt, n_prompts)).mean(axis=-1)
            t2i_sim = t2i_sim.reshape((nt, n_prompts, t2i_sim.shape[1])).mean(axis=1)
            i2t_dsl = i2t_dsl.reshape((i2t_dsl.shape[0], nt, n_prompts)).mean(axis=-1)
            t2i_dsl = t2i_dsl.reshape((nt, n_prompts, t2i_dsl.shape[1])).mean(axis=1)
score_pairs = [
(prefix + "_sim", i2t_sim, t2i_sim),
(prefix + "_dsl", i2t_dsl, t2i_dsl),
]
if i2t_match is not None:
if config.evaluation.get('use_dsl_for_match', False):
score_pairs.append((prefix + "_match (use_dsl)", i2t_match, t2i_match))
else:
score_pairs.append((prefix + "_match", i2t_match, t2i_match))
res = dict()
for name, i2t, t2i in score_pairs:
if i2t is not None:
txt2img_ids = data_loader.dataset.txt2img
img2txt_ids = data_loader.dataset.img2txt
res[name] = itm_eval(i2t, t2i, txt2img_ids, img2txt_ids)
return res
@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
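    """Two-stage retrieval evaluation for audio / video / audio-video models.

    Stage 1 scores all pairs with dual-encoder similarities (plus dual-softmax
    variants). Stage 2, when the model has a matching head, reranks the top-k
    candidates per query with the cross-encoder; rows are sharded across ranks
    and gathered with a SUM all-reduce.
    """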
model.eval()
use_dsl_for_match = config.evaluation.get('use_dsl_for_match', False)
metric_logger = MetricLogger(delimiter=" ")
header = "Evaluation:"
dtype = torch.half if config.use_half_precision else torch.float
media_type = data_loader.dataset.media_type
use_subtitle = hasattr(data_loader.dataset, "use_subtitle") and data_loader.dataset.use_subtitle
if use_subtitle:
assert media_type in ["video", "audio_video"], f"Not support media_type: {media_type}."
assert hasattr(data_loader.dataset, "subtitle") and data_loader.dataset.subtitle is not None, "You don't have subtitle to use."
logger.info(f"Start evaluation for media_type={media_type}")
assert media_type in ['audio', 'video', 'audio_video'], f"Not implement evaluation of {media_type}"
logger.info("Computing dual encoder features...")
start_time = time.time()
# this computes all features in each GPU
texts = data_loader.dataset.text
    # max_txt_l for evaluation comes from the data config
max_txt_l = data_loader.dataset.max_txt_l
text_feats, text_atts = extract_text_feats(
texts, max_txt_l, tokenizer, model, device
) # (bsz, Lt, d), (bsz, Lt)
if use_subtitle:
subtitle_feats, _ = extract_text_feats(
data_loader.dataset.subtitle, max_txt_l, tokenizer, model, device
) # (bsz, Lt, d), (bsz, Lt)
subtitle_proj = model.text_proj(subtitle_feats[:, 0]).unsqueeze(1)
subtitle_feats = subtitle_feats.unsqueeze(1)
if media_type == 'video':
image_feats, pooled_image_feats = extract_vision_feats(
data_loader, model, device, config
) # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
logger.info("Finished vision feature extraction")
logger.info("Computing ITC scores [dot-product]")
if config.evaluation.eval_offload:
            # NOTE: keep image_feats on CPU; moving the full tensor to GPU can OOM
pooled_image_feats = pooled_image_feats.to(device, non_blocking=True)
if use_subtitle:
i2t_scores, t2i_scores = get_sim(
model.vs_fusion(torch.concat([subtitle_proj, model.vision_proj(pooled_image_feats)], dim=-1)), model.text_proj(text_feats[:, 0])
)
else:
i2t_scores, t2i_scores = get_sim(
model.vision_proj(pooled_image_feats), model.text_proj(text_feats[:, 0])
)
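        # Dual softmax (DSL): rescale each similarity by a softmax taken over the
        # opposite axis, which suppresses candidates that score high against
        # every query ("hubs") and typically improves retrieval ranking.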
if use_dsl_for_match:
logger.info("use_dsl_for_match!!!")
old_i2t_scores, old_t2i_scores = i2t_scores, t2i_scores
i2t_scores = old_i2t_scores * old_i2t_scores.softmax(dim=0)
t2i_scores = old_i2t_scores.T * old_i2t_scores.T.softmax(dim=0)
num_medias = len(data_loader.dataset.image)
if use_subtitle:
media_feats = torch.concat([subtitle_feats, image_feats], dim=-2)
if hasattr(model, "vstm_head"):
match_head = model.vstm_head
else:
match_head = None
else:
media_feats = image_feats
if hasattr(model, "itm_head"):
match_head = model.itm_head
else:
match_head = None
elif media_type == 'audio':
audio_feats, pooled_audio_feats = extract_audio_feats(
data_loader, model, device, config
)
logger.info("Finished audio feature extraction")
logger.info("Computing ITC scores [dot-product]")
if config.evaluation.eval_offload:
pooled_audio_feats = pooled_audio_feats.to(device, non_blocking=True)
i2t_scores, t2i_scores = get_sim(
model.audio_proj(pooled_audio_feats), model.text_proj(text_feats[:, 0])
)
num_medias = len(data_loader.dataset.audio)
media_feats = audio_feats
if hasattr(model, "atm_head"):
match_head = model.atm_head
else:
match_head = None
elif media_type == 'audio_video':
audio_feats, pooled_audio_feats, image_feats, pooled_image_feats = extract_audio_vision_feats(
data_loader, model, device, config
)
logger.info("Finished audio and vision feature extraction")
logger.info("Computing ITC scores [dot-product]")
if config.evaluation.eval_offload:
pooled_audio_feats = pooled_audio_feats.to(device, non_blocking=True)
pooled_image_feats = pooled_image_feats.to(device, non_blocking=True)
if use_subtitle:
i2t_scores, t2i_scores = get_sim(
model.avs_fusion(torch.concat([model.audio_proj(pooled_audio_feats), subtitle_proj, model.vision_proj(pooled_image_feats)], dim=-1)), model.text_proj(text_feats[:, 0])
)
else:
i2t_scores, t2i_scores = get_sim(
model.av_fusion(torch.concat([model.audio_proj(pooled_audio_feats), model.vision_proj(pooled_image_feats)], dim=-1)), model.text_proj(text_feats[:, 0])
)
num_medias = len(data_loader.dataset.image)
if use_subtitle:
media_feats = torch.concat([audio_feats, subtitle_feats, image_feats], dim=-2)
if hasattr(model, "avstm_head"):
match_head = model.avstm_head
else:
match_head = None
else:
media_feats = torch.concat([audio_feats, image_feats], dim=-2)
if hasattr(model, "avtm_head"):
match_head = model.avtm_head
else:
match_head = None
else:
raise NotImplementedError(media_type)
logger.info("Computing ITC scores [dot-product], done!")
if match_head is not None:
i2t_scores_x = torch.full((num_medias, len(texts)), -100.0).to(
device, torch.float, non_blocking=True
)
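        # Initialised to -100 so entries this rank never scores stay far below
        # any real logit after the cross-rank SUM all-reduce.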
# computes only part of the scores at each GPU, gather at the end
logger.info("Rerank dual-encoder results with cross-encoder...")
num_tasks = get_world_size()
rank = get_rank()
# only uses the part associated with the raw eval set
# compute media2text #
step = num_medias // num_tasks + 1
start = rank * step
end = min(num_medias, start + step)
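        # Row sharding: this rank reranks media indices [start, end).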
text_encoder = model.get_text_encoder()
iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")
# generate score for each clip, and aggregate all clip scores for a video
n_clip_per_video = (
media_feats.shape[1] if not config.deep_fusion else media_feats[0].shape[1]
)
assert not config.deep_fusion and n_clip_per_video == 1, f"Not implemented for config.deep_fusion={config.deep_fusion} n_clip_per_video={n_clip_per_video}"
logger.info(
f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={config.evaluation.eval_frame_ensemble}"
)
for i, sims in enumerate(iterator):
k = min(len(sims), config.evaluation.k_test)
topk_sim, topk_idx = sims.topk(k=k, dim=0)
clip_scores = []
for clip_idx in range(n_clip_per_video):
if config.deep_fusion:
encoder_output = [
feat[start + i, clip_idx].to(device, non_blocking=True)
if config.evaluation.eval_offload
else feat[start + i, clip_idx]
for feat in media_feats
]
else:
encoder_output = (
media_feats[start + i, clip_idx].to(device, non_blocking=True)
if config.evaluation.eval_offload
else media_feats[start + i, clip_idx]
) # (#frm*Li, d)
                # Rerank the top-k texts for this video in mini-batches of size `bs`.
                bs = 32
                # bs = config.batch_size_test.video
                itm_embeds = []
if config.deep_fusion:
if len(topk_idx) % bs != 0:
left = len(topk_idx) % bs
left_encoder_output = [feat.repeat(left, 1, 1) for feat in encoder_output]
left_encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for feat in left_encoder_output
]
encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for feat in encoder_output
]
else:
if len(topk_idx) % bs != 0:
left = len(topk_idx) % bs
                        left_encoder_output = encoder_output.repeat(left, 1, 1)  # (left, #frm*Li, d)
left_encoder_att = torch.ones(left_encoder_output.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
                    encoder_output = encoder_output.repeat(bs, 1, 1)  # (bs, #frm*Li, d)
encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for j in range(0, len(topk_idx), bs):
if j + bs > len(topk_idx):
output = text_encoder(
encoder_embeds=text_feats[topk_idx[j:]],
attention_mask=text_atts[topk_idx[j:]],
encoder_hidden_states=left_encoder_output,
encoder_attention_mask=left_encoder_att,
return_dict=True,
mode="fusion",
)
else:
output = text_encoder(
encoder_embeds=text_feats[topk_idx[j : j + bs]],
attention_mask=text_atts[topk_idx[j : j + bs]],
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_att,
return_dict=True,
mode="fusion",
)
batch_itm_embeds = output.last_hidden_state[:, 0]
itm_embeds.append(batch_itm_embeds)
itm_embeds = torch.cat(itm_embeds, dim=0)
score = match_head(itm_embeds)[:, 1]
clip_scores.append(score)
if len(clip_scores) == 1:
score = clip_scores[0]
else:
raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")
i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)
# compute text2media #
num_text = len(data_loader.dataset.text)
t2i_scores_x = torch.full((num_text, num_medias), -100.0).to(
device, torch.float, non_blocking=True
)
step = num_text // num_tasks + 1
start = rank * step
end = min(num_text, start + step)
iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")
# generate score for each clip, and aggregate all clip scores for a video
n_clip_per_video = (
media_feats.shape[1] if not config.deep_fusion else media_feats[0].shape[1]
)
for i, sims in enumerate(iterator):
k = min(len(sims), config.evaluation.k_test)
topk_sim, topk_idx = sims.topk(k=k, dim=0)
clip_scores = []
for clip_idx in range(n_clip_per_video):
                # Rerank the top-k videos for this text in mini-batches of size `bs`.
                bs = 32
                # bs = config.batch_size_test.video
                itm_embeds = []
for j in range(0, len(topk_idx), bs):
if config.deep_fusion:
encoder_output = [
feat[topk_idx[j : j + bs].cpu(), clip_idx].to(device, non_blocking=True)
if config.evaluation.eval_offload
else feat[topk_idx[j : j + bs], clip_idx]
for feat in media_feats
]
encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for feat in encoder_output
]
else:
encoder_output = (
media_feats[topk_idx[j : j + bs].cpu(), clip_idx].to(
device, non_blocking=True
)
if config.evaluation.eval_offload
else media_feats[topk_idx[j : j + bs], clip_idx]
)
encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
repeat_n = (
encoder_output.shape[0]
if not config.deep_fusion
else encoder_output[0].shape[0]
)
output = text_encoder(
encoder_embeds=text_feats[start + i].repeat(repeat_n, 1, 1),
attention_mask=text_atts[start + i].repeat(repeat_n, 1),
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_att,
return_dict=True,
mode="fusion",
)
batch_itm_embeds = output.last_hidden_state[:, 0]
itm_embeds.append(batch_itm_embeds)
itm_embeds = torch.cat(itm_embeds, dim=0)
score = match_head(itm_embeds)[:, 1]
clip_scores.append(score)
if len(clip_scores) == 1:
score = clip_scores[0]
else:
raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")
t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)
logger.info("Compute over!!!")
if config.distributed:
logger.info("Gather across GPUs!!!")
# gather across GPUs
dist.barrier()
logger.info("dist.barrier()!!!")
dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
logger.info("dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM) over!!!")
dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)
logger.info("dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM) over!!!")
if use_dsl_for_match:
i2t_scores_dsl = i2t_scores
i2t_scores_dsl_T = t2i_scores
i2t_scores = old_i2t_scores
t2i_scores = old_t2i_scores
else:
i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
else:
i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info(f"Evaluation time {total_time_str}")
if match_head is not None:
return (
i2t_scores_x.softmax(dim=1).cpu().float().numpy(),
t2i_scores_x.softmax(dim=1).cpu().float().numpy(),
i2t_scores.softmax(dim=1).cpu().float().numpy(),
i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy()
)
else:
return (
None,
None,
i2t_scores.softmax(dim=1).cpu().float().numpy(),
i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy()
)
@torch.no_grad()
def evaluation_simple(model, data_loader, tokenizer, device, config):
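    """Video-only variant of `evaluation` without subtitle or DSL-for-match
    handling; otherwise follows the same two-stage procedure."""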
model.eval()
metric_logger = MetricLogger(delimiter=" ")
header = "Evaluation:"
media_type = data_loader.dataset.media_type
logger.info(f"Start evaluation for media_type={media_type}")
    assert media_type in ['video'], f"Evaluation not implemented for media_type={media_type}"
logger.info("Computing dual encoder features...")
start_time = time.time()
# this computes all features in each GPU
texts = data_loader.dataset.text
    # max_txt_l for evaluation comes from the data config
max_txt_l = data_loader.dataset.max_txt_l
text_feats, text_atts = extract_text_feats(
texts, max_txt_l, tokenizer, model, device
) # (bsz, Lt, d), (bsz, Lt)
if media_type == 'video':
image_feats, pooled_image_feats = extract_vision_feats(
data_loader, model, device, config
) # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
logger.info("Finished vision feature extraction")
logger.info("Computing ITC scores [dot-product]")
if config.evaluation.eval_offload:
            # NOTE: keep image_feats on CPU; moving the full tensor to GPU can OOM
pooled_image_feats = pooled_image_feats.to(device, non_blocking=True)
i2t_scores, t2i_scores = get_sim(
model.vision_proj(pooled_image_feats), model.text_proj(text_feats[:, 0])
)
num_medias = len(data_loader.dataset.image)
media_feats = image_feats
if hasattr(model, "itm_head"):
match_head = model.itm_head
else:
match_head = None
else:
raise NotImplementedError(media_type)
logger.info("Computing ITC scores [dot-product], done!")
if match_head is not None:
i2t_scores_x = torch.full((num_medias, len(texts)), -100.0).to(
device, torch.float, non_blocking=True
)
# computes only part of the scores at each GPU, gather at the end
logger.info("Rerank dual-encoder results with cross-encoder...")
num_tasks = get_world_size()
rank = get_rank()
# only uses the part associated with the raw eval set
# compute media2text #
step = num_medias // num_tasks + 1
start = rank * step
end = min(num_medias, start + step)
text_encoder = model.get_text_encoder()
iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")
# generate score for each clip, and aggregate all clip scores for a video
n_clip_per_video = (
media_feats.shape[1] if not config.deep_fusion else media_feats[0].shape[1]
)
assert not config.deep_fusion and n_clip_per_video == 1, f"Not implemented for config.deep_fusion={config.deep_fusion} n_clip_per_video={n_clip_per_video}"
logger.info(
f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={config.evaluation.eval_frame_ensemble}"
)
for i, sims in enumerate(iterator):
k = min(len(sims), config.evaluation.k_test)
topk_sim, topk_idx = sims.topk(k=k, dim=0)
clip_scores = []
for clip_idx in range(n_clip_per_video):
if config.deep_fusion:
encoder_output = [
feat[start + i, clip_idx].to(device, non_blocking=True)
if config.evaluation.eval_offload
else feat[start + i, clip_idx]
for feat in media_feats
]
else:
encoder_output = (
media_feats[start + i, clip_idx].to(device, non_blocking=True)
if config.evaluation.eval_offload
else media_feats[start + i, clip_idx]
) # (#frm*Li, d)
                # Rerank the top-k texts for this video in mini-batches of size `bs`.
                bs = 32
                # bs = config.batch_size_test.video
                itm_embeds = []
if config.deep_fusion:
if len(topk_idx) % bs != 0:
left = len(topk_idx) % bs
left_encoder_output = [feat.repeat(left, 1, 1) for feat in encoder_output]
left_encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for feat in left_encoder_output
]
encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for feat in encoder_output
]
else:
if len(topk_idx) % bs != 0:
left = len(topk_idx) % bs
                        left_encoder_output = encoder_output.repeat(left, 1, 1)  # (left, #frm*Li, d)
left_encoder_att = torch.ones(left_encoder_output.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
                    encoder_output = encoder_output.repeat(bs, 1, 1)  # (bs, #frm*Li, d)
encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for j in range(0, len(topk_idx), bs):
if j + bs > len(topk_idx):
output = text_encoder(
encoder_embeds=text_feats[topk_idx[j:]],
attention_mask=text_atts[topk_idx[j:]],
encoder_hidden_states=left_encoder_output,
encoder_attention_mask=left_encoder_att,
return_dict=True,
mode="fusion",
)
else:
output = text_encoder(
encoder_embeds=text_feats[topk_idx[j : j + bs]],
attention_mask=text_atts[topk_idx[j : j + bs]],
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_att,
return_dict=True,
mode="fusion",
)
batch_itm_embeds = output.last_hidden_state[:, 0]
itm_embeds.append(batch_itm_embeds)
itm_embeds = torch.cat(itm_embeds, dim=0)
score = match_head(itm_embeds)[:, 1]
clip_scores.append(score)
if len(clip_scores) == 1:
score = clip_scores[0]
else:
raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")
i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)
# compute text2media #
num_text = len(data_loader.dataset.text)
t2i_scores_x = torch.full((num_text, num_medias), -100.0).to(
device, torch.float, non_blocking=True
)
step = num_text // num_tasks + 1
start = rank * step
end = min(num_text, start + step)
iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")
# generate score for each clip, and aggregate all clip scores for a video
n_clip_per_video = (
media_feats.shape[1] if not config.deep_fusion else media_feats[0].shape[1]
)
for i, sims in enumerate(iterator):
k = min(len(sims), config.evaluation.k_test)
topk_sim, topk_idx = sims.topk(k=k, dim=0)
clip_scores = []
for clip_idx in range(n_clip_per_video):
                # Rerank the top-k videos for this text in mini-batches of size `bs`.
                bs = 32
                # bs = config.batch_size_test.video
                itm_embeds = []
for j in range(0, len(topk_idx), bs):
if config.deep_fusion:
encoder_output = [
feat[topk_idx[j : j + bs].cpu(), clip_idx].to(device, non_blocking=True)
if config.evaluation.eval_offload
else feat[topk_idx[j : j + bs], clip_idx]
for feat in media_feats
]
encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for feat in encoder_output
]
else:
encoder_output = (
media_feats[topk_idx[j : j + bs].cpu(), clip_idx].to(
device, non_blocking=True
)
if config.evaluation.eval_offload
else media_feats[topk_idx[j : j + bs], clip_idx]
)
encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
repeat_n = (
encoder_output.shape[0]
if not config.deep_fusion
else encoder_output[0].shape[0]
)
output = text_encoder(
encoder_embeds=text_feats[start + i].repeat(repeat_n, 1, 1),
attention_mask=text_atts[start + i].repeat(repeat_n, 1),
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_att,
return_dict=True,
mode="fusion",
)
batch_itm_embeds = output.last_hidden_state[:, 0]
itm_embeds.append(batch_itm_embeds)
itm_embeds = torch.cat(itm_embeds, dim=0)
score = match_head(itm_embeds)[:, 1]
clip_scores.append(score)
if len(clip_scores) == 1:
score = clip_scores[0]
else:
raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")
t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)
logger.info("Compute over!!!")
if config.distributed:
logger.info("Gather across GPUs!!!")
# gather across GPUs
dist.barrier()
logger.info("dist.barrier()!!!")
dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
logger.info("dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM) over!!!")
dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)
logger.info("dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM) over!!!")
i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
else:
i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info(f"Evaluation time {total_time_str}")
if match_head is not None:
return (
i2t_scores_x.softmax(dim=1).cpu().float().numpy(),
t2i_scores_x.softmax(dim=1).cpu().float().numpy(),
i2t_scores.softmax(dim=1).cpu().float().numpy(),
i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy()
)
else:
return (
None,
None,
i2t_scores.softmax(dim=1).cpu().float().numpy(),
i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy()
)
@torch.no_grad()
def evaluation_qformer(model, data_loader, tokenizer, device, config):
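    """QFormer variant of `evaluation`: reranks the top-k candidates through
    `model.vtm_embed` before applying the matching head."""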
model.eval()
metric_logger = MetricLogger(delimiter=" ")
header = "Evaluation:"
dtype = torch.half if config.use_half_precision else torch.float
media_type = data_loader.dataset.media_type
logger.info(f"Start evaluation_qformer for media_type={media_type}")
    assert media_type == 'video', f"Evaluation not implemented for media_type={media_type}"
logger.info("Computing dual encoder features...")
start_time = time.time()
# this computes all features in each GPU
texts = data_loader.dataset.text
    # max_txt_l for evaluation comes from the data config
max_txt_l = data_loader.dataset.max_txt_l
text_feats, text_atts, text_ids = extract_text_feats(
texts, max_txt_l, tokenizer, model, device, return_ids=True
) # (bsz, Lt, d), (bsz, Lt)
if media_type == 'video':
image_feats, pooled_image_feats = extract_vision_feats(
data_loader, model, device, config
) # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
logger.info("Finished vision feature extraction")
logger.info("Computing ITC scores [dot-product]")
if config.evaluation.eval_offload:
            # NOTE: keep image_feats on CPU; moving the full tensor to GPU can OOM
pooled_image_feats = pooled_image_feats.to(device, non_blocking=True)
if hasattr(model, "q_vision_proj"):
i2t_scores, t2i_scores = get_sim(
model.q_vision_proj(pooled_image_feats), model.q_text_proj(text_feats[:, 0])
)
else:
i2t_scores, t2i_scores = get_sim(
model.vision_proj(pooled_image_feats), model.text_proj(text_feats[:, 0])
)
num_medias = len(data_loader.dataset.image)
media_feats = image_feats
if hasattr(model, "itm_head"):
match_head = model.itm_head
elif hasattr(model, "q_itm_head"):
match_head = model.q_itm_head
else:
raise NotImplementedError("you must have a match head in qformer!!!")
logger.info("Computing ITC scores [dot-product], done!")
if match_head is not None:
i2t_scores_x = torch.full((num_medias, len(texts)), -100.0).to(
device, torch.float, non_blocking=True
)
# computes only part of the scores at each GPU, gather at the end
logger.info("Rerank dual-encoder results with cross-encoder...")
num_tasks = get_world_size()
rank = get_rank()
# only uses the part associated with the raw eval set
# compute image2text #
step = num_medias // num_tasks + 1
start = rank * step
end = min(num_medias, start + step)
iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")
# generate score for each clip, and aggregate all clip scores for a video
n_clip_per_video = (
image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
)
assert not config.deep_fusion and n_clip_per_video == 1, f"Not implemented for config.deep_fusion={config.deep_fusion} n_clip_per_video={n_clip_per_video}"
logger.info(
f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={config.evaluation.eval_frame_ensemble}"
)
for i, sims in enumerate(iterator):
k = min(len(sims), config.evaluation.k_test)
topk_sim, topk_idx = sims.topk(k=k, dim=0)
clip_scores = []
for clip_idx in range(n_clip_per_video):
if config.deep_fusion:
encoder_output = [
feat[start + i, clip_idx].to(device, non_blocking=True)
if config.evaluation.eval_offload
else feat[start + i, clip_idx]
for feat in media_feats
]
else:
encoder_output = (
image_feats[start + i, clip_idx].to(device, non_blocking=True)
if config.evaluation.eval_offload
else image_feats[start + i, clip_idx]
) # (#frm*Li, d)
                # Rerank the top-k texts for this video in mini-batches of size `bs`.
                bs = 32
                # bs = config.batch_size_test.video
                itm_embeds = []
                if not config.deep_fusion:  # wrap in a list so both branches are handled uniformly
                    encoder_output = [encoder_output]
                encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
for feat in encoder_output
]
                for j in range(0, len(topk_idx), bs):
                    cur_bs = min(bs, len(topk_idx) - j)
                    # Slice into fresh names rather than overwriting encoder_output /
                    # encoder_att, so a short final batch cannot truncate them for good.
                    batch_encoder_output = [feat[:cur_bs] for feat in encoder_output]
                    batch_encoder_att = [att[:cur_bs] for att in encoder_att]
                    if not config.deep_fusion:
                        batch_encoder_output = batch_encoder_output[0]
                        batch_encoder_att = batch_encoder_att[0]
output = model.vtm_embed(
text_ids=text_ids[topk_idx[j:j+bs]],
text_atts=text_atts[topk_idx[j:j+bs]],
vision_embeds=batch_encoder_output,
vision_atts=batch_encoder_att,
)
itm_embeds.append(output)
itm_embeds = torch.cat(itm_embeds, dim=0)
score = match_head(itm_embeds)[:, 1]
clip_scores.append(score)
if len(clip_scores) == 1:
score = clip_scores[0]
else:
raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")
i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)
# compute text2image #
num_text = len(data_loader.dataset.text)
t2i_scores_x = torch.full((num_text, len(data_loader.dataset.image)), -100.0).to(
device, torch.float, non_blocking=True
)
step = num_text // num_tasks + 1
start = rank * step
end = min(num_text, start + step)
iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")
# generate score for each clip, and aggregate all clip scores for a video
n_clip_per_video = (
image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
)
k = config.evaluation.k_test
logger.info(f"Top-{k} matching")
for i, sims in enumerate(iterator):
k = min(len(sims), config.evaluation.k_test)
topk_sim, topk_idx = sims.topk(k=k, dim=0)
clip_scores = []
for clip_idx in range(n_clip_per_video):
                # Rerank the top-k videos for this text in mini-batches of size `bs`.
                bs = 32
                # bs = config.batch_size_test.video
                itm_embeds = []
for j in range(0, len(topk_idx), bs):
fake_image_feats = [image_feats] if not config.deep_fusion else image_feats
encoder_output = [
feat[topk_idx[j : j + bs].cpu(), clip_idx].to(device, non_blocking=True)
if config.evaluation.eval_offload
else feat[topk_idx[j : j + bs], clip_idx]
for feat in fake_image_feats
]
encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for feat in encoder_output
]
cur_bs = min(bs, len(topk_idx) - j)
batch_encoder_output = encoder_output if config.deep_fusion else encoder_output[0]
batch_encoder_att = encoder_att if config.deep_fusion else encoder_att[0]
output = model.vtm_embed(
text_ids=text_ids[start + i].repeat(cur_bs, 1),
text_atts=text_atts[start + i].repeat(cur_bs, 1),
vision_embeds=batch_encoder_output,
vision_atts=batch_encoder_att,
)
itm_embeds.append(output)
itm_embeds = torch.cat(itm_embeds, dim=0)
score = match_head(itm_embeds)[:, 1]
clip_scores.append(score)
if len(clip_scores) == 1:
score = clip_scores[0]
else:
raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")
t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)
logger.info("Compute over!!!")
if config.distributed:
logger.info("Gather across GPUs!!!")
# gather across GPUs
dist.barrier()
logger.info("dist.barrier()!!!")
dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
logger.info("dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM) over!!!")
dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)
logger.info("dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM) over!!!")
i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
else:
i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info(f"Evaluation time {total_time_str}")
if match_head is not None:
return (
i2t_scores_x.softmax(dim=1).cpu().float().numpy(),
t2i_scores_x.softmax(dim=1).cpu().float().numpy(),
i2t_scores.softmax(dim=1).cpu().float().numpy(),
i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy()
)
else:
return (
None,
None,
i2t_scores.softmax(dim=1).cpu().float().numpy(),
i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy()
)
@torch.no_grad()
def evaluation_clip(model, data_loader, tokenizer, device, config):
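    """Dual-encoder-only evaluation for CLIP/CoCa-style models: similarity and
    dual-softmax scores, with no cross-encoder reranking stage."""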
model.eval()
metric_logger = MetricLogger(delimiter=" ")
header = "Evaluation:"
dtype = torch.half if config.use_half_precision else torch.float
media_type = data_loader.dataset.media_type
logger.info(f"Start evaluation_clip for media_type={media_type}")
logger.info("Computing dual encoder features...")
# this computes all features in each GPU
texts = data_loader.dataset.text
num_text = len(texts)
text_bs = 256
text_feats = []
for i in range(0, num_text, text_bs):
text = texts[i : min(num_text, i + text_bs)]
if "internvideo2" in config.model.model_cls.lower():
text_feat = model.encode_text(tokenizer(text).to(device))
else:
raise NotImplementedError
text_feat = model.encode_text(text)
text_feats.append(text_feat.cpu())
text_feats = torch.cat(text_feats, dim=0)
logger.info("Finished computing text features")
media_feats = []
metric_logger = MetricLogger(delimiter=" ")
header = f"extracting {media_type} feats!!!"
iterator = metric_logger.log_every(data_loader, 100, header)
for media, _ in iterator:
if media_type in ['image', 'video']:
media = media.to(device, non_blocking=True)
media_feat = model.encode_vision(media, test=True)
elif media_type == 'audio':
media = media.to(device, non_blocking=True)
media_feat = model.encode_audio(media, test=True)
elif media_type == 'audio_video':
raise NotImplementedError(f"Not implement media_type:{media_type}")
else:
raise NotImplementedError(f"Not implement media_type:{media_type}")
media_feats.append(media_feat.cpu())
media_feats = torch.cat(media_feats, dim=0)
logger.info("Finished feature extraction")
logger.info("Computing ITC scores [dot-product]")
i2t_scores, t2i_scores = get_sim(media_feats.float(), text_feats.float())
del media_feats, text_feats
logger.info("Computing ITC scores [dot-product], done!")
i2t_scores_dsl = i2t_scores * i2t_scores.softmax(dim=0)
i2t_scores_dsl_T = i2t_scores.T * i2t_scores.T.softmax(dim=0)
return (
i2t_scores.cpu().float().numpy(),
i2t_scores.T.cpu().float().numpy(),
i2t_scores_dsl.cpu().float().numpy(),
i2t_scores_dsl_T.cpu().float().numpy(),
)
@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
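    """Compute recall@{1,5,10} in both directions from the score matrices.

    `scores_i2t` is (num_media, num_text), `scores_t2i` is (num_text, num_media);
    `img2txt` / `txt2img` map each query index to its ground-truth index (or list
    of indices, in which case the best-ranked ground truth counts).
    """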
    # Media-to-text retrieval (reported as v2t_*)
ranks = np.zeros(scores_i2t.shape[0])
for index, score in enumerate(scores_i2t):
inds = np.argsort(score)[::-1]
# Score
gt_txt_ids = img2txt[index]
if isinstance(gt_txt_ids, int):
ranks[index] = np.where(inds == gt_txt_ids)[0][0]
else:
rank = 1e20
for i in gt_txt_ids:
tmp = np.where(inds == i)[0][0]
if tmp < rank:
rank = tmp
ranks[index] = rank
# Compute metrics
tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
    # Text-to-media retrieval (reported as t2v_*)
ranks = np.zeros(scores_t2i.shape[0])
for index, score in enumerate(scores_t2i):
inds = np.argsort(score)[::-1]
gt_img_ids = txt2img[index]
if isinstance(gt_img_ids, int):
ranks[index] = np.where(inds == gt_img_ids)[0][0]
else: # list, used in the case each caption has multiple GT images
# Score
rank = 1e20
for i in gt_img_ids:
tmp = np.where(inds == i)[0][0]
if tmp < rank:
rank = tmp
ranks[index] = rank
# Compute metrics
ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
tr_mean = (tr1 + tr5 + tr10) / 3
ir_mean = (ir1 + ir5 + ir10) / 3
r_mean = (tr_mean + ir_mean) / 2
eval_result = {
"v2t_r1": tr1,
"v2t_r5": tr5,
"v2t_r10": tr10,
"v2t_r_mean": tr_mean,
"t2v_r1": ir1,
"t2v_r5": ir5,
"t2v_r10": ir10,
"t2v_r_mean": ir_mean,
"r_mean": r_mean,
}
eval_result = {k: round(v, 2) for k, v in eval_result.items()}
return eval_result
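

if __name__ == "__main__":
    # Minimal, self-contained sanity check for `itm_eval` on a synthetic retrieval
    # problem; this block is illustrative and not part of the evaluation pipeline.
    # Two videos and four captions: captions 0-1 belong to video 0, captions 2-3
    # to video 1, and the score matrix ranks every ground truth first.
    scores_i2t = np.array(
        [[0.9, 0.8, 0.1, 0.2],
         [0.1, 0.2, 0.7, 0.9]]
    )
    scores_t2i = scores_i2t.T
    img2txt = [[0, 1], [2, 3]]  # each video has two ground-truth captions
    txt2img = [0, 0, 1, 1]      # each caption has one ground-truth video
    # A perfect ranking should give 100.0 for every recall metric.
    print(itm_eval(scores_i2t, scores_t2i, txt2img, img2txt))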