oguzakif committed on
Commit
d4b77ac
1 Parent(s): f6248c8
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +1 -0
  2. FGT_codes/FGT/checkpoint/config.yaml +34 -0
  3. FGT_codes/FGT/checkpoint/fgt.pth.tar +3 -0
  4. FGT_codes/FGT/config/data_info.yaml +11 -0
  5. FGT_codes/FGT/config/davis_name2len.pkl +3 -0
  6. FGT_codes/FGT/config/davis_name2len_train.pkl +3 -0
  7. FGT_codes/FGT/config/davis_name2len_val.pkl +3 -0
  8. FGT_codes/FGT/config/train.yaml +93 -0
  9. FGT_codes/FGT/config/valid_config.yaml +8 -0
  10. FGT_codes/FGT/config/youtubevos_name2len.pkl +3 -0
  11. FGT_codes/FGT/data/__init__.py +49 -0
  12. FGT_codes/FGT/data/train_dataset.py +165 -0
  13. FGT_codes/FGT/data/util/MaskModel.py +123 -0
  14. FGT_codes/FGT/data/util/STTN_mask.py +244 -0
  15. FGT_codes/FGT/data/util/__init__.py +28 -0
  16. FGT_codes/FGT/data/util/flow_utils/__init__.py +0 -0
  17. FGT_codes/FGT/data/util/flow_utils/flow_reversal.py +77 -0
  18. FGT_codes/FGT/data/util/flow_utils/region_fill.py +142 -0
  19. FGT_codes/FGT/data/util/freeform_masks.py +266 -0
  20. FGT_codes/FGT/data/util/mask_generators.py +217 -0
  21. FGT_codes/FGT/data/util/readers.py +527 -0
  22. FGT_codes/FGT/data/util/util.py +259 -0
  23. FGT_codes/FGT/data/util/utils.py +158 -0
  24. FGT_codes/FGT/flowCheckPoint/config.yaml +11 -0
  25. FGT_codes/FGT/flowCheckPoint/lafc_single.pth.tar +3 -0
  26. FGT_codes/FGT/inputs.py +83 -0
  27. FGT_codes/FGT/metrics/__init__.py +31 -0
  28. FGT_codes/FGT/metrics/psnr.py +10 -0
  29. FGT_codes/FGT/metrics/ssim.py +46 -0
  30. FGT_codes/FGT/models/BaseNetwork.py +46 -0
  31. FGT_codes/FGT/models/__init__.py +0 -0
  32. FGT_codes/FGT/models/__pycache__/BaseNetwork.cpython-39.pyc +0 -0
  33. FGT_codes/FGT/models/__pycache__/__init__.cpython-39.pyc +0 -0
  34. FGT_codes/FGT/models/__pycache__/model.cpython-39.pyc +0 -0
  35. FGT_codes/FGT/models/lafc_single.py +114 -0
  36. FGT_codes/FGT/models/model.py +284 -0
  37. FGT_codes/FGT/models/temporal_patch_gan.py +76 -0
  38. FGT_codes/FGT/models/transformer_base/__init__.py +0 -0
  39. FGT_codes/FGT/models/transformer_base/__pycache__/__init__.cpython-39.pyc +0 -0
  40. FGT_codes/FGT/models/transformer_base/__pycache__/attention_base.cpython-39.pyc +0 -0
  41. FGT_codes/FGT/models/transformer_base/__pycache__/attention_flow.cpython-39.pyc +0 -0
  42. FGT_codes/FGT/models/transformer_base/__pycache__/ffn_base.cpython-39.pyc +0 -0
  43. FGT_codes/FGT/models/transformer_base/attention_base.py +106 -0
  44. FGT_codes/FGT/models/transformer_base/attention_flow.py +171 -0
  45. FGT_codes/FGT/models/transformer_base/ffn_base.py +114 -0
  46. FGT_codes/FGT/models/utils/RAFT/utils/__init__.py +0 -0
  47. FGT_codes/FGT/models/utils/RAFT/utils/utils.py +82 -0
  48. FGT_codes/FGT/models/utils/__init__.py +0 -0
  49. FGT_codes/FGT/models/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  50. FGT_codes/FGT/models/utils/__pycache__/network_blocks_2d.cpython-39.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  *.pth.tar filter=lfs diff=lfs merge=lfs -text
+ *.o filter=lfs diff=lfs merge=lfs -text
FGT_codes/FGT/checkpoint/config.yaml ADDED
@@ -0,0 +1,34 @@
+ PASSMASK: 1
+ alpha: 0.3
+ ape: 1
+ cnum: 64
+ conv_type: vanilla
+ dist_cnum: 32
+ drop: 0
+ frame_hidden: 512
+ gd: 4
+ in_channel: 4
+ init_weights: 1
+ input_resolution: !!python/tuple
+ - 240
+ - 432
+ flow_inChannel: 2
+ flow_cnum: 64
+ flow_hidden: 256
+ kernel_size: !!python/tuple
+ - 7
+ - 7
+ mlp_ratio: 40
+ numBlocks: 8
+ num_head: 4
+ padding: !!python/tuple
+ - 3
+ - 3
+ stride: !!python/tuple
+ - 3
+ - 3
+ sw: 8
+ tw: 2
+ use_bias: 1
+ norm: None
+ model: model
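A minimal sketch of loading this checkpoint config (assumes PyYAML; the `!!python/tuple` tags require a non-safe loader, and the surrounding variable names are illustrative rather than taken from the repo):

    import yaml

    # The !!python/tuple tags come from dumping a dict that contains tuples,
    # so yaml.safe_load would reject this file; UnsafeLoader (PyYAML >= 5.1) accepts it.
    with open('FGT_codes/FGT/checkpoint/config.yaml') as f:
        ckpt_cfg = yaml.load(f, Loader=yaml.UnsafeLoader)

    print(ckpt_cfg['input_resolution'])  # (240, 432)
    print(ckpt_cfg['kernel_size'], ckpt_cfg['stride'], ckpt_cfg['padding'])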
FGT_codes/FGT/checkpoint/fgt.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:41352263b2d14aec73f0dcf75c4bf5155ddb23404aba6f023a0300aadfd7672f
+ size 157341393
FGT_codes/FGT/config/data_info.yaml ADDED
@@ -0,0 +1,11 @@
+ # dataset general info
+ frame_path: youtubevos_frames
+ flow_path: youtubevos_flows
+ name2len: config/youtubevos_name2len.pkl
+
+ flow:
+ flow_height: 240
+ flow_width: 432
+ augments: False
+ colors: RGB
+ ext: .jpg
FGT_codes/FGT/config/davis_name2len.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6607939cc02910f5badaebff46242f299597e93c07d77b6d740a3004f179f50c
+ size 1621
FGT_codes/FGT/config/davis_name2len_train.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1ad5e89d5486b38f74ac62d08924a4ff7caa445d34df827385457e8516d4763f
+ size 1073
FGT_codes/FGT/config/davis_name2len_val.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:30b2a23f943f40f2a09e98b474b88b07271e46a1224cb415650432d491cc1896
+ size 188
FGT_codes/FGT/config/train.yaml ADDED
@@ -0,0 +1,93 @@
+ ### General settings
+ name: FGT_train
+ use_tb_logger: true
+ outputdir: /myData/ret/experiments
+ datadir: /myData
+ record_iter: 16
+
+ ### Calling definition
+ model: model
+ datasetName_train: train_dataset
+ network: network
+
+ ### datasets
+ datasets:
+   train:
+     name: youtubevos
+     type: video
+     mode: train
+     dataInfo_config: ./config/data_info.yaml
+     use_shuffle: True
+     n_workers: 0
+     batch_size: 2
+
+   val:
+     name: youtubevos
+     type: video
+     mode: val
+     use_shuffle: False
+     n_workers: 1
+     batch_size: 1
+     val_config: ./config/valid_config.yaml
+
+ ### train settings
+ train:
+   lr: 0.0001
+   lr_decay: 0.1
+   manual_seed: 10
+   BETA1: 0.9
+   BETA2: 0.999
+   MAX_ITERS: 500000
+   UPDATE_INTERVAL: 300000 # 400000 is also OK
+   WARMUP: ~
+   val_freq: 1 # Set to 1 for debugging; enlarge it to 50 in regular training
+   TEMPORAL_GAN: ~ # without temporal GAN
+
+ ### logger
+ logger:
+   PRINT_FREQ: 16
+   SAVE_CHECKPOINT_FREQ: 4000 # 100 is for debug consideration
+
+ ### Data related parameters
+ flow2rgb: 1
+ flow_direction: for
+ num_frames: 5
+ sample: random
+ max_val: 0.01
+
+ ### Model related parameters
+ res_h: 240
+ res_w: 432
+ in_channel: 4
+ cnum: 64
+ flow_inChannel: 2
+ flow_cnum: 64
+ dist_cnum: 32
+ frame_hidden: 512
+ flow_hidden: 256
+ PASSMASK: 1
+ num_blocks: 8
+ kernel_size_w: 7
+ kernel_size_h: 7
+ stride_h: 3
+ stride_w: 3
+ num_head: 4
+ conv_type: vanilla
+ norm: None
+ use_bias: 1
+ ape: 1
+ pos_mode: single
+ mlp_ratio: 40
+ drop: 0
+ init_weights: 1
+ tw: 2
+ sw: 8
+ gd: 4
+
+ ### Loss weights
+ L1M: 1
+ L1V: 1
+ adv: 0.01
+
+ ### inference parameters
+ ref_length: 10
FGT_codes/FGT/config/valid_config.yaml ADDED
@@ -0,0 +1,8 @@
+ flow_height: 240
+ flow_width: 432
+ data_root: davis_valid_flows
+ mask_root: rectMask_96
+ frame_root: JPEGImages/480p
+ flow_root: davis_test_flows
+ batch_size: 1
+ name2len: config/davis_name2len_val.pkl
FGT_codes/FGT/config/youtubevos_name2len.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:60410308d4a0e780a531290d8bddc7f204bc0e8a500eab7c01c563b8efce9753
+ size 75501
FGT_codes/FGT/data/__init__.py ADDED
@@ -0,0 +1,49 @@
+ import logging
+ import torch
+ import torch.utils.data
+ from importlib import import_module
+
+
+ def create_dataloader(phase, dataset, dataset_opt, opt=None, sampler=None):
+     logger = logging.getLogger('base')
+     if phase == 'train':
+         num_workers = dataset_opt['n_workers'] * opt['world_size']
+         batch_size = dataset_opt['batch_size']
+         if sampler is not None:
+             logger.info('N_workers: {}, batch_size: {} DDP train dataloader has been established'.format(num_workers,
+                                                                                                           batch_size))
+             return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                                num_workers=num_workers, sampler=sampler,
+                                                pin_memory=True)
+         else:
+             logger.info('N_workers: {}, batch_size: {} train dataloader has been established'.format(num_workers,
+                                                                                                       batch_size))
+             return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                                num_workers=num_workers, shuffle=True,
+                                                pin_memory=True)
+
+     else:
+         logger.info(
+             'N_workers: {}, batch_size: {} validate/test dataloader has been established'.format(
+                 dataset_opt['n_workers'],
+                 dataset_opt['batch_size']))
+         return torch.utils.data.DataLoader(dataset, batch_size=dataset_opt['batch_size'], shuffle=False,
+                                            num_workers=dataset_opt['n_workers'],
+                                            pin_memory=False)
+
+
+ def create_dataset(dataset_opt, dataInfo, phase, dataset_name):
+     if phase == 'train':
+         dataset_package = import_module('data.{}'.format(dataset_name))
+         dataset = dataset_package.VideoBasedDataset(dataset_opt, dataInfo)
+
+         mode = dataset_opt['mode']
+         logger = logging.getLogger('base')
+         logger.info(
+             '{} train dataset [{:s} - {:s} - {:s}] is created.'.format(dataset_opt['type'].upper(),
+                                                                        dataset.__class__.__name__,
+                                                                        dataset_opt['name'], mode))
+     else:  # validate and test dataset
+         raise ValueError('No dataset initialized for the validation set')
+
+     return dataset
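A hedged usage sketch of `create_dataset` and `create_dataloader`, assuming `config/train.yaml` has been parsed into a nested dict and that the trainer merges the top-level options into the per-dataset options before the dataset is built (the `world_size` value and the `input_resolution` assembly below are assumptions, not code from the repo):

    import yaml
    from data import create_dataset, create_dataloader

    with open('config/train.yaml') as f:
        opt = yaml.safe_load(f)
    opt['world_size'] = 1  # create_dataloader multiplies n_workers by this

    with open(opt['datasets']['train']['dataInfo_config']) as f:
        dataInfo = yaml.safe_load(f)

    train_opt = dict(opt)                                         # top-level keys (sample, num_frames, ...)
    train_opt.update(opt['datasets']['train'])                    # per-dataset keys (batch_size, n_workers, ...)
    train_opt['input_resolution'] = (opt['res_h'], opt['res_w'])  # assumed mapping used by the dataset

    train_set = create_dataset(train_opt, dataInfo, phase='train', dataset_name=opt['datasetName_train'])
    train_loader = create_dataloader('train', train_set, train_opt, opt, sampler=None)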
FGT_codes/FGT/data/train_dataset.py ADDED
@@ -0,0 +1,165 @@
+ import random
+ import pickle
+
+ import logging
+ import torch
+ import cv2
+ import os
+
+ from torch.utils.data.dataset import Dataset
+ import numpy as np
+ import cvbase
+ from .util.STTN_mask import create_random_shape_with_random_motion
+ import imageio
+ from .util.flow_utils import region_fill as rf
+
+ logger = logging.getLogger('base')
+
+
+ class VideoBasedDataset(Dataset):
+     def __init__(self, opt, dataInfo):
+         self.opt = opt
+         self.sampleMethod = opt['sample']
+         self.dataInfo = dataInfo
+         self.height, self.width = self.opt['input_resolution']
+         self.frame_path = dataInfo['frame_path']
+         self.flow_path = dataInfo['flow_path']  # The path of the optical flows
+         self.train_list = os.listdir(self.frame_path)
+         self.name2length = self.dataInfo['name2len']
+         with open(self.name2length, 'rb') as f:
+             self.name2length = pickle.load(f)
+         self.sequenceLen = self.opt['num_frames']
+         self.flow2rgb = opt['flow2rgb']  # whether to change flow to rgb domain
+         self.flow_direction = opt[
+             'flow_direction']  # The direction must be in ['for', 'back', 'bi'], indicating forward, backward and bidirectional flows
+
+     def __len__(self):
+         return len(self.train_list)
+
+     def __getitem__(self, idx):
+         try:
+             item = self.load_item(idx)
+         except:
+             print('Loading error: ' + self.train_list[idx])
+             item = self.load_item(0)
+         return item
+
+     def frameSample(self, frameLen, sequenceLen):
+         if self.sampleMethod == 'random':
+             indices = [i for i in range(frameLen)]
+             sampleIndices = random.sample(indices, sequenceLen)
+         elif self.sampleMethod == 'seq':
+             pivot = random.randint(0, frameLen - sequenceLen - 1)
+             sampleIndices = [i for i in range(pivot, pivot + sequenceLen)]
+         else:
+             raise ValueError('Cannot determine the sample method {}'.format(self.sampleMethod))
+         return sampleIndices
+
+     def load_item(self, idx):
+         video = self.train_list[idx]
+         frame_dir = os.path.join(self.frame_path, video)
+         forward_flow_dir = os.path.join(self.flow_path, video, 'forward_flo')
+         backward_flow_dir = os.path.join(self.flow_path, video, 'backward_flo')
+         frameLen = self.name2length[video]
+         flowLen = frameLen - 1
+         assert frameLen > self.sequenceLen, 'Frame length {} is less than sequence length'.format(frameLen)
+         sampledIndices = self.frameSample(frameLen, self.sequenceLen)
+
+         # generate random masks for these sampled frames
+         candidateMasks = create_random_shape_with_random_motion(frameLen, 0.9, 1.1, 1, 10)
+
+         # read the frames and masks
+         frames, masks, forward_flows, backward_flows = [], [], [], []
+         for i in range(len(sampledIndices)):
+             frame = self.read_frame(os.path.join(frame_dir, '{:05d}.jpg'.format(sampledIndices[i])), self.height,
+                                     self.width)
+             mask = self.read_mask(candidateMasks[sampledIndices[i]], self.height, self.width)
+             frames.append(frame)
+             masks.append(mask)
+             if self.flow_direction == 'for':
+                 forward_flow = self.read_forward_flow(forward_flow_dir, sampledIndices[i], flowLen)
+                 forward_flow = self.diffusion_flow(forward_flow, mask)
+                 forward_flows.append(forward_flow)
+             elif self.flow_direction == 'back':
+                 backward_flow = self.read_backward_flow(backward_flow_dir, sampledIndices[i])
+                 backward_flow = self.diffusion_flow(backward_flow, mask)
+                 backward_flows.append(backward_flow)
+             elif self.flow_direction == 'bi':
+                 forward_flow = self.read_forward_flow(forward_flow_dir, sampledIndices[i], flowLen)
+                 forward_flow = self.diffusion_flow(forward_flow, mask)
+                 forward_flows.append(forward_flow)
+                 backward_flow = self.read_backward_flow(backward_flow_dir, sampledIndices[i])
+                 backward_flow = self.diffusion_flow(backward_flow, mask)
+                 backward_flows.append(backward_flow)
+             else:
+                 raise ValueError('Unknown flow direction mode: {}'.format(self.flow_direction))
+         inputs = {'frames': frames, 'masks': masks, 'forward_flo': forward_flows, 'backward_flo': backward_flows}
+         inputs = self.to_tensor(inputs)
+         inputs['frames'] = (inputs['frames'] / 255.) * 2 - 1
+         return inputs
+
+     def diffusion_flow(self, flow, mask):
+         flow_filled = np.zeros(flow.shape)
+         flow_filled[:, :, 0] = rf.regionfill(flow[:, :, 0] * (1 - mask), mask)
+         flow_filled[:, :, 1] = rf.regionfill(flow[:, :, 1] * (1 - mask), mask)
+         return flow_filled
+
+     def read_frame(self, path, height, width):
+         frame = imageio.imread(path)
+         frame = cv2.resize(frame, (width, height), cv2.INTER_LINEAR)
+         return frame
+
+     def read_mask(self, mask, height, width):
+         mask = np.array(mask)
+         mask = mask / 255.
+         raw_mask = (mask > 0.5).astype(np.uint8)
+         raw_mask = cv2.resize(raw_mask, dsize=(width, height), interpolation=cv2.INTER_NEAREST)
+         return raw_mask
+
+     def read_forward_flow(self, forward_flow_dir, sampledIndex, flowLen):
+         if sampledIndex >= flowLen:
+             sampledIndex = flowLen - 1
+         flow = cvbase.read_flow(os.path.join(forward_flow_dir, '{:05d}.flo'.format(sampledIndex)))
+         height, width = flow.shape[:2]
+         flow = cv2.resize(flow, (self.width, self.height), cv2.INTER_LINEAR)
+         flow[:, :, 0] = flow[:, :, 0] / width * self.width
+         flow[:, :, 1] = flow[:, :, 1] / height * self.height
+         return flow
+
+     def read_backward_flow(self, backward_flow_dir, sampledIndex):
+         if sampledIndex == 0:
+             sampledIndex = 0
+         else:
+             sampledIndex -= 1
+         flow = cvbase.read_flow(os.path.join(backward_flow_dir, '{:05d}.flo'.format(sampledIndex)))
+         height, width = flow.shape[:2]
+         flow = cv2.resize(flow, (self.width, self.height), cv2.INTER_LINEAR)
+         flow[:, :, 0] = flow[:, :, 0] / width * self.width
+         flow[:, :, 1] = flow[:, :, 1] / height * self.height
+         return flow
+
+     def to_tensor(self, data_list):
+         """
+         Args:
+             data_list: A dict whose values are lists of numpy arrays
+
+         Returns: The dict with each entry stacked into a float tensor
+         """
+         keys = list(data_list.keys())
+         for key in keys:
+             if data_list[key] is None or data_list[key] == []:
+                 data_list.pop(key)
+             else:
+                 item = data_list[key]
+                 if not isinstance(item, list):
+                     item = torch.from_numpy(np.transpose(item, (2, 0, 1))).float()  # [c, h, w]
+                 else:
+                     item = np.stack(item, axis=0)
+                     if len(item.shape) == 3:  # [t, h, w]
+                         item = item[:, :, :, np.newaxis]
+                     item = torch.from_numpy(np.transpose(item, (0, 3, 1, 2))).float()  # [t, c, h, w]
+                 data_list[key] = item
+         return data_list
+
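For reference, a minimal direct instantiation of `VideoBasedDataset` outside the training loop; the option keys are the ones `__init__` reads, and the paths are the illustrative ones from `config/data_info.yaml`:

    from data.train_dataset import VideoBasedDataset

    opt = {'sample': 'random', 'input_resolution': (240, 432), 'num_frames': 5,
           'flow2rgb': 1, 'flow_direction': 'for'}
    dataInfo = {'frame_path': 'youtubevos_frames',
                'flow_path': 'youtubevos_flows',
                'name2len': 'config/youtubevos_name2len.pkl'}

    dataset = VideoBasedDataset(opt, dataInfo)
    sample = dataset[0]  # dict with 'frames' [t, 3, h, w] scaled to [-1, 1], 'masks', 'forward_flo'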
FGT_codes/FGT/data/util/MaskModel.py ADDED
@@ -0,0 +1,123 @@
+ import random
+ import numpy as np
+
+ class RandomMask():
+     def __init__(self, videoLength, dataInfo):
+         self.videoLength = videoLength
+         self.imageHeight, self.imageWidth = dataInfo['image']['image_height'], \
+                                             dataInfo['image']['image_width']
+         self.maskHeight, self.maskWidth = dataInfo['mask']['mask_height'], \
+                                           dataInfo['mask']['mask_width']
+         try:
+             self.maxDeltaHeight, self.maxDeltaWidth = dataInfo['mask']['max_delta_height'], \
+                                                       dataInfo['mask']['max_delta_width']
+         except KeyError:
+             self.maxDeltaHeight, self.maxDeltaWidth = 0, 0
+
+         try:
+             self.verticalMargin, self.horizontalMargin = dataInfo['mask']['vertical_margin'], \
+                                                          dataInfo['mask']['horizontal_margin']
+         except KeyError:
+             self.verticalMargin, self.horizontalMargin = 0, 0
+
+     def __call__(self):
+         from .utils import random_bbox
+         from .utils import bbox2mask
+         masks = []
+         bbox = random_bbox(self.imageHeight, self.imageWidth, self.verticalMargin, self.horizontalMargin,
+                            self.maskHeight, self.maskWidth)
+         if random.uniform(0, 1) > 0.5:
+             mask = bbox2mask(self.imageHeight, self.imageWidth, 0, 0, bbox)
+             for frame in range(self.videoLength):
+                 masks.append(mask)
+         else:
+             for frame in range(self.videoLength):
+                 delta_h, delta_w = random.randint(-3, 3), random.randint(-3, 3)  # move by at most 3 pixels in each direction per frame
+                 bbox = list(bbox)
+                 bbox[0] = min(max(self.verticalMargin, bbox[0] + delta_h), self.imageHeight - self.verticalMargin - bbox[2])
+                 bbox[1] = min(max(self.horizontalMargin, bbox[1] + delta_w), self.imageWidth - self.horizontalMargin - bbox[3])
+                 mask = bbox2mask(self.imageHeight, self.imageWidth, 0, 0, bbox)
+                 masks.append(mask)
+         masks = np.stack(masks, axis=0)
+         if len(masks.shape) == 3:
+             masks = masks[:, :, :, np.newaxis]
+         assert len(masks.shape) == 4, 'Wrong mask dimension {}'.format(len(masks.shape))
+         return masks
+
+
+ class MidRandomMask():
+     ### This mask is considered without random motion
+     def __init__(self, videoLength, dataInfo):
+         self.videoLength = videoLength
+         self.imageHeight, self.imageWidth = dataInfo['image']['image_height'], \
+                                             dataInfo['image']['image_width']
+         self.maskHeight, self.maskWidth = dataInfo['mask']['mask_height'], \
+                                           dataInfo['mask']['mask_width']
+
+     def __call__(self):
+         from .utils import mid_bbox_mask
+         mask = mid_bbox_mask(self.imageHeight, self.imageWidth, self.maskHeight, self.maskWidth)
+         masks = []
+         for _ in range(self.videoLength):
+             masks.append(mask)
+         return mask
+
+
+ class MatrixMask():
+     ### This mask is considered without random motion
+     def __init__(self, videoLength, dataInfo):
+         self.videoLength = videoLength
+         self.imageHeight, self.imageWidth = dataInfo['image']['image_height'], \
+                                             dataInfo['image']['image_width']
+         self.maskHeight, self.maskWidth = dataInfo['mask']['mask_height'], \
+                                           dataInfo['mask']['mask_width']
+         try:
+             self.row, self.column = dataInfo['mask']['row'], \
+                                     dataInfo['mask']['column']
+         except KeyError:
+             self.row, self.column = 5, 4
+
+     def __call__(self):
+         from .utils import matrix2bbox
+         mask = matrix2bbox(self.imageHeight, self.imageWidth, self.maskHeight,
+                            self.maskWidth, self.row, self.column)
+         masks = []
+         for video in range(self.videoLength):
+             masks.append(mask)
+         return mask
+
+
+ class FreeFormMask():
+     def __init__(self, videoLength, dataInfo):
+         self.videoLength = videoLength
+         self.imageHeight, self.imageWidth = dataInfo['image']['image_height'], \
+                                             dataInfo['image']['image_width']
+         self.maxVertex = dataInfo['mask']['max_vertex']
+         self.maxLength = dataInfo['mask']['max_length']
+         self.maxBrushWidth = dataInfo['mask']['max_brush_width']
+         self.maxAngle = dataInfo['mask']['max_angle']
+
+     def __call__(self):
+         from .utils import freeFormMask
+         mask = freeFormMask(self.imageHeight, self.imageWidth,
+                             self.maxVertex, self.maxLength,
+                             self.maxBrushWidth, self.maxAngle)
+         return mask
+
+
+ class StationaryMask():
+     def __init__(self, videoLength, dataInfo):
+         self.videoLength = videoLength
+         self.imageHeight, self.imageWidth = dataInfo['image']['image_height'], \
+                                             dataInfo['image']['image_width']
+         # self.maxPointNum = dataInfo['mask']['max_point_num']
+         # self.maxLength = dataInfo['mask']['max_length']
+
+     def __call__(self):
+         from .STTN_mask import create_random_shape_with_random_motion
+         masks = create_random_shape_with_random_motion(self.videoLength, 0.9, 1.1, 1, 10, self.imageHeight, self.imageWidth)
+         masks = np.stack(masks, axis=0)
+         if len(masks.shape) == 3:
+             masks = masks[:, :, :, np.newaxis]
+         assert len(masks.shape) == 4, 'Your masks have a wrong shape {}'.format(len(masks.shape))
+         return masks
FGT_codes/FGT/data/util/STTN_mask.py ADDED
@@ -0,0 +1,244 @@
+ import matplotlib.patches as patches
+ from matplotlib.path import Path
+ import os
+ import sys
+ import io
+ import cv2
+ import time
+ import math
+ import argparse
+ import shutil
+ import random
+ import zipfile
+ from glob import glob
+ import math
+ import numpy as np
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+ from PIL import Image, ImageOps, ImageDraw, ImageFilter
+
+ import torch
+ import torchvision
+ import torch.nn as nn
+ import torch.distributed as dist
+
+ import matplotlib
+ from matplotlib import pyplot as plt
+ matplotlib.use('agg')
+
+
+ class GroupRandomHorizontalFlip(object):
+     """Randomly horizontally flips the given PIL.Image with a probability of 0.5
+     """
+
+     def __init__(self, is_flow=False):
+         self.is_flow = is_flow
+
+     def __call__(self, img_group, is_flow=False):
+         v = random.random()
+         if v < 0.5:
+             ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
+             if self.is_flow:
+                 for i in range(0, len(ret), 2):
+                     # invert flow pixel values when flipping
+                     ret[i] = ImageOps.invert(ret[i])
+             return ret
+         else:
+             return img_group
+
+
+ class Stack(object):
+     def __init__(self, roll=False):
+         self.roll = roll
+
+     def __call__(self, img_group):
+         mode = img_group[0].mode
+         if mode == '1':
+             img_group = [img.convert('L') for img in img_group]
+             mode = 'L'
+         if mode == 'L':
+             return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
+         elif mode == 'RGB':
+             if self.roll:
+                 return np.stack([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
+             else:
+                 return np.stack(img_group, axis=2)
+         else:
+             raise NotImplementedError("Image mode {}".format(mode))
+
+
+ class ToTorchFormatTensor(object):
+     """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
+     to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
+
+     def __init__(self, div=True):
+         self.div = div
+
+     def __call__(self, pic):
+         if isinstance(pic, np.ndarray):
+             # numpy img: [L, C, H, W]
+             img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
+         else:
+             # handle PIL Image
+             img = torch.ByteTensor(
+                 torch.ByteStorage.from_buffer(pic.tobytes()))
+             img = img.view(pic.size[1], pic.size[0], len(pic.mode))
+             # put it from HWC to CHW format
+             # yikes, this transpose takes 80% of the loading time/CPU
+             img = img.transpose(0, 1).transpose(0, 2).contiguous()
+         img = img.float().div(255) if self.div else img.float()
+         return img
+
+
+ # ##########################################
+ # ##########################################
+
+ def create_random_shape_with_random_motion(video_length, zoomin, zoomout, rotmin, rotmax, imageHeight=240, imageWidth=432):
+     # get a random shape
+     assert zoomin < 1, "Zoom-in parameter must be smaller than 1"
+     assert zoomout > 1, "Zoom-out parameter must be larger than 1"
+     assert rotmin < rotmax, "Minimum value of rotation must be smaller than maximum value!"
+     height = random.randint(imageHeight//3, imageHeight-1)
+     width = random.randint(imageWidth//3, imageWidth-1)
+     edge_num = random.randint(6, 8)
+     ratio = random.randint(6, 8)/10
+     region = get_random_shape(
+         edge_num=edge_num, ratio=ratio, height=height, width=width)
+     region_width, region_height = region.size
+     # get random position
+     x, y = random.randint(
+         0, imageHeight-region_height), random.randint(0, imageWidth-region_width)
+     velocity = get_random_velocity(max_speed=3)
+     m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
+     m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
+     masks = [m.convert('L')]
+     # return fixed masks
+     if random.uniform(0, 1) > 0.5:
+         return masks*video_length  # -> directly copy all the base masks
+     # return moving masks
+     for _ in range(video_length-1):
+         x, y, velocity = random_move_control_points(
+             x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
+         m = Image.fromarray(
+             np.zeros((imageHeight, imageWidth)).astype(np.uint8))
+         ### added by kaidong to simulate zoom-in, zoom-out and rotation
+         extra_transform = random.uniform(0, 1)
+         # zoom in and zoom out
+         if extra_transform > 0.75:
+             resize_coefficient = random.uniform(zoomin, zoomout)
+             region = region.resize((math.ceil(region_width * resize_coefficient), math.ceil(region_height * resize_coefficient)), Image.NEAREST)
+             m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
+             region_width, region_height = region.size
+         # rotation
+         elif extra_transform > 0.5:
+             m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
+             m = m.rotate(random.randint(rotmin, rotmax))
+             # region_width, region_height = region.size
+         ### end
+         else:
+             m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
+         masks.append(m.convert('L'))
+     return masks
+
+
+ def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
+     '''
+     There is the initial point and 3 points per cubic bezier curve.
+     Thus, the curve will only pass through n points, which will be the sharp edges.
+     The other 2 modify the shape of the bezier curve.
+     edge_num, Number of possibly sharp edges
+     points_num, number of points in the Path
+     ratio, (0, 1) magnitude of the perturbation from the unit circle,
+     '''
+     points_num = edge_num*3 + 1
+     angles = np.linspace(0, 2*np.pi, points_num)
+     codes = np.full(points_num, Path.CURVE4)
+     codes[0] = Path.MOVETO
+     # Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
+     verts = np.stack((np.cos(angles), np.sin(angles))).T * \
+         (2*ratio*np.random.random(points_num)+1-ratio)[:, None]
+     verts[-1, :] = verts[0, :]
+     path = Path(verts, codes)
+     # draw paths into images
+     fig = plt.figure()
+     ax = fig.add_subplot(111)
+     patch = patches.PathPatch(path, facecolor='black', lw=2)
+     ax.add_patch(patch)
+     ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
+     ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
+     ax.axis('off')  # removes the axis to leave only the shape
+     fig.canvas.draw()
+     # convert plt images into numpy images
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,)))
+     plt.close(fig)
+     # postprocess
+     data = cv2.resize(data, (width, height))[:, :, 0]
+     data = (1 - np.array(data > 0).astype(np.uint8))*255
+     corrdinates = np.where(data > 0)
+     xmin, xmax, ymin, ymax = np.min(corrdinates[0]), np.max(
+         corrdinates[0]), np.min(corrdinates[1]), np.max(corrdinates[1])
+     region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
+     return region
+
+
+ def random_accelerate(velocity, maxAcceleration, dist='uniform'):
+     speed, angle = velocity
+     d_speed, d_angle = maxAcceleration
+     if dist == 'uniform':
+         speed += np.random.uniform(-d_speed, d_speed)
+         angle += np.random.uniform(-d_angle, d_angle)
+     elif dist == 'guassian':
+         speed += np.random.normal(0, d_speed / 2)
+         angle += np.random.normal(0, d_angle / 2)
+     else:
+         raise NotImplementedError(
+             f'Distribution type {dist} is not supported.')
+     return (speed, angle)
+
+
+ def get_random_velocity(max_speed=3, dist='uniform'):
+     if dist == 'uniform':
+         speed = np.random.uniform(max_speed)
+     elif dist == 'guassian':
+         speed = np.abs(np.random.normal(0, max_speed / 2))
+     else:
+         raise NotImplementedError(
+             'Distribution type {} is not supported.'.format(dist))
+     angle = np.random.uniform(0, 2 * np.pi)
+     return (speed, angle)
+
+
+ def random_move_control_points(X, Y, imageHeight, imageWidth, lineVelocity, region_size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3):
+     region_width, region_height = region_size
+     speed, angle = lineVelocity
+     X += int(speed * np.cos(angle))
+     Y += int(speed * np.sin(angle))
+     lineVelocity = random_accelerate(
+         lineVelocity, maxLineAcceleration, dist='guassian')
+     if ((X > imageHeight - region_height) or (X < 0) or (Y > imageWidth - region_width) or (Y < 0)):
+         lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian')
+     new_X = np.clip(X, 0, imageHeight - region_height)
+     new_Y = np.clip(Y, 0, imageWidth - region_width)
+     return new_X, new_Y, lineVelocity
+
+
+
+ # ##############################################
+ # ##############################################
+
+ if __name__ == '__main__':
+     import os
+     os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+     trials = 10
+     for _ in range(trials):
+         video_length = 10
+         # The returned masks are either stationary (50%) or moving (50%)
+         masks = create_random_shape_with_random_motion(video_length, zoomin=0.9, zoomout=1.1, rotmin=1, rotmax=10, imageHeight=240, imageWidth=432)
+         i = 0
+
+         for m in masks:
+             cv2.imshow('mask', np.array(m))
+             cv2.waitKey(500)
+             # m.save('mask_{}.png'.format(i))
+             i += 1
FGT_codes/FGT/data/util/__init__.py ADDED
@@ -0,0 +1,28 @@
+ from .STTN_mask import create_random_shape_with_random_motion
+
+ import logging
+ logger = logging.getLogger('base')
+
+
+ def initialize_mask(videoLength, dataInfo):
+     from .MaskModel import RandomMask
+     from .MaskModel import MidRandomMask
+     from .MaskModel import MatrixMask
+     from .MaskModel import FreeFormMask
+     from .MaskModel import StationaryMask
+
+     return {'random': RandomMask(videoLength, dataInfo),
+             'mid': MidRandomMask(videoLength, dataInfo),
+             'matrix': MatrixMask(videoLength, dataInfo),
+             'free': FreeFormMask(videoLength, dataInfo),
+             'stationary': StationaryMask(videoLength, dataInfo)
+             }
+
+
+ def create_mask(maskClass, form):
+     if form == 'mix':
+         from random import randint
+         candidates = list(maskClass.keys())
+         candidate_index = randint(0, len(candidates) - 1)
+         return maskClass[candidates[candidate_index]]()
+     return maskClass[form]()
FGT_codes/FGT/data/util/flow_utils/__init__.py ADDED
File without changes
FGT_codes/FGT/data/util/flow_utils/flow_reversal.py ADDED
@@ -0,0 +1,77 @@
+ import torch
+
+
+ def flow_reversal(flow):
+     """
+     flow: shape [b, c, h, w]
+     return: the backward flow corresponding to the forward flow
+     The formula is borrowed from Quadratic Video Interpolation (Eq. 4)
+     """
+     b, c, h, w = flow.shape
+     y = flow[:, 0:1, :, :]
+     x = flow[:, 1:2, :, :]  # [b, 1, h, w]
+
+     x = x.repeat(1, c, 1, 1)
+     y = y.repeat(1, c, 1, 1)
+
+     # get the four points of the square (x1, y1), (x1, y2), (x2, y1), (x2, y2)
+     x1 = torch.floor(x)
+     x2 = x1 + 1
+     y1 = torch.floor(y)
+     y2 = y1 + 1
+
+     # get gaussian weights
+     w11, w12, w21, w22 = get_gaussian_weights(x, y, x1, x2, y1, y2)
+
+     # calculate the weight maps for each optical flow
+     flow11, o11 = sample_one(flow, x1, y1, w11)
+     flow12, o12 = sample_one(flow, x1, y2, w12)
+     flow21, o21 = sample_one(flow, x2, y1, w21)
+     flow22, o22 = sample_one(flow, x2, y2, w22)
+
+     # fuse all the reversed flows based on equation (4)
+     flow_o = flow11 + flow12 + flow21 + flow22
+     o = o11 + o12 + o21 + o22
+
+     flow_o = -flow_o
+     flow_o[o > 0] = flow_o[o > 0] / o[o > 0]
+
+     return flow_o
+
+
+ def get_gaussian_weights(x, y, x1, x2, y1, y2):
+     sigma = 1
+     w11 = torch.exp(-((x - x1) ** 2 + (y - y1) ** 2) / (sigma ** 2))
+     w12 = torch.exp(-((x - x1) ** 2 + (y - y2) ** 2) / (sigma ** 2))
+     w21 = torch.exp(-((x - x2) ** 2 + (y - y1) ** 2) / (sigma ** 2))
+     w22 = torch.exp(-((x - x2) ** 2 + (y - y2) ** 2) / (sigma ** 2))
+     return w11, w12, w21, w22
+
+
+ def sample_one(flow, shiftx, shifty, weight):
+     b, c, h, w = flow.shape
+     flat_shiftx = shiftx.view(-1)  # [h * w]
+     flat_shifty = shifty.view(-1)  # [h * w]
+     flat_basex = torch.arange(0, h, requires_grad=False).view(-1, 1).long().repeat(b, c, 1, w).view(-1)  # [h * w]
+     flat_basey = torch.arange(0, w, requires_grad=False).view(-1, 1).long().repeat(b, c, h, 1).view(-1)  # [h * w]
+     flat_weight = weight.reshape(-1)  # [h * w]
+     flat_flow = flow.reshape(-1)
+
+     idxn = torch.arange(0, b, requires_grad=False).view(b, 1, 1, 1).long().repeat(1, c, h, w).view(-1)
+     idxc = torch.arange(0, c, requires_grad=False).view(1, c, 1, 1).long().repeat(b, 1, h, w).view(-1)
+     idxx = flat_shiftx.long() + flat_basex  # size [-1]
+     idxy = flat_shifty.long() + flat_basey  # size [-1]
+
+     # record the shifted pixels inside the image boundaries
+     mask = idxx.ge(0) & idxx.lt(h) & idxy.ge(0) & idxy.lt(w)
+
+     # mask off points out of boundaries
+     ids = idxn * c * h * w + idxc * h * w + idxx * w + idxy
+     ids_mask = torch.masked_select(ids, mask).clone()
+
+     # put the value into corresponding regions
+     flow_warp = torch.zeros([b * c * h * w])
+     flow_warp.put_(ids_mask, torch.masked_select(flat_flow * flat_weight, mask), accumulate=True)
+     one_warp = torch.zeros([b * c * h * w])
+     one_warp.put_(ids_mask, torch.masked_select(flat_weight, mask), accumulate=True)
+     return flow_warp.view(b, c, h, w), one_warp.view(b, c, h, w)
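A small usage sketch for `flow_reversal` (hedged; the synthetic input below only demonstrates the expected [b, 2, h, w] pixel-unit flow layout):

    import torch
    from data.util.flow_utils.flow_reversal import flow_reversal

    forward_flow = torch.randn(1, 2, 240, 432) * 5   # synthetic forward flow in pixels
    backward_flow = flow_reversal(forward_flow)       # same shape; splatted holes remain 0
    print(backward_flow.shape)  # torch.Size([1, 2, 240, 432])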
FGT_codes/FGT/data/util/flow_utils/region_fill.py ADDED
@@ -0,0 +1,142 @@
+ import numpy as np
+ import cv2
+ from scipy import sparse
+ from scipy.sparse.linalg import spsolve
+
+
+ # Laplacian filling
+ def regionfill(I, mask, factor=1.0):
+     if np.count_nonzero(mask) == 0:
+         return I.copy()
+     resize_mask = cv2.resize(
+         mask.astype(float), (0, 0), fx=factor, fy=factor) > 0
+     resize_I = cv2.resize(I.astype(float), (0, 0), fx=factor, fy=factor)
+     maskPerimeter = findBoundaryPixels(resize_mask)
+     regionfillLaplace(resize_I, resize_mask, maskPerimeter)
+     resize_I = cv2.resize(resize_I, (I.shape[1], I.shape[0]))
+     resize_I[mask == 0] = I[mask == 0]
+     return resize_I
+
+
+ def findBoundaryPixels(mask):
+     kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
+     maskDilated = cv2.dilate(mask.astype(float), kernel)
+     return (maskDilated > 0) & (mask == 0)
+
+
+ def regionfillLaplace(I, mask, maskPerimeter):
+     height, width = I.shape
+     rightSide = formRightSide(I, maskPerimeter)
+
+     # Location of mask pixels
+     maskIdx = np.where(mask)
+
+     # Only keep values for pixels that are in the mask
+     rightSide = rightSide[maskIdx]
+
+     # Number the mask pixels in a grid matrix
+     grid = -np.ones((height, width))
+     grid[maskIdx] = range(0, maskIdx[0].size)
+     # Pad with zeros to avoid "index out of bounds" errors in the for loop
+     grid = padMatrix(grid)
+     gridIdx = np.where(grid >= 0)
+
+     # Form the connectivity matrix D=sparse(i,j,s)
+     # Connect each mask pixel to itself
+     i = np.arange(0, maskIdx[0].size)
+     j = np.arange(0, maskIdx[0].size)
+     # The coefficient is the number of neighbors over which we average
+     numNeighbors = computeNumberOfNeighbors(height, width)
+     s = numNeighbors[maskIdx]
+     # Now connect the N,E,S,W neighbors if they exist
+     for direction in ((-1, 0), (0, 1), (1, 0), (0, -1)):
+         # Possible neighbors in the current direction
+         neighbors = grid[gridIdx[0] + direction[0], gridIdx[1] + direction[1]]
+         # Connect mask points to neighbors with -1's
+         index = (neighbors >= 0)
+         i = np.concatenate((i, grid[gridIdx[0][index], gridIdx[1][index]]))
+         j = np.concatenate((j, neighbors[index]))
+         s = np.concatenate((s, -np.ones(np.count_nonzero(index))))
+
+     D = sparse.coo_matrix((s, (i.astype(int), j.astype(int)))).tocsr()
+     sol = spsolve(D, rightSide)
+     I[maskIdx] = sol
+     return I
+
+
+ def formRightSide(I, maskPerimeter):
+     height, width = I.shape
+     perimeterValues = np.zeros((height, width))
+     perimeterValues[maskPerimeter] = I[maskPerimeter]
+     rightSide = np.zeros((height, width))
+
+     rightSide[1:height - 1, 1:width - 1] = (
+         perimeterValues[0:height - 2, 1:width - 1] +
+         perimeterValues[2:height, 1:width - 1] +
+         perimeterValues[1:height - 1, 0:width - 2] +
+         perimeterValues[1:height - 1, 2:width])
+
+     rightSide[1:height - 1, 0] = (
+         perimeterValues[0:height - 2, 0] + perimeterValues[2:height, 0] +
+         perimeterValues[1:height - 1, 1])
+
+     rightSide[1:height - 1, width - 1] = (
+         perimeterValues[0:height - 2, width - 1] +
+         perimeterValues[2:height, width - 1] +
+         perimeterValues[1:height - 1, width - 2])
+
+     rightSide[0, 1:width - 1] = (
+         perimeterValues[1, 1:width - 1] + perimeterValues[0, 0:width - 2] +
+         perimeterValues[0, 2:width])
+
+     rightSide[height - 1, 1:width - 1] = (
+         perimeterValues[height - 2, 1:width - 1] +
+         perimeterValues[height - 1, 0:width - 2] +
+         perimeterValues[height - 1, 2:width])
+
+     rightSide[0, 0] = perimeterValues[0, 1] + perimeterValues[1, 0]
+     rightSide[0, width - 1] = (
+         perimeterValues[0, width - 2] + perimeterValues[1, width - 1])
+     rightSide[height - 1, 0] = (
+         perimeterValues[height - 2, 0] + perimeterValues[height - 1, 1])
+     rightSide[height - 1, width - 1] = (perimeterValues[height - 2, width - 1] +
+                                         perimeterValues[height - 1, width - 2])
+     return rightSide
+
+
+ def computeNumberOfNeighbors(height, width):
+     # Initialize
+     numNeighbors = np.zeros((height, width))
+     # Interior pixels have 4 neighbors
+     numNeighbors[1:height - 1, 1:width - 1] = 4
+     # Border pixels have 3 neighbors
+     numNeighbors[1:height - 1, (0, width - 1)] = 3
+     numNeighbors[(0, height - 1), 1:width - 1] = 3
+     # Corner pixels have 2 neighbors
+     numNeighbors[(0, 0, height - 1, height - 1), (0, width - 1, 0,
+                                                   width - 1)] = 2
+     return numNeighbors
+
+
+ def padMatrix(grid):
+     height, width = grid.shape
+     gridPadded = -np.ones((height + 2, width + 2))
+     gridPadded[1:height + 1, 1:width + 1] = grid
+     gridPadded = gridPadded.astype(grid.dtype)
+     return gridPadded
+
+
+ if __name__ == '__main__':
+     import time
+     x = np.linspace(0, 255, 500)
+     xv, _ = np.meshgrid(x, x)
+     image = ((xv + np.transpose(xv)) / 2.0).astype(int)
+     mask = np.zeros((500, 500))
+     mask[100:259, 100:259] = 1
+     mask = (mask > 0)
+     image[mask] = 0
+     st = time.time()
+     inpaint = regionfill(image, mask, 0.5).astype(np.uint8)
+     print(time.time() - st)
+     cv2.imshow('img', np.concatenate((image.astype(np.uint8), inpaint)))
+     cv2.waitKey()
FGT_codes/FGT/data/util/freeform_masks.py ADDED
@@ -0,0 +1,266 @@
+ import os
+ import sys
+ import shutil
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))  # NOQA
+
+ import numpy as np
+ import argparse
+ from PIL import Image
+
+ from .mask_generators import get_video_masks_by_moving_random_stroke, get_masked_ratio
+ from .util import make_dirs, make_dir_under_root, get_everything_under
+ from .readers import MaskReader
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         '-od', '--output_dir',
+         type=str,
+         help="Output directory name"
+     )
+     parser.add_argument(
+         '-im',
+         '--image_masks', action='store_true',
+         help="Set this if you want to generate independent masks in one directory."
+     )
+     parser.add_argument(
+         '-vl', '--video_len',
+         type=int,
+         help="Maximum video length (i.e. #mask)"
+     )
+     parser.add_argument(
+         '-ns', '--num_stroke',
+         type=int,
+         help="Number of strokes in one mask"
+     )
+     parser.add_argument(
+         '-nsb', '--num_stroke_bound',
+         type=int,
+         nargs=2,
+         help="Upper/lower bound of the number of strokes in one mask"
+     )
+     parser.add_argument(
+         '-n',
+         type=int,
+         help="Number of masks to generate"
+     )
+     parser.add_argument(
+         '-sp',
+         '--stroke_preset',
+         type=str,
+         default='rand_curve',
+         help="Preset of the stroke parameters"
+     )
+     parser.add_argument(
+         '-iw',
+         '--image_width',
+         type=int,
+         default=320
+     )
+     parser.add_argument(
+         '-ih',
+         '--image_height',
+         type=int,
+         default=180
+     )
+     parser.add_argument(
+         '--cluster_by_area',
+         action='store_true'
+     )
+     parser.add_argument(
+         '--leave_boarder_unmasked',
+         type=int,
+         help='Set this to a number; a copy of the masks with the border region erased will also be saved.'
+     )
+     parser.add_argument(
+         '--redo_without_generation',
+         action='store_true',
+         help='Set this, and the script will skip the generation and redo the remaining tasks '
+              '(uncluster -> erase border -> re-cluster)'
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def get_stroke_preset(stroke_preset):
+     if stroke_preset == 'object_like':
+         return {
+             "nVertexBound": [5, 30],
+             "maxHeadSpeed": 15,
+             "maxHeadAcceleration": (10, 1.5),
+             "brushWidthBound": (20, 50),
+             "nMovePointRatio": 0.5,
+             "maxPiontMove": 10,
+             "maxLineAcceleration": (5, 0.5),
+             "boarderGap": None,
+             "maxInitSpeed": 10,
+         }
+     elif stroke_preset == 'object_like_middle':
+         return {
+             "nVertexBound": [5, 15],
+             "maxHeadSpeed": 8,
+             "maxHeadAcceleration": (4, 1.5),
+             "brushWidthBound": (20, 50),
+             "nMovePointRatio": 0.5,
+             "maxPiontMove": 5,
+             "maxLineAcceleration": (5, 0.5),
+             "boarderGap": None,
+             "maxInitSpeed": 10,
+         }
+     elif stroke_preset == 'object_like_small':
+         return {
+             "nVertexBound": [5, 20],
+             "maxHeadSpeed": 7,
+             "maxHeadAcceleration": (3.5, 1.5),
+             "brushWidthBound": (10, 30),
+             "nMovePointRatio": 0.5,
+             "maxPiontMove": 5,
+             "maxLineAcceleration": (3, 0.5),
+             "boarderGap": None,
+             "maxInitSpeed": 4,
+         }
+     elif stroke_preset == 'rand_curve':
+         return {
+             "nVertexBound": [10, 30],
+             "maxHeadSpeed": 20,
+             "maxHeadAcceleration": (15, 0.5),
+             "brushWidthBound": (3, 10),
+             "nMovePointRatio": 0.5,
+             "maxPiontMove": 3,
+             "maxLineAcceleration": (5, 0.5),
+             "boarderGap": None,
+             "maxInitSpeed": 6
+         }
+     elif stroke_preset == 'rand_curve_small':
+         return {
+             "nVertexBound": [6, 22],
+             "maxHeadSpeed": 12,
+             "maxHeadAcceleration": (8, 0.5),
+             "brushWidthBound": (2.5, 5),
+             "nMovePointRatio": 0.5,
+             "maxPiontMove": 1.5,
+             "maxLineAcceleration": (3, 0.5),
+             "boarderGap": None,
+             "maxInitSpeed": 3
+         }
+     else:
+         raise NotImplementedError(f'The stroke preset "{stroke_preset}" does not exist.')
+
+
+ def copy_masks_without_boarder(root_dir, args):
+     def erase_mask_boarder(mask, gap):
+         pix = np.asarray(mask).astype('uint8') * 255
+         pix[:gap, :] = 255
+         pix[-gap:, :] = 255
+         pix[:, :gap] = 255
+         pix[:, -gap:] = 255
+         return Image.fromarray(pix).convert('1')
+
+     wo_boarder_dir = root_dir + '_noBoarder'
+     shutil.copytree(root_dir, wo_boarder_dir)
+
+     for i, filename in enumerate(get_everything_under(wo_boarder_dir)):
+         if args.image_masks:
+             mask = Image.open(filename)
+             mask_wo_boarder = erase_mask_boarder(mask, args.leave_boarder_unmasked)
+             mask_wo_boarder.save(filename)
+         else:
+             # filename is a directory containing multiple mask files
+             for f in get_everything_under(filename, pattern='*.png'):
+                 mask = Image.open(f)
+                 mask_wo_boarder = erase_mask_boarder(mask, args.leave_boarder_unmasked)
+                 mask_wo_boarder.save(f)
+
+     return wo_boarder_dir
+
+
+ def cluster_by_masked_area(root_dir, args):
+     clustered_dir = root_dir + '_clustered'
+     make_dirs(clustered_dir)
+     radius = 5
+
+     # all masks with ratio in x +- radius will be stored in sub-directory x
+     clustered_centors = np.arange(radius, 100, radius * 2)
+     clustered_subdirs = []
+     for c in clustered_centors:
+         # make sub-directories for each ratio range
+         clustered_subdirs.append(make_dir_under_root(clustered_dir, str(c)))
+
+     for i, filename in enumerate(get_everything_under(root_dir)):
+         if args.image_masks:
+             ratio = get_masked_ratio(Image.open(filename))
+         else:
+             # filename is a directory containing multiple mask files
+             ratio = np.mean([
+                 get_masked_ratio(Image.open(f))
+                 for f in get_everything_under(filename, pattern='*.png')
+             ])
+
+         # find the nearest center
+         for i, c in enumerate(clustered_centors):
+             if c - radius <= ratio * 100 <= c + radius:
+                 shutil.move(filename, clustered_subdirs[i])
+                 break
+
+     shutil.rmtree(root_dir)
+     os.rename(clustered_dir, root_dir)
+
+
+ def decide_nStroke(args):
+     if args.num_stroke is not None:
+         return args.num_stroke
+     elif args.num_stroke_bound is not None:
+         return np.random.randint(args.num_stroke_bound[0], args.num_stroke_bound[1])
+     else:
+         raise ValueError('One of "-ns" or "-nsb" is needed')
+
+
+ def main(args):
+     preset = get_stroke_preset(args.stroke_preset)
+     make_dirs(args.output_dir)
+
+     if args.redo_without_generation:
+         assert(len(get_everything_under(args.output_dir)) > 0)
+         # put back clustered masks
+         for clustered_subdir in get_everything_under(args.output_dir):
+             if not os.path.isdir(clustered_subdir):
+                 continue
+             for f in get_everything_under(clustered_subdir):
+                 shutil.move(f, args.output_dir)
+             os.rmdir(clustered_subdir)
+
+     else:
+         if args.image_masks:
+             for i in range(args.n):
+                 nStroke = decide_nStroke(args)
+                 mask = get_video_masks_by_moving_random_stroke(
+                     video_len=1, imageWidth=args.image_width, imageHeight=args.image_height,
+                     nStroke=nStroke, **preset
+                 )[0]
+                 mask.save(os.path.join(args.output_dir, f'{i:07d}.png'))
+
+         else:
+             for i in range(args.n):
+                 mask_dir = make_dir_under_root(args.output_dir, f'{i:05d}')
+                 mask_reader = MaskReader(mask_dir, read=False)
+
+                 nStroke = decide_nStroke(args)
+                 masks = get_video_masks_by_moving_random_stroke(
+                     imageWidth=args.image_width, imageHeight=args.image_height,
+                     video_len=args.video_len, nStroke=nStroke, **preset)
+
+                 mask_reader.set_files(masks)
+                 mask_reader.save_files(output_dir=mask_reader.dir_name)
+
+     if args.leave_boarder_unmasked is not None:
+         dir_leave_boarder = copy_masks_without_boarder(args.output_dir, args)
+         if args.cluster_by_area:
+             cluster_by_masked_area(dir_leave_boarder, args)
+
+     if args.cluster_by_area:
+         cluster_by_masked_area(args.output_dir, args)
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
FGT_codes/FGT/data/util/mask_generators.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ from PIL import Image, ImageDraw
4
+
5
+
6
+ def get_video_masks_by_moving_random_stroke(
7
+ video_len, imageWidth=320, imageHeight=180, nStroke=5,
8
+ nVertexBound=[10, 30], maxHeadSpeed=15, maxHeadAcceleration=(15, 0.5),
9
+ brushWidthBound=(5, 20), boarderGap=None, nMovePointRatio=0.5, maxPiontMove=10,
10
+ maxLineAcceleration=5, maxInitSpeed=5
11
+ ):
12
+ '''
13
+ Get video masks by random strokes which move randomly between each
14
+ frame, including the whole stroke and its control points
15
+
16
+ Parameters
17
+ ----------
18
+ imageWidth: Image width
19
+ imageHeight: Image height
20
+ nStroke: Number of drawed lines
21
+ nVertexBound: Lower/upper bound of number of control points for each line
22
+ maxHeadSpeed: Max head speed when creating control points
23
+ maxHeadAcceleration: Max acceleration applying on the current head point (
24
+ a head point and its velosity decides the next point)
25
+ brushWidthBound (min, max): Bound of width for each stroke
26
+ boarderGap: The minimum gap between image boarder and drawed lines
27
+ nMovePointRatio: The ratio of control points to move for next frames
28
+ maxPiontMove: The magnitude of movement for control points for next frames
29
+ maxLineAcceleration: The magnitude of acceleration for the whole line
30
+
31
+ Examples
32
+ ----------
33
+ object_like_setting = {
34
+ "nVertexBound": [5, 20],
35
+ "maxHeadSpeed": 15,
36
+ "maxHeadAcceleration": (15, 3.14),
37
+ "brushWidthBound": (30, 50),
38
+ "nMovePointRatio": 0.5,
39
+ "maxPiontMove": 10,
40
+ "maxLineAcceleration": (5, 0.5),
41
+ "boarderGap": 20,
42
+ "maxInitSpeed": 10,
43
+ }
44
+ rand_curve_setting = {
45
+ "nVertexBound": [10, 30],
46
+ "maxHeadSpeed": 20,
47
+ "maxHeadAcceleration": (15, 0.5),
48
+ "brushWidthBound": (3, 10),
49
+ "nMovePointRatio": 0.5,
50
+ "maxPiontMove": 3,
51
+ "maxLineAcceleration": (5, 0.5),
52
+ "boarderGap": 20,
53
+ "maxInitSpeed": 6
54
+ }
55
+ get_video_masks_by_moving_random_stroke(video_len=5, nStroke=3, **object_like_setting)
56
+ '''
57
+ assert(video_len >= 1)
58
+
59
+ # Initilize a set of control points to draw the first mask
60
+ mask = Image.new(mode='1', size=(imageWidth, imageHeight), color=1)
61
+ control_points_set = []
62
+ for i in range(nStroke):
63
+ brushWidth = np.random.randint(brushWidthBound[0], brushWidthBound[1])
64
+ Xs, Ys, velocity = get_random_stroke_control_points(
65
+ imageWidth=imageWidth, imageHeight=imageHeight,
66
+ nVertexBound=nVertexBound, maxHeadSpeed=maxHeadSpeed,
67
+ maxHeadAcceleration=maxHeadAcceleration, boarderGap=boarderGap,
68
+ maxInitSpeed=maxInitSpeed
69
+ )
70
+ control_points_set.append((Xs, Ys, velocity, brushWidth))
71
+ draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=0)
72
+
73
+ # Generate the following masks by randomly move strokes and their control points
74
+ masks = [mask]
75
+ for i in range(video_len - 1):
76
+ mask = Image.new(mode='1', size=(imageWidth, imageHeight), color=1)
77
+ for j in range(len(control_points_set)):
78
+ Xs, Ys, velocity, brushWidth = control_points_set[j]
79
+ new_Xs, new_Ys = random_move_control_points(
80
+ Xs, Ys, velocity, nMovePointRatio, maxPiontMove,
81
+ maxLineAcceleration, boarderGap
82
+ )
83
+ control_points_set[j] = (new_Xs, new_Ys, velocity, brushWidth)
84
+ for Xs, Ys, velocity, brushWidth in control_points_set:
85
+ draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=0)
86
+ masks.append(mask)
87
+
88
+ return masks
89
+
90
+
91
+ def random_accelerate(velocity, maxAcceleration, dist='uniform'):
92
+ speed, angle = velocity
93
+ d_speed, d_angle = maxAcceleration
94
+
95
+ if dist == 'uniform':
96
+ speed += np.random.uniform(-d_speed, d_speed)
97
+ angle += np.random.uniform(-d_angle, d_angle)
98
+ elif dist == 'guassian':
99
+ speed += np.random.normal(0, d_speed / 2)
100
+ angle += np.random.normal(0, d_angle / 2)
101
+ else:
102
+ raise NotImplementedError(f'Distribution type {dist} is not supported.')
103
+
104
+ return (speed, angle)
105
+
106
+
107
+ def random_move_control_points(Xs, Ys, lineVelocity, nMovePointRatio, maxPiontMove, maxLineAcceleration, boarderGap=15):
108
+ new_Xs = Xs.copy()
109
+ new_Ys = Ys.copy()
110
+
111
+ # move the whole line and accelerate
112
+ speed, angle = lineVelocity
113
+ new_Xs += int(speed * np.cos(angle))
114
+ new_Ys += int(speed * np.sin(angle))
115
+ lineVelocity = random_accelerate(lineVelocity, maxLineAcceleration, dist='guassian')
116
+
117
+ # choose points to move
118
+ chosen = np.arange(len(Xs))
119
+ np.random.shuffle(chosen)
120
+ chosen = chosen[:int(len(Xs) * nMovePointRatio)]
121
+ for i in chosen:
122
+ new_Xs[i] += np.random.randint(-maxPiontMove, maxPiontMove)
123
+ new_Ys[i] += np.random.randint(-maxPiontMove, maxPiontMove)
124
+ return new_Xs, new_Ys
125
+
126
+
127
+ def get_random_stroke_control_points(
128
+ imageWidth, imageHeight,
129
+ nVertexBound=(10, 30), maxHeadSpeed=10, maxHeadAcceleration=(5, 0.5), boarderGap=20,
130
+ maxInitSpeed=10
131
+ ):
132
+ '''
133
+ Implementation the free-form training masks generating algorithm
134
+ proposed by Jiahui Yu et al. in "Free-Form Image Inpainting with Gated Convolution"
135
+ '''
136
+ startX = np.random.randint(imageWidth)
137
+ startY = np.random.randint(imageHeight)
138
+ Xs = [startX]
139
+ Ys = [startY]
140
+
141
+ numVertex = np.random.randint(nVertexBound[0], nVertexBound[1])
142
+
143
+ angle = np.random.uniform(0, 2 * np.pi)
144
+ speed = np.random.uniform(0, maxHeadSpeed)
145
+
146
+ for i in range(numVertex):
147
+ speed, angle = random_accelerate((speed, angle), maxHeadAcceleration)
148
+ speed = np.clip(speed, 0, maxHeadSpeed)
149
+
150
+ nextX = startX + speed * np.sin(angle)
151
+ nextY = startY + speed * np.cos(angle)
152
+
153
+ if boarderGap is not None:
154
+ nextX = np.clip(nextX, boarderGap, imageWidth - boarderGap)
155
+ nextY = np.clip(nextY, boarderGap, imageHeight - boarderGap)
156
+
157
+ startX, startY = nextX, nextY
158
+ Xs.append(nextX)
159
+ Ys.append(nextY)
160
+
161
+ velocity = get_random_velocity(maxInitSpeed, dist='guassian')
162
+
163
+ return np.array(Xs), np.array(Ys), velocity
164
+
165
+
166
+ def get_random_velocity(max_speed, dist='uniform'):
167
+ if dist == 'uniform':
168
+ speed = np.random.uniform(0, max_speed)
169
+ elif dist == 'guassian':
170
+ speed = np.abs(np.random.normal(0, max_speed / 2))
171
+ else:
172
+ raise NotImplementedError(f'Distribution type {dist} is not supported.')
173
+
174
+ angle = np.random.uniform(0, 2 * np.pi)
175
+ return (speed, angle)
176
+
177
+
178
+ def draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=255):
179
+ radius = brushWidth // 2 - 1
180
+ for i in range(1, len(Xs)):
181
+ draw = ImageDraw.Draw(mask)
182
+ startX, startY = Xs[i - 1], Ys[i - 1]
183
+ nextX, nextY = Xs[i], Ys[i]
184
+ draw.line((startX, startY) + (nextX, nextY), fill=fill, width=brushWidth)
185
+ for x, y in zip(Xs, Ys):
186
+ draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=fill)
187
+ return mask
188
+
189
+
190
+ # modified from https://github.com/naoto0804/pytorch-inpainting-with-partial-conv/blob/master/generate_data.py
191
+ def get_random_walk_mask(imageWidth=320, imageHeight=180, length=None):
192
+ action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
193
+ canvas = np.zeros((imageHeight, imageWidth)).astype("i")
194
+ if length is None:
195
+ length = imageWidth * imageHeight
196
+ x = random.randint(0, imageHeight - 1)
197
+ y = random.randint(0, imageWidth - 1)
198
+ x_list = []
199
+ y_list = []
200
+ for i in range(length):
201
+ r = random.randint(0, len(action_list) - 1)
202
+ x = np.clip(x + action_list[r][0], a_min=0, a_max=imageHeight - 1)
203
+ y = np.clip(y + action_list[r][1], a_min=0, a_max=imageWidth - 1)
204
+ x_list.append(x)
205
+ y_list.append(y)
206
+ canvas[np.array(x_list), np.array(y_list)] = 1
207
+ return Image.fromarray(canvas * 255).convert('1')
208
+
209
+
210
+ def get_masked_ratio(mask):
211
+ """
212
+ Calculate the masked ratio.
213
+ mask: Expected a binary PIL image, where 0 and 1 represent
214
+ masked (invalid) and valid pixels, respectively.
215
+ """
216
+ hist = mask.histogram()
217
+ return hist[0] / np.prod(mask.size)
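For orientation, a minimal driver for the stroke-mask generator above is sketched below. It is hypothetical: the module is assumed to be importable as mask_generators, and the keyword names follow the docstring and body of get_video_masks_by_moving_random_stroke (the defaults in the actual signature may differ).

# Hypothetical driver; only PIL and numpy are required.
from mask_generators import get_video_masks_by_moving_random_stroke, get_masked_ratio

masks = get_video_masks_by_moving_random_stroke(
    video_len=5, imageWidth=320, imageHeight=180, nStroke=3)
for t, m in enumerate(masks):
    # Mode '1' frames: background pixels are 1 (valid), stroke pixels are 0 (masked).
    print(t, m.size, f"masked ratio = {get_masked_ratio(m):.3f}")
    m.convert('L').save(f"mask_{t:02d}.png")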
FGT_codes/FGT/data/util/readers.py ADDED
@@ -0,0 +1,527 @@
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # NOQA
4
+ import argparse
5
+ from math import ceil
6
+ from glob import glob
7
+
8
+ import numpy as np
9
+ import cv2
10
+ from PIL import Image, ImageDraw, ImageOps, ImageFont
11
+
12
+ from utils.logging_config import logger
13
+ from utils.util import make_dirs, bbox_offset
14
+
15
+
16
+ DEFAULT_FPS = 6
17
+ MAX_LENGTH = 60
18
+
19
+
20
+ def parse_args():
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument(
23
+ '-fps', '--fps',
24
+ type=int, default=DEFAULT_FPS,
25
+ help="Output video FPS"
26
+ )
27
+ parser.add_argument(
28
+ '-v', '--video_dir',
29
+ type=str,
30
+ help="Video directory name"
31
+ )
32
+ parser.add_argument(
33
+ '-vs', '--video_dirs',
34
+ nargs='+',
35
+ type=str,
36
+ help="Video directory names"
37
+ )
38
+ parser.add_argument(
39
+ '-v2', '--video_dir2',
40
+ type=str,
41
+ help="Video directory name"
42
+ )
43
+ parser.add_argument(
44
+ '-sd', '--segms_dir',
45
+ type=str,
46
+ help="Segmentation directory name"
47
+ )
48
+ parser.add_argument(
49
+ '-fgd', '--fg_dir',
50
+ type=str,
51
+ help="Foreground directory name"
52
+ )
53
+ parser.add_argument(
54
+ '-fgfd', '--fg_frames_dir',
55
+ type=str,
56
+ help="Foreground frames directory name"
57
+ )
58
+ parser.add_argument(
59
+ '-fgsd', '--fg_segms_dir',
60
+ type=str,
61
+ help="Foreground segmentations directory name"
62
+ )
63
+ parser.add_argument(
64
+ '-syfd', '--syn_frames_dir',
65
+ type=str,
66
+ help="Synthesized frames directory name"
67
+ )
68
+ parser.add_argument(
69
+ '-bgfd', '--bg_frames_dir',
70
+ type=str,
71
+ help="Background frames directory name"
72
+ )
73
+ parser.add_argument(
74
+ '-rt', '--reader_type',
75
+ type=str,
76
+ help="Type of reader"
77
+ )
78
+ parser.add_argument(
79
+ '-od', '--output_dir',
80
+ type=str,
81
+ help="Output directory name"
82
+ )
83
+ parser.add_argument(
84
+ '-o', '--output_filename',
85
+ type=str, required=True,
86
+ help="Output output filename"
87
+ )
88
+ args = parser.parse_args()
89
+ return args
90
+
91
+
92
+ class Reader:
93
+ def __init__(self, dir_name, read=True, max_length=None, sample_period=1):
94
+ self.dir_name = dir_name
95
+ self.count = 0
96
+ self.max_length = max_length
97
+ self.filenames = []
98
+ self.sample_period = sample_period
99
+ if read:
100
+ if os.path.exists(dir_name):
101
+ # self.filenames = read_filenames_from_dir(dir_name, self.__class__.__name__)
102
+ # ^^^^^ yield None when reading some videos of face forensics data
103
+ # (related to 'Too many levels of symbolic links'?)
104
+
105
+ self.filenames = sorted(glob(os.path.join(dir_name, '*')))
106
+ self.filenames = [f for f in self.filenames if os.path.isfile(f)]
107
+ self.filenames = self.filenames[::sample_period][:max_length]
108
+ self.files = self.read_files(self.filenames)
109
+ else:
110
+ self.files = []
111
+ logger.warning(f"Directory {dir_name} not exists!")
112
+ else:
113
+ self.files = []
114
+ self.current_index = 0
115
+
116
+ def append(self, file_):
117
+ self.files.append(file_)
118
+
119
+ def set_files(self, files):
120
+ self.files = files
121
+
122
+ def read_files(self, filenames):
123
+ assert type(filenames) == list, f'filenames is not a list; dirname: {self.dir_name}'
124
+ filenames.sort()
125
+ frames = []
126
+ for filename in filenames:
127
+ file_ = self.read_file(filename)
128
+ frames.append(file_)
129
+ return frames
130
+
131
+ def save_files(self, output_dir=None):
132
+ make_dirs(output_dir)
133
+ logger.info(f"Saving {self.__class__.__name__} files to {output_dir}")
134
+ for i, file_ in enumerate(self.files):
135
+ self._save_file(output_dir, i, file_)
136
+
137
+ def _save_file(self, output_dir, i, file_):
138
+ raise NotImplementedError("This is an abstract function")
139
+
140
+ def read_file(self, filename):
141
+ raise NotImplementedError("This is an abstract function")
142
+
143
+ def __iter__(self):
144
+ return self
145
+
146
+ def __next__(self):
147
+ if self.current_index < len(self.files):
148
+ file_ = self.files[self.current_index]
149
+ self.current_index += 1
150
+ return file_
151
+ else:
152
+ self.current_index = 0
153
+ raise StopIteration
154
+
155
+ def __getitem__(self, key):
156
+ return self.files[key]
157
+
158
+ def __len__(self):
159
+ return len(self.files)
160
+
161
+
162
+ class FrameReader(Reader):
163
+ def __init__(
164
+ self, dir_name, resize=None, read=True, max_length=MAX_LENGTH,
165
+ scale=1, sample_period=1
166
+ ):
167
+ self.resize = resize
168
+ self.scale = scale
169
+ self.sample_period = sample_period
170
+ super().__init__(dir_name, read, max_length, sample_period)
171
+
172
+ def read_file(self, filename):
173
+ origin_frame = Image.open(filename)
174
+ size = self.resize if self.resize is not None else origin_frame.size
175
+ origin_frame_resized = origin_frame.resize(
176
+ (int(size[0] * self.scale), int(size[1] * self.scale))
177
+ )
178
+ return origin_frame_resized
179
+
180
+ def _save_file(self, output_dir, i, file_):
181
+ if len(self.filenames) == len(self.files):
182
+ name = sorted(self.filenames)[i].split('/')[-1]
183
+ else:
184
+ name = f"frame_{i:04}.png"
185
+ filename = os.path.join(
186
+ output_dir, name
187
+ )
188
+ file_.save(filename, "PNG")
189
+
190
+ def write_files_to_video(self, output_filename, fps=DEFAULT_FPS, frame_num_when_repeat_list=[1]):
191
+ logger.info(
192
+ f"Writeing frames to video {output_filename} with FPS={fps}")
193
+ video_writer = cv2.VideoWriter(
194
+ output_filename,
195
+ cv2.VideoWriter_fourcc(*"MJPG"),
196
+ fps,
197
+ self.files[0].size
198
+ )
199
+ for frame_num_when_repeat in frame_num_when_repeat_list:
200
+ for frame in self.files:
201
+ frame = frame.convert("RGB")
202
+ frame_cv = np.array(frame)
203
+ frame_cv = cv2.cvtColor(frame_cv, cv2.COLOR_RGB2BGR)
204
+ for i in range(frame_num_when_repeat):
205
+ video_writer.write(frame_cv)
206
+ video_writer.release()
207
+
208
+
209
+ class SynthesizedFrameReader(FrameReader):
210
+ def __init__(
211
+ self, bg_frames_dir, fg_frames_dir,
212
+ fg_segms_dir, segm_bbox_mask_dir, fg_dir, dir_name,
213
+ bboxes_list_dir,
214
+ fg_scale=0.7, fg_location=(48, 27), mask_only=False
215
+ ):
216
+ self.bg_reader = FrameReader(bg_frames_dir)
217
+ self.size = self.bg_reader[0].size
218
+ # TODO: add different location and change scale to var
219
+ self.fg_reader = ForegroundReader(
220
+ fg_frames_dir, fg_segms_dir, fg_dir,
221
+ resize=self.size,
222
+ scale=fg_scale
223
+ )
224
+ self.fg_location = fg_location
225
+ # self.masks = self.fg_reader.masks
226
+ # self.bbox_masks = self.fg_reader.bbox_masks
227
+ super().__init__(dir_name, read=False)
228
+ self.files = self.synthesize_frames(
229
+ self.bg_reader, self.fg_reader, mask_only)
230
+ self.bbox_masks = MaskGenerator(
231
+ segm_bbox_mask_dir, self.size, self.get_bboxeses()
232
+ )
233
+ self.bboxes_list_dir = bboxes_list_dir
234
+ self.bboxes_list = self.get_bboxeses()
235
+ self.save_bboxes()
236
+
237
+ def save_bboxes(self):
238
+ make_dirs(self.bboxes_list_dir)
239
+ logger.info(f"Saving bboxes to {self.bboxes_list_dir}")
240
+ for i, bboxes in enumerate(self.bboxes_list):
241
+ save_path = os.path.join(self.bboxes_list_dir, f"bboxes_{i:04}.txt")
242
+ if len(bboxes) > 0:
243
+ np.savetxt(save_path, bboxes[0], fmt='%4u')
244
+
245
+ def get_bboxeses(self):
246
+ bboxeses = self.fg_reader.segms.bboxeses
247
+ new_bboxeses = []
248
+ for bboxes in bboxeses:
249
+ new_bboxes = []
250
+ for bbox in bboxes:
251
+ offset_bbox = bbox_offset(bbox, self.fg_location)
252
+ new_bboxes.append(offset_bbox)
253
+ new_bboxeses.append(new_bboxes)
254
+ return new_bboxeses
255
+
256
+ def synthesize_frames(self, bg_reader, fg_reader, mask_only=False):
257
+ logger.info(
258
+ f"Synthesizing {bg_reader.dir_name} and {fg_reader.dir_name}"
259
+ )
260
+ synthesized_frames = []
261
+ for i, bg in enumerate(bg_reader):
262
+ if i == len(fg_reader):
263
+ break
264
+ fg = fg_reader[i]
265
+ mask = fg_reader.get_mask(i)
266
+ synthesized_frame = bg.copy()
267
+ if mask_only:
268
+ synthesized_frame.paste(mask, self.fg_location, mask)
269
+ else:
270
+ synthesized_frame.paste(fg, self.fg_location, mask)
271
+ synthesized_frames.append(synthesized_frame)
272
+ return synthesized_frames
273
+
274
+
275
+ class WarpedFrameReader(FrameReader):
276
+ def __init__(self, dir_name, i, ks):
277
+ self.i = i
278
+ self.ks = ks
279
+ super().__init__(dir_name)
280
+
281
+ def _save_file(self, output_dir, i, file_):
282
+ filename = os.path.join(
283
+ output_dir,
284
+ f"warped_frame_{self.i:04}_k{self.ks[i]:02}.png"
285
+ )
286
+ file_.save(filename)
287
+
288
+
289
+ class SegmentationReader(FrameReader):
290
+ def __init__(
291
+ self, dir_name,
292
+ resize=None, scale=1
293
+ ):
294
+ super().__init__(
295
+ dir_name, resize=resize, scale=scale
296
+ )
297
+
298
+ def read_file(self, filename):
299
+ origin_frame = Image.open(filename)
300
+ mask = ImageOps.invert(origin_frame.convert("L"))
301
+ mask = mask.point(lambda x: 0 if x < 255 else 255, '1')
302
+ size = self.resize if self.resize is not None else origin_frame.size
303
+ mask_resized = mask.resize(
304
+ (int(size[0] * self.scale), int(size[1] * self.scale))
305
+ )
306
+ return mask_resized
307
+
308
+
309
+ class MaskReader(Reader):
310
+ def __init__(self, dir_name, read=True):
311
+ super().__init__(dir_name, read=read)
312
+
313
+ def read_file(self, filename):
314
+ mask = Image.open(filename)
315
+ return mask
316
+
317
+ def _save_file(self, output_dir, i, file_):
318
+ filename = os.path.join(
319
+ output_dir,
320
+ f"mask_{i:04}.png"
321
+ )
322
+ file_.save(filename)
323
+
324
+ def get_bboxes(self, i):
325
+ # TODO: save bbox instead of looking for one
326
+ mask = self.files[i]
327
+ mask = ImageOps.invert(mask.convert("L")).convert("1")
328
+ mask = np.array(mask)
329
+ image, contours, hier = cv2.findContours(
330
+ mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
331
+ bboxes = []
332
+ for c in contours:
333
+ # get the bounding rect
334
+ x, y, w, h = cv2.boundingRect(c)
335
+ bbox = ((x, y), (x + w - 1, y + h - 1))
336
+ bboxes.append(bbox)
337
+ return bboxes
338
+
339
+ def get_bbox(self, i):
340
+ # TODO: save bbox instead of looking for one
341
+ mask = self.files[i]
342
+ mask = ImageOps.invert(mask.convert("L"))
343
+ mask = np.array(mask)
344
+ image, contours, hier = cv2.findContours(
345
+ mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
346
+ for c in contours:
347
+ # get the bounding rect
348
+ x, y, w, h = cv2.boundingRect(c)
349
+ bbox = ((x, y), (x + w - 1, y + h - 1))
350
+ return bbox
351
+
352
+
353
+ class MaskGenerator(Reader):
354
+ def __init__(
355
+ self, mask_output_dir, size, bboxeses, save_masks=True
356
+ ):
357
+ self.bboxeses = bboxeses
358
+ self.size = size
359
+ super().__init__(mask_output_dir, read=False)
360
+ self.files = self.generate_masks()
361
+ if save_masks:
362
+ make_dirs(mask_output_dir)
363
+ self.save_files(mask_output_dir)
364
+
365
+ def _save_file(self, output_dir, i, file_):
366
+ filename = os.path.join(
367
+ output_dir,
368
+ f"mask_{i:04}.png"
369
+ )
370
+ file_.save(filename)
371
+
372
+ def get_bboxes(self, i):
373
+ return self.bboxeses[i]
374
+
375
+ def generate_masks(self):
376
+ masks = []
377
+ for i in range(len(self.bboxeses)):
378
+ mask = self.generate_mask(i)
379
+ masks.append(mask)
380
+ return masks
381
+
382
+ def generate_mask(self, i):
383
+ bboxes = self.bboxeses[i]
384
+ mask = Image.new("1", self.size, 1)
385
+ draw = ImageDraw.Draw(mask)
386
+ for bbox in bboxes:
387
+ draw.rectangle(
388
+ bbox, fill=0
389
+ )
390
+ return mask
391
+
392
+
393
+ class ForegroundReader(FrameReader):
394
+ def __init__(
395
+ self, frames_dir, segms_dir, dir_name,
396
+ resize=None, scale=1
397
+ ):
398
+ self.frames_dir = frames_dir
399
+ self.segms_dir = segms_dir
400
+ self.frames = FrameReader(
401
+ frames_dir,
402
+ resize=resize, scale=scale
403
+ )
404
+ self.segms = SegmentationReader(
405
+ segms_dir, resize=resize, scale=scale
406
+ )
407
+ super().__init__(dir_name, read=False)
408
+ self.masks = self.segms.masks
409
+ # self.bbox_masks = self.segms.bbox_masks
410
+ self.files = self.generate_fg_frames(self.frames, self.segms)
411
+
412
+ def get_mask(self, i):
413
+ return self.masks[i]
414
+
415
+ def generate_fg_frames(self, frames, segms):
416
+ logger.info(
417
+ f"Generating fg frames from {self.frames_dir} and {self.segms_dir}"
418
+ )
419
+ fg_frames = []
420
+ for i, frame in enumerate(frames):
421
+ mask = self.masks[i]
422
+ fg_frame = Image.new("RGB", frame.size, (0, 0, 0))
423
+ fg_frame.paste(
424
+ frame, (0, 0),
425
+ mask
426
+ )
427
+ fg_frames.append(fg_frame)
428
+ return fg_frames
429
+
430
+
431
+ class CompareFramesReader(FrameReader):
432
+ def __init__(self, dir_names, col=2, names=[], mask_dir=None):
433
+ self.videos = []
434
+ for dir_name in dir_names:
435
+ # If a method fails on this video, use None to indicate the situation
436
+ try:
437
+ self.videos.append(FrameReader(dir_name))
438
+ except AssertionError:
439
+ self.videos.append(None)
440
+ if mask_dir is not None:
441
+ self.masks = MaskReader(mask_dir)
442
+ self.names = names
443
+ self.files = self.combine_videos(self.videos, col)
444
+
445
+ def combine_videos(self, videos, col=2, edge_offset=35, h_start_offset=35):
446
+ combined_frames = []
447
+ w, h = videos[0][0].size
448
+ # Prevent the first method fails and have a "None" as its video
449
+ i = 0
450
+ while videos[i] is None:
451
+ i += 1
452
+ length = len(videos[i])
453
+ video_num = len(videos)
454
+ row = ceil(video_num / col)
455
+ for frame_idx in range(length):
456
+ width = col * w + (col - 1) * edge_offset
457
+ height = row * h + (row - 1) * edge_offset + h_start_offset
458
+ combined_frame = Image.new("RGBA", (width, height))
459
+ draw = ImageDraw.Draw(combined_frame)
460
+ for i, video in enumerate(videos):
461
+ # Give the failed method a black output
462
+ if video is None or frame_idx >= len(video):
463
+ failed = True
464
+ frame = Image.new("RGBA", (w, h))
465
+ else:
466
+ frame = video[frame_idx].convert("RGBA")
467
+ failed = False
468
+
469
+ f_x = (i % col) * (w + edge_offset)
470
+ f_y = (i // col) * (h + edge_offset) + h_start_offset
471
+ combined_frame.paste(frame, (f_x, f_y))
472
+
473
+ # Draw name
474
+ font = ImageFont.truetype("DejaVuSans.ttf", 12)
475
+ # font = ImageFont.truetype("DejaVuSans-Bold.ttf", 13)
476
+ # font = ImageFont.truetype("timesbd.ttf", 14)
477
+ name = self.names[i] if not failed else f'{self.names[i]} (failed)'
478
+ draw.text(
479
+ (f_x + 10, f_y - 20),
480
+ name, (255, 255, 255), font=font
481
+ )
482
+
483
+ combined_frames.append(combined_frame)
484
+ return combined_frames
485
+
486
+
487
+ class BoundingBoxesListReader(Reader):
488
+ def __init__(
489
+ self, dir_name, resize=None, read=True, max_length=MAX_LENGTH,
490
+ scale=1
491
+ ):
492
+ self.resize = resize
493
+ self.scale = scale
494
+ super().__init__(dir_name, read, max_length)
495
+
496
+ def read_file(self, filename):
497
+ bboxes = np.loadtxt(filename, dtype=int)
498
+ bboxes = [bboxes.tolist()]
499
+ return bboxes
500
+
501
+
502
+ def save_frames_to_dir(frames, dirname):
503
+ reader = FrameReader(dirname, read=False)
504
+ reader.set_files(frames)
505
+ reader.save_files(dirname)
506
+
507
+
508
+ if __name__ == "__main__":
509
+ args = parse_args()
510
+ if args.reader_type is None:
511
+ reader = FrameReader(args.video_dir)
512
+ elif args.reader_type == 'fg':
513
+ reader = ForegroundReader(
514
+ args.video_dir, args.segms_dir, args.fg_dir)
515
+ elif args.reader_type == 'sy':
516
+ reader = SynthesizedFrameReader(
517
+ args.bg_frames_dir, args.fg_frames_dir,
518
+ args.fg_segms_dir, args.fg_dir, args.syn_frames_dir
519
+ )
520
+ elif args.reader_type == 'com':
521
+ reader = CompareFramesReader(
522
+ args.video_dirs
523
+ )
524
+ reader.write_files_to_video(
525
+ os.path.join(args.output_dir, args.output_filename),
526
+ fps=args.fps
527
+ )
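As a usage note for the readers defined above, a hypothetical programmatic sketch follows; directory names are placeholders, the module is assumed importable as readers, the project's utils package must be on sys.path for the logger import, and the caption drawing needs DejaVuSans.ttf to be available.

# Hypothetical sketch; paths and video names are placeholders.
from readers import FrameReader, CompareFramesReader, save_frames_to_dir

frames = FrameReader('results/method_a/video_0001')    # reads up to MAX_LENGTH frames
save_frames_to_dir(list(frames), 'tmp/copied_frames')  # writes them back out as PNGs

# Tile two methods side by side and export an MJPG comparison video.
cmp_reader = CompareFramesReader(
    ['results/method_a/video_0001', 'results/method_b/video_0001'],
    col=2, names=['method A', 'method B'])
cmp_reader.write_files_to_video('comparison.avi', fps=6)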
FGT_codes/FGT/data/util/util.py ADDED
@@ -0,0 +1,259 @@
1
+ import os
2
+ import argparse
3
+ import shutil
4
+ from glob import glob
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ from utils.logging_config import logger
10
+
11
+
12
+ def parse_args():
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument(
15
+ '-v', '--video_dir',
16
+ type=str,
17
+ help="Video directory name"
18
+ )
19
+ parser.add_argument(
20
+ '-fl', '--flow_dir',
21
+ type=str,
22
+ help="Optical flow ground truth directory name"
23
+ )
24
+ parser.add_argument(
25
+ '-od', '--output_dir',
26
+ type=str,
27
+ help="Output directory name"
28
+ )
29
+ parser.add_argument(
30
+ '-o', '--output_filename',
31
+ type=str,
32
+ help="Output output filename"
33
+ )
34
+ args = parser.parse_args()
35
+ return args
36
+
37
+
38
+ def make_dirs(dir_name):
39
+ if not os.path.exists(dir_name):
40
+ os.makedirs(dir_name)
41
+ logger.info(f"Directory {dir_name} made")
42
+
43
+
44
+ ensure_dir = make_dirs
45
+
46
+
47
+ def make_dir_under_root(root_dir, name):
48
+ full_dir_name = os.path.join(root_dir, name)
49
+ make_dirs(full_dir_name)
50
+ return full_dir_name
51
+
52
+
53
+ def rm_dirs(dir_name, ignore_errors=False):
54
+ if os.path.exists(dir_name):
55
+ shutil.rmtree(dir_name, ignore_errors)
56
+ logger.info(f"Directory {dir_name} removed")
57
+
58
+
59
+ def read_dirnames_under_root(root_dir, skip_list=[]):
60
+ dirnames = [
61
+ name for i, name in enumerate(sorted(os.listdir(root_dir)))
62
+ if (os.path.isdir(os.path.join(root_dir, name))
63
+ and name not in skip_list
64
+ and i not in skip_list)
65
+ ]
66
+ logger.info(f"Reading directories under {root_dir}, exclude {skip_list}, num: {len(dirnames)}")
67
+ return dirnames
68
+
69
+
70
+ def bbox_offset(bbox, location):
71
+ x0, y0 = location
72
+ (x1, y1), (x2, y2) = bbox
73
+ return ((x1 + x0, y1 + y0), (x2 + x0, y2 + y0))
74
+
75
+
76
+ def cover2_bbox(bbox1, bbox2):
77
+ x1 = min(bbox1[0][0], bbox2[0][0])
78
+ y1 = min(bbox1[0][1], bbox2[0][1])
79
+ x2 = max(bbox1[1][0], bbox2[1][0])
80
+ y2 = max(bbox1[1][1], bbox2[1][1])
81
+ return ((x1, y1), (x2, y2))
82
+
83
+
84
+ def extend_r_bbox(bbox, w, h, r):
85
+ (x1, y1), (x2, y2) = bbox
86
+ x1 = max(x1 - r, 0)
87
+ x2 = min(x2 + r, w)
88
+ y1 = max(y1 - r, 0)
89
+ y2 = min(y2 + r, h)
90
+ return ((x1, y1), (x2, y2))
91
+
92
+
93
+ def mean_squared_error(A, B):
94
+ return np.square(np.subtract(A, B)).mean()
95
+
96
+
97
+ def bboxes_to_mask(size, bboxes):
98
+ mask = Image.new("L", size, 255)
99
+ mask = np.array(mask)
100
+ for bbox in bboxes:
101
+ try:
102
+ (x1, y1), (x2, y2) = bbox
103
+ except Exception:
104
+ (x1, y1, x2, y2) = bbox
105
+
106
+ mask[y1:y2, x1:x2] = 0
107
+ mask = Image.fromarray(mask.astype("uint8"))
108
+ return mask
109
+
110
+
111
+ def get_extended_from_box(img_size, box, patch_size):
112
+ def _decide_patch_num(box_width, patch_size):
113
+ num = np.ceil(box_width / patch_size).astype(int)
114
+ if (num * patch_size - box_width) < (patch_size // 2):
115
+ num += 1
116
+ return num
117
+
118
+ x1, y1 = box[0]
119
+ x2, y2 = box[1]
120
+ new_box = (x1, y1, x2 - x1, y2 - y1)
121
+ box_x_start, box_y_start, box_x_size, box_y_size = new_box
122
+
123
+ patchN_x = _decide_patch_num(box_x_size, patch_size)
124
+ patchN_y = _decide_patch_num(box_y_size, patch_size)
125
+
126
+ extend_x = (patch_size * patchN_x - box_x_size) // 2
127
+ extend_y = (patch_size * patchN_y - box_y_size) // 2
128
+ img_x_size = img_size[0]
129
+ img_y_size = img_size[1]
130
+
131
+ x_start = max(0, box_x_start - extend_x)
132
+ x_end = min(box_x_start - extend_x + patchN_x * patch_size, img_x_size)
133
+
134
+ y_start = max(0, box_y_start - extend_y)
135
+ y_end = min(box_y_start - extend_y + patchN_y * patch_size, img_y_size)
136
+ x_start, y_start, x_end, y_end = int(x_start), int(y_start), int(x_end), int(y_end)
137
+ extented_box = ((x_start, y_start), (x_end, y_end))
138
+ return extented_box
139
+
140
+
141
+ # code modified from https://github.com/WonwoongCho/Generative-Inpainting-pytorch/blob/master/util.py
142
+ def spatial_discounting_mask(mask_width, mask_height, discounting_gamma):
143
+ """Generate spatial discounting mask constant.
144
+ Spatial discounting mask is first introduced in publication:
145
+ Generative Image Inpainting with Contextual Attention, Yu et al.
146
+ Returns:
147
+ np.array: spatial discounting mask
148
+ """
149
+ gamma = discounting_gamma
150
+ mask_values = np.ones((mask_width, mask_height), dtype=np.float32)
151
+ for i in range(mask_width):
152
+ for j in range(mask_height):
153
+ mask_values[i, j] = max(
154
+ gamma**min(i, mask_width - i),
155
+ gamma**min(j, mask_height - j))
156
+
157
+ return mask_values
158
+
159
+
160
+ def bboxes_to_discounting_loss_mask(img_size, bboxes, discounting_gamma=0.99):
161
+ mask = np.zeros(img_size, dtype=np.float32) + 0.5
162
+ for bbox in bboxes:
163
+ try:
164
+ (x1, y1), (x2, y2) = bbox
165
+ except Exception:
166
+ (x1, y1, x2, y2) = bbox
167
+ mask_width, mask_height = y2 - y1, x2 - x1
168
+ mask[y1:y2, x1:x2] = spatial_discounting_mask(mask_width, mask_height, discounting_gamma)
169
+ return mask
170
+
171
+
172
+ def find_proper_window(image_size, bbox_point):
173
+ '''
174
+ parameters:
175
+ image_size(2-tuple): (height, width)
176
+ bbox_point(2-2-tuple): (first_point, last_point)
177
+ return values:
178
+ window left-up point, (2-tuple)
179
+ window right-bottom point, (2-tuple)
180
+ '''
181
+ bbox_height = bbox_point[1][0] - bbox_point[0][0]
182
+ bbox_width = bbox_point[1][1] - bbox_point[0][1]
183
+
184
+ window_size = min(
185
+ max(bbox_height, bbox_width) * 2,
186
+ image_size[0], image_size[1]
187
+ )
188
+ # Limit min window size due to the requirement of VGG16
189
+ window_size = max(window_size, 32)
190
+
191
+ horizontal_span = window_size - (bbox_point[1][1] - bbox_point[0][1])
192
+ vertical_span = window_size - (bbox_point[1][0] - bbox_point[0][0])
193
+
194
+ top_bound, bottom_bound = bbox_point[0][0] - \
195
+ vertical_span // 2, bbox_point[1][0] + vertical_span // 2
196
+ left_bound, right_bound = bbox_point[0][1] - \
197
+ horizontal_span // 2, bbox_point[1][1] + horizontal_span // 2
198
+
199
+ if left_bound < 0:
200
+ right_bound += 0 - left_bound
201
+ left_bound += 0 - left_bound
202
+ elif right_bound > image_size[1]:
203
+ left_bound -= right_bound - image_size[1]
204
+ right_bound -= right_bound - image_size[1]
205
+
206
+ if top_bound < 0:
207
+ bottom_bound += 0 - top_bound
208
+ top_bound += 0 - top_bound
209
+ elif bottom_bound > image_size[0]:
210
+ top_bound -= bottom_bound - image_size[0]
211
+ bottom_bound -= bottom_bound - image_size[0]
212
+
213
+ return (top_bound, left_bound), (bottom_bound, right_bound)
214
+
215
+
216
+ def drawrect(drawcontext, xy, outline=None, width=0, partial=None):
217
+ (x1, y1), (x2, y2) = xy
218
+ if partial is None:
219
+ points = (x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)
220
+ drawcontext.line(points, fill=outline, width=width)
221
+ else:
222
+ drawcontext.line([(x1, y1), (x1, y1 + partial)], fill=outline, width=width)
223
+ drawcontext.line([(x1 + partial, y1), (x1, y1)], fill=outline, width=width)
224
+
225
+ drawcontext.line([(x2, y1), (x2, y1 + partial)], fill=outline, width=width)
226
+ drawcontext.line([(x2, y1), (x2 - partial, y1)], fill=outline, width=width)
227
+
228
+ drawcontext.line([(x1, y2), (x1 + partial, y2)], fill=outline, width=width)
229
+ drawcontext.line([(x1, y2), (x1, y2 - partial)], fill=outline, width=width)
230
+
231
+ drawcontext.line([(x2 - partial, y2), (x2, y2)], fill=outline, width=width)
232
+ drawcontext.line([(x2, y2), (x2, y2 - partial)], fill=outline, width=width)
233
+
234
+
235
+ def get_everything_under(root_dir, pattern='*', only_dirs=False, only_files=False):
236
+ assert not (only_dirs and only_files), 'You will get nothing ' \
237
+ 'when "only_dirs" and "only_files" are both set to True'
238
+ everything = sorted(glob(os.path.join(root_dir, pattern)))
239
+ if only_dirs:
240
+ everything = [f for f in everything if os.path.isdir(f)]
241
+ if only_files:
242
+ everything = [f for f in everything if os.path.isfile(f)]
243
+
244
+ return everything
245
+
246
+
247
+ def read_filenames_from_dir(dir_name, reader, max_length=None):
248
+ logger.debug(
249
+ f"{reader} reading files from {dir_name}")
250
+ filenames = []
251
+ for root, dirs, files in os.walk(dir_name):
252
+ assert len(dirs) == 0, f"There are direcories: {dirs} in {root}"
253
+ assert len(files) != 0, f"There are no files in {root}"
254
+ filenames = [os.path.join(root, name) for name in sorted(files)]
255
+ for name in filenames:
256
+ logger.debug(name)
257
+ if max_length is not None:
258
+ return filenames[:max_length]
259
+ return filenames
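To make the bbox helpers above concrete, here is a hypothetical standalone sketch; the values are illustrative, the module is assumed importable as util, and the utils.logging_config import at the top of the file must resolve.

# Hypothetical example of the bbox/mask helpers.
import numpy as np
from util import bboxes_to_mask, bboxes_to_discounting_loss_mask, extend_r_bbox

size = (320, 180)                 # PIL-style (width, height)
bboxes = [((40, 30), (120, 90))]  # ((x1, y1), (x2, y2))

hole_mask = bboxes_to_mask(size, bboxes)                         # 'L' image: 0 inside the boxes, 255 outside
loss_mask = bboxes_to_discounting_loss_mask((180, 320), bboxes)  # numpy (H, W) float weight map
bigger = extend_r_bbox(bboxes[0], w=320, h=180, r=16)            # pad the box by 16 px, clipped to the image

print(hole_mask.size, loss_mask.shape, bigger)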
FGT_codes/FGT/data/util/utils.py ADDED
@@ -0,0 +1,158 @@
1
+ import random
2
+ import numpy as np
3
+ import cv2
4
+
5
+ def random_bbox(img_height, img_width, vertical_margin, horizontal_margin, mask_height, mask_width):
6
+ maxt = img_height - vertical_margin - mask_height
7
+ maxl = img_width - horizontal_margin - mask_width
8
+
9
+ t = random.randint(vertical_margin, maxt)
10
+ l = random.randint(horizontal_margin, maxl)
11
+ h = random.randint(mask_height // 2, mask_height)
12
+ w = random.randint(mask_width // 2, mask_width)
13
+ return (t, l, h, w) # returns a random block-shaped bbox; it is later turned into a mask
14
+
15
+
16
+ def mid_bbox_mask(img_height, img_width, mask_height, mask_width):
17
+ def npmask(bbox, height, width):
18
+ mask = np.zeros((height, width, 1), np.float32)
19
+ mask[bbox[0]: bbox[0] + bbox[2], bbox[1]: bbox[1] + bbox[3], :] = 255.
20
+ return mask
21
+
22
+ bbox = (img_height * 3 // 8, img_width * 3 // 8, mask_height, mask_width)
23
+ mask = npmask(bbox, img_height, img_width)
24
+
25
+ return mask
26
+
27
+
28
+ def bbox2mask(img_height, img_width, max_delta_height, max_delta_width, bbox):
29
+ """Generate mask tensor from bbox.
30
+
31
+ Args:
32
+ bbox: configuration tuple, (top, left, height, width)
33
+ config: Config should have configuration including IMG_SHAPES,
34
+ MAX_DELTA_HEIGHT, MAX_DELTA_WIDTH.
35
+
36
+ Returns:
37
+ tf.Tensor: output with shape [B, 1, H, W]
38
+
39
+ """
40
+
41
+ def npmask(bbox, height, width, delta_h, delta_w):
42
+ mask = np.zeros((height, width, 1), np.float32)
43
+ h = np.random.randint(delta_h // 2 + 1) # the +1 keeps randint from receiving 0
44
+ w = np.random.randint(delta_w // 2 + 1)
45
+ mask[bbox[0] + h: bbox[0] + bbox[2] - h, bbox[1] + w: bbox[1] + bbox[3] - w, :] = 255. # height_true = height - 2 * h, width_true = width - 2 * w
46
+ return mask
47
+
48
+ mask = npmask(bbox, img_height, img_width,
49
+ max_delta_height,
50
+ max_delta_width)
51
+
52
+ return mask
53
+
54
+
55
+ def matrix2bbox(img_height, img_width, mask_height, mask_width, row, column):
56
+ """Generate masks with a matrix form
57
+ @param img_height
58
+ @param img_width
59
+ @param mask_height
60
+ @param mask_width
61
+ @param row: number of blocks in row
62
+ @param column: number of blocks in column
63
+ @return mbbox: multiple bboxes in (y, h, h, w) manner
64
+ """
65
+ assert img_height - column * mask_height > img_height // 2, "Too many masks across a column"
66
+ assert img_width - row * mask_width > img_width // 2, "Too many masks across a row"
67
+
68
+ interval_height = (img_height - column * mask_height) // (column + 1)
69
+ interval_width = (img_width - row * mask_width) // (row + 1)
70
+
71
+ mbbox = []
72
+ for i in range(row):
73
+ for j in range(column):
74
+ y = interval_height * (j+1) + j * mask_height
75
+ x = interval_width * (i+1) + i * mask_width
76
+ mbbox.append((y, x, mask_height, mask_width))
77
+ return mbbox
78
+
79
+
80
+ def mbbox2masks(img_height, img_width, mbbox):
81
+
82
+ def npmask(mbbox, height, width):
83
+ mask = np.zeros((height, width, 1), np.float32)
84
+ for bbox in mbbox:
85
+ mask[bbox[0]: bbox[0] + bbox[2], bbox[1]: bbox[1] + bbox[3], :] = 255. # height_true = height - 2 * h, width_true = width - 2 * w
86
+ return mask
87
+
88
+ mask = npmask(mbbox, img_height, img_width)
89
+
90
+ return mask
91
+
92
+
93
+ def draw_line(mask, startX, startY, angle, length, brushWidth):
94
+ """assume the size of mask is (H,W,1)
95
+ """
96
+ assert len(mask.shape) == 2 or mask.shape[2] == 1, "The channel of mask doesn't fit the opencv format"
97
+ offsetX = int(np.round(length * np.cos(angle)))
98
+ offsetY = int(np.round(length * np.sin(angle)))
99
+ endX = startX + offsetX
100
+ endY = startY + offsetY
101
+ if endX > mask.shape[1]:
102
+ endX = mask.shape[1]
103
+ if endY > mask.shape[0]:
104
+ endY = mask.shape[0]
105
+ mask_processed = cv2.line(mask, (startX, startY), (endX, endY), 255, brushWidth)
106
+ return mask_processed, endX, endY
107
+
108
+
109
+ def draw_circle(mask, circle_x, circle_y, brushWidth):
110
+ radius = brushWidth // 2
111
+ assert len(mask.shape) == 2 or mask.shape[2] == 1, "The channel of mask doesn't fit the opencv format"
112
+ mask_processed = cv2.circle(mask, (circle_x, circle_y), radius, 255)
113
+ return mask_processed
114
+
115
+
116
+ def freeFormMask(img_height, img_width, maxVertex, maxLength, maxBrushWidth, maxAngle):
117
+ mask = np.zeros((img_height, img_width))
118
+ numVertex = random.randint(1, maxVertex)
119
+ startX = random.randint(10, img_width)
120
+ startY = random.randint(10, img_height)
121
+ brushWidth = random.randint(10, maxBrushWidth)
122
+ for i in range(numVertex):
123
+ angle = random.uniform(0, maxAngle)
124
+ if i % 2 == 0:
125
+ angle = 2 * np.pi - angle
126
+ length = random.randint(10, maxLength)
127
+ mask, endX, endY = draw_line(mask, startX, startY, angle, length, brushWidth)
128
+ startX = startX + int(length * np.sin(angle))
129
+ startY = startY + int(length * np.cos(angle))
130
+ mask = draw_circle(mask, endX, endY, brushWidth)
131
+
132
+ if random.random() < 0.5:
133
+ mask = np.fliplr(mask)
134
+ if random.random() < 0.5:
135
+ mask = np.flipud(mask)
136
+
137
+ if len(mask.shape) == 2:
138
+ mask = mask[:, :, np.newaxis]
139
+
140
+ return mask
141
+
142
+
143
+ if __name__ == "__main__":
144
+ # for stationary mask generation
145
+ # stationary_mask_generator(240, 480, 50, 120)
146
+
147
+ # for free-form mask generation
148
+ # mask = freeFormMask(240, 480, 30, 50, 20, np.pi)
149
+ # cv2.imwrite('mask.png', mask)
150
+
151
+ # for matrix mask generation
152
+ # img_height, img_width = 240, 480
153
+ # masks = matrix2bbox(240, 480, 20, 20, 5, 4)
154
+ # matrixMask = mbbox2masks(img_height, img_width, masks)
155
+ # cv2.imwrite('matrixMask.png', matrixMask)
156
+ pass
157
+
158
+
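To spell out the intended random_bbox -> bbox2mask pipeline (and the free-form variant) hinted at by the commented-out __main__ above, a hypothetical driver could look like this; all parameter values are illustrative, and the file is assumed importable as utils.

# Hypothetical driver for the mask helpers in this file.
import cv2
from utils import random_bbox, bbox2mask, freeFormMask

h, w = 240, 432
bbox = random_bbox(h, w, vertical_margin=20, horizontal_margin=20, mask_height=80, mask_width=120)
box_mask = bbox2mask(h, w, max_delta_height=20, max_delta_width=20, bbox=bbox)  # (H, W, 1), values {0, 255}
stroke_mask = freeFormMask(h, w, maxVertex=10, maxLength=60, maxBrushWidth=20, maxAngle=3.14)

cv2.imwrite('box_mask.png', box_mask.astype('uint8'))
cv2.imwrite('stroke_mask.png', stroke_mask.astype('uint8'))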
FGT_codes/FGT/flowCheckPoint/config.yaml ADDED
@@ -0,0 +1,11 @@
1
+ PASSMASK: 1
2
+ cnum: 48
3
+ conv_type: vanilla
4
+ flow_interval: 1
5
+ in_channel: 3
6
+ init_weights: 1
7
+ num_flows: 1
8
+ resBlocks: 1
9
+ use_bias: 1
10
+ use_residual: 1
11
+ model: lafc_single
FGT_codes/FGT/flowCheckPoint/lafc_single.pth.tar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0fa168e8b711852c458594cddf4262afdb81e096253197a802a29b4dec9d6d12
3
+ size 11547053
FGT_codes/FGT/inputs.py ADDED
@@ -0,0 +1,83 @@
1
+ import argparse
2
+
3
+
4
+ def args_parser():
5
+ parser = argparse.ArgumentParser(description="General top layer trainer")
6
+ parser.add_argument("--opt", type=str, default="config/train.yaml", help="Path to optional configuration file")
7
+ parser.add_argument('--model', type=str, default='model',
8
+ help='Model block name, in the `model` directory')
9
+ parser.add_argument('--name', type=str, default='FGT_train', help='Experiment name')
10
+ parser.add_argument('--outputdir', type=str, default='/myData/ret/experiments', help='Output dir to save results')
11
+ parser.add_argument('--datadir', type=str, default='/myData/', metavar='PATH')
12
+ parser.add_argument('--datasetName_train', type=str, default='train_dataset_frames_diffusedFlows',
13
+ help='The file name of the train dataset, in `data` directory')
14
+ parser.add_argument('--network', type=str, default='network',
15
+ help='The network file which defines the training process, in the `network` directory')
16
+ parser.add_argument('--finetune', type=int, default=0, help='Whether to fine tune trained models')
17
+ # parser.add_argument('--checkPoint', type=str, default='', help='checkpoint path for continue training')
18
+ parser.add_argument('--gen_state', type=str, default='', help='Checkpoint of the generator')
19
+ parser.add_argument('--dis_state', type=str, default='', help='Checkpoint of the discriminator')
20
+ parser.add_argument('--opt_state', type=str, default='', help='Checkpoint of the options')
21
+ parser.add_argument('--record_iter', type=int, default=16, help='How many iters to print an item of log')
22
+ parser.add_argument('--flow_checkPoint', type=str, default='flowCheckPoint/',
23
+ help='The path for flow model filling')
24
+ parser.add_argument('--dataMode', type=str, default='resize', choices=['resize', 'crop'])
25
+
26
+ # data related parameters
27
+ parser.add_argument('--flow2rgb', type=int, default=1, help='Whether to transform flows from raw data to rgb')
28
+ parser.add_argument('--flow_direction', type=str, default='for', choices=['for', 'back', 'bi'],
29
+ help='Which GT flow should be chosen for guidance')
30
+ parser.add_argument('--num_frames', type=int, default=5, help='How many frames are chosen for frame completion')
31
+ parser.add_argument('--sample', type=str, default='random', choices=['random', 'seq'],
32
+ help='Choose the sample method for training in each iterations')
33
+ parser.add_argument('--max_val', type=float, default=0.01, help='The maximal value to quantize the optical flows')
34
+
35
+ # model related parameters
36
+ parser.add_argument('--res_h', type=int, default=240, help='The height of the frame resolution')
37
+ parser.add_argument('--res_w', type=int, default=432, help='The width of the frame resolution')
38
+ parser.add_argument('--in_channel', type=int, default=4, help='The input channel of the frame branch')
39
+ parser.add_argument('--cnum', type=int, default=64, help='The initial channel number of the frame branch')
40
+ parser.add_argument('--flow_inChannel', type=int, default=2, help='The input channel of the flow branch')
41
+ parser.add_argument('--flow_cnum', type=int, default=64, help='The initial channel dimension of the flow branch')
42
+ parser.add_argument('--dist_cnum', type=int, default=32, help='The initial channel num in the discriminator')
43
+ parser.add_argument('--frame_hidden', type=int, default=512,
44
+ help='The channel / patch dimension in the frame branch')
45
+ parser.add_argument('--flow_hidden', type=int, default=256, help='The channel / patch dimension in the flow branch')
46
+ parser.add_argument('--PASSMASK', type=int, default=1,
47
+ help='1 -> concat the mask with the corrupted optical flows to fill the flow')
48
+ parser.add_argument('--numBlocks', type=int, default=8, help='How many transformer blocks do we need to stack')
49
+ parser.add_argument('--kernel_size_w', type=int, default=7, help='The width of the kernel for extracting patches')
50
+ parser.add_argument('--kernel_size_h', type=int, default=7, help='The height of the kernel for extracting patches')
51
+ parser.add_argument('--stride_h', type=int, default=3, help='The height of the stride')
52
+ parser.add_argument('--stride_w', type=int, default=3, help='The width of the stride')
53
+ parser.add_argument('--pad_h', type=int, default=3, help='The height of the padding')
54
+ parser.add_argument('--pad_w', type=int, default=3, help='The width of the padding')
55
+ parser.add_argument('--num_head', type=int, default=4, help='The head number for the multihead attention')
56
+ parser.add_argument('--conv_type', type=str, choices=['vanilla', 'gated', 'partial'], default='vanilla',
57
+ help='Which kind of conv to use')
58
+ parser.add_argument('--norm', type=str, default='None', choices=['None', 'BN', 'SN', 'IN'],
59
+ help='The normalization method for the conv blocks')
60
+ parser.add_argument('--use_bias', type=int, default=1, help='If 1, use bias in the convolution blocks')
61
+ parser.add_argument('--ape', type=int, default=1, help='If ape = 1, use absolute positional embedding')
62
+ parser.add_argument('--pos_mode', type=str, default='single', choices=['single', 'dual'],
63
+ help='If pos_mode = dual, add positional embedding to flow patches')
64
+ parser.add_argument('--mlp_ratio', type=int, default=40, help='The mlp dilation rate for the feed forward layers')
65
+ parser.add_argument('--drop', type=int, default=0, help='The dropout rate, 0 by default')
66
+ parser.add_argument('--init_weights', type=int, default=1, help='If 1, initialize the network, 1 by default')
67
+
68
+ # loss related parameters
69
+ parser.add_argument('--L1M', type=float, default=1, help='The weight of L1 loss in the masked area')
70
+ parser.add_argument('--L1V', type=float, default=1, help='The weight of L1 loss in the valid area')
71
+ parser.add_argument('--adv', type=float, default=0.01, help='The weight of adversarial loss')
72
+
73
+ # spatial and temporal related parameters
74
+ parser.add_argument('--tw', type=int, default=2, help='The number of temporal group in the temporal transformer')
75
+ parser.add_argument('--sw', type=int, default=8,
76
+ help='The number of spatial window size in the spatial transformer')
77
+ parser.add_argument('--gd', type=int, default=4, help='Global downsample rate for spatial transformer')
78
+
79
+ parser.add_argument('--ref_length', type=int, default=10, help='The sample interval during inference')
80
+ parser.add_argument('--use_valid', action='store_true')
81
+
82
+ args = parser.parse_args()
83
+ return args
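Since every flag above has a default, the parser can be exercised without any CLI arguments; the following hypothetical snippet only shows how the parsed namespace is accessed.

# Hypothetical smoke test of the argument parser.
import sys
from inputs import args_parser

sys.argv = ['train.py']   # parse with defaults only
args = args_parser()
print(args.num_frames, (args.res_h, args.res_w), (args.kernel_size_h, args.kernel_size_w))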
FGT_codes/FGT/metrics/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ import numpy as np
2
+ from skimage.metrics import peak_signal_noise_ratio as psnr
3
+ from skimage.metrics import structural_similarity as ssim
4
+ import os
5
+
6
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
7
+
8
+
9
+ def calculate_metrics(results, gts):
10
+ B, H, W, C = results.shape
11
+ psnr_values, ssim_values, L1errors, L2errors = [], [], [], []
12
+ for i in range(B):
13
+ result = results[i]
14
+ gt = gts[i]
15
+ result_img = result
16
+ gt_img = gt
17
+ residual = result - gt
18
+ L1error = np.mean(np.abs(residual))
19
+ L2error = np.sum(residual ** 2) ** 0.5 / (H * W * C)
20
+ psnr_value = psnr(result_img, gt_img)
21
+ ssim_value = ssim(result_img, gt_img, multichannel=True)
22
+ L1errors.append(L1error)
23
+ L2errors.append(L2error)
24
+ psnr_values.append(psnr_value)
25
+ ssim_values.append(ssim_value)
26
+ L1_value = np.mean(L1errors)
27
+ L2_value = np.mean(L2errors)
28
+ psnr_value = np.mean(psnr_values)
29
+ ssim_value = np.mean(ssim_values)
30
+
31
+ return {'l1': L1_value, 'l2': L2_value, 'psnr': psnr_value, 'ssim': ssim_value}
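A hypothetical smoke test for calculate_metrics, using random arrays in (B, H, W, C) layout with values in [0, 1]:

import numpy as np
from metrics import calculate_metrics  # assuming the metrics package is importable

gt = np.random.rand(2, 120, 216, 3)
pred = np.clip(gt + np.random.normal(0, 0.05, gt.shape), 0.0, 1.0)
print(calculate_metrics(pred, gt))  # -> {'l1': ..., 'l2': ..., 'psnr': ..., 'ssim': ...}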
FGT_codes/FGT/metrics/psnr.py ADDED
@@ -0,0 +1,10 @@
1
+ import numpy
2
+ import math
3
+
4
+
5
+ def psnr(img1, img2):
6
+ mse = numpy.mean( (img1 - img2) ** 2 )
7
+ if mse == 0:
8
+ return 100
9
+ PIXEL_MAX = 255.0
10
+ return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
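A quick numeric check of the helper above (hypothetical; note the inputs should be float arrays in [0, 255], since uint8 subtraction would wrap around):

import numpy as np
from psnr import psnr  # assuming this file is importable as psnr

gt = np.random.rand(64, 64, 3) * 255.0
noisy = np.clip(gt + np.random.normal(0, 5.0, gt.shape), 0, 255)

print(psnr(gt, gt))      # identical inputs are capped at 100
print(psnr(gt, noisy))   # roughly 20 * log10(255 / 5) ~ 34 dB for sigma = 5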
FGT_codes/FGT/metrics/ssim.py ADDED
@@ -0,0 +1,46 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ def calculate_ssim(img1, img2):
6
+ C1 = (0.01 * 255)**2
7
+ C2 = (0.03 * 255)**2
8
+
9
+ img1 = img1.astype(np.float64)
10
+ img2 = img2.astype(np.float64)
11
+ kernel = cv2.getGaussianKernel(11, 1.5)
12
+ window = np.outer(kernel, kernel.transpose())
13
+
14
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
15
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
16
+ mu1_sq = mu1**2
17
+ mu2_sq = mu2**2
18
+ mu1_mu2 = mu1 * mu2
19
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
20
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
21
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
22
+
23
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
24
+ (sigma1_sq + sigma2_sq + C2))
25
+ return ssim_map.mean()
26
+
27
+
28
+ def ssim(img1, img2):
29
+ '''calculate SSIM
30
+ the same outputs as MATLAB's
31
+ img1, img2: [0, 255]
32
+ '''
33
+ if not img1.shape == img2.shape:
34
+ raise ValueError('Input images must have the same dimensions.')
35
+ if img1.ndim == 2:
36
+ return calculate_ssim(img1, img2)
37
+ elif img1.ndim == 3:
38
+ if img1.shape[2] == 3:
39
+ ssims = []
40
+ for i in range(3):
41
+ ssims.append(calculate_ssim(img1[:, :, i], img2[:, :, i]))
42
+ return np.array(ssims).mean()
43
+ elif img1.shape[2] == 1:
44
+ return calculate_ssim(np.squeeze(img1), np.squeeze(img2))
45
+ else:
46
+ raise ValueError('Wrong input image dimensions.')
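And a matching sanity check for the SSIM helper (hypothetical, random data):

import numpy as np
from ssim import ssim  # assuming this file is importable as ssim

img = np.random.rand(64, 64, 3) * 255.0
noisy = np.clip(img + np.random.normal(0, 10.0, img.shape), 0, 255)

print(ssim(img, img))    # exactly 1.0 for identical inputs
print(ssim(img, noisy))  # < 1.0, decreasing as the noise level grows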
FGT_codes/FGT/models/BaseNetwork.py ADDED
@@ -0,0 +1,46 @@
1
+ from .utils.network_blocks_2d import *
2
+
3
+
4
+ class BaseNetwork(nn.Module):
5
+ def __init__(self, conv_type):
6
+ super(BaseNetwork, self).__init__()
7
+ self.conv_type = conv_type
8
+ if conv_type == 'gated':
9
+ self.ConvBlock = GatedConv
10
+ self.DeconvBlock = GatedDeconv
11
+ if conv_type == 'partial':
12
+ self.ConvBlock = PartialConv
13
+ self.DeconvBlock = PartialDeconv
14
+ if conv_type == 'vanilla':
15
+ self.ConvBlock = VanillaConv
16
+ self.DeconvBlock = VanillaDeconv
17
+ self.ConvBlock2d = self.ConvBlock
18
+ self.DeconvBlock2d = self.DeconvBlock
19
+
20
+ def init_weights(self, init_type='normal', gain=0.02):
21
+ '''
22
+ initialize network's weights
23
+ init_type: normal | xavier | kaiming | orthogonal
24
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
25
+ '''
26
+
27
+ def init_func(m):
28
+ classname = m.__class__.__name__
29
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
30
+ if init_type == 'normal':
31
+ nn.init.normal_(m.weight.data, 0.0, gain)
32
+ elif init_type == 'xavier':
33
+ nn.init.xavier_normal_(m.weight.data, gain=gain)
34
+ elif init_type == 'kaiming':
35
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
36
+ elif init_type == 'orthogonal':
37
+ nn.init.orthogonal_(m.weight.data, gain=gain)
38
+
39
+ if hasattr(m, 'bias') and m.bias is not None:
40
+ nn.init.constant_(m.bias.data, 0.0)
41
+
42
+ elif classname.find('BatchNorm2d') != -1:
43
+ nn.init.normal_(m.weight.data, 1.0, gain)
44
+ nn.init.constant_(m.bias.data, 0.0)
45
+
46
+ self.apply(init_func)
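The init_func pattern above is self-contained; a minimal sketch of the same idiom applied to a plain torch module (not using the project's conv blocks) would be:

import torch.nn as nn

def normal_init(m, gain=0.02):
    # Same idea as BaseNetwork.init_weights with init_type='normal'.
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.normal_(m.weight.data, 0.0, gain)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, gain)
        nn.init.constant_(m.bias.data, 0.0)

net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
net.apply(normal_init)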
FGT_codes/FGT/models/__init__.py ADDED
File without changes
FGT_codes/FGT/models/__pycache__/BaseNetwork.cpython-39.pyc ADDED
Binary file (1.97 kB).
 
FGT_codes/FGT/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (163 Bytes).
 
FGT_codes/FGT/models/__pycache__/model.cpython-39.pyc ADDED
Binary file (10.3 kB).
 
FGT_codes/FGT/models/lafc_single.py ADDED
@@ -0,0 +1,114 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ import functools
5
+ from .BaseNetwork import BaseNetwork
6
+ from models.utils.reconstructionLayers import make_layer, ResidualBlock_noBN
7
+
8
+
9
+ class Model(nn.Module):
10
+ def __init__(self, config):
11
+ super(Model, self).__init__()
12
+ self.net = P3DNet(config['num_flows'], config['cnum'], config['in_channel'], config['PASSMASK'],
13
+ config['use_residual'],
14
+ config['resBlocks'], config['use_bias'], config['conv_type'], config['init_weights'])
15
+
16
+ def forward(self, flows, masks, edges=None):
17
+ ret = self.net(flows, masks, edges)
18
+ return ret
19
+
20
+
21
+ class P3DNet(BaseNetwork):
22
+ def __init__(self, num_flows, num_feats, in_channels, passmask, use_residual, res_blocks,
23
+ use_bias, conv_type, init_weights):
24
+ super().__init__(conv_type)
25
+ self.passmask = passmask
26
+ self.encoder2 = nn.Sequential(
27
+ nn.ReplicationPad2d(2),
28
+ self.ConvBlock2d(in_channels, num_feats, kernel_size=5, stride=1, padding=0, bias=use_bias, norm=None),
29
+ self.ConvBlock2d(num_feats, num_feats * 2, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=None)
30
+ )
31
+ self.encoder4 = nn.Sequential(
32
+ self.ConvBlock2d(num_feats * 2, num_feats * 2, kernel_size=3, stride=1, padding=1, bias=use_bias,
33
+ norm=None),
34
+ self.ConvBlock2d(num_feats * 2, num_feats * 4, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=None)
35
+ )
36
+ residualBlock = functools.partial(ResidualBlock_noBN, nf=num_feats * 4)
37
+ self.res_blocks = make_layer(residualBlock, res_blocks)
38
+ self.resNums = res_blocks
39
+ # dilation convolution to enlarge the receptive field
40
+ self.middle = nn.Sequential(
41
+ self.ConvBlock2d(num_feats * 4, num_feats * 4, kernel_size=3, stride=1, padding=8, bias=use_bias,
42
+ dilation=8, norm=None),
43
+ self.ConvBlock2d(num_feats * 4, num_feats * 4, kernel_size=3, stride=1, padding=4, bias=use_bias,
44
+ dilation=4, norm=None),
45
+ self.ConvBlock2d(num_feats * 4, num_feats * 4, kernel_size=3, stride=1, padding=2, bias=use_bias,
46
+ dilation=2, norm=None),
47
+ self.ConvBlock2d(num_feats * 4, num_feats * 4, kernel_size=3, stride=1, padding=1, bias=use_bias,
48
+ dilation=1, norm=None),
49
+ )
50
+ self.decoder2 = nn.Sequential(
51
+ self.DeconvBlock2d(num_feats * 8, num_feats * 2, kernel_size=3, stride=1, padding=1, bias=use_bias,
52
+ norm=None),
53
+ self.ConvBlock2d(num_feats * 2, num_feats * 2, kernel_size=3, stride=1, padding=1, bias=use_bias,
54
+ norm=None),
55
+ self.ConvBlock2d(num_feats * 2, num_feats * 2, kernel_size=3, stride=1, padding=1, bias=use_bias,
56
+ norm=None)
57
+ )
58
+ self.decoder = nn.Sequential(
59
+ self.DeconvBlock2d(num_feats * 4, num_feats, kernel_size=3, stride=1, padding=1, bias=use_bias,
60
+ norm=None),
61
+ self.ConvBlock2d(num_feats, num_feats // 2, kernel_size=3, stride=1, padding=1, bias=use_bias,
62
+ norm=None),
63
+ self.ConvBlock2d(num_feats // 2, 2, kernel_size=3, stride=1, padding=1, bias=use_bias,
64
+ norm=None)
65
+ )
66
+ self.edgeDetector = EdgeDetection(conv_type)
67
+ if init_weights:
68
+ self.init_weights()
69
+
70
+ def forward(self, flows, masks, edges=None):
71
+ if self.passmask:
72
+ inputs = torch.cat((flows, masks), dim=1)
73
+ else:
74
+ inputs = flows
75
+ if edges is not None:
76
+ inputs = torch.cat((inputs, edges), dim=1)
77
+ e2 = self.encoder2(inputs)
78
+ e4 = self.encoder4(e2)
79
+ if self.resNums > 0:
80
+ e4_res = self.res_blocks(e4)
81
+ else:
82
+ e4_res = e4
83
+ c_e4_filled = self.middle(e4_res)
84
+ c_e4 = torch.cat((c_e4_filled, e4), dim=1)
85
+ c_e2Post = self.decoder2(c_e4)
86
+ c_e2 = torch.cat((c_e2Post, e2), dim=1)
87
+ output = self.decoder(c_e2)
88
+ edge = self.edgeDetector(output)
89
+ return output, edge
90
+
91
+
92
+ class EdgeDetection(BaseNetwork):
93
+ def __init__(self, conv_type, in_channels=2, out_channels=1, mid_channels=16):
94
+ super(EdgeDetection, self).__init__(conv_type)
95
+ self.projection = self.ConvBlock2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, stride=1,
96
+ padding=1, norm=None)
97
+ self.mid_layer_1 = self.ConvBlock2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=3,
98
+ stride=1, padding=1, norm=None)
99
+ self.mid_layer_2 = self.ConvBlock2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=3,
100
+ stride=1, padding=1, activation=None, norm=None)
101
+ self.l_relu = nn.LeakyReLU()
102
+ self.out_layer = self.ConvBlock2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1,
103
+ activation=None, norm=None)
104
+
105
+ def forward(self, flow):
106
+ flow = self.projection(flow)
107
+ edge = self.mid_layer_1(flow)
108
+ edge = self.mid_layer_2(edge)
109
+ edge = self.l_relu(flow + edge)
110
+ edge = self.out_layer(edge)
111
+ edge = torch.sigmoid(edge)
112
+ return edge
113
+
114
+
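For reference, a hypothetical construction of the single-flow completion model above; the config keys mirror FGT_codes/FGT/flowCheckPoint/config.yaml from this diff, and the FGT package root is assumed to be on sys.path so that models.utils.* resolves.

import torch
from models.lafc_single import Model

config = dict(num_flows=1, cnum=48, in_channel=3, PASSMASK=1, use_residual=1,
              resBlocks=1, use_bias=1, conv_type='vanilla', init_weights=1)
net = Model(config)

flows = torch.randn(1, 2, 240, 432)   # corrupted flow field (B, 2, H, W), H and W divisible by 4
masks = torch.ones(1, 1, 240, 432)    # hole mask, concatenated because PASSMASK=1
completed_flow, edge = net(flows, masks)
print(completed_flow.shape, edge.shape)  # expected (1, 2, 240, 432) and (1, 1, 240, 432)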
FGT_codes/FGT/models/model.py ADDED
@@ -0,0 +1,284 @@
1
+ from models.BaseNetwork import BaseNetwork
2
+ from models.transformer_base.ffn_base import FusionFeedForward
3
+ from models.transformer_base.attention_flow import SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow
4
+ from models.transformer_base.attention_base import TMHSA
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from functools import reduce
9
+ import torch.nn.functional as F
10
+
11
+
12
+ class Model(nn.Module):
13
+ def __init__(self, config):
14
+ super(Model, self).__init__()
15
+ self.net = FGT(config['tw'], config['sw'], config['gd'], config['input_resolution'], config['in_channel'],
16
+ config['cnum'], config['flow_inChannel'], config['flow_cnum'], config['frame_hidden'],
17
+ config['flow_hidden'], config['PASSMASK'],
18
+ config['numBlocks'], config['kernel_size'], config['stride'], config['padding'],
19
+ config['num_head'], config['conv_type'], config['norm'],
20
+ config['use_bias'], config['ape'],
21
+ config['mlp_ratio'], config['drop'], config['init_weights'])
22
+
23
+ def forward(self, frames, flows, masks):
24
+ ret = self.net(frames, flows, masks)
25
+ return ret
26
+
27
+
28
+ class Encoder(nn.Module):
29
+ def __init__(self, in_channels):
30
+ super(Encoder, self).__init__()
31
+ self.group = [1, 2, 4, 8, 1]
32
+ self.layers = nn.ModuleList([
33
+ nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
34
+ nn.LeakyReLU(0.2, inplace=True),
35
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
36
+ nn.LeakyReLU(0.2, inplace=True),
37
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
38
+ nn.LeakyReLU(0.2, inplace=True),
39
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
40
+ nn.LeakyReLU(0.2, inplace=True),
41
+ nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
42
+ nn.LeakyReLU(0.2, inplace=True),
43
+ nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
44
+ nn.LeakyReLU(0.2, inplace=True),
45
+ nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
46
+ nn.LeakyReLU(0.2, inplace=True),
47
+ nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
48
+ nn.LeakyReLU(0.2, inplace=True),
49
+ nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
50
+ nn.LeakyReLU(0.2, inplace=True)
51
+ ])
52
+
53
+ def forward(self, x):
54
+ bt, c, h, w = x.size()
55
+ h, w = h // 4, w // 4
56
+ out = x
57
+ for i, layer in enumerate(self.layers):
58
+ if i == 8:
59
+ x0 = out
60
+ if i > 8 and i % 2 == 0:
61
+ g = self.group[(i - 8) // 2]
62
+ x = x0.view(bt, g, -1, h, w)
63
+ o = out.view(bt, g, -1, h, w)
64
+ out = torch.cat([x, o], 2).view(bt, -1, h, w)
65
+ out = layer(out)
66
+ return out
67
+
68
+
69
+ class AddPosEmb(nn.Module):
70
+ def __init__(self, h, w, in_channels, out_channels):
71
+ super(AddPosEmb, self).__init__()
72
+ self.proj = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True, groups=out_channels)
73
+ self.h, self.w = h, w
74
+
75
+ def forward(self, x, h=0, w=0):
76
+ B, N, C = x.shape
77
+ if h == 0 and w == 0:
78
+ assert N == self.h * self.w, 'Wrong input size'
79
+ else:
80
+ assert N == h * w, 'Wrong input size during inference'
81
+ feat_token = x
82
+ if h == 0 and w == 0:
83
+ cnn_feat = feat_token.transpose(1, 2).view(B, C, self.h, self.w)
84
+ else:
85
+ cnn_feat = feat_token.transpose(1, 2).view(B, C, h, w)
86
+ x = self.proj(cnn_feat) + cnn_feat
87
+ x = x.flatten(2).transpose(1, 2)
88
+ return x
89
+
90
+
91
+ class Vec2Patch(nn.Module):
92
+ def __init__(self, channel, hidden, output_size, kernel_size, stride, padding):
93
+ super(Vec2Patch, self).__init__()
94
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
95
+ c_out = reduce((lambda x, y: x * y), kernel_size) * channel
96
+ self.embedding = nn.Linear(hidden, c_out)
97
+ self.restore = nn.Fold(output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding)
98
+ self.kernel_size = kernel_size
99
+ self.stride = stride
100
+ self.padding = padding
101
+
102
+ def forward(self, x, output_h=0, output_w=0):
103
+ feat = self.embedding(x)
104
+ feat = feat.permute(0, 2, 1)
105
+ if output_h != 0 or output_w != 0:
106
+ feat = F.fold(feat, output_size=(output_h, output_w), kernel_size=self.kernel_size, stride=self.stride,
107
+ padding=self.padding)
108
+ else:
109
+ feat = self.restore(feat)
110
+ return feat
111
+
112
+
113
+ class TemporalTransformer(nn.Module):
114
+ def __init__(self, token_size, frame_hidden, num_heads, t_groupSize, mlp_ratio, dropout, n_vecs,
115
+ t2t_params):
116
+ super(TemporalTransformer, self).__init__()
117
+ self.attention = TMHSA(token_size=token_size, group_size=t_groupSize, d_model=frame_hidden, head=num_heads,
118
+ p=dropout)
119
+ self.ffn = FusionFeedForward(frame_hidden, mlp_ratio, n_vecs, t2t_params, p=dropout)
120
+ self.norm1 = nn.LayerNorm(frame_hidden)
121
+ self.norm2 = nn.LayerNorm(frame_hidden)
122
+ self.dropout = nn.Dropout(p=dropout)
123
+
124
+ def forward(self, x, t, h, w, output_size):
125
+ token_size = h * w
126
+ s = self.norm1(x)
127
+ x = x + self.dropout(self.attention(s, t, h, w))
128
+ y = self.norm2(x)
129
+ x = x + self.ffn(y, token_size, output_size[0], output_size[1])
130
+ return x
131
+
132
+
133
+ class SpatialTransformer(nn.Module):
134
+ def __init__(self, token_size, frame_hidden, flow_hidden, num_heads, s_windowSize, g_downSize, mlp_ratio,
135
+ dropout, n_vecs, t2t_params):
136
+ super(SpatialTransformer, self).__init__()
137
+ self.attention = SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow(token_size=token_size, window_size=s_windowSize,
138
+ kernel_size=g_downSize, d_model=frame_hidden,
139
+ flow_dModel=flow_hidden, head=num_heads, p=dropout)
140
+ self.ffn = FusionFeedForward(frame_hidden, mlp_ratio, n_vecs, t2t_params, p=dropout)
141
+ self.norm = nn.LayerNorm(frame_hidden)
142
+ self.dropout = nn.Dropout(p=dropout)
143
+
144
+ def forward(self, x, f, t, h, w, output_size):
145
+ token_size = h * w
146
+ x = x + self.dropout(self.attention(x, f, t, h, w))
147
+ y = self.norm(x)
148
+ x = x + self.ffn(y, token_size, output_size[0], output_size[1])
149
+ return x
150
+
151
+
152
+ class TransformerBlock(nn.Module):
153
+ def __init__(self, token_size, frame_hidden, flow_hidden, num_heads, t_groupSize, s_windowSize, g_downSize,
154
+ mlp_ratio,
155
+ dropout, n_vecs,
156
+ t2t_params):
157
+ super(TransformerBlock, self).__init__()
158
+ self.t_transformer = TemporalTransformer(token_size=token_size, frame_hidden=frame_hidden, num_heads=num_heads,
159
+ t_groupSize=t_groupSize, mlp_ratio=mlp_ratio,
160
+ dropout=dropout, n_vecs=n_vecs,
161
+ t2t_params=t2t_params) # temporal multi-head self attention
162
+ self.s_transformer = SpatialTransformer(token_size=token_size, frame_hidden=frame_hidden,
163
+ flow_hidden=flow_hidden, num_heads=num_heads, s_windowSize=s_windowSize,
164
+ g_downSize=g_downSize, mlp_ratio=mlp_ratio,
165
+ dropout=dropout, n_vecs=n_vecs, t2t_params=t2t_params)
166
+
167
+ def forward(self, inputs):
168
+ x, f, t = inputs['x'], inputs['f'], inputs['t']
169
+ h, w = inputs['h'], inputs['w']
170
+ output_size = inputs['output_size']
171
+ x = self.t_transformer(x, t, h, w, output_size)
172
+ x = self.s_transformer(x, f, t, h, w, output_size)
173
+ return {'x': x, 'f': f, 't': t, 'h': h, 'w': w, 'output_size': output_size}
174
+
175
+
176
+ class Decoder(BaseNetwork):
177
+ def __init__(self, conv_type, in_channels, out_channels, use_bias, norm=None):
178
+ super(Decoder, self).__init__(conv_type)
179
+ self.layer1 = self.DeconvBlock(in_channels, in_channels, kernel_size=3, padding=1, norm=norm,
180
+ bias=use_bias)
181
+ self.layer2 = self.ConvBlock(in_channels, in_channels // 2, kernel_size=3, stride=1, padding=1, norm=norm,
182
+ bias=use_bias)
183
+ self.layer3 = self.DeconvBlock(in_channels // 2, in_channels // 2, kernel_size=3, padding=1, norm=norm,
184
+ bias=use_bias)
185
+ self.final = self.ConvBlock(in_channels // 2, out_channels, kernel_size=3, stride=1, padding=1, norm=norm,
186
+ bias=use_bias, activation=None)
187
+
188
+ def forward(self, features):
189
+ feat1 = self.layer1(features)
190
+ feat2 = self.layer2(feat1)
191
+ feat3 = self.layer3(feat2)
192
+ output = self.final(feat3)
193
+ return output
194
+
195
+
196
+ class FGT(BaseNetwork):
197
+ def __init__(self, t_groupSize, s_windowSize, g_downSize, input_resolution, in_channels, cnum, flow_inChannel,
198
+ flow_cnum,
199
+ frame_hidden, flow_hidden, passmask, numBlocks, kernel_size, stride, padding, num_heads, conv_type,
200
+ norm, use_bias, ape, mlp_ratio=4, drop=0, init_weights=True):
201
+ super(FGT, self).__init__(conv_type)
202
+ self.in_channels = in_channels
203
+ self.passmask = passmask
204
+ self.ape = ape
205
+ self.frame_endoder = Encoder(in_channels)
206
+ self.flow_encoder = nn.Sequential(
207
+ nn.ReplicationPad2d(2),
208
+ self.ConvBlock(flow_inChannel, flow_cnum, kernel_size=5, stride=1, padding=0, bias=use_bias, norm=norm),
209
+ self.ConvBlock(flow_cnum, flow_cnum * 2, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=norm),
210
+ self.ConvBlock(flow_cnum * 2, flow_cnum * 2, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=norm),
211
+ self.ConvBlock(flow_cnum * 2, flow_cnum * 2, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=norm)
212
+ )
213
+ # patch to vector operation
214
+ self.patch2vec = nn.Conv2d(cnum * 2, frame_hidden, kernel_size=kernel_size, stride=stride, padding=padding)
215
+ self.f_patch2vec = nn.Conv2d(flow_cnum * 2, flow_hidden, kernel_size=kernel_size, stride=stride,
216
+ padding=padding)
217
+ # initialize transformer blocks for frame completion
218
+ n_vecs = 1
219
+ token_size = []
220
+ output_shape = (input_resolution[0] // 4, input_resolution[1] // 4)
221
+ for i, d in enumerate(kernel_size):
222
+ token_nums = int((output_shape[i] + 2 * padding[i] - kernel_size[i]) / stride[i] + 1)
223
+ n_vecs *= token_nums
224
+ token_size.append(token_nums)
225
+ # Add positional embedding to the encode features
226
+ if self.ape:
227
+ self.add_pos_emb = AddPosEmb(token_size[0], token_size[1], frame_hidden, frame_hidden)
228
+ self.token_size = token_size
229
+ # initialize transformer blocks
230
+ blocks = []
231
+ t2t_params = {'kernel_size': kernel_size, 'stride': stride, 'padding': padding, 'output_size': output_shape}
232
+ for i in range(numBlocks // 2 - 1):
233
+ layer = TransformerBlock(token_size, frame_hidden, flow_hidden, num_heads, t_groupSize, s_windowSize,
234
+ g_downSize, mlp_ratio, drop, n_vecs, t2t_params)
235
+ blocks.append(layer)
236
+ self.first_t_transformer = TemporalTransformer(token_size, frame_hidden, num_heads, t_groupSize, mlp_ratio,
237
+ drop, n_vecs, t2t_params)
238
+ self.first_s_transformer = SpatialTransformer(token_size, frame_hidden, flow_hidden, num_heads, s_windowSize,
239
+ g_downSize, mlp_ratio, drop, n_vecs, t2t_params)
240
+ self.transformer = nn.Sequential(*blocks)
241
+ # vector to patch operation
242
+ self.vec2patch = Vec2Patch(cnum * 2, frame_hidden, output_shape, kernel_size, stride, padding)
243
+ # decoder
244
+ self.decoder = Decoder(conv_type, cnum * 2, 3, use_bias, norm)
245
+
246
+ if init_weights:
247
+ self.init_weights()
248
+
249
+ def forward(self, masked_frames, flows, masks):
250
+ b, t, c, h, w = masked_frames.shape
251
+ cf = flows.shape[2]
252
+ output_shape = (h // 4, w // 4)
253
+ if self.passmask:
254
+ inputs = torch.cat((masked_frames, masks), dim=2)
255
+ else:
256
+ inputs = masked_frames
257
+ inputs = inputs.view(b * t, self.in_channels, h, w)
258
+ flows = flows.view(b * t, cf, h, w)
259
+ enc_feats = self.frame_endoder(inputs)
260
+ flow_feats = self.flow_encoder(flows)
261
+ trans_feat = self.patch2vec(enc_feats)
262
+ flow_patches = self.f_patch2vec(flow_feats)
263
+ _, c, h, w = trans_feat.shape
264
+ cf = flow_patches.shape[1]
265
+ if h != self.token_size[0] or w != self.token_size[1]:
266
+ new_h, new_w = h, w
267
+ else:
268
+ new_h, new_w = 0, 0
269
+ output_shape = (0, 0)
270
+ trans_feat = trans_feat.view(b * t, c, -1).permute(0, 2, 1)
271
+ flow_patches = flow_patches.view(b * t, cf, -1).permute(0, 2, 1)
272
+ trans_feat = self.first_t_transformer(trans_feat, t, new_h, new_w, output_shape)
273
+ trans_feat = self.add_pos_emb(trans_feat, new_h, new_w)
274
+ trans_feat = self.first_s_transformer(trans_feat, flow_patches, t, new_h, new_w, output_shape)
275
+ inputs_trans_feat = {'x': trans_feat, 'f': flow_patches, 't': t, 'h': new_h, 'w': new_w,
276
+ 'output_size': output_shape}
277
+ trans_feat = self.transformer(inputs_trans_feat)['x']
278
+ trans_feat = self.vec2patch(trans_feat, output_shape[0], output_shape[1])
279
+ enc_feats = enc_feats + trans_feat
280
+
281
+ output = self.decoder(enc_feats)
282
+ output = torch.tanh(output)
283
+ return output
284
+
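For orientation, the Model wrapper above only unpacks a config dictionary into FGT. A minimal sketch of building and calling it with dummy tensors; the hyper-parameter values below are illustrative assumptions, not the shipped configuration:

    import torch
    from models.model import Model

    config = {
        'tw': 2, 'sw': 8, 'gd': 4,                    # temporal group / spatial window / global down-sample sizes (assumed)
        'input_resolution': (240, 432),
        'in_channel': 4, 'cnum': 64,                  # 3 RGB channels + 1 mask channel
        'flow_inChannel': 2, 'flow_cnum': 64,
        'frame_hidden': 512, 'flow_hidden': 256,
        'PASSMASK': True,
        'numBlocks': 8,
        'kernel_size': (7, 7), 'stride': (3, 3), 'padding': (3, 3),
        'num_head': 4, 'conv_type': 'vanilla', 'norm': None,
        'use_bias': True, 'ape': True,
        'mlp_ratio': 40, 'drop': 0, 'init_weights': True,
    }
    net = Model(config)
    frames = torch.randn(1, 5, 3, 240, 432)   # [b, t, c, h, w] masked frames
    flows = torch.randn(1, 5, 2, 240, 432)    # completed optical flows
    masks = torch.ones(1, 5, 1, 240, 432)     # inpainting masks
    out = net(frames, flows, masks)           # -> [b*t, 3, 240, 432], tanh-scaled to [-1, 1]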
FGT_codes/FGT/models/temporal_patch_gan.py ADDED
@@ -0,0 +1,76 @@
1
+ # temporal patch GAN to maintain the temporal consistency of the flows
2
+ import torch
3
+ import torch.nn as nn
4
+ from .BaseNetwork import BaseNetwork
5
+
6
+
7
+ class Discriminator(BaseNetwork):
8
+ def __init__(self, in_channels, conv_type, dist_cnum, use_sigmoid=False, use_spectral_norm=True, init_weights=True):
9
+ """
10
+
11
+ Args:
12
+ in_channels: The input channels of the discriminator
13
+ use_sigmoid: Whether to apply a sigmoid to the output (True for the NSGAN objective)
14
+ use_spectral_norm: Whether to apply spectral normalization; keep True for GAN training stability
15
+ init_weights: Whether to initialize the network weights; should always be True
16
+ """
17
+ super(Discriminator, self).__init__(conv_type)
18
+ self.use_sigmoid = use_sigmoid
19
+ nf = dist_cnum
20
+
21
+ self.conv = nn.Sequential(
22
+ spectral_norm(
23
+ nn.Conv3d(in_channels=in_channels, out_channels=nf * 1, kernel_size=(3, 5, 5), stride=(1, 2, 2),
24
+ padding=(1, 2, 2),
25
+ bias=not use_spectral_norm), use_spectral_norm),
26
+ nn.LeakyReLU(0.2, inplace=True),
27
+ spectral_norm(
28
+ nn.Conv3d(in_channels=nf * 1, out_channels=nf * 2, kernel_size=(3, 5, 5), stride=(1, 2, 2),
29
+ padding=(1, 2, 2),
30
+ bias=not use_spectral_norm), use_spectral_norm),
31
+ nn.LeakyReLU(0.2, inplace=True),
32
+ spectral_norm(
33
+ nn.Conv3d(in_channels=nf * 2, out_channels=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
34
+ padding=(1, 2, 2),
35
+ bias=not use_spectral_norm), use_spectral_norm),
36
+ nn.LeakyReLU(0.2, inplace=True),
37
+ spectral_norm(
38
+ nn.Conv3d(in_channels=nf * 4, out_channels=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
39
+ padding=(1, 2, 2),
40
+ bias=not use_spectral_norm), use_spectral_norm),
41
+ nn.LeakyReLU(0.2, inplace=True),
42
+ spectral_norm(
43
+ nn.Conv3d(in_channels=nf * 4, out_channels=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
44
+ padding=(1, 2, 2),
45
+ bias=not use_spectral_norm), use_spectral_norm),
46
+ nn.LeakyReLU(0.2, inplace=True),
47
+ nn.Conv3d(in_channels=nf * 4, out_channels=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
48
+ padding=(1, 2, 2))
49
+ )
50
+
51
+ if init_weights:
52
+ self.init_weights()
53
+
54
+ def forward(self, xs, t):
55
+ """
56
+
57
+ Args:
58
+ xs: Input features with shape [bt, c, h, w]; t is the number of frames per clip
59
+
60
+ Returns: The spatio-temporal patch discrimination map
61
+
62
+ """
63
+ bt, c, h, w = xs.shape
64
+ b = bt // t
65
+ xs = xs.view(b, t, c, h, w).permute(0, 2, 1, 3, 4).contiguous()
66
+ feat = self.conv(xs)
67
+ if self.use_sigmoid:
68
+ feat = torch.sigmoid(feat)
69
+ out = torch.transpose(feat, 1, 2) # [b, t, c, h, w]
70
+ return out
71
+
72
+
73
+ def spectral_norm(module, mode=True):
74
+ if mode:
75
+ return nn.utils.spectral_norm(module)
76
+ return module
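A brief sketch of how this temporal patch discriminator might be driven: frames arrive flattened as [b*t, c, h, w], are regrouped into a 5D clip, and the stack of 3D convolutions scores spatio-temporal patches. The hinge-style loss below is a common choice shown purely for illustration; it is not necessarily the objective used by the training scripts, and the hyper-parameters are assumptions:

    import torch
    from models.temporal_patch_gan import Discriminator

    netD = Discriminator(in_channels=3, conv_type='vanilla', dist_cnum=32)  # assumed values

    b, t, c, h, w = 1, 5, 3, 240, 432
    real = torch.randn(b * t, c, h, w)      # ground-truth frames
    fake = torch.randn(b * t, c, h, w)      # inpainted frames from the generator
    real_score = netD(real, t)              # spatio-temporal patch scores
    fake_score = netD(fake, t)

    d_loss = torch.relu(1 - real_score).mean() + torch.relu(1 + fake_score).mean()  # hinge loss, D side
    g_loss = -fake_score.mean()                                                     # hinge loss, G side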
FGT_codes/FGT/models/transformer_base/__init__.py ADDED
File without changes
FGT_codes/FGT/models/transformer_base/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (180 Bytes).
 
FGT_codes/FGT/models/transformer_base/__pycache__/attention_base.cpython-39.pyc ADDED
Binary file (4.1 kB).
 
FGT_codes/FGT/models/transformer_base/__pycache__/attention_flow.cpython-39.pyc ADDED
Binary file (5.51 kB).
 
FGT_codes/FGT/models/transformer_base/__pycache__/ffn_base.cpython-39.pyc ADDED
Binary file (4.11 kB).
 
FGT_codes/FGT/models/transformer_base/attention_base.py ADDED
@@ -0,0 +1,106 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class Attention(nn.Module):
8
+ """
9
+ Compute scaled dot-product attention.
10
+ """
11
+
12
+ def __init__(self, p=0.1):
13
+ super(Attention, self).__init__()
14
+ self.dropout = nn.Dropout(p=p)
15
+
16
+ def forward(self, query, key, value):
17
+ scores = torch.matmul(query, key.transpose(-2, -1)
18
+ ) / math.sqrt(query.size(-1))
19
+ p_attn = F.softmax(scores, dim=-1)
20
+ p_attn = self.dropout(p_attn)
21
+ p_val = torch.matmul(p_attn, value)
22
+ return p_val, p_attn
23
+
24
+
25
+ class TMHSA(nn.Module):
26
+ def __init__(self, token_size, group_size, d_model, head, p=0.1):
27
+ super(TMHSA, self).__init__()
28
+ self.h, self.w = token_size
29
+ self.group_size = group_size # group_size is the number of groups the token map is divided into
30
+ self.wh, self.ww = math.ceil(self.h / self.group_size), math.ceil(self.w / self.group_size)
31
+ self.pad_r = (self.ww - self.w % self.ww) % self.ww
32
+ self.pad_b = (self.wh - self.h % self.wh) % self.wh
33
+ self.new_h, self.new_w = self.h + self.pad_b, self.w + self.pad_r # pad only on the right and bottom sides; leaving the other sides unpadded is simpler to implement
34
+ self.window_h, self.window_w = self.new_h // self.group_size, self.new_w // self.group_size # here the group denotes the window extent while window_h/window_w denote the per-group size (the opposite of the spatial attention's naming)
35
+ self.d_model = d_model
36
+ self.p = p
37
+ self.query_embedding = nn.Linear(d_model, d_model)
38
+ self.key_embedding = nn.Linear(d_model, d_model)
39
+ self.value_embedding = nn.Linear(d_model, d_model)
40
+ self.output_linear = nn.Linear(d_model, d_model)
41
+ self.attention = Attention(p=p)
42
+ self.head = head
43
+
44
+ def inference(self, x, t, h, w):
45
+ # calculate the attention related parameters
46
+ wh, ww = math.ceil(h / self.group_size), math.ceil(w / self.group_size)
47
+ pad_r = (ww - w % ww) % ww
48
+ pad_b = (wh - h % wh) % wh
49
+ new_h, new_w = h + pad_b, w + pad_r
50
+ window_h, window_w = new_h // self.group_size, new_w // self.group_size
51
+ bt, n, c = x.shape
52
+ b = bt // t
53
+ c_h = c // self.head
54
+ x = x.view(bt, h, w, c)
55
+ if pad_r > 0 or pad_b > 0:
56
+ x = F.pad(x,
57
+ (0, 0, 0, pad_r, 0, pad_b)) # channel, channel, left, right, top, bottom -> [bt, new_h, new_w, c]
58
+ query = self.query_embedding(x)
59
+ key = self.key_embedding(x)
60
+ value = self.value_embedding(x)
61
+ query = query.view(b, t, self.group_size, window_h, self.group_size, window_w, self.head, c_h)
62
+ query = query.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(b, self.group_size * self.group_size, self.head, -1, c_h)
63
+ key = key.view(b, t, self.group_size, window_h, self.group_size, window_w, self.head, c_h)
64
+ key = key.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(b, self.group_size * self.group_size, self.head, -1, c_h)
65
+ value = value.view(b, t, self.group_size, window_h, self.group_size, window_w, self.head, c_h)
66
+ value = value.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(b, self.group_size * self.group_size, self.head, -1, c_h)
67
+ att, _ = self.attention(query, key, value)
68
+ att = att.view(b, self.group_size, self.group_size, self.head, t, window_h, window_w, c_h)
69
+ att = att.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous().view(bt, new_h, new_w, c)
70
+ if pad_b > 0 or pad_r > 0:
71
+ att = att[:, :h, :w, :]
72
+ att = att.reshape(bt, n, c)
73
+ output = self.output_linear(att)
74
+ return output
75
+
76
+ def forward(self, x, t, h=0, w=0):
77
+ bt, n, c = x.shape
78
+ if h == 0 and w == 0:
79
+ assert n == self.h * self.w, 'Wrong input shape: {} with token: h->{}, w->{}'.format(x.shape, self.h,
80
+ self.w)
81
+ else:
82
+ assert n == h * w, 'Wrong input shape: {} with token: h->{}, w->{}'.format(x.shape, h, w)
83
+ return self.inference(x, t, h, w)
84
+ b = bt // t
85
+ c_h = c // self.head
86
+ x = x.view(bt, self.h, self.w, c)
87
+ if self.pad_r > 0 or self.pad_b > 0:
88
+ x = F.pad(x, (
89
+ 0, 0, 0, self.pad_r, 0, self.pad_b)) # channel, channel, left, right, top, bottom -> [bt, new_h, new_w, c]
90
+ query = self.query_embedding(x)
91
+ key = self.key_embedding(x)
92
+ value = self.value_embedding(x)
93
+ query = query.view(b, t, self.group_size, self.window_h, self.group_size, self.window_w, self.head, c_h)
94
+ query = query.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(b, self.group_size * self.group_size, self.head, -1, c_h)
95
+ key = key.view(b, t, self.group_size, self.window_h, self.group_size, self.window_w, self.head, c_h)
96
+ key = key.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(b, self.group_size * self.group_size, self.head, -1, c_h)
97
+ value = value.view(b, t, self.group_size, self.window_h, self.group_size, self.window_w, self.head, c_h)
98
+ value = value.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(b, self.group_size * self.group_size, self.head, -1, c_h)
99
+ att, _ = self.attention(query, key, value)
100
+ att = att.view(b, self.group_size, self.group_size, self.head, t, self.window_h, self.window_w, c_h)
101
+ att = att.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous().view(bt, self.new_h, self.new_w, c)
102
+ if self.pad_b > 0 or self.pad_r > 0:
103
+ att = att[:, :self.h, :self.w, :]
104
+ att = att.reshape(bt, n, c)
105
+ output = self.output_linear(att)
106
+ return output
FGT_codes/FGT/models/transformer_base/attention_flow.py ADDED
@@ -0,0 +1,171 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class Attention(nn.Module):
8
+ """
9
+ Compute scaled dot-product attention.
10
+ """
11
+
12
+ def __init__(self, p=0.1):
13
+ super(Attention, self).__init__()
14
+ self.dropout = nn.Dropout(p=p)
15
+
16
+ def forward(self, query, key, value):
17
+ scores = torch.matmul(query, key.transpose(-2, -1)
18
+ ) / math.sqrt(query.size(-1))
19
+ p_attn = F.softmax(scores, dim=-1)
20
+ p_attn = self.dropout(p_attn)
21
+ p_val = torch.matmul(p_attn, value)
22
+ return p_val, p_attn
23
+
24
+
25
+ class SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow(nn.Module):
26
+ def __init__(self, token_size, window_size, kernel_size, d_model, flow_dModel, head, p=0.1):
27
+ super(SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow, self).__init__()
28
+ self.h, self.w = token_size
29
+ self.head = head
30
+ self.window_size = window_size
31
+ self.d_model = d_model
32
+ self.flow_dModel = flow_dModel
33
+ in_channels = d_model + flow_dModel
34
+ self.query_embedding = nn.Linear(in_channels, d_model)
35
+ self.key_embedding = nn.Linear(in_channels, d_model)
36
+ self.value_embedding = nn.Linear(d_model, d_model)
37
+ self.output_linear = nn.Linear(d_model, d_model)
38
+ self.attention = Attention(p)
39
+ self.pad_l = self.pad_t = 0
40
+ self.pad_r = (self.window_size - self.w % self.window_size) % self.window_size
41
+ self.pad_b = (self.window_size - self.h % self.window_size) % self.window_size
42
+ self.new_h, self.new_w = self.h + self.pad_b, self.w + self.pad_r
43
+ self.group_h, self.group_w = self.new_h // self.window_size, self.new_w // self.window_size
44
+ self.global_extract_v = nn.Conv2d(d_model, d_model, kernel_size=kernel_size, stride=kernel_size, padding=0,
45
+ groups=d_model)
46
+ self.global_extract_k = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=kernel_size,
47
+ padding=0,
48
+ groups=in_channels)
49
+ self.q_norm = nn.LayerNorm(d_model + flow_dModel)
50
+ self.k_norm = nn.LayerNorm(d_model + flow_dModel)
51
+ self.v_norm = nn.LayerNorm(d_model)
52
+ self.reweightFlow = nn.Sequential(
53
+ nn.Linear(in_channels, flow_dModel),
54
+ nn.Sigmoid()
55
+ )
56
+
57
+ def inference(self, x, f, h, w):
58
+ pad_r = (self.window_size - w % self.window_size) % self.window_size
59
+ pad_b = (self.window_size - h % self.window_size) % self.window_size
60
+ new_h, new_w = h + pad_b, w + pad_r
61
+ group_h, group_w = new_h // self.window_size, new_w // self.window_size
62
+ bt, n, c = x.shape
63
+ cf = f.shape[2]
64
+ x = x.view(bt, h, w, c)
65
+ f = f.view(bt, h, w, cf)
66
+ if pad_r > 0 or pad_b > 0:
67
+ x = F.pad(x, (0, 0, self.pad_l, pad_r, self.pad_t, pad_b))
68
+ f = F.pad(f, (0, 0, self.pad_l, pad_r, self.pad_t, pad_b))
69
+ y = x.permute(0, 3, 1, 2)
70
+ xf = torch.cat((x, f), dim=-1)
71
+ flow_weights = self.reweightFlow(xf)
72
+ f = f * flow_weights
73
+ qk = torch.cat((x, f), dim=-1) # [b, h, w, c]
74
+ qk_c = qk.shape[-1]
75
+ # generate q
76
+ q = qk.reshape(bt, group_h, self.window_size, group_w, self.window_size, qk_c).transpose(2, 3)
77
+ q = q.reshape(bt, group_h * group_w, self.window_size * self.window_size, qk_c)
78
+ # generate k
79
+ ky = qk.permute(0, 3, 1, 2) # [b, c, h, w]
80
+ k_global = self.global_extract_k(ky)
81
+ k_global = k_global.permute(0, 2, 3, 1).reshape(bt, -1, qk_c).unsqueeze(1).repeat(1, group_h * group_w, 1, 1)
82
+ k = torch.cat((q, k_global), dim=2)
83
+ # norm q and k
84
+ q = self.q_norm(q)
85
+ k = self.k_norm(k)
86
+ # generate v
87
+ global_tokens = self.global_extract_v(y) # [bt, c, h', w']
88
+ global_tokens = global_tokens.permute(0, 2, 3, 1).reshape(bt, -1, c).unsqueeze(1).repeat(1,
89
+ group_h * group_w,
90
+ 1,
91
+ 1) # [bt, gh * gw, h'*w', c]
92
+ x = x.reshape(bt, group_h, self.window_size, group_w, self.window_size, c).transpose(2,
93
+ 3) # [bt, gh, gw, ws, ws, c]
94
+ x = x.reshape(bt, group_h * group_w, self.window_size * self.window_size, c) # [bt, gh * gw, ws^2, c]
95
+ v = torch.cat((x, global_tokens), dim=2)
96
+ v = self.v_norm(v)
97
+ query = self.query_embedding(q) # [bt, self.group_h, self.group_w, self.window_size, self.window_size, c]
98
+ key = self.key_embedding(k)
99
+ value = self.value_embedding(v)
100
+ query = query.reshape(bt, group_h * group_w, self.window_size * self.window_size, self.head,
101
+ c // self.head).permute(0, 1, 3, 2, 4)
102
+ key = key.reshape(bt, group_h * group_w, -1, self.head,
103
+ c // self.head).permute(0, 1, 3, 2, 4)
104
+ value = value.reshape(bt, group_h * group_w, -1, self.head,
105
+ c // self.head).permute(0, 1, 3, 2, 4)
106
+ attn, _ = self.attention(query, key, value)
107
+ x = attn.transpose(2, 3).reshape(bt, group_h, group_w, self.window_size, self.window_size, c)
108
+ x = x.transpose(2, 3).reshape(bt, group_h * self.window_size, group_w * self.window_size, c)
109
+ if pad_r > 0 or pad_b > 0:
110
+ x = x[:, :h, :w, :].contiguous()
111
+ x = x.reshape(bt, n, c)
112
+ output = self.output_linear(x)
113
+ return output
114
+
115
+ def forward(self, x, f, t, h=0, w=0):
116
+ if h != 0 or w != 0:
117
+ return self.inference(x, f, h, w)
118
+ bt, n, c = x.shape
119
+ cf = f.shape[2]
120
+ x = x.view(bt, self.h, self.w, c)
121
+ f = f.view(bt, self.h, self.w, cf)
122
+ if self.pad_r > 0 or self.pad_b > 0:
123
+ x = F.pad(x, (0, 0, self.pad_l, self.pad_r, self.pad_t, self.pad_b))
124
+ f = F.pad(f, (0, 0, self.pad_l, self.pad_r, self.pad_t, self.pad_b)) # [bt, new_h, new_w, cf]
125
+ y = x.permute(0, 3, 1, 2)
126
+ xf = torch.cat((x, f), dim=-1)
127
+ weights = self.reweightFlow(xf)
128
+ f = f * weights
129
+ qk = torch.cat((x, f), dim=-1) # [b, h, w, c]
130
+ qk_c = qk.shape[-1]
131
+ # generate q
132
+ q = qk.reshape(bt, self.group_h, self.window_size, self.group_w, self.window_size, qk_c).transpose(2, 3)
133
+ q = q.reshape(bt, self.group_h * self.group_w, self.window_size * self.window_size, qk_c)
134
+ # generate k
135
+ ky = qk.permute(0, 3, 1, 2) # [b, c, h, w]
136
+ k_global = self.global_extract_k(ky) # [b, qk_c, h, w]
137
+ k_global = k_global.permute(0, 2, 3, 1).reshape(bt, -1, qk_c).unsqueeze(1).repeat(1,
138
+ self.group_h * self.group_w,
139
+ 1, 1)
140
+ k = torch.cat((q, k_global), dim=2)
141
+ # norm q and k
142
+ q = self.q_norm(q)
143
+ k = self.k_norm(k)
144
+ # generate v
145
+ global_tokens = self.global_extract_v(y) # [bt, c, h', w']
146
+ global_tokens = global_tokens.permute(0, 2, 3, 1).reshape(bt, -1, c).unsqueeze(1).repeat(1,
147
+ self.group_h * self.group_w,
148
+ 1,
149
+ 1) # [bt, gh * gw, h'*w', c]
150
+ x = x.reshape(bt, self.group_h, self.window_size, self.group_w, self.window_size, c).transpose(2,
151
+ 3) # [bt, gh, gw, ws, ws, c]
152
+ x = x.reshape(bt, self.group_h * self.group_w, self.window_size * self.window_size, c) # [bt, gh * gw, ws^2, c]
153
+ v = torch.cat((x, global_tokens), dim=2)
154
+ v = self.v_norm(v)
155
+ query = self.query_embedding(q) # [bt, self.group_h, self.group_w, self.window_size, self.window_size, c]
156
+ key = self.key_embedding(k)
157
+ value = self.value_embedding(v)
158
+ query = query.reshape(bt, self.group_h * self.group_w, self.window_size * self.window_size, self.head,
159
+ c // self.head).permute(0, 1, 3, 2, 4)
160
+ key = key.reshape(bt, self.group_h * self.group_w, -1, self.head,
161
+ c // self.head).permute(0, 1, 3, 2, 4)
162
+ value = value.reshape(bt, self.group_h * self.group_w, -1, self.head,
163
+ c // self.head).permute(0, 1, 3, 2, 4)
164
+ attn, _ = self.attention(query, key, value)
165
+ x = attn.transpose(2, 3).reshape(bt, self.group_h, self.group_w, self.window_size, self.window_size, c)
166
+ x = x.transpose(2, 3).reshape(bt, self.group_h * self.window_size, self.group_w * self.window_size, c)
167
+ if self.pad_r > 0 or self.pad_b > 0:
168
+ x = x[:, :self.h, :self.w, :].contiguous()
169
+ x = x.reshape(bt, n, c)
170
+ output = self.output_linear(x)
171
+ return output
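The window attention above steers the attention pattern with motion: flow tokens are gated by a learned sigmoid (reweightFlow) and concatenated onto the frame tokens to build queries and keys, while values stay purely appearance-based; each local window additionally receives depth-wise down-sampled global tokens for its keys and values. A minimal shape check; all sizes below are illustrative assumptions:

    import torch
    from models.transformer_base.attention_flow import SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow

    attn = SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow(
        token_size=(20, 36), window_size=8, kernel_size=4,
        d_model=512, flow_dModel=256, head=4)
    bt = 5                                  # flattened batch * time
    x = torch.randn(bt, 20 * 36, 512)       # frame tokens
    f = torch.randn(bt, 20 * 36, 256)       # flow tokens
    y = attn(x, f, t=5)                     # training path: h, w default to 0
    assert y.shape == x.shape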
FGT_codes/FGT/models/transformer_base/ffn_base.py ADDED
@@ -0,0 +1,114 @@
1
+ import torch
2
+ from functools import reduce
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from functools import partial
6
+
7
+
8
+ class FeedForward(nn.Module):
9
+ def __init__(self, frame_hidden, mlp_ratio, n_vecs, t2t_params, p):
10
+ """
11
+
12
+ Args:
13
+ frame_hidden: hidden size of frame features
14
+ mlp_ratio: mlp ratio in the middle layer of the transformers
15
+ n_vecs: number of vectors in the transformer
16
+ t2t_params: dictionary -> {'kernel_size': kernel_size, 'stride': stride, 'padding': padding, 'output_size': output_shape}
17
+ p: dropout rate, 0 by default
18
+ """
19
+ super(FeedForward, self).__init__()
20
+ self.conv = nn.Sequential(
21
+ nn.Linear(frame_hidden, frame_hidden * mlp_ratio),
22
+ nn.ReLU(inplace=True),
23
+ nn.Dropout(p),
24
+ nn.Linear(frame_hidden * mlp_ratio, frame_hidden),
25
+ nn.Dropout(p)
26
+ )
27
+
28
+ def forward(self, x, n_vecs=0, output_h=0, output_w=0):
29
+ x = self.conv(x)
30
+ return x
31
+
32
+
33
+ class FusionFeedForward(nn.Module):
34
+ def __init__(self, frame_hidden, mlp_ratio, n_vecs, t2t_params, p):
35
+ super(FusionFeedForward, self).__init__()
36
+ self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size'])
37
+ self.t2t_params = t2t_params
38
+ hidden_size = self.kernel_shape * mlp_ratio
39
+ self.conv1 = nn.Linear(frame_hidden, hidden_size)
40
+ self.conv2 = nn.Sequential(
41
+ nn.ReLU(inplace=True),
42
+ nn.Dropout(p),
43
+ nn.Linear(hidden_size, frame_hidden),
44
+ nn.Dropout(p)
45
+ )
46
+ assert t2t_params is not None and n_vecs is not None
47
+ tp = t2t_params.copy()
48
+ self.fold = nn.Fold(**tp)
49
+ del tp['output_size']
50
+ self.unfold = nn.Unfold(**tp)
51
+ self.n_vecs = n_vecs
52
+
53
+ def forward(self, x, n_vecs=0, output_h=0, output_w=0):
54
+ x = self.conv1(x)
55
+ b, n, c = x.size()
56
+ if n_vecs != 0:
57
+ normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1)
58
+ x = self.unfold(F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1), output_size=(output_h, output_w),
59
+ kernel_size=self.t2t_params['kernel_size'], stride=self.t2t_params['stride'],
60
+ padding=self.t2t_params['padding']) / F.fold(normalizer,
61
+ output_size=(output_h, output_w),
62
+ kernel_size=self.t2t_params[
63
+ 'kernel_size'],
64
+ stride=self.t2t_params['stride'],
65
+ padding=self.t2t_params[
66
+ 'padding'])).permute(0,
67
+ 2,
68
+ 1).contiguous().view(
69
+ b, n, c)
70
+ else:
71
+ normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, self.n_vecs, self.kernel_shape).permute(0, 2, 1)
72
+ x = self.unfold(self.fold(x.view(-1, self.n_vecs, c).permute(0, 2, 1)) / self.fold(normalizer)).permute(0,
73
+ 2,
74
+ 1).contiguous().view(
75
+ b, n, c)
76
+ x = self.conv2(x)
77
+ return x
78
+
79
+
80
+ class ResidualBlock_noBN(nn.Module):
81
+ """Residual block w/o BN
82
+ ---Conv-ReLU-Conv-+-
83
+ |________________|
84
+ """
85
+
86
+ def __init__(self, nf=64):
87
+ super(ResidualBlock_noBN, self).__init__()
88
+ self.conv1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
89
+ self.conv2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
90
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
91
+
92
+ def forward(self, x):
93
+ """
94
+
95
+ Args:
96
+ x: features with shape [b, c, h, w]
97
+
98
+ Returns: processed features with shape [b, c, h, w]
99
+
100
+ """
101
+ identity = x
102
+ out = self.lrelu(self.conv1(x))
103
+ out = self.conv2(out)
104
+ out = identity + out
105
+ # Remove ReLU at the end of the residual block
106
+ # http://torch.ch/blog/2016/02/04/resnets.html
107
+ return out
108
+
109
+
110
+ def make_layer(block, n_layers):
111
+ layers = []
112
+ for _ in range(n_layers):
113
+ layers.append(block())
114
+ return nn.Sequential(*layers)
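FusionFeedForward above differs from a plain MLP in that, between its two linear layers, the token vectors are folded back onto the image plane with nn.Fold, divided by a folded all-ones normalizer so overlapping patches are averaged, and re-extracted with nn.Unfold; this lets neighbouring patches exchange information inside the FFN. The fold/normalize/unfold step in isolation, with assumed sizes matching a 20 x 36 token map:

    import torch
    import torch.nn.functional as F

    b, n, c = 1, 720, 49                  # 720 tokens, channel dim equal to the 7*7 patch area (assumed)
    kernel, stride, pad = (7, 7), (3, 3), (3, 3)
    out_size = (60, 108)                  # pre-patch feature resolution (assumed)

    x = torch.randn(b, n, c)
    patches = x.permute(0, 2, 1)          # [b, c, n], the layout expected by F.fold
    ones = torch.ones_like(patches)
    folded = F.fold(patches, out_size, kernel_size=kernel, stride=stride, padding=pad)
    norm = F.fold(ones, out_size, kernel_size=kernel, stride=stride, padding=pad)
    mixed = folded / norm                 # average the contributions of overlapping patches
    x_mixed = F.unfold(mixed, kernel_size=kernel, stride=stride, padding=pad).permute(0, 2, 1)
    assert x_mixed.shape == x.shape       # [1, 720, 49]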
FGT_codes/FGT/models/utils/RAFT/utils/__init__.py ADDED
File without changes
FGT_codes/FGT/models/utils/RAFT/utils/utils.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from scipy import interpolate
5
+
6
+
7
+ class InputPadder:
8
+ """ Pads images such that dimensions are divisible by 8 """
9
+ def __init__(self, dims, mode='sintel'):
10
+ self.ht, self.wd = dims[-2:]
11
+ pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12
+ pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13
+ if mode == 'sintel':
14
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15
+ else:
16
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17
+
18
+ def pad(self, *inputs):
19
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20
+
21
+ def unpad(self,x):
22
+ ht, wd = x.shape[-2:]
23
+ c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24
+ return x[..., c[0]:c[1], c[2]:c[3]]
25
+
26
+ def forward_interpolate(flow):
27
+ flow = flow.detach().cpu().numpy()
28
+ dx, dy = flow[0], flow[1]
29
+
30
+ ht, wd = dx.shape
31
+ x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32
+
33
+ x1 = x0 + dx
34
+ y1 = y0 + dy
35
+
36
+ x1 = x1.reshape(-1)
37
+ y1 = y1.reshape(-1)
38
+ dx = dx.reshape(-1)
39
+ dy = dy.reshape(-1)
40
+
41
+ valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42
+ x1 = x1[valid]
43
+ y1 = y1[valid]
44
+ dx = dx[valid]
45
+ dy = dy[valid]
46
+
47
+ flow_x = interpolate.griddata(
48
+ (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49
+
50
+ flow_y = interpolate.griddata(
51
+ (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52
+
53
+ flow = np.stack([flow_x, flow_y], axis=0)
54
+ return torch.from_numpy(flow).float()
55
+
56
+
57
+ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58
+ """ Wrapper for grid_sample, uses pixel coordinates """
59
+ H, W = img.shape[-2:]
60
+ xgrid, ygrid = coords.split([1,1], dim=-1)
61
+ xgrid = 2*xgrid/(W-1) - 1
62
+ ygrid = 2*ygrid/(H-1) - 1
63
+
64
+ grid = torch.cat([xgrid, ygrid], dim=-1)
65
+ img = F.grid_sample(img, grid, align_corners=True)
66
+
67
+ if mask:
68
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69
+ return img, mask.float()
70
+
71
+ return img
72
+
73
+
74
+ def coords_grid(batch, ht, wd):
75
+ coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
76
+ coords = torch.stack(coords[::-1], dim=0).float()
77
+ return coords[None].repeat(batch, 1, 1, 1)
78
+
79
+
80
+ def upflow8(flow, mode='bilinear'):
81
+ new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82
+ return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
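The RAFT helpers above pad frames so both spatial dimensions are divisible by 8 before flow estimation and crop the result back afterwards. A minimal usage sketch; the flow network call is a placeholder assumption:

    import torch
    from models.utils.RAFT.utils.utils import InputPadder

    img1 = torch.randn(1, 3, 244, 430)        # deliberately not a multiple of 8
    img2 = torch.randn(1, 3, 244, 430)
    padder = InputPadder(img1.shape)          # computes the per-side replicate padding
    img1_p, img2_p = padder.pad(img1, img2)   # both now 248 x 432
    # flow = raft(img1_p, img2_p)             # placeholder: run the flow network on the padded pair
    flow = torch.randn(1, 2, *img1_p.shape[-2:])
    flow = padder.unpad(flow)                 # cropped back to 244 x 430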
FGT_codes/FGT/models/utils/__init__.py ADDED
File without changes
FGT_codes/FGT/models/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (169 Bytes).
 
FGT_codes/FGT/models/utils/__pycache__/network_blocks_2d.cpython-39.pyc ADDED
Binary file (5.41 kB).