Spaces:
Runtime error
Runtime error
import torch | |
from torch.utils.data import Dataset | |
import numpy as np | |
from tqdm import tqdm | |
import random | |
import logging | |
from os.path import join, exists | |
from utils.basic_utils import load_jsonl, l2_normalize_np_array | |
from utils.tensor_utils import pad_sequences_1d | |
from moment_detr.span_utils import span_xx_to_cxw | |
# Module-level logger named after this module, so its level/handlers can be
# configured independently of other modules.
logger = logging.getLogger(name=__name__)
class StartEndDataset(Dataset):
    """Dataset of (query, video) pairs with moment-retrieval labels.

    Each item of ``self.data`` is one line of the JSONL file at ``data_path``,
    for example::

        {
          "qid": 7803,
          "query": "Man in gray top walks from outside to inside.",
          "duration": 150,
          "vid": "RoripwjYFp8_360.0_510.0",
          "relevant_clip_ids": [13, 14, 15, 16, 17],
          "relevant_windows": [[26, 36]]
        }

    When labels are loaded and ``data_path`` does not contain "subs_train",
    each line must also carry "saliency_scores": one row of scores per entry
    of "relevant_clip_ids" (see ``get_saliency_labels``).
    """

    # Accepted values of ``q_feat_type``: the array name read from each
    # query ``.npz`` file.
    Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"]

    def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir,
                 q_feat_type="last_hidden_state",
                 max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video",
                 normalize_v=True, normalize_t=True, load_labels=True,
                 clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0):
        """
        Args:
            dset_name: dataset name; stored as ``self.dset_name``.
            data_path: JSONL annotation file. Its name also switches behavior:
                "val"/"test" require ``txt_drop_ratio == 0``; "subs_train"
                derives saliency labels from the first relevant window.
            v_feat_dirs: a directory or list of directories, each holding
                ``{vid}.npz`` files with a "features" array; features from all
                directories are concatenated along the feature axis.
            q_feat_dir: directory holding ``qid{qid}.npz`` query feature files.
            q_feat_type: which array to read from the query file; one of
                ``Q_FEAT_TYPES``.
            max_q_l: max query tokens kept (only for "last_hidden_state").
            max_v_l: max video clips kept; also the context length when
                ``ctx_mode`` does not contain "video".
            data_ratio: keep only this leading fraction of the annotations.
            ctx_mode: substring "video" loads video features; substring "tef"
                appends per-clip normalized [start, end] positions.
            normalize_v / normalize_t: L2-normalize video / query features.
            load_labels: whether ``__getitem__`` builds span/saliency labels.
            clip_len: seconds per clip, used to map seconds to clip indices.
            max_windows: max number of GT windows used as span labels.
            span_loss_type: "l1" (normalized center/width) or "ce" (clip
                indices).
            txt_drop_ratio: fraction of query rows zeroed at load time.
        """
        self.dset_name = dset_name
        self.data_path = data_path
        self.data_ratio = data_ratio
        self.v_feat_dirs = v_feat_dirs \
            if isinstance(v_feat_dirs, list) else [v_feat_dirs]
        self.q_feat_dir = q_feat_dir
        self.q_feat_type = q_feat_type
        self.max_q_l = max_q_l
        self.max_v_l = max_v_l
        self.ctx_mode = ctx_mode
        self.use_tef = "tef" in ctx_mode
        self.use_video = "video" in ctx_mode
        self.normalize_t = normalize_t
        self.normalize_v = normalize_v
        self.load_labels = load_labels
        self.clip_len = clip_len
        self.max_windows = max_windows  # maximum number of windows to use as labels
        self.span_loss_type = span_loss_type
        self.txt_drop_ratio = txt_drop_ratio
        if "val" in data_path or "test" in data_path:
            # text dropout is only permitted on training data
            assert txt_drop_ratio == 0, \
                "txt_drop_ratio must be 0 for val/test data"

        # checks
        assert q_feat_type in self.Q_FEAT_TYPES, \
            f"q_feat_type must be one of {self.Q_FEAT_TYPES}, got {q_feat_type!r}"

        # data
        self.data = self.load_data()

    def load_data(self):
        """Load the JSONL annotations, keeping only the first ``data_ratio`` fraction."""
        datalist = load_jsonl(self.data_path)
        if self.data_ratio != 1:
            n_examples = int(len(datalist) * self.data_ratio)
            datalist = datalist[:n_examples]
            logger.info("Using %s%% of the data: %d examples",
                        self.data_ratio * 100, n_examples)
        return datalist

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Build one sample.

        Returns:
            dict(meta=<the raw annotation dict>, model_inputs=dict). model_inputs
            holds "query_feat"; "video_feat" when ``ctx_mode`` contains "video"
            and/or "tef"; and, if ``load_labels``, "span_labels",
            "saliency_pos_labels" and "saliency_neg_labels".
        """
        meta = self.data[index]

        model_inputs = dict()
        model_inputs["query_feat"] = self._get_query_feat_by_qid(meta["qid"])  # (Dq, ) or (Lq, Dq)
        if self.use_video:
            model_inputs["video_feat"] = self._get_video_feat_by_vid(meta["vid"])  # (Lv, Dv)
            ctx_l = len(model_inputs["video_feat"])
        else:
            ctx_l = self.max_v_l

        if self.use_tef:
            # per-clip [start, end] position, normalized by the context length
            tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
            tef_ed = tef_st + 1.0 / ctx_l
            tef = torch.stack([tef_st, tef_ed], dim=1)  # (Lv, 2)
            if self.use_video:
                model_inputs["video_feat"] = torch.cat(
                    [model_inputs["video_feat"], tef], dim=1)  # (Lv, Dv+2)
            else:
                model_inputs["video_feat"] = tef

        if self.load_labels:
            model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l)  # (#windows, 2)
            if "subs_train" not in self.data_path:
                model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
                    self.get_saliency_labels(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
            else:
                model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
                    self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l)  # only one gt
        return dict(meta=meta, model_inputs=model_inputs)

    @staticmethod
    def _sample_k(pool, k):
        """Draw ``k`` items from ``pool``: without replacement if it has at least
        ``k`` items, otherwise with replacement (duplicates allowed)."""
        if len(pool) >= k:
            return random.sample(pool, k=k)
        return random.choices(pool, k=k)

    def get_saliency_labels_sub_as_query(self, gt_window, ctx_l, max_n=2):
        """Sample positive/negative clip indices from a single GT window.

        Used by ``__getitem__`` when ``data_path`` contains "subs_train", where
        no per-clip saliency scores are read.

        Args:
            gt_window: [st, ed] in seconds.
            ctx_l: int, number of clips in the video.
            max_n: int, number of positive and of negative indices returned.

        Returns:
            (pos_clip_indices, neg_clip_indices), each a list of ``max_n`` ints.
            Positives lie inside the window, negatives outside it. Indices are
            distinct when the candidate pool has at least ``max_n`` clips,
            otherwise they are sampled with replacement.

        Raises:
            ValueError: if the window covers every clip, leaving no negatives.
        """
        gt_st = int(gt_window[0] / self.clip_len)
        # inclusive end clip index, clamped into [0, ctx_l - 1]
        gt_ed = max(0, min(int(gt_window[1] / self.clip_len), ctx_l) - 1)
        gt_st = min(gt_st, gt_ed)

        pos_pool = list(range(gt_st, gt_ed + 1))
        neg_pool = list(range(0, gt_st)) + list(range(gt_ed + 1, ctx_l))
        if not neg_pool:
            raise ValueError(
                f"window {gt_window} covers all {ctx_l} clips; no negative clips to sample")
        return self._sample_k(pos_pool, max_n), self._sample_k(neg_pool, max_n)

    def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1, add_easy_negative=True):
        """Pick positive/negative clip indices from annotated saliency scores.

        Each relevant clip's score row (e.g. [anno1, anno2, anno3]) is summed.
        The ``max_n`` highest-scoring relevant clips are hard positives and the
        ``max_n`` lowest-scoring relevant clips are hard negatives.

        Args:
            rel_clip_ids: list(int), relevant clip ids (parallel to ``scores``).
            scores: list of per-clip score rows, shape (#rel_clips, #annos).
            ctx_l: int, number of clips in the video; every returned index is
                clamped to ``ctx_l - 1``.
            max_n: int, #clips used as positive and negative, for hard and
                easy pairs respectively.
            add_easy_negative: bool, if True, also sample ``max_n`` easy
                negatives from clips outside ``rel_clip_ids`` paired with
                ``max_n`` random relevant clips; if too few such clips exist,
                the hard pair is repeated instead.

        Returns:
            (pos_clip_indices, neg_clip_indices), lists of equal length.
        """
        # indices inside rel_clip_ids
        scores = np.array(scores)  # (#rel_clips, #annos)
        agg_scores = np.sum(scores, 1)  # (#rel_clips, )
        sort_indices = np.argsort(agg_scores)  # increasing

        # Indices in the whole video. Clamping to ctx_l - 1 is only a guard for
        # annotations that run past the (truncated) features; it should rarely fire.
        last_clip = ctx_l - 1
        hard_pos_clip_indices = [min(rel_clip_ids[idx], last_clip) for idx in sort_indices[-max_n:]]
        hard_neg_clip_indices = [min(rel_clip_ids[idx], last_clip) for idx in sort_indices[:max_n]]
        easy_pos_clip_indices = []
        easy_neg_clip_indices = []
        if add_easy_negative:
            easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
            if len(easy_neg_pool) >= max_n:
                # clamp like the hard positives so no index exceeds the video length
                easy_pos_clip_indices = [
                    min(i, last_clip) for i in random.sample(rel_clip_ids, k=max_n)]
                easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
            else:  # copy the hard ones
                easy_pos_clip_indices = hard_pos_clip_indices
                easy_neg_clip_indices = hard_neg_clip_indices

        pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
        neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
        return pos_clip_indices, neg_clip_indices

    def get_span_labels(self, windows, ctx_l):
        """Convert GT windows in seconds to span-label tensors.

        Args:
            windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding
                st_ed clip_indices [[13, 17]] (inclusive). Not modified.
            ctx_l: int, number of clips in the video.

        At most ``self.max_windows`` windows are used, chosen at random when
        there are more.

        Returns:
            Tensor of shape (#windows, 2). For "l1": rows converted by
            ``span_xx_to_cxw`` from windows normalized by video length
            (center/width). For "ce": LongTensor of inclusive
            [start_clip, end_clip] indices.

        Raises:
            NotImplementedError: for any other ``span_loss_type``.
        """
        if len(windows) > self.max_windows:
            # random.sample builds a new list, so the caller's (stored) annotation
            # is never reordered in place.
            windows = random.sample(windows, self.max_windows)
        if self.span_loss_type == "l1":
            windows = torch.Tensor(windows) / (ctx_l * self.clip_len)  # normalized windows in xx
            windows = span_xx_to_cxw(windows)  # normalized windows in cxw
        elif self.span_loss_type == "ce":
            windows = torch.Tensor([
                [int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1]
                for w in windows]).long()  # inclusive
        else:
            raise NotImplementedError
        return windows

    def _get_query_feat_by_qid(self, qid):
        """Load the query features for ``qid`` as a float32 tensor, (D, ) or (Lq, D)."""
        q_feat_path = join(self.q_feat_dir, f"qid{qid}.npz")
        # close the NpzFile (and its file handle) once the array has been read
        with np.load(q_feat_path) as npz:
            q_feat = npz[self.q_feat_type].astype(np.float32)
        if self.q_feat_type == "last_hidden_state":
            q_feat = q_feat[:self.max_q_l]
        if self.normalize_t:
            q_feat = l2_normalize_np_array(q_feat)
        if self.txt_drop_ratio > 0:
            q_feat = self.random_drop_rows(q_feat)
        return torch.from_numpy(q_feat)  # (D, ) or (Lq, D)

    def random_drop_rows(self, embeddings):
        """Zero out ``round(len(embeddings) * txt_drop_ratio)`` random rows in place.

        Args:
            embeddings: np.ndarray (L, D)

        Returns:
            The same (mutated) array.
        """
        num_drop_rows = round(len(embeddings) * self.txt_drop_ratio)
        if num_drop_rows > 0:
            row_indices = np.random.choice(
                len(embeddings), size=num_drop_rows, replace=False)
            embeddings[row_indices] = 0
        return embeddings

    def _get_video_feat_by_vid(self, vid):
        """Load and concatenate ``{vid}.npz`` "features" from every feature dir.

        Each source is truncated to ``max_v_l`` clips, then all are cut to the
        shortest length before concatenating along the feature axis.
        """
        v_feat_list = []
        for _feat_dir in self.v_feat_dirs:
            _feat_path = join(_feat_dir, f"{vid}.npz")
            # close the NpzFile (and its file handle) once the array has been read
            with np.load(_feat_path) as npz:
                _feat = npz["features"][:self.max_v_l].astype(np.float32)
            if self.normalize_v:
                _feat = l2_normalize_np_array(_feat)
            v_feat_list.append(_feat)
        # some features are slightly longer than the others
        min_len = min(len(e) for e in v_feat_list)
        v_feat_list = [e[:min_len] for e in v_feat_list]
        v_feat = np.concatenate(v_feat_list, axis=1)
        return torch.from_numpy(v_feat)  # (Lv, D)
def start_end_collate(batch):
    """Collate a list of dataset items (each ``dict(meta=..., model_inputs=...)``).

    Returns:
        (batch_meta, batched_data) where ``batch_meta`` is the list of each
        item's "meta", passed through uncollated, and ``batched_data`` has one
        entry per key of the first item's "model_inputs":
          - "span_labels": list of ``dict(spans=...)``, one per item (the
            number of windows varies, so they are not stacked);
          - "saliency_pos_labels" / "saliency_neg_labels": a LongTensor built
            from the per-item index lists;
          - any other key: the result of ``pad_sequences_1d`` (float32).
    """
    saliency_keys = ("saliency_pos_labels", "saliency_neg_labels")
    batch_meta = [item["meta"] for item in batch]  # metadata needs no collation

    batched_data = {}
    for key in batch[0]["model_inputs"]:
        values = [item["model_inputs"][key] for item in batch]
        if key == "span_labels":
            batched_data[key] = [dict(spans=v) for v in values]
        elif key in saliency_keys:
            batched_data[key] = torch.LongTensor(values)
        else:
            batched_data[key] = pad_sequences_1d(
                values, dtype=torch.float32, fixed_length=None)
    return batch_meta, batched_data
def prepare_batch_inputs(batched_model_inputs, device, non_blocking=False):
    """Move a collated batch to ``device`` and split it into model inputs and targets.

    Args:
        batched_model_inputs: dict expected to look like the ``batched_data``
            produced by ``start_end_collate``. "query_feat" and "video_feat"
            are indexed at [0] and [1] (assumed to be padded features and
            their mask — confirm against ``pad_sequences_1d``); optional
            "span_labels" is a list of ``dict(spans=Tensor)``; optional
            "saliency_pos_labels" implies "saliency_neg_labels" is present too.
        device: target passed to ``Tensor.to``.
        non_blocking: forwarded to every ``Tensor.to`` call.

    Returns:
        (model_inputs, targets): ``model_inputs`` has keys "src_txt",
        "src_txt_mask", "src_vid", "src_vid_mask"; ``targets`` is a dict of
        the available labels, or None when the batch carries no labels.
    """
    def _move(tensor):
        return tensor.to(device, non_blocking=non_blocking)

    query = batched_model_inputs["query_feat"]
    video = batched_model_inputs["video_feat"]
    model_inputs = {
        "src_txt": _move(query[0]),
        "src_txt_mask": _move(query[1]),
        "src_vid": _move(video[0]),
        "src_vid_mask": _move(video[1]),
    }

    targets = {}
    if "span_labels" in batched_model_inputs:
        targets["span_labels"] = [
            {"spans": _move(entry["spans"])}
            for entry in batched_model_inputs["span_labels"]
        ]
    if "saliency_pos_labels" in batched_model_inputs:
        targets.update({
            key: _move(batched_model_inputs[key])
            for key in ("saliency_pos_labels", "saliency_neg_labels")
        })
    # an empty target dict means "no labels" to callers
    return model_inputs, (targets or None)