Spaces:
Running
on
T4
Running
on
T4
File size: 7,165 Bytes
d4b77ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import random
import pickle
import logging
import torch
import cv2
import os
from torch.utils.data.dataset import Dataset
import numpy as np
import cvbase
from .util.STTN_mask import create_random_shape_with_random_motion
import imageio
from .util.flow_utils import region_fill as rf
logger = logging.getLogger('base')
class VideoBasedDataset(Dataset):
    """Video dataset returning sampled frames, random masks and mask-filled flows.

    Each item corresponds to one entry of ``os.listdir(frame_path)``. For that
    video a set of frame indices is sampled, a random moving mask is generated
    for every frame, and the optical flow(s) belonging to each sampled frame
    are read, resized to the input resolution and filled inside the mask.
    """

    def __init__(self, opt, dataInfo):
        """Store the configuration and load the per-video frame counts.

        Args:
            opt: Mapping read for the keys 'sample', 'input_resolution',
                'num_frames', 'flow2rgb' and 'flow_direction'.
            dataInfo: Mapping read for the keys 'frame_path', 'flow_path' and
                'name2len' (path of a pickle file).
        """
        self.opt = opt
        self.sampleMethod = opt['sample']
        self.dataInfo = dataInfo
        self.height, self.width = self.opt['input_resolution']
        self.frame_path = dataInfo['frame_path']
        self.flow_path = dataInfo['flow_path']  # The path of the optical flows
        self.train_list = os.listdir(self.frame_path)
        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file; only point 'name2len' at a trusted, locally generated pickle.
        name2len_path = self.dataInfo['name2len']
        with open(name2len_path, 'rb') as f:
            # Indexed by video name in load_item to obtain the frame count.
            self.name2length = pickle.load(f)
        self.sequenceLen = self.opt['num_frames']
        self.flow2rgb = opt['flow2rgb']  # whether to change flow to rgb domain
        # The direction must be in ['for', 'back', 'bi'], indicating forward,
        # backward and bidirectional flows
        self.flow_direction = opt['flow_direction']

    def __len__(self):
        """Return the number of entries found in ``frame_path``."""
        return len(self.train_list)

    def __getitem__(self, idx):
        """Load item ``idx``; if loading fails, fall back to item 0.

        Only ``Exception`` subclasses trigger the fallback, so that
        KeyboardInterrupt / SystemExit are no longer swallowed.
        """
        try:
            item = self.load_item(idx)
        except Exception:
            print('Loading error: ' + self.train_list[idx])
            item = self.load_item(0)
        return item

    def frameSample(self, frameLen, sequenceLen):
        """Pick ``sequenceLen`` frame indices out of ``range(frameLen)``.

        Args:
            frameLen: Number of frames available in the video.
            sequenceLen: Number of indices to return.

        Returns:
            A list of indices: distinct indices in random order for the
            'random' method, a contiguous ascending window for 'seq'.

        Raises:
            ValueError: If the sample method is unknown, or if ``sequenceLen``
                exceeds ``frameLen``.
        """
        if self.sampleMethod == 'random':
            return random.sample(range(frameLen), sequenceLen)
        if self.sampleMethod == 'seq':
            # randint is inclusive on both ends, so the last possible window
            # is [frameLen - sequenceLen, frameLen - 1].
            pivot = random.randint(0, frameLen - sequenceLen)
            return list(range(pivot, pivot + sequenceLen))
        raise ValueError('Cannot determine the sample method {}'.format(self.sampleMethod))

    def load_item(self, idx):
        """Build the tensor dictionary for the video at position ``idx``.

        Returns:
            The dict produced by ``to_tensor`` with the keys 'frames' (rescaled
            from [0, 255] to [-1, 1]), 'masks' and, depending on
            ``flow_direction``, 'forward_flo' and/or 'backward_flo'.

        Raises:
            ValueError: If the video does not have more frames than the
                requested sequence length, or the flow direction is unknown.
        """
        video = self.train_list[idx]
        frame_dir = os.path.join(self.frame_path, video)
        forward_flow_dir = os.path.join(self.flow_path, video, 'forward_flo')
        backward_flow_dir = os.path.join(self.flow_path, video, 'backward_flo')
        frameLen = self.name2length[video]
        flowLen = frameLen - 1
        # Explicit raise instead of assert so the check survives `python -O`.
        if not frameLen > self.sequenceLen:
            raise ValueError('Frame length {} is less than sequence length'.format(frameLen))
        if self.flow_direction not in ('for', 'back', 'bi'):
            raise ValueError('Unknown flow direction mode: {}'.format(self.flow_direction))
        use_forward = self.flow_direction in ('for', 'bi')
        use_backward = self.flow_direction in ('back', 'bi')

        sampledIndices = self.frameSample(frameLen, self.sequenceLen)

        # generate random masks for these sampled frames
        candidateMasks = create_random_shape_with_random_motion(frameLen, 0.9, 1.1, 1, 10)

        # read the frames and masks
        frames, masks, forward_flows, backward_flows = [], [], [], []
        for index in sampledIndices:
            frame_file = os.path.join(frame_dir, '{:05d}.jpg'.format(index))
            frame = self.read_frame(frame_file, self.height, self.width)
            mask = self.read_mask(candidateMasks[index], self.height, self.width)
            frames.append(frame)
            masks.append(mask)
            if use_forward:
                forward_flow = self.read_forward_flow(forward_flow_dir, index, flowLen)
                forward_flows.append(self.diffusion_flow(forward_flow, mask))
            if use_backward:
                backward_flow = self.read_backward_flow(backward_flow_dir, index)
                backward_flows.append(self.diffusion_flow(backward_flow, mask))

        inputs = {'frames': frames, 'masks': masks, 'forward_flo': forward_flows, 'backward_flo': backward_flows}
        # to_tensor drops the flow entry that stayed empty for 'for' / 'back'.
        inputs = self.to_tensor(inputs)
        inputs['frames'] = (inputs['frames'] / 255.) * 2 - 1
        return inputs

    def diffusion_flow(self, flow, mask):
        """Zero the first two flow channels inside ``mask`` and region-fill them.

        Args:
            flow: Array indexed as ``flow[:, :, c]``; channels 0 and 1 are used.
            mask: Array multiplied against each channel as ``1 - mask`` and
                passed to ``rf.regionfill`` (assumed binary with 1 marking the
                region to fill - TODO confirm against regionfill).

        Returns:
            A new array of ``flow.shape`` (NumPy default float dtype).
        """
        flow_filled = np.zeros(flow.shape)
        keep = 1 - mask
        for channel in (0, 1):
            flow_filled[:, :, channel] = rf.regionfill(flow[:, :, channel] * keep, mask)
        return flow_filled

    def read_frame(self, path, height, width):
        """Read the image at ``path`` and resize it to ``(height, width)``."""
        frame = imageio.imread(path)
        # interpolation must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, not the interpolation flag.
        return cv2.resize(frame, (width, height), interpolation=cv2.INTER_LINEAR)

    def read_mask(self, mask, height, width):
        """Binarize ``mask`` (threshold 0.5 after dividing by 255) and resize it.

        Returns:
            A uint8 array of 0/1 values resized with nearest-neighbour
            interpolation to ``(height, width)``.
        """
        binary = (np.array(mask) / 255. > 0.5).astype(np.uint8)
        return cv2.resize(binary, dsize=(width, height), interpolation=cv2.INTER_NEAREST)

    def _load_resized_flow(self, flow_dir, flow_index):
        """Read ``{flow_index:05d}.flo`` and rescale it to the input resolution."""
        flow = cvbase.read_flow(os.path.join(flow_dir, '{:05d}.flo'.format(flow_index)))
        height, width = flow.shape[:2]
        flow = cv2.resize(flow, (self.width, self.height), interpolation=cv2.INTER_LINEAR)
        # Displacements are in pixels, so they scale with the resize factor.
        flow[:, :, 0] = flow[:, :, 0] / width * self.width
        flow[:, :, 1] = flow[:, :, 1] / height * self.height
        return flow

    def read_forward_flow(self, forward_flow_dir, sampledIndex, flowLen):
        """Read the forward flow for frame ``sampledIndex``.

        Indices at or beyond ``flowLen`` are clamped to ``flowLen - 1`` (the
        last frame has no forward flow file of its own).
        """
        if sampledIndex >= flowLen:
            sampledIndex = flowLen - 1
        return self._load_resized_flow(forward_flow_dir, sampledIndex)

    def read_backward_flow(self, backward_flow_dir, sampledIndex):
        """Read the backward flow for frame ``sampledIndex``.

        The file index is ``sampledIndex - 1``, except that frame 0 reuses
        file 0.
        """
        if sampledIndex != 0:
            sampledIndex -= 1
        return self._load_resized_flow(backward_flow_dir, sampledIndex)

    def to_tensor(self, data_list):
        """Convert the numpy entries of a dict to float tensors, in place.

        Args:
            data_list: A dict whose values are numpy arrays, lists of numpy
                arrays, None or empty lists.

        Returns:
            The same dict. Entries that were None or an empty list are removed;
            a single [h, w, c] array becomes a [c, h, w] tensor; a list is
            stacked into a [t, c, h, w] tensor (a channel axis is added for
            [t, h, w] stacks).
        """
        for key in list(data_list.keys()):
            item = data_list[key]
            # Check for the empty list explicitly: `ndarray == []` is an
            # elementwise comparison and does not yield a plain bool.
            if item is None or (isinstance(item, list) and not item):
                data_list.pop(key)
                continue
            if isinstance(item, list):
                item = np.stack(item, axis=0)
                if len(item.shape) == 3:  # [t, h, w]
                    item = item[:, :, :, np.newaxis]
                item = torch.from_numpy(np.transpose(item, (0, 3, 1, 2))).float()  # [t, c, h, w]
            else:
                item = torch.from_numpy(np.transpose(item, (2, 0, 1))).float()  # [c, h, w]
            data_list[key] = item
        return data_list
|