Spaces: Runtime error
maduvantha committed
Commit eddf80e • Parent(s): e6adef9
Upload 168 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +6 -0
- __pycache__/animate.cpython-310.pyc +0 -0
- __pycache__/app.cpython-310.pyc +0 -0
- __pycache__/augmentation.cpython-310.pyc +0 -0
- __pycache__/demo.cpython-310.pyc +0 -0
- __pycache__/frames_dataset.cpython-310.pyc +0 -0
- __pycache__/logger.cpython-310.pyc +0 -0
- __pycache__/some.cpython-310.pyc +0 -0
- config/bair-256.yaml +82 -0
- config/fashion-256.yaml +77 -0
- config/mgif-256.yaml +84 -0
- config/nemo-256.yaml +76 -0
- config/taichi-256.yaml +157 -0
- config/taichi-adv-256.yaml +150 -0
- config/vox-256.yaml +83 -0
- config/vox-adv-256.yaml +84 -0
- data/bair256.csv +51 -0
- data/taichi-loading/README.md +18 -0
- data/taichi-loading/load_videos.py +113 -0
- data/taichi-loading/taichi-metadata.csv +0 -0
- data/taichi256.csv +51 -0
- modules/__pycache__/dense_motion.cpython-310.pyc +0 -0
- modules/__pycache__/generator.cpython-310.pyc +0 -0
- modules/__pycache__/keypoint_detector.cpython-310.pyc +0 -0
- modules/__pycache__/util.cpython-310.pyc +0 -0
- modules/dense_motion.py +113 -0
- modules/discriminator.py +95 -0
- modules/generator.py +97 -0
- modules/keypoint_detector.py +75 -0
- modules/model.py +259 -0
- modules/util.py +245 -0
- share/doc/networkx-3.0/LICENSE.txt +37 -0
- share/doc/networkx-3.0/examples/3d_drawing/README.txt +2 -0
- share/doc/networkx-3.0/examples/3d_drawing/__pycache__/mayavi2_spring.cpython-310.pyc +0 -0
- share/doc/networkx-3.0/examples/3d_drawing/__pycache__/plot_basic.cpython-310.pyc +0 -0
- share/doc/networkx-3.0/examples/3d_drawing/mayavi2_spring.py +43 -0
- share/doc/networkx-3.0/examples/3d_drawing/plot_basic.py +51 -0
- share/doc/networkx-3.0/examples/README.txt +8 -0
- share/doc/networkx-3.0/examples/algorithms/README.txt +2 -0
- share/doc/networkx-3.0/examples/algorithms/WormNet.v3.benchmark.txt +0 -0
- share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_beam_search.cpython-310.pyc +0 -0
- share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_betweenness_centrality.cpython-310.pyc +0 -0
- share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_blockmodel.cpython-310.pyc +0 -0
- share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_circuits.cpython-310.pyc +0 -0
- share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_davis_club.cpython-310.pyc +0 -0
- share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_dedensification.cpython-310.pyc +0 -0
- share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_iterated_dynamical_systems.cpython-310.pyc +0 -0
- share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_krackhardt_centrality.cpython-310.pyc +0 -0
- share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_parallel_betweenness.cpython-310.pyc +0 -0
- share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_rcm.cpython-310.pyc +0 -0
.gitattributes
CHANGED
@@ -32,3 +32,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+sup-mat/absolute-demo.gif filter=lfs diff=lfs merge=lfs -text
+sup-mat/face-swap.gif filter=lfs diff=lfs merge=lfs -text
+sup-mat/fashion-teaser.gif filter=lfs diff=lfs merge=lfs -text
+sup-mat/mgif-teaser.gif filter=lfs diff=lfs merge=lfs -text
+sup-mat/relative-demo.gif filter=lfs diff=lfs merge=lfs -text
+sup-mat/vox-teaser.gif filter=lfs diff=lfs merge=lfs -text
__pycache__/animate.cpython-310.pyc
ADDED
Binary file (3.1 kB)

__pycache__/app.cpython-310.pyc
ADDED
Binary file (820 Bytes)

__pycache__/augmentation.cpython-310.pyc
ADDED
Binary file (11.2 kB)

__pycache__/demo.cpython-310.pyc
ADDED
Binary file (5.34 kB)

__pycache__/frames_dataset.cpython-310.pyc
ADDED
Binary file (6.69 kB)

__pycache__/logger.cpython-310.pyc
ADDED
Binary file (7.4 kB)

__pycache__/some.cpython-310.pyc
ADDED
Binary file (2.14 kB)
config/bair-256.yaml
ADDED
@@ -0,0 +1,82 @@
dataset_params:
  root_dir: data/bair
  frame_shape: [256, 256, 3]
  id_sampling: False
  augmentation_params:
    flip_param:
      horizontal_flip: True
      time_flip: True
    jitter_param:
      brightness: 0.1
      contrast: 0.1
      saturation: 0.1
      hue: 0.1


model_params:
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25
    num_blocks: 5
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    num_bottleneck_blocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 64
      max_features: 1024
      num_blocks: 5
      scale_factor: 0.25
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True

train_params:
  num_epochs: 20
  num_repeats: 1
  epoch_milestones: [12, 18]
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4
  batch_size: 36
  scales: [1, 0.5, 0.25, 0.125]
  checkpoint_freq: 10
  transform_params:
    sigma_affine: 0.05
    sigma_tps: 0.005
    points_tps: 5
  loss_weights:
    generator_gan: 1
    discriminator_gan: 1
    feature_matching: [10, 10, 10, 10]
    perceptual: [10, 10, 10, 10, 10]
    equivariance_value: 10
    equivariance_jacobian: 10

reconstruction_params:
  num_videos: 1000
  format: '.mp4'

animate_params:
  num_pairs: 50
  format: '.mp4'
  normalization_params:
    adapt_movement_scale: False
    use_relative_movement: True
    use_relative_jacobian: True

visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'
config/fashion-256.yaml
ADDED
@@ -0,0 +1,77 @@
dataset_params:
  root_dir: data/fashion-png
  frame_shape: [256, 256, 3]
  id_sampling: False
  augmentation_params:
    flip_param:
      horizontal_flip: True
      time_flip: True
    jitter_param:
      hue: 0.1

model_params:
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25
    num_blocks: 5
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    num_bottleneck_blocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 64
      max_features: 1024
      num_blocks: 5
      scale_factor: 0.25
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4

train_params:
  num_epochs: 100
  num_repeats: 50
  epoch_milestones: [60, 90]
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4
  batch_size: 27
  scales: [1, 0.5, 0.25, 0.125]
  checkpoint_freq: 50
  transform_params:
    sigma_affine: 0.05
    sigma_tps: 0.005
    points_tps: 5
  loss_weights:
    generator_gan: 1
    discriminator_gan: 1
    feature_matching: [10, 10, 10, 10]
    perceptual: [10, 10, 10, 10, 10]
    equivariance_value: 10
    equivariance_jacobian: 10

reconstruction_params:
  num_videos: 1000
  format: '.mp4'

animate_params:
  num_pairs: 50
  format: '.mp4'
  normalization_params:
    adapt_movement_scale: False
    use_relative_movement: True
    use_relative_jacobian: True

visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'
config/mgif-256.yaml
ADDED
@@ -0,0 +1,84 @@
dataset_params:
  root_dir: data/moving-gif
  frame_shape: [256, 256, 3]
  id_sampling: False
  augmentation_params:
    flip_param:
      horizontal_flip: True
      time_flip: True
    crop_param:
      size: [256, 256]
    resize_param:
      ratio: [0.9, 1.1]
    jitter_param:
      hue: 0.5

model_params:
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25
    num_blocks: 5
    single_jacobian_map: True
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    num_bottleneck_blocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 64
      max_features: 1024
      num_blocks: 5
      scale_factor: 0.25
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True

train_params:
  num_epochs: 100
  num_repeats: 25
  epoch_milestones: [60, 90]
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4

  batch_size: 36
  scales: [1, 0.5, 0.25, 0.125]
  checkpoint_freq: 100
  transform_params:
    sigma_affine: 0.05
    sigma_tps: 0.005
    points_tps: 5
  loss_weights:
    generator_gan: 1
    discriminator_gan: 1
    feature_matching: [10, 10, 10, 10]
    perceptual: [10, 10, 10, 10, 10]
    equivariance_value: 10
    equivariance_jacobian: 10

reconstruction_params:
  num_videos: 1000
  format: '.mp4'

animate_params:
  num_pairs: 50
  format: '.mp4'
  normalization_params:
    adapt_movement_scale: False
    use_relative_movement: True
    use_relative_jacobian: True

visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'
config/nemo-256.yaml
ADDED
@@ -0,0 +1,76 @@
dataset_params:
  root_dir: data/nemo-png
  frame_shape: [256, 256, 3]
  id_sampling: False
  augmentation_params:
    flip_param:
      horizontal_flip: True
      time_flip: True

model_params:
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25
    num_blocks: 5
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    num_bottleneck_blocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 64
      max_features: 1024
      num_blocks: 5
      scale_factor: 0.25
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True

train_params:
  num_epochs: 100
  num_repeats: 8
  epoch_milestones: [60, 90]
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4
  batch_size: 36
  scales: [1, 0.5, 0.25, 0.125]
  checkpoint_freq: 50
  transform_params:
    sigma_affine: 0.05
    sigma_tps: 0.005
    points_tps: 5
  loss_weights:
    generator_gan: 1
    discriminator_gan: 1
    feature_matching: [10, 10, 10, 10]
    perceptual: [10, 10, 10, 10, 10]
    equivariance_value: 10
    equivariance_jacobian: 10

reconstruction_params:
  num_videos: 1000
  format: '.mp4'

animate_params:
  num_pairs: 50
  format: '.mp4'
  normalization_params:
    adapt_movement_scale: False
    use_relative_movement: True
    use_relative_jacobian: True

visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'
config/taichi-256.yaml
ADDED
@@ -0,0 +1,157 @@
# Dataset parameters
# Each dataset should contain 2 folders: train and test
# Each video can be represented as:
#   - an image of concatenated frames
#   - '.mp4' or '.gif'
#   - a folder with all frames from a specific video
# In case of TaiChi, the same (youtube) video can be split into many parts (chunks). Each part has the following
# format: (id)#other#info.mp4. For example '12335#adsbf.mp4' has the id 12335. In case of TaiChi the id stands for the
# youtube video id.
dataset_params:
  # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
  root_dir: data/taichi-png
  # Image shape, needed for stacked .png format.
  frame_shape: [256, 256, 3]
  # In case of TaiChi a single video can be split into many chunks, or there may be several videos for a single person.
  # In this case an epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False)
  # If the name of the video is '12335#adsbf.mp4' the id is assumed to be 12335
  id_sampling: True
  # List with pairs for animation, None for random pairs
  pairs_list: data/taichi256.csv
  # Augmentation parameters, see augmentation.py for all possible augmentations
  augmentation_params:
    flip_param:
      horizontal_flip: True
      time_flip: True
    jitter_param:
      brightness: 0.1
      contrast: 0.1
      saturation: 0.1
      hue: 0.1

# Defines model architecture
model_params:
  common_params:
    # Number of keypoints
    num_kp: 10
    # Number of channels per image
    num_channels: 3
    # Using first or zero order model
    estimate_jacobian: True
  kp_detector_params:
    # Softmax temperature for keypoint heatmaps
    temperature: 0.1
    # Number of features multiplier
    block_expansion: 32
    # Maximum allowed number of features
    max_features: 1024
    # Number of blocks in Unet. Can be increased or decreased depending on resolution.
    num_blocks: 5
    # Keypoints are predicted on smaller images for better performance,
    # scale_factor=0.25 means that a 256x256 image will be resized to 64x64
    scale_factor: 0.25
  generator_params:
    # Number of features multiplier
    block_expansion: 64
    # Maximum allowed number of features
    max_features: 512
    # Number of downsampling blocks in Johnson architecture.
    # Can be increased or decreased depending on resolution.
    num_down_blocks: 2
    # Number of ResBlocks in Johnson architecture.
    num_bottleneck_blocks: 6
    # Use occlusion map or not
    estimate_occlusion_map: True

    dense_motion_params:
      # Number of features multiplier
      block_expansion: 64
      # Maximum allowed number of features
      max_features: 1024
      # Number of blocks in Unet. Can be increased or decreased depending on resolution.
      num_blocks: 5
      # Dense motion is predicted on smaller images for better performance,
      # scale_factor=0.25 means that a 256x256 image will be resized to 64x64
      scale_factor: 0.25
  discriminator_params:
    # Discriminator can be multiscale, if you want 2 discriminators on the original
    # resolution and half of the original, specify scales: [1, 0.5]
    scales: [1]
    # Number of features multiplier
    block_expansion: 32
    # Maximum allowed number of features
    max_features: 512
    # Number of blocks. Can be increased or decreased depending on resolution.
    num_blocks: 4

# Parameters of training
train_params:
  # Number of training epochs
  num_epochs: 100
  # For better i/o performance when the number of videos is small, the number of epochs can be multiplied by this number.
  # Thus effectively with num_repeats=100 each epoch is 100 times larger.
  num_repeats: 150
  # Drop learning rate by 10 times after these epochs
  epoch_milestones: [60, 90]
  # Initial learning rate for all modules
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4
  batch_size: 30
  # Scales for perceptual pyramid loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,
  # then the loss will be computed on resolutions 256x256, 128x128, 64x64, 32x32.
  scales: [1, 0.5, 0.25, 0.125]
  # Save checkpoint this frequently. If checkpoint_freq=50, a checkpoint will be saved every 50 epochs.
  checkpoint_freq: 50
  # Parameters of transform for equivariance loss
  transform_params:
    # Sigma for affine part
    sigma_affine: 0.05
    # Sigma for deformation part
    sigma_tps: 0.005
    # Number of points in the deformation grid
    points_tps: 5
  loss_weights:
    # Weight for LSGAN loss in generator, 0 for no adversarial loss.
    generator_gan: 0
    # Weight for LSGAN loss in discriminator
    discriminator_gan: 1
    # Weights for feature matching loss, the number should be the same as the number of blocks in discriminator.
    feature_matching: [10, 10, 10, 10]
    # Weights for perceptual loss.
    perceptual: [10, 10, 10, 10, 10]
    # Weight for value equivariance.
    equivariance_value: 10
    # Weight for jacobian equivariance.
    equivariance_jacobian: 10

# Parameters of reconstruction
reconstruction_params:
  # Maximum number of videos for reconstruction
  num_videos: 1000
  # Format for visualization, note that results will also be stored in stacked .png.
  format: '.mp4'

# Parameters of animation
animate_params:
  # Maximum number of pairs for animation, the pairs will be either taken from pairs_list or random.
  num_pairs: 50
  # Format for visualization, note that results will also be stored in stacked .png.
  format: '.mp4'
  # Normalization of driving keypoints
  normalization_params:
    # Increase or decrease relative movement scale depending on the size of the object
    adapt_movement_scale: False
    # Apply only relative displacement of the keypoint
    use_relative_movement: True
    # Apply only relative change in jacobian
    use_relative_jacobian: True

# Visualization parameters
visualizer_params:
  # Draw keypoints of this size, increase or decrease depending on resolution
  kp_size: 5
  # Draw white border around images
  draw_border: True
  # Color map for keypoints
  colormap: 'gist_rainbow'
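The configs above are plain nested YAML, so their values can be inspected programmatically before training. Below is a minimal sketch, not part of this commit, assuming PyYAML is installed and the config path exists in the checkout:

```
# Hypothetical helper: load one of the configs above and read back a few of its values.
import yaml

with open("config/taichi-256.yaml") as f:
    config = yaml.safe_load(f)

# The nested dictionaries mirror the YAML sections shown above.
print(config["model_params"]["common_params"]["num_kp"])              # 10
print(config["model_params"]["generator_params"]["num_down_blocks"])  # 2
print(config["train_params"]["batch_size"])                           # 30
```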
config/taichi-adv-256.yaml
ADDED
@@ -0,0 +1,150 @@
# Dataset parameters
dataset_params:
  # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
  root_dir: data/taichi-png
  # Image shape, needed for stacked .png format.
  frame_shape: [256, 256, 3]
  # In case of TaiChi a single video can be split into many chunks, or there may be several videos for a single person.
  # In this case an epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False)
  # If the name of the video is '12335#adsbf.mp4' the id is assumed to be 12335
  id_sampling: True
  # List with pairs for animation, None for random pairs
  pairs_list: data/taichi256.csv
  # Augmentation parameters, see augmentation.py for all possible augmentations
  augmentation_params:
    flip_param:
      horizontal_flip: True
      time_flip: True
    jitter_param:
      brightness: 0.1
      contrast: 0.1
      saturation: 0.1
      hue: 0.1

# Defines model architecture
model_params:
  common_params:
    # Number of keypoints
    num_kp: 10
    # Number of channels per image
    num_channels: 3
    # Using first or zero order model
    estimate_jacobian: True
  kp_detector_params:
    # Softmax temperature for keypoint heatmaps
    temperature: 0.1
    # Number of features multiplier
    block_expansion: 32
    # Maximum allowed number of features
    max_features: 1024
    # Number of blocks in Unet. Can be increased or decreased depending on resolution.
    num_blocks: 5
    # Keypoints are predicted on smaller images for better performance,
    # scale_factor=0.25 means that a 256x256 image will be resized to 64x64
    scale_factor: 0.25
  generator_params:
    # Number of features multiplier
    block_expansion: 64
    # Maximum allowed number of features
    max_features: 512
    # Number of downsampling blocks in Johnson architecture.
    # Can be increased or decreased depending on resolution.
    num_down_blocks: 2
    # Number of ResBlocks in Johnson architecture.
    num_bottleneck_blocks: 6
    # Use occlusion map or not
    estimate_occlusion_map: True

    dense_motion_params:
      # Number of features multiplier
      block_expansion: 64
      # Maximum allowed number of features
      max_features: 1024
      # Number of blocks in Unet. Can be increased or decreased depending on resolution.
      num_blocks: 5
      # Dense motion is predicted on smaller images for better performance,
      # scale_factor=0.25 means that a 256x256 image will be resized to 64x64
      scale_factor: 0.25
  discriminator_params:
    # Discriminator can be multiscale, if you want 2 discriminators on the original
    # resolution and half of the original, specify scales: [1, 0.5]
    scales: [1]
    # Number of features multiplier
    block_expansion: 32
    # Maximum allowed number of features
    max_features: 512
    # Number of blocks. Can be increased or decreased depending on resolution.
    num_blocks: 4
    use_kp: True

# Parameters of training
train_params:
  # Number of training epochs
  num_epochs: 150
  # For better i/o performance when the number of videos is small, the number of epochs can be multiplied by this number.
  # Thus effectively with num_repeats=100 each epoch is 100 times larger.
  num_repeats: 150
  # Drop learning rate by 10 times after these epochs
  epoch_milestones: []
  # Initial learning rate for all modules
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 0
  batch_size: 27
  # Scales for perceptual pyramid loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,
  # then the loss will be computed on resolutions 256x256, 128x128, 64x64, 32x32.
  scales: [1, 0.5, 0.25, 0.125]
  # Save checkpoint this frequently. If checkpoint_freq=50, a checkpoint will be saved every 50 epochs.
  checkpoint_freq: 50
  # Parameters of transform for equivariance loss
  transform_params:
    # Sigma for affine part
    sigma_affine: 0.05
    # Sigma for deformation part
    sigma_tps: 0.005
    # Number of points in the deformation grid
    points_tps: 5
  loss_weights:
    # Weight for LSGAN loss in generator
    generator_gan: 1
    # Weight for LSGAN loss in discriminator
    discriminator_gan: 1
    # Weights for feature matching loss, the number should be the same as the number of blocks in discriminator.
    feature_matching: [10, 10, 10, 10]
    # Weights for perceptual loss.
    perceptual: [10, 10, 10, 10, 10]
    # Weight for value equivariance.
    equivariance_value: 10
    # Weight for jacobian equivariance.
    equivariance_jacobian: 10

# Parameters of reconstruction
reconstruction_params:
  # Maximum number of videos for reconstruction
  num_videos: 1000
  # Format for visualization, note that results will also be stored in stacked .png.
  format: '.mp4'

# Parameters of animation
animate_params:
  # Maximum number of pairs for animation, the pairs will be either taken from pairs_list or random.
  num_pairs: 50
  # Format for visualization, note that results will also be stored in stacked .png.
  format: '.mp4'
  # Normalization of driving keypoints
  normalization_params:
    # Increase or decrease relative movement scale depending on the size of the object
    adapt_movement_scale: False
    # Apply only relative displacement of the keypoint
    use_relative_movement: True
    # Apply only relative change in jacobian
    use_relative_jacobian: True

# Visualization parameters
visualizer_params:
  # Draw keypoints of this size, increase or decrease depending on resolution
  kp_size: 5
  # Draw white border around images
  draw_border: True
  # Color map for keypoints
  colormap: 'gist_rainbow'
config/vox-256.yaml
ADDED
@@ -0,0 +1,83 @@
dataset_params:
  root_dir: data/vox-png
  frame_shape: [256, 256, 3]
  id_sampling: True
  pairs_list: data/vox256.csv
  augmentation_params:
    flip_param:
      horizontal_flip: True
      time_flip: True
    jitter_param:
      brightness: 0.1
      contrast: 0.1
      saturation: 0.1
      hue: 0.1


model_params:
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25
    num_blocks: 5
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    num_bottleneck_blocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 64
      max_features: 1024
      num_blocks: 5
      scale_factor: 0.25
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True

train_params:
  num_epochs: 100
  num_repeats: 75
  epoch_milestones: [60, 90]
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4
  batch_size: 40
  scales: [1, 0.5, 0.25, 0.125]
  checkpoint_freq: 50
  transform_params:
    sigma_affine: 0.05
    sigma_tps: 0.005
    points_tps: 5
  loss_weights:
    generator_gan: 0
    discriminator_gan: 1
    feature_matching: [10, 10, 10, 10]
    perceptual: [10, 10, 10, 10, 10]
    equivariance_value: 10
    equivariance_jacobian: 10

reconstruction_params:
  num_videos: 1000
  format: '.mp4'

animate_params:
  num_pairs: 50
  format: '.mp4'
  normalization_params:
    adapt_movement_scale: False
    use_relative_movement: True
    use_relative_jacobian: True

visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'
config/vox-adv-256.yaml
ADDED
@@ -0,0 +1,84 @@
dataset_params:
  root_dir: data/vox-png
  frame_shape: [256, 256, 3]
  id_sampling: True
  pairs_list: data/vox256.csv
  augmentation_params:
    flip_param:
      horizontal_flip: True
      time_flip: True
    jitter_param:
      brightness: 0.1
      contrast: 0.1
      saturation: 0.1
      hue: 0.1


model_params:
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25
    num_blocks: 5
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    num_bottleneck_blocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 64
      max_features: 1024
      num_blocks: 5
      scale_factor: 0.25
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    use_kp: True


train_params:
  num_epochs: 150
  num_repeats: 75
  epoch_milestones: []
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4
  batch_size: 36
  scales: [1, 0.5, 0.25, 0.125]
  checkpoint_freq: 50
  transform_params:
    sigma_affine: 0.05
    sigma_tps: 0.005
    points_tps: 5
  loss_weights:
    generator_gan: 1
    discriminator_gan: 1
    feature_matching: [10, 10, 10, 10]
    perceptual: [10, 10, 10, 10, 10]
    equivariance_value: 10
    equivariance_jacobian: 10

reconstruction_params:
  num_videos: 1000
  format: '.mp4'

animate_params:
  num_pairs: 50
  format: '.mp4'
  normalization_params:
    adapt_movement_scale: False
    use_relative_movement: True
    use_relative_jacobian: True

visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'
data/bair256.csv
ADDED
@@ -0,0 +1,51 @@
distance,source,driving,frame
0,000054.mp4,000048.mp4,0
0,000050.mp4,000063.mp4,0
0,000073.mp4,000007.mp4,0
0,000021.mp4,000010.mp4,0
0,000084.mp4,000046.mp4,0
0,000031.mp4,000102.mp4,0
0,000029.mp4,000111.mp4,0
0,000090.mp4,000112.mp4,0
0,000039.mp4,000010.mp4,0
0,000008.mp4,000069.mp4,0
0,000068.mp4,000076.mp4,0
0,000051.mp4,000052.mp4,0
0,000022.mp4,000098.mp4,0
0,000096.mp4,000032.mp4,0
0,000032.mp4,000099.mp4,0
0,000006.mp4,000053.mp4,0
0,000098.mp4,000020.mp4,0
0,000029.mp4,000066.mp4,0
0,000022.mp4,000007.mp4,0
0,000027.mp4,000065.mp4,0
0,000026.mp4,000059.mp4,0
0,000015.mp4,000112.mp4,0
0,000086.mp4,000123.mp4,0
0,000103.mp4,000052.mp4,0
0,000123.mp4,000103.mp4,0
0,000051.mp4,000005.mp4,0
0,000062.mp4,000125.mp4,0
0,000126.mp4,000111.mp4,0
0,000066.mp4,000090.mp4,0
0,000075.mp4,000106.mp4,0
0,000020.mp4,000010.mp4,0
0,000076.mp4,000028.mp4,0
0,000062.mp4,000002.mp4,0
0,000095.mp4,000127.mp4,0
0,000113.mp4,000072.mp4,0
0,000027.mp4,000104.mp4,0
0,000054.mp4,000124.mp4,0
0,000019.mp4,000089.mp4,0
0,000052.mp4,000072.mp4,0
0,000108.mp4,000033.mp4,0
0,000044.mp4,000118.mp4,0
0,000029.mp4,000086.mp4,0
0,000068.mp4,000066.mp4,0
0,000014.mp4,000036.mp4,0
0,000053.mp4,000071.mp4,0
0,000022.mp4,000094.mp4,0
0,000000.mp4,000121.mp4,0
0,000071.mp4,000079.mp4,0
0,000127.mp4,000005.mp4,0
0,000085.mp4,000023.mp4,0
data/taichi-loading/README.md
ADDED
@@ -0,0 +1,18 @@
# TaiChi dataset

The scripts for loading the TaiChi dataset.

We provide only the id of the corresponding video and the bounding box. The following script will download videos from youtube and crop them according to the provided bounding boxes.

1) Load youtube-dl:
```
wget https://yt-dl.org/downloads/latest/youtube-dl -O youtube-dl
chmod a+rx youtube-dl
```

2) Run the script to download videos. There are 2 formats that can be used for storing videos: one is .mp4 and the other is a folder with .png images. While .png images occupy significantly more space, the format is loss-less and has better i/o performance when training.

```
python load_videos.py --metadata taichi-metadata.csv --format .mp4 --out_folder taichi --workers 8
```
Select the number of workers based on the number of cpus available. Note that the .png format takes approximately 80GB.
data/taichi-loading/load_videos.py
ADDED
@@ -0,0 +1,113 @@
import numpy as np
import pandas as pd
import imageio
import os
import subprocess
from multiprocessing import Pool
from itertools import cycle
import warnings
import glob
import time
from tqdm import tqdm
from argparse import ArgumentParser
from skimage import img_as_ubyte
from skimage.transform import resize
warnings.filterwarnings("ignore")

DEVNULL = open(os.devnull, 'wb')


def save(path, frames, format):
    if format == '.mp4':
        imageio.mimsave(path, frames)
    elif format == '.png':
        if os.path.exists(path):
            print ("Warning: skiping video %s" % os.path.basename(path))
            return
        else:
            os.makedirs(path)
        for j, frame in enumerate(frames):
            imageio.imsave(os.path.join(path, str(j).zfill(7) + '.png'), frames[j])
    else:
        print ("Unknown format %s" % format)
        exit()


def download(video_id, args):
    video_path = os.path.join(args.video_folder, video_id + ".mp4")
    subprocess.call([args.youtube, '-f', "''best/mp4''", '--write-auto-sub', '--write-sub',
                     '--sub-lang', 'en', '--skip-unavailable-fragments',
                     "https://www.youtube.com/watch?v=" + video_id, "--output",
                     video_path], stdout=DEVNULL, stderr=DEVNULL)
    return video_path


def run(data):
    video_id, args = data
    if not os.path.exists(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4')):
        download(video_id.split('#')[0], args)

    if not os.path.exists(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4')):
        print ('Can not load video %s, broken link' % video_id.split('#')[0])
        return
    reader = imageio.get_reader(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4'))
    fps = reader.get_meta_data()['fps']

    df = pd.read_csv(args.metadata)
    df = df[df['video_id'] == video_id]

    all_chunks_dict = [{'start': df['start'].iloc[j], 'end': df['end'].iloc[j],
                        'bbox': list(map(int, df['bbox'].iloc[j].split('-'))), 'frames': []} for j in range(df.shape[0])]
    ref_fps = df['fps'].iloc[0]
    ref_height = df['height'].iloc[0]
    ref_width = df['width'].iloc[0]
    partition = df['partition'].iloc[0]
    try:
        for i, frame in enumerate(reader):
            for entry in all_chunks_dict:
                if (i * ref_fps >= entry['start'] * fps) and (i * ref_fps < entry['end'] * fps):
                    left, top, right, bot = entry['bbox']
                    left = int(left / (ref_width / frame.shape[1]))
                    top = int(top / (ref_height / frame.shape[0]))
                    right = int(right / (ref_width / frame.shape[1]))
                    bot = int(bot / (ref_height / frame.shape[0]))
                    crop = frame[top:bot, left:right]
                    if args.image_shape is not None:
                        crop = img_as_ubyte(resize(crop, args.image_shape, anti_aliasing=True))
                    entry['frames'].append(crop)
    except imageio.core.format.CannotReadFrameError:
        None

    for entry in all_chunks_dict:
        first_part = '#'.join(video_id.split('#')[::-1])
        path = first_part + '#' + str(entry['start']).zfill(6) + '#' + str(entry['end']).zfill(6) + '.mp4'
        save(os.path.join(args.out_folder, partition, path), entry['frames'], args.format)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--video_folder", default='youtube-taichi', help='Path to youtube videos')
    parser.add_argument("--metadata", default='taichi-metadata-new.csv', help='Path to metadata')
    parser.add_argument("--out_folder", default='taichi-png', help='Path to output')
    parser.add_argument("--format", default='.png', help='Storing format')
    parser.add_argument("--workers", default=1, type=int, help='Number of workers')
    parser.add_argument("--youtube", default='./youtube-dl', help='Path to youtube-dl')

    parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))),
                        help="Image shape, None for no resize")

    args = parser.parse_args()
    if not os.path.exists(args.video_folder):
        os.makedirs(args.video_folder)
    if not os.path.exists(args.out_folder):
        os.makedirs(args.out_folder)
    for partition in ['test', 'train']:
        if not os.path.exists(os.path.join(args.out_folder, partition)):
            os.makedirs(os.path.join(args.out_folder, partition))

    df = pd.read_csv(args.metadata)
    video_ids = set(df['video_id'])
    pool = Pool(processes=args.workers)
    args_list = cycle([args])
    for chunks_data in tqdm(pool.imap_unordered(run, zip(video_ids, args_list))):
        None
data/taichi-loading/taichi-metadata.csv
ADDED
The diff for this file is too large to render.
See raw diff
data/taichi256.csv
ADDED
@@ -0,0 +1,51 @@
distance,source,driving,frame
3.54437869822485,ab28GAufK8o#000261#000596.mp4,aDyyTMUBoLE#000164#000351.mp4,0
2.8639053254437887,DMEaUoA8EPE#000028#000354.mp4,0Q914by5A98#010440#010764.mp4,0
2.153846153846153,L82WHgYRq6I#000021#000479.mp4,0Q914by5A98#010440#010764.mp4,0
2.8994082840236666,oNkBx4CZuEg#000000#001024.mp4,DMEaUoA8EPE#000028#000354.mp4,0
3.3905325443786998,ab28GAufK8o#000261#000596.mp4,uEqWZ9S_-Lw#000089#000581.mp4,0
3.266272189349112,0Q914by5A98#010440#010764.mp4,ab28GAufK8o#000261#000596.mp4,0
2.7514792899408294,WlDYrq8K6nk#008186#008512.mp4,OiblkvkAHWM#014331#014459.mp4,0
3.0177514792899407,oNkBx4CZuEg#001024#002048.mp4,aDyyTMUBoLE#000375#000518.mp4,0
3.4792899408284064,aDyyTMUBoLE#000164#000351.mp4,w2awOCDRtrc#001729#002009.mp4,0
2.769230769230769,oNkBx4CZuEg#000000#001024.mp4,L82WHgYRq6I#000021#000479.mp4,0
3.8047337278106514,ab28GAufK8o#000261#000596.mp4,w2awOCDRtrc#001729#002009.mp4,0
3.4260355029585763,w2awOCDRtrc#001729#002009.mp4,oNkBx4CZuEg#000000#001024.mp4,0
3.313609467455621,DMEaUoA8EPE#000028#000354.mp4,WlDYrq8K6nk#005943#006135.mp4,0
3.8402366863905333,oNkBx4CZuEg#001024#002048.mp4,ab28GAufK8o#000261#000596.mp4,0
3.3254437869822504,aDyyTMUBoLE#000164#000351.mp4,oNkBx4CZuEg#000000#001024.mp4,0
1.2485207100591724,0Q914by5A98#010440#010764.mp4,aDyyTMUBoLE#000164#000351.mp4,0
3.804733727810652,OiblkvkAHWM#006251#006533.mp4,aDyyTMUBoLE#000375#000518.mp4,0
3.662721893491124,uEqWZ9S_-Lw#000089#000581.mp4,DMEaUoA8EPE#000028#000354.mp4,0
3.230769230769233,A3ZmT97hAWU#000095#000678.mp4,ab28GAufK8o#000261#000596.mp4,0
3.3668639053254434,w81Tr0Dp1K8#015329#015485.mp4,WlDYrq8K6nk#008186#008512.mp4,0
3.313609467455621,WlDYrq8K6nk#005943#006135.mp4,DMEaUoA8EPE#000028#000354.mp4,0
2.7514792899408294,OiblkvkAHWM#014331#014459.mp4,WlDYrq8K6nk#008186#008512.mp4,0
1.964497041420118,L82WHgYRq6I#000021#000479.mp4,DMEaUoA8EPE#000028#000354.mp4,0
3.78698224852071,FBuF0xOal9M#046824#047542.mp4,lCb5w6n8kPs#011879#012014.mp4,0
3.92307692307692,ab28GAufK8o#000261#000596.mp4,L82WHgYRq6I#000021#000479.mp4,0
3.8402366863905333,ab28GAufK8o#000261#000596.mp4,oNkBx4CZuEg#001024#002048.mp4,0
3.828402366863905,ab28GAufK8o#000261#000596.mp4,OiblkvkAHWM#006251#006533.mp4,0
2.041420118343196,L82WHgYRq6I#000021#000479.mp4,aDyyTMUBoLE#000164#000351.mp4,0
3.2485207100591724,0Q914by5A98#010440#010764.mp4,w2awOCDRtrc#001729#002009.mp4,0
3.2485207100591746,oNkBx4CZuEg#000000#001024.mp4,0Q914by5A98#010440#010764.mp4,0
1.964497041420118,DMEaUoA8EPE#000028#000354.mp4,L82WHgYRq6I#000021#000479.mp4,0
3.5266272189349115,kgvcI9oe3NI#001578#001763.mp4,lCb5w6n8kPs#004451#004631.mp4,0
3.005917159763317,A3ZmT97hAWU#000095#000678.mp4,0Q914by5A98#010440#010764.mp4,0
3.230769230769233,ab28GAufK8o#000261#000596.mp4,A3ZmT97hAWU#000095#000678.mp4,0
3.5266272189349115,lCb5w6n8kPs#004451#004631.mp4,kgvcI9oe3NI#001578#001763.mp4,0
2.769230769230769,L82WHgYRq6I#000021#000479.mp4,oNkBx4CZuEg#000000#001024.mp4,0
3.165680473372782,WlDYrq8K6nk#005943#006135.mp4,w81Tr0Dp1K8#001375#001516.mp4,0
2.8994082840236666,DMEaUoA8EPE#000028#000354.mp4,oNkBx4CZuEg#000000#001024.mp4,0
2.4556213017751523,0Q914by5A98#010440#010764.mp4,mndSqTrxpts#000000#000175.mp4,0
2.201183431952659,A3ZmT97hAWU#000095#000678.mp4,VMSqvTE90hk#007168#007312.mp4,0
3.8047337278106514,w2awOCDRtrc#001729#002009.mp4,ab28GAufK8o#000261#000596.mp4,0
3.769230769230769,uEqWZ9S_-Lw#000089#000581.mp4,0Q914by5A98#010440#010764.mp4,0
3.6568047337278102,A3ZmT97hAWU#000095#000678.mp4,aDyyTMUBoLE#000164#000351.mp4,0
3.7869822485207107,uEqWZ9S_-Lw#000089#000581.mp4,L82WHgYRq6I#000021#000479.mp4,0
3.78698224852071,lCb5w6n8kPs#011879#012014.mp4,FBuF0xOal9M#046824#047542.mp4,0
3.591715976331361,nAQEOC1Z10M#020177#020600.mp4,w81Tr0Dp1K8#004036#004218.mp4,0
3.8757396449704156,uEqWZ9S_-Lw#000089#000581.mp4,aDyyTMUBoLE#000164#000351.mp4,0
2.45562130177515,aDyyTMUBoLE#000164#000351.mp4,DMEaUoA8EPE#000028#000354.mp4,0
3.5502958579881647,uEqWZ9S_-Lw#000089#000581.mp4,OiblkvkAHWM#006251#006533.mp4,0
3.7928994082840224,aDyyTMUBoLE#000375#000518.mp4,ab28GAufK8o#000261#000596.mp4,0
modules/__pycache__/dense_motion.cpython-310.pyc
ADDED
Binary file (3.87 kB)

modules/__pycache__/generator.cpython-310.pyc
ADDED
Binary file (3.07 kB)

modules/__pycache__/keypoint_detector.cpython-310.pyc
ADDED
Binary file (2.51 kB)

modules/__pycache__/util.cpython-310.pyc
ADDED
Binary file (7.56 kB)
modules/dense_motion.py
ADDED
@@ -0,0 +1,113 @@
from torch import nn
import torch.nn.functional as F
import torch
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian


class DenseMotionNetwork(nn.Module):
    """
    Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving
    """

    def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False,
                 scale_factor=1, kp_variance=0.01):
        super(DenseMotionNetwork, self).__init__()
        self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1),
                                   max_features=max_features, num_blocks=num_blocks)

        self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3))

        if estimate_occlusion_map:
            self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
        else:
            self.occlusion = None

        self.num_kp = num_kp
        self.scale_factor = scale_factor
        self.kp_variance = kp_variance

        if self.scale_factor != 1:
            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)

    def create_heatmap_representations(self, source_image, kp_driving, kp_source):
        """
        Eq 6. in the paper H_k(z)
        """
        spatial_size = source_image.shape[2:]
        gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
        gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
        heatmap = gaussian_driving - gaussian_source

        #adding background feature
        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type())
        heatmap = torch.cat([zeros, heatmap], dim=1)
        heatmap = heatmap.unsqueeze(2)
        return heatmap

    def create_sparse_motions(self, source_image, kp_driving, kp_source):
        """
        Eq 4. in the paper T_{s<-d}(z)
        """
        bs, _, h, w = source_image.shape
        identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
        identity_grid = identity_grid.view(1, 1, h, w, 2)
        coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2)
        if 'jacobian' in kp_driving:
            jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
            jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
            jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
            coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
            coordinate_grid = coordinate_grid.squeeze(-1)

        driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2)

        #adding background feature
        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)
        return sparse_motions

    def create_deformed_source_image(self, source_image, sparse_motions):
        """
        Eq 7. in the paper \hat{T}_{s<-d}(z)
        """
        bs, _, h, w = source_image.shape
        source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)
        source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)
        sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))
        sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
        sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))
        return sparse_deformed

    def forward(self, source_image, kp_driving, kp_source):
        if self.scale_factor != 1:
            source_image = self.down(source_image)

        bs, _, h, w = source_image.shape

        out_dict = dict()
        heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
        sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source)
        deformed_source = self.create_deformed_source_image(source_image, sparse_motion)
        out_dict['sparse_deformed'] = deformed_source

        input = torch.cat([heatmap_representation, deformed_source], dim=2)
        input = input.view(bs, -1, h, w)

        prediction = self.hourglass(input)

        mask = self.mask(prediction)
        mask = F.softmax(mask, dim=1)
        out_dict['mask'] = mask
        mask = mask.unsqueeze(2)
        sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)
        deformation = (sparse_motion * mask).sum(dim=1)
        deformation = deformation.permute(0, 2, 3, 1)

        out_dict['deformation'] = deformation

        # Sec. 3.2 in the paper
        if self.occlusion:
            occlusion_map = torch.sigmoid(self.occlusion(prediction))
            out_dict['occlusion_map'] = occlusion_map

        return out_dict
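For reference, a minimal usage sketch (not part of this commit) that exercises DenseMotionNetwork with random tensors, using the dense_motion_params values from the configs above. It assumes torch is installed and the repository root is on PYTHONPATH so that modules.util resolves:

```
# Hypothetical smoke test for the dense motion module defined above.
import torch
from modules.dense_motion import DenseMotionNetwork

net = DenseMotionNetwork(block_expansion=64, num_blocks=5, max_features=1024,
                         num_kp=10, num_channels=3,
                         estimate_occlusion_map=True, scale_factor=0.25)

bs = 2
source = torch.randn(bs, 3, 256, 256)
# Keypoint dicts follow the format used in forward(): 'value' is (bs, num_kp, 2) in [-1, 1],
# 'jacobian' is (bs, num_kp, 2, 2).
kp = {'value': torch.zeros(bs, 10, 2), 'jacobian': torch.eye(2).repeat(bs, 10, 1, 1)}

out = net(source, kp_driving=kp, kp_source=kp)
print(out['deformation'].shape)    # (bs, 64, 64, 2) after the scale_factor=0.25 downsampling
print(out['occlusion_map'].shape)  # (bs, 1, 64, 64)
```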
modules/discriminator.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from modules.util import kp2gaussian
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
class DownBlock2d(nn.Module):
|
8 |
+
"""
|
9 |
+
Simple block for processing video (encoder).
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
|
13 |
+
super(DownBlock2d, self).__init__()
|
14 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
|
15 |
+
|
16 |
+
if sn:
|
17 |
+
self.conv = nn.utils.spectral_norm(self.conv)
|
18 |
+
|
19 |
+
if norm:
|
20 |
+
self.norm = nn.InstanceNorm2d(out_features, affine=True)
|
21 |
+
else:
|
22 |
+
self.norm = None
|
23 |
+
self.pool = pool
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
out = x
|
27 |
+
out = self.conv(out)
|
28 |
+
if self.norm:
|
29 |
+
out = self.norm(out)
|
30 |
+
out = F.leaky_relu(out, 0.2)
|
31 |
+
if self.pool:
|
32 |
+
out = F.avg_pool2d(out, (2, 2))
|
33 |
+
return out
|
34 |
+
|
35 |
+
|
36 |
+
class Discriminator(nn.Module):
|
37 |
+
"""
|
38 |
+
Discriminator similar to Pix2Pix
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
|
42 |
+
sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs):
|
43 |
+
super(Discriminator, self).__init__()
|
44 |
+
|
45 |
+
down_blocks = []
|
46 |
+
for i in range(num_blocks):
|
47 |
+
down_blocks.append(
|
48 |
+
DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)),
|
49 |
+
min(max_features, block_expansion * (2 ** (i + 1))),
|
50 |
+
norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
|
51 |
+
|
52 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
53 |
+
self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
|
54 |
+
if sn:
|
55 |
+
self.conv = nn.utils.spectral_norm(self.conv)
|
56 |
+
self.use_kp = use_kp
|
57 |
+
self.kp_variance = kp_variance
|
58 |
+
|
59 |
+
def forward(self, x, kp=None):
|
60 |
+
feature_maps = []
|
61 |
+
out = x
|
62 |
+
if self.use_kp:
|
63 |
+
heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)
|
64 |
+
out = torch.cat([out, heatmap], dim=1)
|
65 |
+
|
66 |
+
for down_block in self.down_blocks:
|
67 |
+
feature_maps.append(down_block(out))
|
68 |
+
out = feature_maps[-1]
|
69 |
+
prediction_map = self.conv(out)
|
70 |
+
|
71 |
+
return feature_maps, prediction_map
|
72 |
+
|
73 |
+
|
74 |
+
class MultiScaleDiscriminator(nn.Module):
|
75 |
+
"""
|
76 |
+
Multi-scale discriminator (one Discriminator per scale)
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, scales=(), **kwargs):
|
80 |
+
super(MultiScaleDiscriminator, self).__init__()
|
81 |
+
self.scales = scales
|
82 |
+
discs = {}
|
83 |
+
for scale in scales:
|
84 |
+
discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
|
85 |
+
self.discs = nn.ModuleDict(discs)
|
86 |
+
|
87 |
+
def forward(self, x, kp=None):
|
88 |
+
out_dict = {}
|
89 |
+
for scale, disc in self.discs.items():
|
90 |
+
scale = str(scale).replace('-', '.')
|
91 |
+
key = 'prediction_' + scale
|
92 |
+
feature_maps, prediction_map = disc(x[key], kp)
|
93 |
+
out_dict['feature_maps_' + scale] = feature_maps
|
94 |
+
out_dict['prediction_map_' + scale] = prediction_map
|
95 |
+
return out_dict
|
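MultiScaleDiscriminator expects a dict keyed by 'prediction_<scale>', the format produced by ImagePyramide in modules/model.py. A hypothetical usage sketch with invented tensor sizes, assuming the repo's modules package is importable (not part of this commit):

import torch
from modules.discriminator import MultiScaleDiscriminator

disc = MultiScaleDiscriminator(scales=[1, 0.5], num_channels=3, block_expansion=32,
                               num_blocks=4, max_features=512, sn=True)
pyramid = {'prediction_1': torch.randn(2, 3, 256, 256),
           'prediction_0.5': torch.randn(2, 3, 128, 128)}
out = disc(pyramid)  # kp can be omitted because use_kp defaults to False
# out['prediction_map_1'] is the patch-wise realness map at full resolution;
# out['feature_maps_1'] holds the intermediate features used for feature matching.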
modules/generator.py
ADDED
@@ -0,0 +1,97 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
|
5 |
+
from modules.dense_motion import DenseMotionNetwork
|
6 |
+
|
7 |
+
|
8 |
+
class OcclusionAwareGenerator(nn.Module):
|
9 |
+
"""
|
10 |
+
Generator that, given a source image and keypoints, transforms the image according to the movement trajectories
|
11 |
+
induced by the keypoints. The generator follows the Johnson architecture.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
|
15 |
+
num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
|
16 |
+
super(OcclusionAwareGenerator, self).__init__()
|
17 |
+
|
18 |
+
if dense_motion_params is not None:
|
19 |
+
self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
|
20 |
+
estimate_occlusion_map=estimate_occlusion_map,
|
21 |
+
**dense_motion_params)
|
22 |
+
else:
|
23 |
+
self.dense_motion_network = None
|
24 |
+
|
25 |
+
self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
|
26 |
+
|
27 |
+
down_blocks = []
|
28 |
+
for i in range(num_down_blocks):
|
29 |
+
in_features = min(max_features, block_expansion * (2 ** i))
|
30 |
+
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
|
31 |
+
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
32 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
33 |
+
|
34 |
+
up_blocks = []
|
35 |
+
for i in range(num_down_blocks):
|
36 |
+
in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))
|
37 |
+
out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))
|
38 |
+
up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
39 |
+
self.up_blocks = nn.ModuleList(up_blocks)
|
40 |
+
|
41 |
+
self.bottleneck = torch.nn.Sequential()
|
42 |
+
in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
|
43 |
+
for i in range(num_bottleneck_blocks):
|
44 |
+
self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))
|
45 |
+
|
46 |
+
self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
|
47 |
+
self.estimate_occlusion_map = estimate_occlusion_map
|
48 |
+
self.num_channels = num_channels
|
49 |
+
|
50 |
+
def deform_input(self, inp, deformation):
|
51 |
+
_, h_old, w_old, _ = deformation.shape
|
52 |
+
_, _, h, w = inp.shape
|
53 |
+
if h_old != h or w_old != w:
|
54 |
+
deformation = deformation.permute(0, 3, 1, 2)
|
55 |
+
deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
|
56 |
+
deformation = deformation.permute(0, 2, 3, 1)
|
57 |
+
return F.grid_sample(inp, deformation)
|
58 |
+
|
59 |
+
def forward(self, source_image, kp_driving, kp_source):
|
60 |
+
# Encoding (downsampling) part
|
61 |
+
out = self.first(source_image)
|
62 |
+
for i in range(len(self.down_blocks)):
|
63 |
+
out = self.down_blocks[i](out)
|
64 |
+
|
65 |
+
# Transforming feature representation according to deformation and occlusion
|
66 |
+
output_dict = {}
|
67 |
+
if self.dense_motion_network is not None:
|
68 |
+
dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
|
69 |
+
kp_source=kp_source)
|
70 |
+
output_dict['mask'] = dense_motion['mask']
|
71 |
+
output_dict['sparse_deformed'] = dense_motion['sparse_deformed']
|
72 |
+
|
73 |
+
if 'occlusion_map' in dense_motion:
|
74 |
+
occlusion_map = dense_motion['occlusion_map']
|
75 |
+
output_dict['occlusion_map'] = occlusion_map
|
76 |
+
else:
|
77 |
+
occlusion_map = None
|
78 |
+
deformation = dense_motion['deformation']
|
79 |
+
out = self.deform_input(out, deformation)
|
80 |
+
|
81 |
+
if occlusion_map is not None:
|
82 |
+
if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
|
83 |
+
occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
|
84 |
+
out = out * occlusion_map
|
85 |
+
|
86 |
+
output_dict["deformed"] = self.deform_input(source_image, deformation)
|
87 |
+
|
88 |
+
# Decoding part
|
89 |
+
out = self.bottleneck(out)
|
90 |
+
for i in range(len(self.up_blocks)):
|
91 |
+
out = self.up_blocks[i](out)
|
92 |
+
out = self.final(out)
|
93 |
+
out = F.sigmoid(out)
|
94 |
+
|
95 |
+
output_dict["prediction"] = out
|
96 |
+
|
97 |
+
return output_dict
|
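deform_input above is essentially F.grid_sample plus a resize of the grid when resolutions differ, and the occlusion map then attenuates features that the warped source cannot explain. A toy illustration with invented shapes (not part of this commit):

import torch
import torch.nn.functional as F

feat = torch.randn(1, 256, 64, 64)              # encoder features (N, C, H, W)
deformation = torch.rand(1, 64, 64, 2) * 2 - 1  # sampling grid in [-1, 1], (N, H, W, 2)
warped = F.grid_sample(feat, deformation)       # bilinear resampling of feat at the grid
occlusion = torch.sigmoid(torch.randn(1, 1, 64, 64))
warped = warped * occlusion                     # down-weight occluded regions before decoding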
modules/keypoint_detector.py
ADDED
@@ -0,0 +1,75 @@
1 |
+
from torch import nn
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d
|
5 |
+
|
6 |
+
|
7 |
+
class KPDetector(nn.Module):
|
8 |
+
"""
|
9 |
+
Detects keypoints. Returns the keypoint position and the jacobian near each keypoint.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, block_expansion, num_kp, num_channels, max_features,
|
13 |
+
num_blocks, temperature, estimate_jacobian=False, scale_factor=1,
|
14 |
+
single_jacobian_map=False, pad=0):
|
15 |
+
super(KPDetector, self).__init__()
|
16 |
+
|
17 |
+
self.predictor = Hourglass(block_expansion, in_features=num_channels,
|
18 |
+
max_features=max_features, num_blocks=num_blocks)
|
19 |
+
|
20 |
+
self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
|
21 |
+
padding=pad)
|
22 |
+
|
23 |
+
if estimate_jacobian:
|
24 |
+
self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
|
25 |
+
self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
|
26 |
+
out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
|
27 |
+
self.jacobian.weight.data.zero_()
|
28 |
+
self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
|
29 |
+
else:
|
30 |
+
self.jacobian = None
|
31 |
+
|
32 |
+
self.temperature = temperature
|
33 |
+
self.scale_factor = scale_factor
|
34 |
+
if self.scale_factor != 1:
|
35 |
+
self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
|
36 |
+
|
37 |
+
def gaussian2kp(self, heatmap):
|
38 |
+
"""
|
39 |
+
Extract the mean (expected keypoint coordinate) from a heatmap
|
40 |
+
"""
|
41 |
+
shape = heatmap.shape
|
42 |
+
heatmap = heatmap.unsqueeze(-1)
|
43 |
+
grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
|
44 |
+
value = (heatmap * grid).sum(dim=(2, 3))
|
45 |
+
kp = {'value': value}
|
46 |
+
|
47 |
+
return kp
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
if self.scale_factor != 1:
|
51 |
+
x = self.down(x)
|
52 |
+
|
53 |
+
feature_map = self.predictor(x)
|
54 |
+
prediction = self.kp(feature_map)
|
55 |
+
|
56 |
+
final_shape = prediction.shape
|
57 |
+
heatmap = prediction.view(final_shape[0], final_shape[1], -1)
|
58 |
+
heatmap = F.softmax(heatmap / self.temperature, dim=2)
|
59 |
+
heatmap = heatmap.view(*final_shape)
|
60 |
+
|
61 |
+
out = self.gaussian2kp(heatmap)
|
62 |
+
|
63 |
+
if self.jacobian is not None:
|
64 |
+
jacobian_map = self.jacobian(feature_map)
|
65 |
+
jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
|
66 |
+
final_shape[3])
|
67 |
+
heatmap = heatmap.unsqueeze(2)
|
68 |
+
|
69 |
+
jacobian = heatmap * jacobian_map
|
70 |
+
jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
|
71 |
+
jacobian = jacobian.sum(dim=-1)
|
72 |
+
jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)
|
73 |
+
out['jacobian'] = jacobian
|
74 |
+
|
75 |
+
return out
|
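gaussian2kp above is a soft-argmax: the temperature-softmaxed heatmap is treated as a distribution over the coordinate grid and its expectation is the keypoint position. A standalone sketch of that step with invented sizes, assuming the repo's modules package (and its sync_batchnorm dependency) is importable:

import torch
import torch.nn.functional as F
from modules.util import make_coordinate_grid

logits = torch.randn(1, 10, 58, 58)                         # raw per-keypoint maps (bs, num_kp, h, w)
heatmap = F.softmax(logits.view(1, 10, -1) / 0.1, dim=2)    # temperature of 0.1, invented for the example
heatmap = heatmap.view(1, 10, 58, 58)

grid = make_coordinate_grid((58, 58), heatmap.type())       # (h, w, 2) coordinates in [-1, 1]
kp_value = (heatmap.unsqueeze(-1) * grid).sum(dim=(2, 3))   # expected (x, y) per keypoint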
modules/model.py
ADDED
@@ -0,0 +1,259 @@
1 |
+
from torch import nn
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from modules.util import AntiAliasInterpolation2d, make_coordinate_grid
|
5 |
+
from torchvision import models
|
6 |
+
import numpy as np
|
7 |
+
from torch.autograd import grad
|
8 |
+
|
9 |
+
|
10 |
+
class Vgg19(torch.nn.Module):
|
11 |
+
"""
|
12 |
+
Vgg19 network for perceptual loss. See Sec 3.3.
|
13 |
+
"""
|
14 |
+
def __init__(self, requires_grad=False):
|
15 |
+
super(Vgg19, self).__init__()
|
16 |
+
vgg_pretrained_features = models.vgg19(pretrained=True).features
|
17 |
+
self.slice1 = torch.nn.Sequential()
|
18 |
+
self.slice2 = torch.nn.Sequential()
|
19 |
+
self.slice3 = torch.nn.Sequential()
|
20 |
+
self.slice4 = torch.nn.Sequential()
|
21 |
+
self.slice5 = torch.nn.Sequential()
|
22 |
+
for x in range(2):
|
23 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
24 |
+
for x in range(2, 7):
|
25 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
26 |
+
for x in range(7, 12):
|
27 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
28 |
+
for x in range(12, 21):
|
29 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
30 |
+
for x in range(21, 30):
|
31 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
32 |
+
|
33 |
+
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
|
34 |
+
requires_grad=False)
|
35 |
+
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
|
36 |
+
requires_grad=False)
|
37 |
+
|
38 |
+
if not requires_grad:
|
39 |
+
for param in self.parameters():
|
40 |
+
param.requires_grad = False
|
41 |
+
|
42 |
+
def forward(self, X):
|
43 |
+
X = (X - self.mean) / self.std
|
44 |
+
h_relu1 = self.slice1(X)
|
45 |
+
h_relu2 = self.slice2(h_relu1)
|
46 |
+
h_relu3 = self.slice3(h_relu2)
|
47 |
+
h_relu4 = self.slice4(h_relu3)
|
48 |
+
h_relu5 = self.slice5(h_relu4)
|
49 |
+
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
50 |
+
return out
|
51 |
+
|
52 |
+
|
53 |
+
class ImagePyramide(torch.nn.Module):
|
54 |
+
"""
|
55 |
+
Create an image pyramid for computing the pyramid perceptual loss. See Sec 3.3
|
56 |
+
"""
|
57 |
+
def __init__(self, scales, num_channels):
|
58 |
+
super(ImagePyramide, self).__init__()
|
59 |
+
downs = {}
|
60 |
+
for scale in scales:
|
61 |
+
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
|
62 |
+
self.downs = nn.ModuleDict(downs)
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
out_dict = {}
|
66 |
+
for scale, down_module in self.downs.items():
|
67 |
+
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
|
68 |
+
return out_dict
|
69 |
+
|
70 |
+
|
71 |
+
class Transform:
|
72 |
+
"""
|
73 |
+
Random TPS transformation for the equivariance constraints. See Sec 3.3
|
74 |
+
"""
|
75 |
+
def __init__(self, bs, **kwargs):
|
76 |
+
noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
|
77 |
+
self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
|
78 |
+
self.bs = bs
|
79 |
+
|
80 |
+
if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
|
81 |
+
self.tps = True
|
82 |
+
self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
|
83 |
+
self.control_points = self.control_points.unsqueeze(0)
|
84 |
+
self.control_params = torch.normal(mean=0,
|
85 |
+
std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
|
86 |
+
else:
|
87 |
+
self.tps = False
|
88 |
+
|
89 |
+
def transform_frame(self, frame):
|
90 |
+
grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0)
|
91 |
+
grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
|
92 |
+
grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
|
93 |
+
return F.grid_sample(frame, grid, padding_mode="reflection")
|
94 |
+
|
95 |
+
def warp_coordinates(self, coordinates):
|
96 |
+
theta = self.theta.type(coordinates.type())
|
97 |
+
theta = theta.unsqueeze(1)
|
98 |
+
transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
|
99 |
+
transformed = transformed.squeeze(-1)
|
100 |
+
|
101 |
+
if self.tps:
|
102 |
+
control_points = self.control_points.type(coordinates.type())
|
103 |
+
control_params = self.control_params.type(coordinates.type())
|
104 |
+
distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
|
105 |
+
distances = torch.abs(distances).sum(-1)
|
106 |
+
|
107 |
+
result = distances ** 2
|
108 |
+
result = result * torch.log(distances + 1e-6)
|
109 |
+
result = result * control_params
|
110 |
+
result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
|
111 |
+
transformed = transformed + result
|
112 |
+
|
113 |
+
return transformed
|
114 |
+
|
115 |
+
def jacobian(self, coordinates):
|
116 |
+
new_coordinates = self.warp_coordinates(coordinates)
|
117 |
+
grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
|
118 |
+
grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
|
119 |
+
jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
|
120 |
+
return jacobian
|
121 |
+
|
122 |
+
|
123 |
+
def detach_kp(kp):
|
124 |
+
return {key: value.detach() for key, value in kp.items()}
|
125 |
+
|
126 |
+
|
127 |
+
class GeneratorFullModel(torch.nn.Module):
|
128 |
+
"""
|
129 |
+
Merge all generator-related updates into a single model for better multi-GPU usage
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(self, kp_extractor, generator, discriminator, train_params):
|
133 |
+
super(GeneratorFullModel, self).__init__()
|
134 |
+
self.kp_extractor = kp_extractor
|
135 |
+
self.generator = generator
|
136 |
+
self.discriminator = discriminator
|
137 |
+
self.train_params = train_params
|
138 |
+
self.scales = train_params['scales']
|
139 |
+
self.disc_scales = self.discriminator.scales
|
140 |
+
self.pyramid = ImagePyramide(self.scales, generator.num_channels)
|
141 |
+
if torch.cuda.is_available():
|
142 |
+
self.pyramid = self.pyramid.cuda()
|
143 |
+
|
144 |
+
self.loss_weights = train_params['loss_weights']
|
145 |
+
|
146 |
+
if sum(self.loss_weights['perceptual']) != 0:
|
147 |
+
self.vgg = Vgg19()
|
148 |
+
if torch.cuda.is_available():
|
149 |
+
self.vgg = self.vgg.cuda()
|
150 |
+
|
151 |
+
def forward(self, x):
|
152 |
+
kp_source = self.kp_extractor(x['source'])
|
153 |
+
kp_driving = self.kp_extractor(x['driving'])
|
154 |
+
|
155 |
+
generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)
|
156 |
+
generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
|
157 |
+
|
158 |
+
loss_values = {}
|
159 |
+
|
160 |
+
pyramide_real = self.pyramid(x['driving'])
|
161 |
+
pyramide_generated = self.pyramid(generated['prediction'])
|
162 |
+
|
163 |
+
if sum(self.loss_weights['perceptual']) != 0:
|
164 |
+
value_total = 0
|
165 |
+
for scale in self.scales:
|
166 |
+
x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
|
167 |
+
y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
|
168 |
+
|
169 |
+
for i, weight in enumerate(self.loss_weights['perceptual']):
|
170 |
+
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
|
171 |
+
value_total += self.loss_weights['perceptual'][i] * value
|
172 |
+
loss_values['perceptual'] = value_total
|
173 |
+
|
174 |
+
if self.loss_weights['generator_gan'] != 0:
|
175 |
+
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
|
176 |
+
discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
|
177 |
+
value_total = 0
|
178 |
+
for scale in self.disc_scales:
|
179 |
+
key = 'prediction_map_%s' % scale
|
180 |
+
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
|
181 |
+
value_total += self.loss_weights['generator_gan'] * value
|
182 |
+
loss_values['gen_gan'] = value_total
|
183 |
+
|
184 |
+
if sum(self.loss_weights['feature_matching']) != 0:
|
185 |
+
value_total = 0
|
186 |
+
for scale in self.disc_scales:
|
187 |
+
key = 'feature_maps_%s' % scale
|
188 |
+
for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
|
189 |
+
if self.loss_weights['feature_matching'][i] == 0:
|
190 |
+
continue
|
191 |
+
value = torch.abs(a - b).mean()
|
192 |
+
value_total += self.loss_weights['feature_matching'][i] * value
|
193 |
+
loss_values['feature_matching'] = value_total
|
194 |
+
|
195 |
+
if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
|
196 |
+
transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
|
197 |
+
transformed_frame = transform.transform_frame(x['driving'])
|
198 |
+
transformed_kp = self.kp_extractor(transformed_frame)
|
199 |
+
|
200 |
+
generated['transformed_frame'] = transformed_frame
|
201 |
+
generated['transformed_kp'] = transformed_kp
|
202 |
+
|
203 |
+
## Value loss part
|
204 |
+
if self.loss_weights['equivariance_value'] != 0:
|
205 |
+
value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
|
206 |
+
loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
|
207 |
+
|
208 |
+
## jacobian loss part
|
209 |
+
if self.loss_weights['equivariance_jacobian'] != 0:
|
210 |
+
jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
|
211 |
+
transformed_kp['jacobian'])
|
212 |
+
|
213 |
+
normed_driving = torch.inverse(kp_driving['jacobian'])
|
214 |
+
normed_transformed = jacobian_transformed
|
215 |
+
value = torch.matmul(normed_driving, normed_transformed)
|
216 |
+
|
217 |
+
eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())
|
218 |
+
|
219 |
+
value = torch.abs(eye - value).mean()
|
220 |
+
loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value
|
221 |
+
|
222 |
+
return loss_values, generated
|
223 |
+
|
224 |
+
|
225 |
+
class DiscriminatorFullModel(torch.nn.Module):
|
226 |
+
"""
|
227 |
+
Merge all discriminator-related updates into a single model for better multi-GPU usage
|
228 |
+
"""
|
229 |
+
|
230 |
+
def __init__(self, kp_extractor, generator, discriminator, train_params):
|
231 |
+
super(DiscriminatorFullModel, self).__init__()
|
232 |
+
self.kp_extractor = kp_extractor
|
233 |
+
self.generator = generator
|
234 |
+
self.discriminator = discriminator
|
235 |
+
self.train_params = train_params
|
236 |
+
self.scales = self.discriminator.scales
|
237 |
+
self.pyramid = ImagePyramide(self.scales, generator.num_channels)
|
238 |
+
if torch.cuda.is_available():
|
239 |
+
self.pyramid = self.pyramid.cuda()
|
240 |
+
|
241 |
+
self.loss_weights = train_params['loss_weights']
|
242 |
+
|
243 |
+
def forward(self, x, generated):
|
244 |
+
pyramide_real = self.pyramid(x['driving'])
|
245 |
+
pyramide_generated = self.pyramid(generated['prediction'].detach())
|
246 |
+
|
247 |
+
kp_driving = generated['kp_driving']
|
248 |
+
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
|
249 |
+
discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
|
250 |
+
|
251 |
+
loss_values = {}
|
252 |
+
value_total = 0
|
253 |
+
for scale in self.scales:
|
254 |
+
key = 'prediction_map_%s' % scale
|
255 |
+
value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2
|
256 |
+
value_total += self.loss_weights['discriminator_gan'] * value.mean()
|
257 |
+
loss_values['disc_gan'] = value_total
|
258 |
+
|
259 |
+
return loss_values
|
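The equivariance terms hinge on the Transform class: the same random affine-plus-TPS warp is applied to the driving frame and, analytically, to keypoint coordinates, and the detector is asked to stay consistent between the two. A small sketch that drives both methods directly; the sigma/points values are invented and are not the training configuration:

import torch
from modules.model import Transform

frame = torch.rand(2, 3, 64, 64)
transform = Transform(bs=2, sigma_affine=0.05, sigma_tps=0.005, points_tps=5)
warped_frame = transform.transform_frame(frame)     # randomly warped driving frame
coords = torch.rand(2, 10, 2) * 2 - 1               # ten fake keypoints per image, in [-1, 1]
warped_coords = transform.warp_coordinates(coords)  # where those keypoints land after the warp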
modules/util.py
ADDED
@@ -0,0 +1,245 @@
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
|
7 |
+
|
8 |
+
|
9 |
+
def kp2gaussian(kp, spatial_size, kp_variance):
|
10 |
+
"""
|
11 |
+
Transform a keypoint into a gaussian-like representation
|
12 |
+
"""
|
13 |
+
mean = kp['value']
|
14 |
+
|
15 |
+
coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
|
16 |
+
number_of_leading_dimensions = len(mean.shape) - 1
|
17 |
+
shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
|
18 |
+
coordinate_grid = coordinate_grid.view(*shape)
|
19 |
+
repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1)
|
20 |
+
coordinate_grid = coordinate_grid.repeat(*repeats)
|
21 |
+
|
22 |
+
# Preprocess kp shape
|
23 |
+
shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2)
|
24 |
+
mean = mean.view(*shape)
|
25 |
+
|
26 |
+
mean_sub = (coordinate_grid - mean)
|
27 |
+
|
28 |
+
out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
|
29 |
+
|
30 |
+
return out
|
31 |
+
|
32 |
+
|
33 |
+
def make_coordinate_grid(spatial_size, type):
|
34 |
+
"""
|
35 |
+
Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
|
36 |
+
"""
|
37 |
+
h, w = spatial_size
|
38 |
+
x = torch.arange(w).type(type)
|
39 |
+
y = torch.arange(h).type(type)
|
40 |
+
|
41 |
+
x = (2 * (x / (w - 1)) - 1)
|
42 |
+
y = (2 * (y / (h - 1)) - 1)
|
43 |
+
|
44 |
+
yy = y.view(-1, 1).repeat(1, w)
|
45 |
+
xx = x.view(1, -1).repeat(h, 1)
|
46 |
+
|
47 |
+
meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
|
48 |
+
|
49 |
+
return meshed
|
50 |
+
|
51 |
+
|
52 |
+
class ResBlock2d(nn.Module):
|
53 |
+
"""
|
54 |
+
Res block, preserve spatial resolution.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(self, in_features, kernel_size, padding):
|
58 |
+
super(ResBlock2d, self).__init__()
|
59 |
+
self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
|
60 |
+
padding=padding)
|
61 |
+
self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
|
62 |
+
padding=padding)
|
63 |
+
self.norm1 = BatchNorm2d(in_features, affine=True)
|
64 |
+
self.norm2 = BatchNorm2d(in_features, affine=True)
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
out = self.norm1(x)
|
68 |
+
out = F.relu(out)
|
69 |
+
out = self.conv1(out)
|
70 |
+
out = self.norm2(out)
|
71 |
+
out = F.relu(out)
|
72 |
+
out = self.conv2(out)
|
73 |
+
out += x
|
74 |
+
return out
|
75 |
+
|
76 |
+
|
77 |
+
class UpBlock2d(nn.Module):
|
78 |
+
"""
|
79 |
+
Upsampling block for use in decoder.
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
83 |
+
super(UpBlock2d, self).__init__()
|
84 |
+
|
85 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
86 |
+
padding=padding, groups=groups)
|
87 |
+
self.norm = BatchNorm2d(out_features, affine=True)
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
out = F.interpolate(x, scale_factor=2)
|
91 |
+
out = self.conv(out)
|
92 |
+
out = self.norm(out)
|
93 |
+
out = F.relu(out)
|
94 |
+
return out
|
95 |
+
|
96 |
+
|
97 |
+
class DownBlock2d(nn.Module):
|
98 |
+
"""
|
99 |
+
Downsampling block for use in encoder.
|
100 |
+
"""
|
101 |
+
|
102 |
+
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
103 |
+
super(DownBlock2d, self).__init__()
|
104 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
105 |
+
padding=padding, groups=groups)
|
106 |
+
self.norm = BatchNorm2d(out_features, affine=True)
|
107 |
+
self.pool = nn.AvgPool2d(kernel_size=(2, 2))
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
out = self.conv(x)
|
111 |
+
out = self.norm(out)
|
112 |
+
out = F.relu(out)
|
113 |
+
out = self.pool(out)
|
114 |
+
return out
|
115 |
+
|
116 |
+
|
117 |
+
class SameBlock2d(nn.Module):
|
118 |
+
"""
|
119 |
+
Simple block, preserve spatial resolution.
|
120 |
+
"""
|
121 |
+
|
122 |
+
def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
|
123 |
+
super(SameBlock2d, self).__init__()
|
124 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
|
125 |
+
kernel_size=kernel_size, padding=padding, groups=groups)
|
126 |
+
self.norm = BatchNorm2d(out_features, affine=True)
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
out = self.conv(x)
|
130 |
+
out = self.norm(out)
|
131 |
+
out = F.relu(out)
|
132 |
+
return out
|
133 |
+
|
134 |
+
|
135 |
+
class Encoder(nn.Module):
|
136 |
+
"""
|
137 |
+
Hourglass Encoder
|
138 |
+
"""
|
139 |
+
|
140 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
141 |
+
super(Encoder, self).__init__()
|
142 |
+
|
143 |
+
down_blocks = []
|
144 |
+
for i in range(num_blocks):
|
145 |
+
down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
|
146 |
+
min(max_features, block_expansion * (2 ** (i + 1))),
|
147 |
+
kernel_size=3, padding=1))
|
148 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
outs = [x]
|
152 |
+
for down_block in self.down_blocks:
|
153 |
+
outs.append(down_block(outs[-1]))
|
154 |
+
return outs
|
155 |
+
|
156 |
+
|
157 |
+
class Decoder(nn.Module):
|
158 |
+
"""
|
159 |
+
Hourglass Decoder
|
160 |
+
"""
|
161 |
+
|
162 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
163 |
+
super(Decoder, self).__init__()
|
164 |
+
|
165 |
+
up_blocks = []
|
166 |
+
|
167 |
+
for i in range(num_blocks)[::-1]:
|
168 |
+
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
|
169 |
+
out_filters = min(max_features, block_expansion * (2 ** i))
|
170 |
+
up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
|
171 |
+
|
172 |
+
self.up_blocks = nn.ModuleList(up_blocks)
|
173 |
+
self.out_filters = block_expansion + in_features
|
174 |
+
|
175 |
+
def forward(self, x):
|
176 |
+
out = x.pop()
|
177 |
+
for up_block in self.up_blocks:
|
178 |
+
out = up_block(out)
|
179 |
+
skip = x.pop()
|
180 |
+
out = torch.cat([out, skip], dim=1)
|
181 |
+
return out
|
182 |
+
|
183 |
+
|
184 |
+
class Hourglass(nn.Module):
|
185 |
+
"""
|
186 |
+
Hourglass architecture.
|
187 |
+
"""
|
188 |
+
|
189 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
190 |
+
super(Hourglass, self).__init__()
|
191 |
+
self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
|
192 |
+
self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
|
193 |
+
self.out_filters = self.decoder.out_filters
|
194 |
+
|
195 |
+
def forward(self, x):
|
196 |
+
return self.decoder(self.encoder(x))
|
197 |
+
|
198 |
+
|
199 |
+
class AntiAliasInterpolation2d(nn.Module):
|
200 |
+
"""
|
201 |
+
Band-limited downsampling, for better preservation of the input signal.
|
202 |
+
"""
|
203 |
+
def __init__(self, channels, scale):
|
204 |
+
super(AntiAliasInterpolation2d, self).__init__()
|
205 |
+
sigma = (1 / scale - 1) / 2
|
206 |
+
kernel_size = 2 * round(sigma * 4) + 1
|
207 |
+
self.ka = kernel_size // 2
|
208 |
+
self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
|
209 |
+
|
210 |
+
kernel_size = [kernel_size, kernel_size]
|
211 |
+
sigma = [sigma, sigma]
|
212 |
+
# The gaussian kernel is the product of the
|
213 |
+
# gaussian function of each dimension.
|
214 |
+
kernel = 1
|
215 |
+
meshgrids = torch.meshgrid(
|
216 |
+
[
|
217 |
+
torch.arange(size, dtype=torch.float32)
|
218 |
+
for size in kernel_size
|
219 |
+
]
|
220 |
+
)
|
221 |
+
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
222 |
+
mean = (size - 1) / 2
|
223 |
+
kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
|
224 |
+
|
225 |
+
# Make sure sum of values in gaussian kernel equals 1.
|
226 |
+
kernel = kernel / torch.sum(kernel)
|
227 |
+
# Reshape to depthwise convolutional weight
|
228 |
+
kernel = kernel.view(1, 1, *kernel.size())
|
229 |
+
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
230 |
+
|
231 |
+
self.register_buffer('weight', kernel)
|
232 |
+
self.groups = channels
|
233 |
+
self.scale = scale
|
234 |
+
inv_scale = 1 / scale
|
235 |
+
self.int_inv_scale = int(inv_scale)
|
236 |
+
|
237 |
+
def forward(self, input):
|
238 |
+
if self.scale == 1.0:
|
239 |
+
return input
|
240 |
+
|
241 |
+
out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
|
242 |
+
out = F.conv2d(out, weight=self.weight, groups=self.groups)
|
243 |
+
out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
|
244 |
+
|
245 |
+
return out
|
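Two helpers in this file are used throughout the other modules: kp2gaussian renders keypoint coordinates as gaussian heatmaps, and AntiAliasInterpolation2d blurs before striding so that downsampling stays band-limited. A quick sketch of both with invented sizes (not part of this commit):

import torch
from modules.util import kp2gaussian, AntiAliasInterpolation2d

kp = {'value': torch.rand(2, 10, 2) * 2 - 1}                          # keypoints in [-1, 1]
heatmaps = kp2gaussian(kp, spatial_size=(64, 64), kp_variance=0.01)   # -> (2, 10, 64, 64)

down = AntiAliasInterpolation2d(channels=3, scale=0.25)
small = down(torch.rand(2, 3, 256, 256))                              # blur then stride -> (2, 3, 64, 64)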
share/doc/networkx-3.0/LICENSE.txt
ADDED
@@ -0,0 +1,37 @@
1 |
+
NetworkX is distributed with the 3-clause BSD license.
|
2 |
+
|
3 |
+
::
|
4 |
+
|
5 |
+
Copyright (C) 2004-2023, NetworkX Developers
|
6 |
+
Aric Hagberg <hagberg@lanl.gov>
|
7 |
+
Dan Schult <dschult@colgate.edu>
|
8 |
+
Pieter Swart <swart@lanl.gov>
|
9 |
+
All rights reserved.
|
10 |
+
|
11 |
+
Redistribution and use in source and binary forms, with or without
|
12 |
+
modification, are permitted provided that the following conditions are
|
13 |
+
met:
|
14 |
+
|
15 |
+
* Redistributions of source code must retain the above copyright
|
16 |
+
notice, this list of conditions and the following disclaimer.
|
17 |
+
|
18 |
+
* Redistributions in binary form must reproduce the above
|
19 |
+
copyright notice, this list of conditions and the following
|
20 |
+
disclaimer in the documentation and/or other materials provided
|
21 |
+
with the distribution.
|
22 |
+
|
23 |
+
* Neither the name of the NetworkX Developers nor the names of its
|
24 |
+
contributors may be used to endorse or promote products derived
|
25 |
+
from this software without specific prior written permission.
|
26 |
+
|
27 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
28 |
+
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
29 |
+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
30 |
+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
31 |
+
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
32 |
+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
33 |
+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
34 |
+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
35 |
+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
36 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
37 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
share/doc/networkx-3.0/examples/3d_drawing/README.txt
ADDED
@@ -0,0 +1,2 @@
1 |
+
3D Drawing
|
2 |
+
----------
|
share/doc/networkx-3.0/examples/3d_drawing/__pycache__/mayavi2_spring.cpython-310.pyc
ADDED
Binary file (1.1 kB). View file
|
|
share/doc/networkx-3.0/examples/3d_drawing/__pycache__/plot_basic.cpython-310.pyc
ADDED
Binary file (1.49 kB). View file
|
|
share/doc/networkx-3.0/examples/3d_drawing/mayavi2_spring.py
ADDED
@@ -0,0 +1,43 @@
1 |
+
"""
|
2 |
+
=======
|
3 |
+
Mayavi2
|
4 |
+
=======
|
5 |
+
|
6 |
+
"""
|
7 |
+
|
8 |
+
import networkx as nx
|
9 |
+
import numpy as np
|
10 |
+
from mayavi import mlab
|
11 |
+
|
12 |
+
# some graphs to try
|
13 |
+
# H=nx.krackhardt_kite_graph()
|
14 |
+
# H=nx.Graph();H.add_edge('a','b');H.add_edge('a','c');H.add_edge('a','d')
|
15 |
+
# H=nx.grid_2d_graph(4,5)
|
16 |
+
H = nx.cycle_graph(20)
|
17 |
+
|
18 |
+
# reorder nodes from 0,len(G)-1
|
19 |
+
G = nx.convert_node_labels_to_integers(H)
|
20 |
+
# 3d spring layout
|
21 |
+
pos = nx.spring_layout(G, dim=3, seed=1001)
|
22 |
+
# numpy array of x,y,z positions in sorted node order
|
23 |
+
xyz = np.array([pos[v] for v in sorted(G)])
|
24 |
+
# scalar colors
|
25 |
+
scalars = np.array(list(G.nodes())) + 5
|
26 |
+
|
27 |
+
mlab.figure()
|
28 |
+
|
29 |
+
pts = mlab.points3d(
|
30 |
+
xyz[:, 0],
|
31 |
+
xyz[:, 1],
|
32 |
+
xyz[:, 2],
|
33 |
+
scalars,
|
34 |
+
scale_factor=0.1,
|
35 |
+
scale_mode="none",
|
36 |
+
colormap="Blues",
|
37 |
+
resolution=20,
|
38 |
+
)
|
39 |
+
|
40 |
+
pts.mlab_source.dataset.lines = np.array(list(G.edges()))
|
41 |
+
tube = mlab.pipeline.tube(pts, tube_radius=0.01)
|
42 |
+
mlab.pipeline.surface(tube, color=(0.8, 0.8, 0.8))
|
43 |
+
mlab.orientation_axes()
|
share/doc/networkx-3.0/examples/3d_drawing/plot_basic.py
ADDED
@@ -0,0 +1,51 @@
1 |
+
"""
|
2 |
+
================
|
3 |
+
Basic matplotlib
|
4 |
+
================
|
5 |
+
|
6 |
+
A basic example of 3D Graph visualization using `mpl_toolkits.mplot3d`.
|
7 |
+
|
8 |
+
"""
|
9 |
+
|
10 |
+
import networkx as nx
|
11 |
+
import numpy as np
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
from mpl_toolkits.mplot3d import Axes3D
|
14 |
+
|
15 |
+
# The graph to visualize
|
16 |
+
G = nx.cycle_graph(20)
|
17 |
+
|
18 |
+
# 3d spring layout
|
19 |
+
pos = nx.spring_layout(G, dim=3, seed=779)
|
20 |
+
# Extract node and edge positions from the layout
|
21 |
+
node_xyz = np.array([pos[v] for v in sorted(G)])
|
22 |
+
edge_xyz = np.array([(pos[u], pos[v]) for u, v in G.edges()])
|
23 |
+
|
24 |
+
# Create the 3D figure
|
25 |
+
fig = plt.figure()
|
26 |
+
ax = fig.add_subplot(111, projection="3d")
|
27 |
+
|
28 |
+
# Plot the nodes - alpha is scaled by "depth" automatically
|
29 |
+
ax.scatter(*node_xyz.T, s=100, ec="w")
|
30 |
+
|
31 |
+
# Plot the edges
|
32 |
+
for vizedge in edge_xyz:
|
33 |
+
ax.plot(*vizedge.T, color="tab:gray")
|
34 |
+
|
35 |
+
|
36 |
+
def _format_axes(ax):
|
37 |
+
"""Visualization options for the 3D axes."""
|
38 |
+
# Turn gridlines off
|
39 |
+
ax.grid(False)
|
40 |
+
# Suppress tick labels
|
41 |
+
for dim in (ax.xaxis, ax.yaxis, ax.zaxis):
|
42 |
+
dim.set_ticks([])
|
43 |
+
# Set axes labels
|
44 |
+
ax.set_xlabel("x")
|
45 |
+
ax.set_ylabel("y")
|
46 |
+
ax.set_zlabel("z")
|
47 |
+
|
48 |
+
|
49 |
+
_format_axes(ax)
|
50 |
+
fig.tight_layout()
|
51 |
+
plt.show()
|
share/doc/networkx-3.0/examples/README.txt
ADDED
@@ -0,0 +1,8 @@
1 |
+
.. _examples_gallery:
|
2 |
+
|
3 |
+
Gallery
|
4 |
+
=======
|
5 |
+
|
6 |
+
General-purpose and introductory examples for NetworkX.
|
7 |
+
The `tutorial <../tutorial.html>`_ introduces conventions and basic graph
|
8 |
+
manipulations.
|
share/doc/networkx-3.0/examples/algorithms/README.txt
ADDED
@@ -0,0 +1,2 @@
1 |
+
Algorithms
|
2 |
+
----------
|
share/doc/networkx-3.0/examples/algorithms/WormNet.v3.benchmark.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_beam_search.cpython-310.pyc
ADDED
Binary file (3.22 kB). View file
|
|
share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_betweenness_centrality.cpython-310.pyc
ADDED
Binary file (2.4 kB). View file
|
|
share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_blockmodel.cpython-310.pyc
ADDED
Binary file (2.75 kB). View file
|
|
share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_circuits.cpython-310.pyc
ADDED
Binary file (2.47 kB). View file
|
|
share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_davis_club.cpython-310.pyc
ADDED
Binary file (1.25 kB). View file
|
|
share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_dedensification.cpython-310.pyc
ADDED
Binary file (2.44 kB). View file
|
|
share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_iterated_dynamical_systems.cpython-310.pyc
ADDED
Binary file (6.55 kB). View file
|
|
share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_krackhardt_centrality.cpython-310.pyc
ADDED
Binary file (915 Bytes). View file
|
|
share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_parallel_betweenness.cpython-310.pyc
ADDED
Binary file (2.55 kB). View file
|
|
share/doc/networkx-3.0/examples/algorithms/__pycache__/plot_rcm.cpython-310.pyc
ADDED
Binary file (1.11 kB). View file
|
|