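"""
Moment retrieval (MR) inference and evaluation.

Loads an evaluation split with DatasetMR, runs the model to produce ranked
[start, end, score] windows plus clip-level saliency scores, optionally applies
temporal NMS, and scores the submission with eval_submission when ground truth
is available. Entry point is start_inference(); all options come from
main.config.TestOptions.
"""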
import pdb
import pprint
from tqdm import tqdm, trange
import numpy as np
import os
from collections import OrderedDict, defaultdict
from utils.basic_utils import AverageMeter

import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from main.config import TestOptions, setup_model
from main.dataset import DatasetMR, start_end_collate_mr, prepare_batch_inputs_mr
from eval.eval import eval_submission
from eval.postprocessing import PostProcessorDETR
from utils.basic_utils import save_jsonl, save_json
from utils.temporal_nms import temporal_nms
from utils.span_utils import span_cxw_to_xx

import logging
import importlib

logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)


def post_processing_mr_nms(mr_res, nms_thd, max_before_nms, max_after_nms):
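    """Apply temporal NMS to each query's predicted windows.

    At most `max_before_nms` windows per query are passed to `temporal_nms`, which
    suppresses windows that overlap a higher-scored window by more than `nms_thd`
    and keeps at most `max_after_nms` of the survivors.
    """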
    mr_res_after_nms = []
    for e in mr_res:
        e["pred_relevant_windows"] = temporal_nms(
            e["pred_relevant_windows"][:max_before_nms],
            nms_thd=nms_thd,
            max_after_nms=max_after_nms
        )
        mr_res_after_nms.append(e)
    return mr_res_after_nms


def eval_epoch_post_processing(submission, opt, gt_data, save_submission_filename):
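    """Save the submission, optionally apply NMS, and evaluate against ground truth.

    Returns (metrics, metrics_nms, latest_file_paths). `metrics` / `metrics_nms`
    are None when the split has no ground truth or NMS is disabled
    (`opt.nms_thd == -1`).
    """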
    # IOU_THDS = (0.5, 0.7)
    logger.info("Saving/Evaluating before nms results")
    submission_path = os.path.join(opt.results_dir, save_submission_filename)
    save_jsonl(submission, submission_path)

    if opt.eval_split_name in ["val", "test"]:  # since test_public has no GT
        metrics = eval_submission(
            submission, gt_data,
            verbose=opt.debug, match_number=not opt.debug,
        )
        save_metrics_path = submission_path.replace(".jsonl", "_metrics.json")
        save_json(metrics, save_metrics_path, save_pretty=True, sort_keys=False)
        latest_file_paths = [submission_path, save_metrics_path]
    else:
        metrics = None
        latest_file_paths = [submission_path, ]

    if opt.nms_thd != -1:
        logger.info("[MR] Performing nms with nms_thd {}".format(opt.nms_thd))
        submission_after_nms = post_processing_mr_nms(
            submission, nms_thd=opt.nms_thd,
            max_before_nms=opt.max_before_nms, max_after_nms=opt.max_after_nms
        )

        logger.info("Saving/Evaluating nms results")
        submission_nms_path = submission_path.replace(".jsonl", "_nms_thd_{}.jsonl".format(opt.nms_thd))
        save_jsonl(submission_after_nms, submission_nms_path)
        if opt.eval_split_name == "val":
            metrics_nms = eval_submission(
                submission_after_nms, gt_data,
                verbose=opt.debug, match_number=not opt.debug
            )
            save_metrics_nms_path = submission_nms_path.replace(".jsonl", "_metrics.json")
            save_json(metrics_nms, save_metrics_nms_path, save_pretty=True, sort_keys=False)
            latest_file_paths += [submission_nms_path, save_metrics_nms_path]
        else:
            metrics_nms = None
            latest_file_paths = [submission_nms_path, ]
    else:
        metrics_nms = None
    return metrics, metrics_nms, latest_file_paths


def compute_mr_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None):
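    """Run the model over `eval_loader` and collect ranked moment predictions.

    Returns (mr_res, loss_meters), where `mr_res` is a list with one dict per query
    (see the schema comment inside the loop below) and `loss_meters` holds averaged
    eval losses when a `criterion` is supplied.
    """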
    model.eval()
    if criterion:
        assert eval_loader.dataset.load_labels
        criterion.eval()

    loss_meters = defaultdict(AverageMeter)
    write_tb = tb_writer is not None and epoch_i is not None

    mr_res = []
    for batch in tqdm(eval_loader, desc="compute st ed scores"):
        query_meta = batch[0]
        model_inputs, targets = prepare_batch_inputs_mr(batch[1], opt.device, non_blocking=opt.pin_memory)
        outputs = model(**model_inputs)
        prob = outputs["pred_logits"]  # the last channel may be 1 or 2
        # if opt.eval_mode == 'v1':
        #     prob = prob * outputs["saliency_scores"].unsqueeze(-1)  # v1
        # if opt.eval_mode == 'v2':
        #     prob = F.softmax(prob, dim=1) * outputs["saliency_scores"].unsqueeze(-1)  # v2
        # if opt.eval_mode == 'v3':
        #     prob = outputs["saliency_scores"].unsqueeze(-1)
        if outputs["pred_logits"].shape[-1] > 1:
            prob = F.softmax(outputs["pred_logits"], -1)  # (batch_size, #queries, #classes=2)

        if opt.span_loss_type == "l1":
            scores = prob[..., 0]  # (batch_size, #queries); the foreground label is 0, so take its probability directly
            pred_spans = outputs["pred_spans"]  # (bsz, #queries, 2)

            if opt.model_id not in ['moment_detr']:  # dense regression: spans are offsets from each clip's timestamp
                start_spans = targets['timestamp']
                pred_spans = start_spans + pred_spans
                mask = targets['timestamp_mask'].bool()
                scores[~mask] = 0
                # if opt.eval_mode == 'v4':
                #     _mask = targets['timestamp_window'].bool()
                #     scores[~_mask] = 0

            if opt.eval_mode == 'add':
                # pdb.set_trace()
                _saliency_scores = outputs["saliency_scores"].half() + prob.squeeze(-1)
            else:
                _saliency_scores = outputs["saliency_scores"].half()  # (bsz, L)

            if opt.eval_mode == 'add_mr':
                prob = outputs["saliency_scores"].half().unsqueeze(-1) + prob

            saliency_scores = []
            valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist()
            for j in range(len(valid_vid_lengths)):
                saliency_scores.append(_saliency_scores[j, :int(valid_vid_lengths[j])].tolist())
        else:
            bsz, n_queries = outputs["pred_spans"].shape[:2]  # (bsz, #queries, max_v_l * 2)
            pred_spans_logits = outputs["pred_spans"].view(bsz, n_queries, 2, opt.max_v_l)
            # TODO: use a more advanced decoding method with st_ed product
            pred_span_scores, pred_spans = F.softmax(pred_spans_logits, dim=-1).max(-1)  # 2 * (bsz, #queries, 2)
            scores = torch.prod(pred_span_scores, 2)  # (bsz, #queries)
            pred_spans[:, 1] += 1
            pred_spans *= opt.clip_length

        # compose predictions
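        # Each prediction appended below follows the submission schema consumed by
        # eval_submission / save_jsonl (the values shown are illustrative):
        #   {
        #       "qid": 2579,
        #       "query": "A girl is petting a dog.",
        #       "vid": "video_name",
        #       "pred_relevant_windows": [[st, ed, score], ...],  # seconds, sorted by score unless sorting is disabled
        #       "pred_saliency_scores": [...],                    # one score per valid clip
        #   }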
        for idx, (meta, spans, score) in enumerate(zip(query_meta, pred_spans.cpu(), scores.cpu())):
            if opt.span_loss_type == "l1":
                if opt.model_id in ['moment_detr']:
                    spans = span_cxw_to_xx(spans) * meta["duration"]
                else:
                    spans = spans * meta["duration"]
                spans = torch.clamp(spans, 0, meta["duration"])  # added by Kevin: a window cannot extend beyond the video duration
            # (#queries, 3), [st(float), ed(float), score(float)]
            cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist()
            if not opt.no_sort_results:
                cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
            cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds]
            cur_query_pred = dict(
                qid=meta["qid"],
                query=meta["query"],
                vid=meta["vid"],
                pred_relevant_windows=cur_ranked_preds,
                pred_saliency_scores=saliency_scores[idx]
            )
            mr_res.append(cur_query_pred)

        if criterion:
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
            loss_dict["loss_overall"] = float(losses)  # for logging only
            for k, v in loss_dict.items():
                loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))

        if opt.debug:
            break

    if write_tb and criterion:
        for k, v in loss_meters.items():
            tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1)
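
    # PostProcessorDETR presumably snaps window boundaries to multiples of clip_length
    # and enforces the min/max window lengths configured below; see
    # eval/postprocessing.py for the authoritative behavior.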
    post_processor = PostProcessorDETR(
        clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150,
        min_w_l=2, max_w_l=150, move_window_method="left",
        # process_func_names=("clip_ts", "round_multiple")
        process_func_names=["round_multiple"]  # spans are already clamped to the video duration above, so `clip_ts` is not needed here
    )
    # TODO: do we still need round_multiple?
    if opt.round_multiple > 0:
        mr_res = post_processor(mr_res)

    return mr_res, loss_meters


def get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer):
    """Thin wrapper around compute_mr_results; returns (eval_res, eval_loss_meters)."""
    eval_res, eval_loss_meters = compute_mr_results(model, eval_loader, opt, epoch_i, criterion, tb_writer)  # list(dict)
    return eval_res, eval_loss_meters


def eval_epoch(model, eval_dataset, opt, save_submission_filename, epoch_i=None, criterion=None, tb_writer=None):
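    """Evaluate the model on `eval_dataset` for one epoch.

    Builds the eval DataLoader, generates the prediction submission, then saves
    and scores it via eval_epoch_post_processing. Returns
    (metrics, metrics_nms, eval_loss_meters, latest_file_paths).
    """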
logger.info("Generate submissions") | |
model.eval() | |
if criterion is not None and eval_dataset.load_labels: | |
criterion.eval() | |
else: | |
criterion = None | |
eval_loader = DataLoader( | |
eval_dataset, | |
collate_fn=start_end_collate_mr, | |
batch_size=opt.eval_bsz, | |
num_workers=opt.num_workers, | |
shuffle=False, | |
pin_memory=opt.pin_memory | |
) | |
submission, eval_loss_meters = get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer) | |
if opt.no_sort_results: | |
save_submission_filename = save_submission_filename.replace(".jsonl", "_unsorted.jsonl") | |
metrics, metrics_nms, latest_file_paths = eval_epoch_post_processing( | |
submission, opt, eval_dataset.data, save_submission_filename) | |
return metrics, metrics_nms, eval_loss_meters, latest_file_paths | |


def start_inference():
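    """Entry point: parse TestOptions, build the eval dataset and model, run inference.

    Predictions are written under opt.results_dir; metrics are logged when the
    split has ground truth.
    """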
logger.info("Setup config, data and model...") | |
opt = TestOptions().parse() | |
# pdb.set_trace() | |
cudnn.benchmark = True | |
cudnn.deterministic = False | |
assert opt.eval_path is not None | |
eval_dataset = DatasetMR( | |
dset_name=opt.dset_name, | |
data_path=opt.eval_path, | |
v_feat_dirs=opt.v_feat_dirs, | |
q_feat_dir=opt.t_feat_dir, | |
v_feat_dim=opt.v_feat_dim, | |
q_feat_dim=opt.t_feat_dim, | |
q_feat_type="last_hidden_state", | |
max_q_l=opt.max_q_l, | |
max_v_l=opt.max_v_l, | |
ctx_mode=opt.ctx_mode, | |
data_ratio=opt.data_ratio, | |
normalize_v=not opt.no_norm_vfeat, | |
normalize_t=not opt.no_norm_tfeat, | |
clip_len=opt.clip_length, | |
max_windows=opt.max_windows, | |
load_labels=True, # opt.eval_split_name == "val", | |
span_loss_type=opt.span_loss_type, | |
txt_drop_ratio=0, | |
use_cache=opt.use_cache, | |
) | |
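
    # Note: setup_model returns four values (the last two, presumably the optimizer and
    # scheduler, are discarded here), so the lr_warmup bookkeeping below mirrors the
    # training-time format of [warmup_steps, total_steps] even though it is unused for inference.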
    if opt.lr_warmup > 0:
        # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
        total_steps = opt.n_epoch
        warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
        opt.lr_warmup = [warmup_steps, total_steps]

    model, criterion, _, _ = setup_model(opt)

    save_submission_filename = "inference_{}_{}_{}_preds.jsonl".format(
        opt.dset_name, opt.eval_split_name, opt.eval_id)
    logger.info("Starting inference...")
    with torch.no_grad():
        metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
            eval_epoch(model, eval_dataset, opt, save_submission_filename, criterion=criterion)
    logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
    if metrics_nms is not None:
        logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))


if __name__ == '__main__':
    start_inference()