"""
Datasets for the clip-level proposal retrieval models (CAL / MCN).
"""
import logging
import torch
from torch.utils.data import Dataset
import numpy as np
import h5py
import math
import random
from utils.basic_utils import load_jsonl, load_json, l2_normalize_np_array
from utils.tensor_utils import pad_sequences_1d
from baselines.clip_alignment_with_language.local_utils.proposal import get_proposal_interface
from baselines.clip_alignment_with_language.local_utils.compute_proposal_upper_bound import \
    get_didemo_agreed_ts
from standalone_eval.eval import compute_temporal_iou_batch

logger = logging.getLogger(__name__)


class ProposalRetrievalDataset(Dataset):
    """
    Args:
        dset_name, str, ["tvr"]
        ctx_mode: str,
        pos_iou_thd: float, in [0, 1]; moments with IoU >= pos_iou_thd are treated as positives
        neg_iou_thd: float, in [0, 1]; moments with IoU < neg_iou_thd are treated as negatives
    Return:
        a dict: {
            "meta": {
                "desc_id": int,
                "desc": str,
                "vid_name": str,
                "duration": float,
                "ts": [st (float), ed (float)], seconds, ground_truth timestamps
                "pos_moment": [st (float), ed (float)], seconds, IoU with "ts" >= pos_iou_thd
                "intra_neg_moment": [st (float), ed (float)], seconds, IoU with "ts" < neg_iou_thd
                "inter_neg_vid_name": str,
                "inter_neg_duration": float,
                "inter_neg_moment": [st (float), ed (float)], seconds, IoU with "ts" < neg_iou_thd
            }
            "model_inputs": {
                "desc_feat": torch.tensor, (L, D_t)
                "pos_moment_feat": torch.tensor, (n_clip_in_moment, D)
                "intra_neg_moment_feat": torch.tensor, (n_clip_in_moment, D)
                "inter_neg_moment_feat": torch.tensor, (n_clip_in_moment, D)
            }
        }
    """
    def __init__(self, dset_name, data_path, desc_bert_path, sub_bert_path, max_desc_len,
                 vid_feat_path, clip_length, vid_feat_size, sub_feat_size=0, ctx_mode="video_tef",
                 pos_iou_thd=0.7, neg_iou_thd=0.3, h5driver=None, data_ratio=1.0,
                 normalize_vfeat=True, normalize_tfeat=True, model_type="cal",
                 external_train_vr_res_path=None, corpus_path=None):
        self.dset_name = dset_name
        self.model_type = model_type
        self.pool_local = model_type == "mcn"  # pool local feature
        self.data_path = data_path
        self.data_ratio = data_ratio

        self.desc_bert_path = desc_bert_path
        self.max_desc_len = max_desc_len
        self.sub_bert_path = sub_bert_path

        self.vid_feat_path = vid_feat_path
        self.clip_length = clip_length
        self.ctx_mode = ctx_mode

        self.pos_iou_thd = pos_iou_thd
        self.neg_iou_thd = neg_iou_thd

        self.vid_feat_output_size = 2 * vid_feat_size * ("video" in ctx_mode) + 2 * ("tef" in ctx_mode)
        self.sub_feat_output_size = 2 * sub_feat_size * ("sub" in ctx_mode) + 2 * ("tef" in ctx_mode)
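        # i.e., each stream outputs [moment_feat; ctx_feat] (hence 2 * base feature size),
        # plus a 2-dim temporal endpoint feature (tef) when "tef" is in ctx_mode.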

        # prepare desc data
        self.data = load_jsonl(data_path)
        if self.data_ratio != 1:
            n_examples = int(len(self.data) * data_ratio)
            self.data = self.data[:n_examples]
            logger.info("Using {}% of the data: {} examples".format(data_ratio * 100, n_examples))

        self.proposal_fn = get_proposal_interface(dset_name)
        if self.ctx_mode != "tef":
            self.vid_feat_h5 = h5py.File(self.vid_feat_path, "r", driver=h5driver)
        self.desc_bert_h5 = h5py.File(self.desc_bert_path, "r", driver=h5driver)
        if "sub" in self.ctx_mode:
            self.sub_bert_h5 = h5py.File(self.sub_bert_path, "r", driver=h5driver)
        self.normalize_vfeat = normalize_vfeat
        self.normalize_tfeat = normalize_tfeat
        self.use_video = "video" in self.ctx_mode
        self.use_sub = "sub" in self.ctx_mode
        self.use_tef = "tef" in self.ctx_mode

        if external_train_vr_res_path is not None:
            video_data = load_json(corpus_path)["train"]
            # {video_idx: [vid_name, vid_duration]}
            video_idx2name_dur_pair = {v[1]: [k, v[0]] for k, v in video_data.items()}
            external_vr_res = load_json(external_train_vr_res_path)
            # {desc_id: [(vid_name, vid_duration), ...]}
            self.desc_id2video_names_dur_pairs = \
                {e["desc_id"]: [video_idx2name_dur_pair[int(sub_e[0])] for sub_e in e["predictions"]]
                 for e in external_vr_res["VR"]}  # ordered

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        raw_data = self.data[index]

        # initialize with basic data
        meta = dict(
            desc_id=raw_data["desc_id"],
            desc=raw_data["desc"],
            vid_name=raw_data["vid_name"],
            duration=raw_data["duration"],
            ts=raw_data["ts"] if self.dset_name != "didemo" else get_didemo_agreed_ts(raw_data["ts"]),
        )
        model_inputs = dict()
        query_feat = self.desc_bert_h5[str(raw_data["desc_id"])][:self.max_desc_len]
        if self.normalize_tfeat:
            query_feat = l2_normalize_np_array(query_feat)
        model_inputs["query_feat"] = torch.from_numpy(query_feat)

        # sample positive and negative moments
        meta["pos_moment"] = self.align_ts_to_clip_boundaries(meta["duration"], meta["ts"])
        meta["intra_neg_moment"] = self.sample_intra_neg_moment(meta["duration"], meta["ts"])
        meta["inter_neg_moment"], meta["inter_neg_vid_name"], meta["inter_neg_duration"] = \
            self.sample_inter_video_negative(meta["vid_name"], meta["pos_moment"] / meta["duration"],
                                             desc_id=meta["desc_id"])

        pos_tef, intra_neg_tef, inter_neg_tef = (None,) * 3
        if self.use_tef:
            pos_tef = meta["pos_moment"] / meta["duration"]  # temporal endpoint feature, (2, )
            intra_neg_tef = meta["intra_neg_moment"] / meta["duration"]
            inter_neg_tef = meta["inter_neg_moment"] / meta["inter_neg_duration"]

        if self.use_video:
            pos_v_feat = self.vid_feat_h5[meta["vid_name"]]  # (N_frm, D)
            neg_v_feat = self.vid_feat_h5[meta["inter_neg_vid_name"]]
            pos_v_ctx_feat = np.mean(pos_v_feat, axis=0)
            neg_v_ctx_feat = np.mean(neg_v_feat, axis=0)
            if self.normalize_vfeat:
                pos_v_ctx_feat = l2_normalize_np_array(pos_v_ctx_feat)
                neg_v_ctx_feat = l2_normalize_np_array(neg_v_ctx_feat)
            pos_moment_v_feat = self.get_moment_feat(pos_v_feat, meta["pos_moment"],
                                                     normalize=self.normalize_vfeat,
                                                     fix_outbound=True, pool_local=self.pool_local)
            intra_neg_moment_v_feat = self.get_moment_feat(pos_v_feat, meta["intra_neg_moment"],
                                                           normalize=self.normalize_vfeat,
                                                           fix_outbound=True, pool_local=self.pool_local)
            inter_neg_moment_v_feat = self.get_moment_feat(neg_v_feat, meta["inter_neg_moment"],
                                                           normalize=self.normalize_vfeat,
                                                           fix_outbound=True, pool_local=self.pool_local)

            # concat features, [video_clip_feat; video_context_feat; temporal_endpoint_feat]
            model_inputs["pos_moment_video_feat"] = self.concat_feat_adv(
                moment_feats=[pos_moment_v_feat, pos_v_ctx_feat], tef=pos_tef, ctx_mode=self.ctx_mode)
            model_inputs["intra_neg_moment_video_feat"] = self.concat_feat_adv(
                moment_feats=[intra_neg_moment_v_feat, pos_v_ctx_feat], tef=intra_neg_tef, ctx_mode=self.ctx_mode)
            model_inputs["inter_neg_moment_video_feat"] = self.concat_feat_adv(
                moment_feats=[inter_neg_moment_v_feat, neg_v_ctx_feat], tef=inter_neg_tef, ctx_mode=self.ctx_mode)
        else:
            for k in ["pos_moment_video_feat", "intra_neg_moment_video_feat", "inter_neg_moment_video_feat"]:
                model_inputs[k] = torch.zeros((2, 2))

        if self.use_sub:  # subtitle features are already contextualized; a mean-pooled context feature is still used
            pos_s_feat = self.sub_bert_h5[meta["vid_name"]]  # (N_words, D_t)
            neg_s_feat = self.sub_bert_h5[meta["inter_neg_vid_name"]]
            pos_s_ctx_feat = np.mean(pos_s_feat, axis=0)
            neg_s_ctx_feat = np.mean(neg_s_feat, axis=0)
            if self.normalize_tfeat:
                pos_s_ctx_feat = l2_normalize_np_array(pos_s_ctx_feat)
                neg_s_ctx_feat = l2_normalize_np_array(neg_s_ctx_feat)
            pos_moment_s_feat = self.get_moment_feat(pos_s_feat, meta["pos_moment"],
                                                     normalize=self.normalize_tfeat,
                                                     fix_outbound=True, pool_local=self.pool_local)
            intra_neg_moment_s_feat = self.get_moment_feat(pos_s_feat, meta["intra_neg_moment"],
                                                           normalize=self.normalize_tfeat,
                                                           fix_outbound=True, pool_local=self.pool_local)
            inter_neg_moment_s_feat = self.get_moment_feat(neg_s_feat, meta["inter_neg_moment"],
                                                           normalize=self.normalize_tfeat,
                                                           fix_outbound=True, pool_local=self.pool_local)

            # concat features, [sub_clip_feat; sub_context_feat; temporal_endpoint_feat]
            model_inputs["pos_moment_sub_feat"] = self.concat_feat_adv(
                moment_feats=[pos_moment_s_feat, pos_s_ctx_feat], tef=pos_tef, ctx_mode=self.ctx_mode)
            model_inputs["intra_neg_moment_sub_feat"] = self.concat_feat_adv(
                moment_feats=[intra_neg_moment_s_feat, pos_s_ctx_feat], tef=intra_neg_tef, ctx_mode=self.ctx_mode)
            model_inputs["inter_neg_moment_sub_feat"] = self.concat_feat_adv(
                moment_feats=[inter_neg_moment_s_feat, neg_s_ctx_feat], tef=inter_neg_tef, ctx_mode=self.ctx_mode)
        else:
            for k in ["pos_moment_sub_feat", "intra_neg_moment_sub_feat", "inter_neg_moment_sub_feat"]:
                model_inputs[k] = torch.zeros((2, 2))

        if not self.use_sub and not self.use_video and self.use_tef:  # tef-only: reuse the video stream keys
            model_inputs["pos_moment_video_feat"] = \
                self.concat_feat_adv(tef=pos_tef, ctx_mode=self.ctx_mode)
            model_inputs["intra_neg_moment_video_feat"] = \
                self.concat_feat_adv(tef=intra_neg_tef, ctx_mode=self.ctx_mode)
            model_inputs["inter_neg_moment_video_feat"] = \
                self.concat_feat_adv(tef=inter_neg_tef, ctx_mode=self.ctx_mode)
        return dict(meta=meta, model_inputs=model_inputs)

    def align_ts_to_clip_boundaries(self, duration, ts):
        """  # TODO Do we really need this???
        Align a moment [st, ed] to clip boundaries: st is rounded down and ed rounded up
        to multiples of self.clip_length, with ed capped at the video duration.
        duration: float,
        ts: [st (float), ed (float)], ground_truth ts
        """
        clip_aligned_ts = np.array([math.floor(ts[0] / self.clip_length),
                                    math.ceil(ts[1] / self.clip_length)]) * self.clip_length
        clip_aligned_ts[1] = min(clip_aligned_ts[1], duration)
        return clip_aligned_ts

    def sample_intra_neg_moment(self, duration, ts):
        """ Generate a intra negative moment given the video duration and the GT ts.
        The returned moment will be aligned to clip boundaries.
        1) neg_moment has at least 2 clips
        2) its iou with ts should be < self.neg_iou_thd
        Args:
            duration: float
            ts: [st (float), ed (float)], ground_truth ts

        Returns:
            np.ndarray, (2, ), a clip-aligned moment with the smallest IoU among the sampled candidates
        """
        max_n_search = 5  # search at most max_n_search times, so the program will not be stuck in infinite loops.
        sampled_moments = self.sample_ts_at_clip_boundaries(duration, n_pairs=max_n_search)  # (n_pairs, 2)
        sampled_moments_ious = compute_temporal_iou_batch(sampled_moments, ts)  # (n_pairs, )
        smallest_iou_idx = np.argmin(sampled_moments_ious)
        sampled_moment = sampled_moments[smallest_iou_idx]
        # only a small number (<20 with max_n_search==10) of samples are wrong,
        # usually when the video_duration is too short.
        # if sampled_moments_ious[smallest_iou_idx] >= self.neg_iou_thd:
        #     logger.warning("the sampled intra-neg might be wrong. "
        #                    "v_dur {}, ts {}, sampled neg moment {}, iou {}"
        #                    .format(duration, ts, sampled_moment, sampled_moments_ious[smallest_iou_idx]))
        return sampled_moment

    def sample_ts_at_clip_boundaries(self, duration, n_pairs=1):
        """sample n_pairs moment at clip boundaries, each has at least two clips."""
        # the '+/- self.clip_length' fix-up below assumes half-open clip indexing [clip_st_idx, clip_ed_idx)
        moments = np.random.randint(0, np.ceil(duration / self.clip_length), size=(n_pairs, 2))
        moments = np.sort(moments, axis=1) * self.clip_length
        less_equal = moments[:, 1] - moments[:, 0] <= self.clip_length
        start_zero = moments[:, 0] == 0
        moments[:, 1][less_equal & start_zero] += self.clip_length
        moments[:, 0][less_equal & ~start_zero] -= self.clip_length  # boolean masks
        return moments
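    # Illustrative example (values are made up): with clip_length=1.5 a sampled pair
    # [1, 2] becomes [1.5, 3.0] after scaling; its length equals one clip, so the
    # start is moved back one clip -> [0.0, 3.0] (two clips). If the start were 0
    # (e.g. [0.0, 1.5]), the end would be extended instead -> [0.0, 3.0].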

    def sample_inter_video_negative(self, pos_vid_name, normalized_pos_moment, desc_id=None):
        """Sample a negative moment --> negative video + similar normalized moment.
        1) they are not from the same video
        Args:
            pos_vid_name: str,
            normalized_pos_moment: np.ndarray, (2, ), value in [0, 1], normalized by duration.
            desc_id: str
        Returns:
            moment: np.ndarray, (2, ), ts aligned to clip boundaries.

        """
        use_guided_negative = hasattr(self, "desc_id2video_names_dur_pairs")
        if use_guided_negative:
            top_videos = self.desc_id2video_names_dur_pairs[desc_id]
            max_idx = len(top_videos) - 1

        while True:  # usually only run once.
            if use_guided_negative:
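                # random.expovariate(0.1) has mean 1/0.1 = 10, so sampling is biased
                # toward the top-ranked external VR results (likely harder negatives)
                # while still occasionally reaching further down the ranked list.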
                sampled_idx = min(max_idx, int(random.expovariate(0.1)))
                sampled_video_name, sampled_video_dur = top_videos[sampled_idx]
            else:
                neg_vid_data = self.data[int(random.random() * len(self))]
                sampled_video_name, sampled_video_dur = neg_vid_data["vid_name"], neg_vid_data["duration"]
            if sampled_video_name != pos_vid_name:
                inter_neg_moment = self.align_ts_to_clip_boundaries(
                    sampled_video_dur, sampled_video_dur * normalized_pos_moment)
                break

        return inter_neg_moment, sampled_video_name, sampled_video_dur

    @classmethod
    def get_clip_indices_from_moments(cls, moment, clip_length):
        clip_st_ed_indices = moment / clip_length
        return math.floor(clip_st_ed_indices[0]), math.ceil(clip_st_ed_indices[1])
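    # Example (illustrative): moment=[1.5, 7.1] with clip_length=1.5 maps to clip
    # indices (1, 5), i.e. the half-open slice vid_feat[1:5] used in get_moment_feat.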

    def get_moment_feat(self, vid_feat, moment, normalize=True, fix_outbound=False, pool_local=False):
        """Each moment contains multiple clips.
        Inside means [moment[0], moment[1]] (seconds)
        Args:
            vid_feat: np.ndarray, (N_clips, D)
            moment: [st (float), ed (float)], np.ndarray
            normalize: L2 normalize features
            fix_outbound: bool,
            pool_local: whether to mean pool the features
        Returns:
            moment_feature: np.ndarray, ((moment[1] - moment[0]) / clip_length, D) or (D, )
        """
        clip_st_idx, clip_ed_idx = self.get_clip_indices_from_moments(moment, self.clip_length)
        if fix_outbound:
            vid_feat_len = len(vid_feat)
            if clip_st_idx >= vid_feat_len:
                clip_st_idx = vid_feat_len - 2
        moment_feat = vid_feat[clip_st_idx:clip_ed_idx]  # indexed as [st, ed)
        if pool_local:
            moment_feat = np.mean(moment_feat, axis=0, keepdims=True)
        if normalize:
            moment_feat = l2_normalize_np_array(moment_feat)
        return moment_feat  # (n_clips_in_moment, D) or (1, D) if pool_local

    @classmethod
    def concat_feat_adv(cls, moment_feats=None, tef=None, to_torch=True, ctx_mode="tef"):
        """ Concat moment_feat with other_feats and tef. All the features should be L2 normalized before concatenating
        Args:
            moment_feats: list of feats, one of them might be None. Other possible values are
                ctx_feat (D, ) or sub(vid)_moment_feat (N_p, N_clips, D_t) or (N_clips, D_t).
                The first non-None feature array is used as base for the rest to concatenate with.
            tef: (N_p, 2) or (2, ), np.ndarray
            to_torch: convert resulting np.ndarray to torch.tensor
            ctx_mode:
        """
        if ctx_mode == "tef":
            assembled_feat = np.expand_dims(tef, axis=-2)
        else:  # concat moment_feat with all other_feats
            moment_feats = [e for e in moment_feats if e is not None]  # remove possible None (placeholder)
            extra_dims = moment_feats[0].shape[:-1]  # all others will need to broadcast to match it.
            if isinstance(extra_dims, int):  # defensive only; shape[:-1] always yields a tuple
                extra_dims = (extra_dims, )
            last_dim_lengths = [0, ] + [e.shape[-1] for e in moment_feats]
            if "tef" in ctx_mode:  # add tef
                last_dim_lengths += [2, ]
                moment_feats += [np.expand_dims(tef, axis=-2), ]

            if len(moment_feats) > 1:
                assembled_feat = np.empty(extra_dims + (sum(last_dim_lengths), ), dtype=np.float32)
                last_dim_lengths_cumsum = [sum(last_dim_lengths[0:idx+1]) for idx in range(len(last_dim_lengths))]
                for idx, feat in enumerate(moment_feats):
                    assembled_feat[..., last_dim_lengths_cumsum[idx]:last_dim_lengths_cumsum[idx+1]] = feat
            else:
                assembled_feat = moment_feats[0]

        if to_torch:
            return torch.from_numpy(assembled_feat)
        else:
            return assembled_feat  # (N_prop, N_clips, D_concat) or (N_clips, D_concat)
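    # Shape sketch (illustrative): with a video moment feat (N_clips, D_v), a context
    # feat (D_v, ) and tef (2, ), ctx_mode="video_tef" yields (N_clips, 2 * D_v + 2);
    # the (D_v, ) context feat and the (1, 2) expanded tef broadcast along the clip axis.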


class ProposalRetrievalEvalDataset(Dataset):
    """
    init_data_mode: `video_query` or `video_only` or `query_only`,
        indicates which data to load when initializing the Dataset object.
    data_mode: `context` or `query`, indicates which data to return from self.__getitem__()
    desc_bert_path_or_handler: h5py.File object or str path
    vid_feat_path_or_handler: h5py.File object or str path
    eval_proposal_bsz: the proposals of a single video are sorted by length and batched with at most
        eval_proposal_bsz proposals per batch; a single video may yield multiple batches.
    load_gt_video: load the ground-truth video name, useful when evaluating single-video moment retrieval.
    data_ratio: percentage of query data to use.
    """
    def __init__(self, dset_name, eval_split_name, data_path=None,
                 desc_bert_path_or_handler=None, max_desc_len=None,
                 sub_bert_path_or_handler=None, vid_feat_path_or_handler=None,
                 corpus_path=None, clip_length=None,
                 eval_proposal_bsz=None, ctx_mode="tef", data_mode="context",
                 h5driver=None, data_ratio=1.0, normalize_vfeat=True,
                 normalize_tfeat=True, max_n_proposals=90, model_type="cal"):
        self.dset_name = dset_name
        self.model_type = model_type
        self.pool_local = model_type == "mcn"  # pool local feature
        self.eval_split_name = eval_split_name
        self.ctx_mode = ctx_mode
        self.load_gt_video = False
        self.data_ratio = data_ratio  # only affect query data
        self.normalize_vfeat = normalize_vfeat
        self.normalize_tfeat = normalize_tfeat
        self.max_n_proposals = max_n_proposals

        self.data_mode = None
        self.set_data_mode(data_mode)

        self.max_desc_len = max_desc_len
        self.data_path = data_path
        self.query_data = load_jsonl(data_path)
        if data_ratio != 1:
            n_examples = int(len(self.query_data) * data_ratio)
            self.query_data = self.query_data[:n_examples]
            logger.info("Using {}% of the data: {} examples".format(data_ratio * 100, n_examples))
        if isinstance(desc_bert_path_or_handler, h5py.File):
            self.desc_bert_h5 = desc_bert_path_or_handler
        else:
            self.desc_bert_h5 = h5py.File(desc_bert_path_or_handler, "r", driver=h5driver)

        video_data = load_json(corpus_path)[self.eval_split_name]
        self.video_data = [{"vid_name": k, "duration": v[0]} for k, v in video_data.items()]
        self.video2idx = {k: v[1] for k, v in video_data.items()}
        self.eval_proposal_bsz = eval_proposal_bsz
        self.clip_length = clip_length
        self.proposal_fn = get_proposal_interface(dset_name)

        self.use_video = "video" in self.ctx_mode
        self.use_sub = "sub" in self.ctx_mode
        self.use_tef = "tef" in self.ctx_mode

        if self.use_video:
            if isinstance(vid_feat_path_or_handler, h5py.File):
                self.vid_feat_h5 = vid_feat_path_or_handler
            else:  # str path
                self.vid_feat_h5 = h5py.File(vid_feat_path_or_handler, "r", driver=h5driver)

        if self.use_sub:
            if isinstance(sub_bert_path_or_handler, h5py.File):
                self.sub_bert_h5 = sub_bert_path_or_handler
            else:  # str path
                self.sub_bert_h5 = h5py.File(sub_bert_path_or_handler, "r", driver=h5driver)

    def set_data_mode(self, data_mode):
        """context or query"""
        assert data_mode in ["context", "query"]
        self.data_mode = data_mode

    def load_gt_vid_name_for_query(self, load_gt_video):
        """load_gt_video: bool, affect the returned value of self._get_item_query"""
        assert "vid_name" in self.query_data[0]
        self.load_gt_video = load_gt_video

    def __len__(self):
        if self.data_mode == "context":
            return len(self.video_data)
        else:
            return len(self.query_data)

    def __getitem__(self, index):
        if self.data_mode == "context":
            return self._get_item_context(index)
        else:
            return self._get_item_query(index)

    def _get_item_query(self, index):
        """Need to batch"""
        raw_data = self.query_data[index]

        meta = dict(
            desc_id=raw_data["desc_id"],
            desc=raw_data["desc"],
            vid_name=raw_data["vid_name"] if self.load_gt_video else None
        )

        model_inputs = dict()
        query_feat = self.desc_bert_h5[str(raw_data["desc_id"])][:self.max_desc_len]
        if self.normalize_tfeat:
            query_feat = l2_normalize_np_array(query_feat)
        model_inputs["query_feat"] = torch.from_numpy(query_feat)
        return dict(meta=meta, model_inputs=model_inputs)

    def _get_item_context(self, index):
        """No need to batch, since it has already been batched here"""
        raw_data = self.video_data[index]

        # get proposals and sort them by length (ascending) for more efficient batching
        proposals = self.proposal_fn(
            video_id="", metadata={"duration": raw_data["duration"]})  # np.ndarray (N_p, 2)
        proposals_lengths = proposals[:, 1] - proposals[:, 0]  # seconds
        sorted_proposal_indices = np.argsort(proposals_lengths)[:self.max_n_proposals]
        sorted_proposals = proposals[sorted_proposal_indices]

        # initialize with basic data
        meta = dict(
            vid_name=raw_data["vid_name"],
            duration=raw_data["duration"],
            proposals=sorted_proposals
        )
        model_inputs = dict()

        n_proposal_batches = math.ceil(1.0 * len(sorted_proposals) / self.eval_proposal_bsz)

        tef_batched_list = [None, ] * n_proposal_batches
        t_moments_mask_list = [None, ] * n_proposal_batches
        if self.use_tef:
            tef_array = sorted_proposals / meta["duration"]  # (N_p, 2)
            for batch_idx in range(n_proposal_batches):
                st_m_idx = batch_idx * self.eval_proposal_bsz
                ed_m_idx = (batch_idx + 1) * self.eval_proposal_bsz
                tef_batched_list[batch_idx] = tef_array[st_m_idx:ed_m_idx]
                t_moments_mask_list[batch_idx] = \
                    np.ones((len(tef_batched_list[batch_idx]), 1), dtype=np.float32)
            if not self.use_video and not self.use_sub:  # tef-only: reuse the video stream keys
                model_inputs["video_moment_features_list"] = [
                    ProposalRetrievalDataset.concat_feat_adv(tef=t, ctx_mode=self.ctx_mode) for t in tef_batched_list]
                model_inputs["video_moment_mask_list"] = [torch.from_numpy(e) for e in t_moments_mask_list]

        # extract/group/pad
        if self.use_video:
            v_feat = self.vid_feat_h5[meta["vid_name"]]  # (N_frm, D)
            v_ctx_feat = np.mean(v_feat, axis=0)  # (D, )
            if self.normalize_vfeat:
                v_ctx_feat = l2_normalize_np_array(v_ctx_feat)
            v_padded_moments_features_list, v_moments_mask_list = \
                self.get_batched_moment_feat_for_all_proposals(v_feat, sorted_proposals,
                                                               pool_local=self.pool_local,
                                                               normalize=self.normalize_vfeat)

            model_inputs["video_moment_features_list"] = [ProposalRetrievalDataset.concat_feat_adv(
                moment_feats=[v, v_ctx_feat], tef=t, ctx_mode=self.ctx_mode)
                for v, t in zip(v_padded_moments_features_list, tef_batched_list)]
            model_inputs["video_moment_mask_list"] = [torch.from_numpy(e) for e in v_moments_mask_list]

        if self.use_sub:
            s_feat = self.sub_bert_h5[meta["vid_name"]]  # (N_frm, D)
            s_ctx_feat = np.mean(s_feat, axis=0)  # (D, )
            if self.normalize_tfeat:
                s_ctx_feat = l2_normalize_np_array(s_ctx_feat)
            s_padded_moments_features_list, s_moments_mask_list = \
                self.get_batched_moment_feat_for_all_proposals(s_feat, sorted_proposals,
                                                               pool_local=self.pool_local,
                                                               normalize=self.normalize_tfeat)
            model_inputs["sub_moment_features_list"] = [ProposalRetrievalDataset.concat_feat_adv(
                moment_feats=[s, s_ctx_feat], tef=t, ctx_mode=self.ctx_mode)
                for s, t in zip(s_padded_moments_features_list, tef_batched_list)]
            model_inputs["sub_moment_mask_list"] = [torch.from_numpy(e) for e in s_moments_mask_list]
        return dict(meta=meta, model_inputs=model_inputs)
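    # Illustrative sizes (assumed, not from the original code): with 80 proposals and
    # eval_proposal_bsz=32, "video_moment_features_list" holds 3 tensors with batch sizes
    # 32/32/16, each of shape (batch, max_n_clips, D_concat), with matching masks in
    # "video_moment_mask_list"; the "sub_*" keys follow the same layout when subtitles are used.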

    def get_batched_moment_feat_for_all_proposals(self, feature, moments, pool_local=False, normalize=True):
        """proposals of the same video wil be segmented into multiple batches to accomodate GPU memory
        pool_local: pool local feature into a single vector
        """
        n_proposal_batches = math.ceil(1.0 * len(moments) / self.eval_proposal_bsz)
        padded_moments_features_list = [None, ] * n_proposal_batches
        moments_mask_list = [None, ] * n_proposal_batches
        moments_features = self.get_moment_feat_for_all_proposals(
            feature, moments, normalize=normalize, pool_local=pool_local)  # N_p * [(N_clips, D), ]
        for batch_idx in range(n_proposal_batches):
            st_m_idx = batch_idx * self.eval_proposal_bsz
            ed_m_idx = (batch_idx + 1) * self.eval_proposal_bsz
            padded_moments_features, moments_mask = \
                pad_sequences_1d(moments_features[st_m_idx:ed_m_idx], dtype=np.float32)
            padded_moments_features_list[batch_idx] = padded_moments_features
            moments_mask_list[batch_idx] = moments_mask
            assert np.sum(np.sum(moments_mask, axis=1) == 0) == 0, " err {}".format(moments_mask)
        assert np.sum(np.sum(moments_mask_list[0], axis=1) == 0) == 0, " err {}".format(moments_mask_list)
        return padded_moments_features_list, moments_mask_list

    def get_moment_feat_for_all_proposals(self, vid_feat, moments, normalize=True, pool_local=False):
        """Each moment is comprised of multiple clips
        Args:
            vid_feat: np.ndarray, (N_clips, D)
            moments: np.ndarray, (N_p, 2), each row is [st (float), ed (float)],
            normalize: L2 normalize
            pool_local:
        Returns:
            moments_features: list(np.ndarray), [(N_clips, D), ] * N_p, where N_clips varies per moment.
        """
        if normalize and not pool_local:
            vid_feat = l2_normalize_np_array(vid_feat)
        vid_feat_len = len(vid_feat)
        moments_st_clip_indices = np.floor(moments[:, 0] / self.clip_length).astype(np.int64).clip(0, vid_feat_len-1)
        moments_ed_clip_indices = np.ceil(moments[:, 1] / self.clip_length).astype(np.int64).clip(1, vid_feat_len)
        moments_features = []
        for st_idx, ed_idx, m in zip(moments_st_clip_indices, moments_ed_clip_indices, moments):
            feat = vid_feat[st_idx:ed_idx]
            if pool_local:
                feat = np.mean(feat, axis=0, keepdims=True)
                if normalize:
                    feat = l2_normalize_np_array(feat)
            moments_features.append(feat)
        return moments_features


def proposal_retrieval_collate(batch):
    batch_meta = [e["meta"] for e in batch]  # seems no need to collate ?

    model_inputs_keys = batch[0]["model_inputs"].keys()
    batched_data = {k: pad_sequences_1d([e["model_inputs"][k] for e in batch], dtype=torch.float32)
                    for k in model_inputs_keys}
    return batch_meta, batched_data


def prepare_batch_inputs(batched_model_inputs, device, non_blocking=False):
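    # Each value produced by proposal_retrieval_collate is a (padded_tensor, mask) pair
    # from pad_sequences_1d; move both to `device` and expose the mask under the matching
    # "*_mask" key (e.g., "pos_moment_video_feat" -> "pos_moment_video_mask").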
    model_inputs = {}
    for k, v in batched_model_inputs.items():
        model_inputs[k] = v[0].to(device, non_blocking=non_blocking)
        model_inputs[k.replace("feat", "mask")] = v[1].to(device, non_blocking=non_blocking)
    return model_inputs


if __name__ == '__main__':
    from baselines.clip_alignment_with_language.config import BaseOptions
    options = BaseOptions().parse()
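
    # Minimal usage sketch (not part of the original script). Paths and most arguments
    # below are placeholders / assumptions; in practice they come from `options` above.
    #
    # from torch.utils.data import DataLoader
    # train_set = ProposalRetrievalDataset(
    #     dset_name="tvr",
    #     data_path="data/tvr_train_release.jsonl",       # hypothetical path
    #     desc_bert_path="data/desc_bert_features.h5",    # hypothetical path
    #     sub_bert_path="data/sub_bert_features.h5",      # hypothetical path
    #     max_desc_len=30,
    #     vid_feat_path="data/video_features.h5",         # hypothetical path
    #     clip_length=1.5,
    #     vid_feat_size=1024,
    #     sub_feat_size=768,
    #     ctx_mode="video_sub_tef",
    #     model_type="cal",
    # )
    # train_loader = DataLoader(train_set, batch_size=32, shuffle=True,
    #                           collate_fn=proposal_retrieval_collate)
    # batch_meta, batched_inputs = next(iter(train_loader))
    # model_inputs = prepare_batch_inputs(batched_inputs, device="cuda")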