# UniVTG / main /dataset_qfvs.py
# KevinQHLin's picture
# Upload 60 files
# 9d0a4ae
# raw
# history blame
# No virus
# 12.7 kB
import os
import pdb
import h5py
import nncore
import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
import random
import logging
from os.path import join, exists
from nncore.dataset import DATASETS
from nncore.parallel import DataContainer
from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array
from utils.tensor_utils import pad_sequences_1d
from utils.span_utils import span_xx_to_cxw
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class DatasetQFVS(Dataset):
    """Training dataset for query-focused video summarization (QFVS).

    ``self.dataset`` holds two kinds of entries:

    * ``['Oracle', '<concept1>_<concept2>_<video_id>']`` -- one per oracle
      summary file found for a training video; served by :meth:`get_oracle`.
    * ``[video_id, concept, labels]`` -- one per concept found in the dense
      per-shot tag file of a training video (only collected when
      ``config['qfvs_dense_shot'] > 0``); served by :meth:`get_dense`.

    Both kinds of item are returned as a dict with the same keys:
    ``features``, ``seg_len``, ``concept1_GT``, ``concept2_GT``, ``mask_GT``,
    ``oracle_summary``, ``tokens_pad1``, ``tokens_pad2``,
    ``saliency_pos_labels_1``, ``saliency_pos_labels_2`` and
    ``saliency_pos_labels_oracle``.
    """

    def __init__(self, config, use_tef=True):
        """Index all training items and open the per-video feature files.

        Args:
            config: mapping read for the keys ``txt_feature``,
                ``vid_feature``, ``train_videos``, ``qfvs_dense_shot``,
                ``max_segment_num`` and ``max_frame_num``.
            use_tef: stored on the instance; not consulted inside this class.
        """
        self.config = config
        self.dataset = []
        self.use_tef = use_tef

        self.embedding = load_pickle(f"./data/qfvs/txt_clip/{self.config['txt_feature']}.pkl")

        # Concept names used in the metadata files -> keys used for the
        # text-embedding lookup.
        self.transfer = {"Cupglass": "Glass",
                         "Musicalinstrument": "Instrument",
                         "Petsanimal": "Animal"}

        # One read-only HDF5 handle per training video, kept open for the
        # lifetime of the dataset and shared by get_dense / get_oracle.
        self.f_dict = {}
        feat_type = self.config['vid_feature']
        num_slots = self.config["max_segment_num"] * self.config["max_frame_num"]
        oracle_suffix = "_oracle.txt"

        for video_id in self.config["train_videos"]:
            self.f_dict[str(video_id)] = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')

            oracle_dir = "./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0" + str(video_id)
            for _, _, files in os.walk(oracle_dir):
                for file in files:
                    # Skip stray files (editor backups, hidden files, ...):
                    # they would otherwise yield a malformed entry.
                    if not file.endswith(oracle_suffix):
                        continue
                    self.dataset.append(['Oracle', file[:-len(oracle_suffix)] + "_" + str(video_id)])

            if self.config['qfvs_dense_shot'] > 0:
                tags_path = f"./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0{video_id}/P0{video_id}.txt"
                with open(tags_path, "r") as tag_file:
                    dense_concept = self._dense_labels_from_lines(tag_file, self.transfer, num_slots)

                for key, value in dense_concept.items():
                    if value.sum().item() > 0:
                        self.dataset.append([video_id, key, value])

    @staticmethod
    def _dense_labels_from_lines(lines, transfer, num_slots):
        """Build one 0/1 label vector of length ``num_slots`` per concept.

        Line ``i`` of ``lines`` is a comma-separated list of concepts; slot
        ``i`` of a concept's vector is set to 1 when the concept is listed on
        that line -- including the first line on which the concept appears.
        Concept names are mapped through ``transfer`` first.
        """
        labels = {}
        for shot_idx, line in enumerate(lines):
            for concept in line.strip().split(','):
                concept = transfer.get(concept, concept)
                if concept not in labels:
                    labels[concept] = torch.zeros(num_slots)
                labels[concept][shot_idx] = 1
        return labels

    @staticmethod
    def _build_frame_mask(seg_len, max_segment_num, max_frame_num):
        """Return a bool ``(max_segment_num, max_frame_num)`` mask.

        Row ``j`` is True for its first ``seg_len[j]`` columns. A length
        larger than ``max_frame_num`` is clipped by the slice.
        """
        mask = torch.zeros(max_segment_num, max_frame_num, dtype=torch.bool)
        for row, length in enumerate(seg_len):
            mask[row, :int(length)] = True
        return mask

    @staticmethod
    def _sample_positive(labels):
        """Pick one random index where ``labels > 0``.

        Returns a one-element float tensor holding the index, or an empty
        tensor when ``labels`` has no positive entry.
        """
        positives = torch.where(labels > 0)[0].tolist()
        if not positives:
            return torch.Tensor(0)
        return torch.Tensor([random.choice(positives)])

    def _embed_concept(self, concept):
        """Look up the L2-normalised text embedding of ``concept`` as a tensor."""
        concept = self.transfer.get(concept, concept)
        return torch.from_numpy(l2_normalize_np_array(self.embedding[concept]))

    def _make_item(self, video_id, concept1, concept2, concept1_GT, concept2_GT, oracle_summary):
        """Assemble the sample dict shared by get_dense and get_oracle."""
        h5_file = self.f_dict[video_id]
        features = h5_file['features'][()]
        seg_len = h5_file['seg_len'][()]
        mask_GT = self._build_frame_mask(
            seg_len, self.config["max_segment_num"], self.config["max_frame_num"])

        tokens_pad1 = self._embed_concept(concept1)
        tokens_pad2 = self._embed_concept(concept2)

        # Sampled in the order concept1, concept2, oracle so the module-level
        # RNG is consumed in a fixed sequence.
        saliency_pos_labels_1 = self._sample_positive(concept1_GT)
        saliency_pos_labels_2 = self._sample_positive(concept2_GT)
        saliency_pos_labels_oracle = self._sample_positive(oracle_summary)

        return {
            'features': torch.from_numpy(features),
            'seg_len': torch.from_numpy(seg_len),
            'concept1_GT': concept1_GT,
            'concept2_GT': concept2_GT,
            'mask_GT': mask_GT,
            'oracle_summary': oracle_summary,
            'tokens_pad1': tokens_pad1,
            'tokens_pad2': tokens_pad2,
            'saliency_pos_labels_1': saliency_pos_labels_1,
            'saliency_pos_labels_2': saliency_pos_labels_2,
            'saliency_pos_labels_oracle': saliency_pos_labels_oracle,
        }

    def __getitem__(self, index):
        """Return the sample dict for entry ``index`` (oracle or dense)."""
        if self.dataset[index][0] == 'Oracle':
            return self.get_oracle(index)
        return self.get_dense(index)

    def get_dense(self, index):
        """Build a sample from a dense single-concept entry.

        The same concept and the same label vector are used for concept 1,
        concept 2 and the oracle summary.
        """
        video_id, concept, labels = self.dataset[index]
        return self._make_item(str(video_id), concept, concept, labels, labels, labels)

    def get_oracle(self, index):
        """Build a sample from an oracle-summary entry.

        Per-concept labels are read from the dense per-shot tag file of the
        video; the oracle summary is read from the oracle file, whose lines
        are 1-based shot numbers.
        """
        concept1, concept2, video_id = self.dataset[index][1].split('_')[:3]
        num_slots = self.config["max_segment_num"] * self.config["max_frame_num"]

        concept1_GT = torch.zeros(num_slots)
        concept2_GT = torch.zeros(num_slots)
        tags_path = f"./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0{video_id}/P0{video_id}.txt"
        with open(tags_path, "r") as tag_file:
            for shot_idx, line in enumerate(tag_file):
                concepts = line.strip().split(',')
                if concept1 in concepts:
                    concept1_GT[shot_idx] = 1
                if concept2 in concepts:
                    concept2_GT[shot_idx] = 1

        oracle_summary = torch.zeros(num_slots)
        oracle_path = ("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"
                       f"{video_id}/{concept1}_{concept2}_oracle.txt")
        with open(oracle_path, "r") as oracle_file:
            for line in oracle_file:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                oracle_summary[int(line) - 1] = 1

        return self._make_item(video_id, concept1, concept2, concept1_GT, concept2_GT, oracle_summary)

    def __len__(self):
        """Number of indexed entries (oracle plus dense)."""
        return len(self.dataset)
def start_end_collate_qfvs(batch):
    """Collate a list of sample dicts into one dict of padded sequences.

    The keys are taken from the first sample. For every key, the ``.data``
    of each sample's value is gathered and passed to ``pad_sequences_1d``
    with ``dtype=torch.float32`` and no fixed length; whatever that helper
    returns is stored under the key.
    """
    return {
        key: pad_sequences_1d(
            [sample[key].data for sample in batch],
            dtype=torch.float32,
            fixed_length=None,
        )
        for key in batch[0].keys()
    }
def prepare_batch_inputs_qfvs(data, config, eval=False, device='cuda'):
    """Turn one collated QFVS batch into model inputs and training targets.

    Args:
        data: mapping whose values are indexable pairs; ``value[0]`` is read
            as the tensor and, for ``tokens_pad1`` / ``tokens_pad2``,
            ``value[1]`` as the token mask. Keys read: ``features``,
            ``mask_GT``, ``tokens_pad1``, ``tokens_pad2`` and, unless
            ``eval`` is set, ``concept1_GT``, ``concept2_GT``,
            ``oracle_summary`` and the three ``saliency_pos_labels_*``.
            NOTE(review): the ``squeeze(0)`` calls only drop the leading
            axis when it has size 1, so this presumably expects a batch of
            one video -- confirm against the data loader.
        config: not used by this function; kept for interface compatibility.
        eval: when true, no targets are built.
        device: device every returned tensor is moved to. Defaults to
            ``'cuda'``, which was previously hard-coded.

    Returns:
        ``(inputs_1, inputs_2, inputs_oracle, mask_GT)`` when ``eval`` is
        true, otherwise ``(inputs_1, inputs_2, inputs_oracle, targets_1,
        targets_2, targets_oracle, mask_GT)``. Each ``inputs_*`` dict has the
        keys ``src_vid``, ``src_vid_mask``, ``src_txt``, ``src_txt_mask``;
        the oracle query is the two concept queries concatenated along
        dim 1.
    """
    # Video input.
    frame_mask = data['mask_GT'][0].to(device)
    mask_GT = frame_mask.reshape(1, -1).bool()
    mask = frame_mask.squeeze(0)
    seq = data['features'][0].to(device).squeeze(0)
    num_seg, ctx_l = seq.shape[0], seq.shape[1]

    # Temporal endpoint features: normalised [start, end) of every position
    # along dim 1, appended to the feature dimension of every segment.
    tef_st = torch.arange(0, ctx_l, 1.0, device=device) / ctx_l
    tef_ed = tef_st + 1.0 / ctx_l
    tef = torch.stack([tef_st, tef_ed], dim=1)  # (ctx_l, 2)
    seq = torch.cat([seq, tef.repeat(num_seg, 1, 1)], dim=-1)

    # Text input: each query is repeated once per segment.
    src_txt_1 = data['tokens_pad1'][0].to(torch.float32).to(device).repeat(num_seg, 1, 1)
    src_txt_2 = data['tokens_pad2'][0].to(torch.float32).to(device).repeat(num_seg, 1, 1)
    src_txt_mask_1 = data['tokens_pad1'][1].to(device).repeat(num_seg, 1)
    src_txt_mask_2 = data['tokens_pad2'][1].to(device).repeat(num_seg, 1)
    src_txt_oracle = torch.cat((src_txt_1, src_txt_2), dim=1)
    src_txt_mask_oracle = torch.cat((src_txt_mask_1, src_txt_mask_2), dim=1)

    model_inputs_1 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_1, src_txt_mask=src_txt_mask_1)
    model_inputs_2 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_2, src_txt_mask=src_txt_mask_2)
    model_inputs_oracle = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_oracle, src_txt_mask=src_txt_mask_oracle)

    if eval:
        return model_inputs_1, model_inputs_2, model_inputs_oracle, mask_GT

    def build_targets(scores_key, pos_labels_key):
        # The same label tensor serves as saliency scores and timestamp window.
        scores = data[scores_key][0].to(device)
        return dict(
            saliency_scores=scores,
            saliency_pos_labels=data[pos_labels_key][0].to(device),
            timestamp_mask=mask,
            timestamp_window=scores,
        )

    targets_1 = build_targets('concept1_GT', 'saliency_pos_labels_1')
    targets_2 = build_targets('concept2_GT', 'saliency_pos_labels_2')
    targets_oracle = build_targets('oracle_summary', 'saliency_pos_labels_oracle')

    return model_inputs_1, model_inputs_2, model_inputs_oracle, \
        targets_1, targets_2, targets_oracle, mask_GT