Nadine Rueegg committed on
Commit
dc81e88
1 Parent(s): 5fb530c

(1) adjust licence, (2) save obj instead of glb, (3) upload newly trained model with regularizers on tail

LICENSE CHANGED
@@ -12,6 +12,8 @@ ETH Zurich
  Any copyright or patent right is owned by and proprietary material of the
 
  Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”)
+ and
+ ETH Zurich
 
  hereinafter the “Licensor”.
 
@@ -20,7 +22,7 @@ Licensor grants you (Licensee) personally a single-user, non-exclusive, non-tran
 
  To install the Data & Software on computers owned, leased or otherwise controlled by you and/or your organization;
  To use the Data & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
- Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission.
+ Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s / ETH's prior written permission.
 
  The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it.
 
checkpoint/cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0_forrelease_v0b/checkpoint.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b2401264e10213a41d6797c23feb7734d01b45ed3a19591230ac9636a38b5fef
+ size 639465445
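The three added lines are a Git LFS pointer rather than the checkpoint itself: they record the spec version, the SHA-256 object id, and the size in bytes of the real file. A minimal sketch of reading such a pointer follows; `parse_lfs_pointer` is a hypothetical helper written for illustration and is not part of this repository.

```python
def parse_lfs_pointer(text):
    """Parse the 'key value' lines of a Git LFS pointer (hypothetical helper)."""
    fields = {}
    for line in text.strip().splitlines():
        # Each pointer line is a key, one space, then the value.
        key, _, value = line.partition(' ')
        fields[key] = value
    # The oid has the form '<hash algorithm>:<hex digest>'.
    algo, _, digest = fields['oid'].partition(':')
    return {
        'version': fields['version'],
        'hash_algo': algo,
        'digest': digest,
        'size_bytes': int(fields['size']),
    }


# Pointer content copied from the diff above.
pointer_text = """version https://git-lfs.github.com/spec/v1
oid sha256:b2401264e10213a41d6797c23feb7734d01b45ed3a19591230ac9636a38b5fef
size 639465445
"""

pointer = parse_lfs_pointer(pointer_text)
print(pointer['hash_algo'], pointer['size_bytes'])
```

After downloading the real checkpoint, its size and SHA-256 digest can be compared against these two fields to confirm the download is complete.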
checkpoint/cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0_forrelease_v0b/refinement_loss_weights_withgc_withvertexwise_addnonflat_forrelease_v0b.json ADDED
@@ -0,0 +1,20 @@
+
+
+
+ {
+     "keyp_ref": 0.2,
+     "silh_ref": 50.0,
+     "pose_legs_side": 1.0,
+     "pose_legs_tors": 1.0,
+     "pose_tail_side": 0.1,
+     "pose_tail_tors": 1.0,
+     "pose_spine_side": 0.0,
+     "pose_spine_tors": 0.0,
+     "reg_trans": 0.0,
+     "reg_flength": 0.0,
+     "reg_pose": 0.0,
+     "gc_plane": 5.0,
+     "gc_blowplane": 5.0,
+     "gc_vertexwise": 10.0,
+     "gc_isflat": 0.5
+ }
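How `LossRef` consumes these weights is not shown in this commit. One common pattern, sketched here purely as an assumption, is a weighted sum over named loss terms in which entries with weight 0.0 are skipped. `combine_losses` and the dummy per-term values are hypothetical; only the weights are taken from the file above.

```python
import json

# Weights copied from the JSON file added in this commit.
weights_json = """
{
    "keyp_ref": 0.2, "silh_ref": 50.0,
    "pose_legs_side": 1.0, "pose_legs_tors": 1.0,
    "pose_tail_side": 0.1, "pose_tail_tors": 1.0,
    "pose_spine_side": 0.0, "pose_spine_tors": 0.0,
    "reg_trans": 0.0, "reg_flength": 0.0, "reg_pose": 0.0,
    "gc_plane": 5.0, "gc_blowplane": 5.0,
    "gc_vertexwise": 10.0, "gc_isflat": 0.5
}
"""


def combine_losses(loss_terms, weights):
    """Weighted sum of loss terms; zero-weighted terms are skipped (assumed behaviour)."""
    total = 0.0
    for name, weight in weights.items():
        if weight == 0.0:
            continue  # a zero weight disables the term entirely
        total += weight * loss_terms[name]
    return total


weights = json.loads(weights_json)
# Names of the terms that actually contribute under this configuration.
active = sorted(name for name, w in weights.items() if w > 0.0)
# Hypothetical per-term loss values, all set to 1.0 for illustration.
dummy_terms = {name: 1.0 for name in weights}
total = combine_losses(dummy_terms, weights)
print(active)
print(total)
```

Under this reading, the two tail entries (`pose_tail_side`, `pose_tail_tors`) are the regularizers on the tail mentioned in the commit message, while the spine and `reg_*` terms are switched off.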
checkpoint/cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0_forrelease_v0b/train_withref.py ADDED
@@ -0,0 +1,502 @@
+
+ # [barc2] python scripts/train.py --workers 12 --checkpoint project22_no3dcgloss_smaldogsilvia_v0 --loss-weight-path barc_loss_weights_no3dcgloss.json --config barc_cfg_train.yaml start --model-file-hg hg_ksp_fromnewanipose_stanext_v0/checkpoint.pth.tar --model-file-3d barc_normflow_pret/checkpoint.pth.tar
+ # [barc3] python scripts/train.py --workers 12 --checkpoint project22_no3dcgloss_smaldognadine_v0 --loss-weight-path barc_loss_weights_no3dcgloss.json --config barc_cfg_train.yaml start --model-file-hg hg_ksp_fromnewanipose_stanext_v0/checkpoint.pth.tar --model-file-3d barc_normflow_pret/checkpoint.pth.tar
+
+ # python scripts/train_withref.py --workers 12 --checkpoint project22_no3dcgloss_smaldognadine_v4_ref_v0 --loss-weight-path barc_loss_weights_no3dcgloss.json --config refinement_cfg_train.yaml continue --model-file-complete project22_no3dcgloss_smaldognadine_v4/checkpoint.pth.tar --new-optimizer 1
+ # python scripts/train_withref.py --workers 12 --checkpoint project22_no3dcgloss_smaldognadine_v4_refadd_v0 --loss-weight-path barc_loss_weights_no3dcgloss.json --config refinement_cfg_train.yaml continue --model-file-complete project22_no3dcgloss_smaldognadine_v4/checkpoint.pth.tar --new-optimizer 1
+
+
+
+ print('start ...')
+ import numpy as np
+ import random
+ import torch
+ import argparse
+ import os
+ import json
+ import torch
+ import torch.backends.cudnn
+ from torch.nn import DataParallel
+ from torch.optim.rmsprop import RMSprop
+ from torch.utils.data import DataLoader
+ from tqdm import trange, tqdm
+ from collections import OrderedDict
+ from itertools import chain
+ import shutil
+
+ # set random seeds (we have never changed those and there is probably one missing)
+ torch.manual_seed(52)
+ np.random.seed(435)
+ random.seed(643)
+
+ import sys
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../', 'src'))
+ # from combined_model.train_main_image_to_3d_withbreedrel import do_training_epoch, do_validation_epoch
+ from combined_model.train_main_image_to_3d_wbr_withref import do_training_epoch, do_validation_epoch
+ # from combined_model.model_shape_v7 import ModelImageTo3d_withshape_withproj
+ # from combined_model.model_shape_v7_withref import ModelImageTo3d_withshape_withproj
+ from combined_model.model_shape_v7_withref_withgraphcnn import ModelImageTo3d_withshape_withproj
+
+ from combined_model.loss_image_to_3d_withbreedrel import Loss
+ from combined_model.loss_image_to_3d_refinement import LossRef
+ from stacked_hourglass.utils.misc import save_checkpoint, adjust_learning_rate
+ from stacked_hourglass.datasets.samplers.custom_pair_samplers import CustomPairBatchSampler
+ from stacked_hourglass.datasets.samplers.custom_gc_sampler import CustomGCSampler
+ from stacked_hourglass.datasets.samplers.custom_gc_sampler_noclasses import CustomGCSamplerNoCLass
+ from configs.barc_cfg_defaults import get_cfg_defaults, update_cfg_global_with_yaml, get_cfg_global_updated
+
+
+
+ class PrintLog():
+     def __init__(self, out_file):
+         self.out_file = out_file
+         # self._print_to_file('------------------------------------------------------')
+     def clean_file(self):
+         # this function deletes all content of that file
+         with open(self.out_file,'w') as file:
+             pass
+     def _print_to_file(self, *args, **kwargs):
+         with open(self.out_file,'a') as file:
+             print(*args, **kwargs, file=file)
+     def print(self, *args, **kwargs):
+         print(*args, **kwargs)
+         self._print_to_file(*args, **kwargs)
+     def print_log_only(self, *args, **kwargs):
+         self._print_to_file(*args, **kwargs)
+
+
+ def main(args):
+
+     # load all configs and weights
+     # step 1: load default configs
+     # step 2: load updates from .yaml file
+     # step 3: load training weights
+     path_config = os.path.join(get_cfg_defaults().barc_dir, 'src', 'configs', args.config)
+     update_cfg_global_with_yaml(path_config)
+     cfg = get_cfg_global_updated()
+     with open(os.path.join(os.path.dirname(__file__), '../', 'src', 'configs', args.loss_weight_path), 'r') as f:
+         weight_dict = json.load(f)
+     with open(os.path.join(os.path.dirname(__file__), '../', 'src', 'configs', args.loss_weight_ref_path), 'r') as f:
+         weight_dict_ref = json.load(f)
+     # Select the hardware device to use for training.
+     if torch.cuda.is_available() and cfg.device=='cuda':
+         device = torch.device('cuda', torch.cuda.current_device())
+         torch.backends.cudnn.benchmark = True
+     else:
+         device = torch.device('cpu')
+
+     # import data loader
+     if cfg.data.DATASET == 'stanext24_easy':
+         from stacked_hourglass.datasets.stanext24_easy import StanExtEasy as StanExt
+     elif cfg.data.DATASET == 'stanext24':
+         from stacked_hourglass.datasets.stanext24 import StanExt
+     elif cfg.data.DATASET == 'stanext24_withgc':
+         from stacked_hourglass.datasets.stanext24_withgc import StanExtGC as StanExt ###################
+     elif cfg.data.DATASET == 'stanext24_withgc_big':
+         from stacked_hourglass.datasets.stanext24_withgc_v2 import StanExtGC as StanExt
+     elif cfg.data.DATASET == 'stanext24_withgc_cs0':
+         from stacked_hourglass.datasets.stanext24_withgc_v2 import StanExtGC as StanExt
+         # -> same dataset as in stanext24_withgc_big, but different training sampler
+     elif cfg.data.DATASET == 'stanext24_withgc_csaddnonflat':
+         from stacked_hourglass.datasets.stanext24_withgc_v2 import StanExtGC as StanExt
+     elif cfg.data.DATASET == 'stanext24_withgc_csaddnonflatmorestanding':
+         from stacked_hourglass.datasets.stanext24_withgc_v2 import StanExtGC as StanExt
+     elif cfg.data.DATASET == 'stanext24_withgc_noclasses':
+         from stacked_hourglass.datasets.stanext24_withgc_v2 import StanExtGC as StanExt
+     else:
+         raise NotImplementedError
+
+     # Disable gradient calculations by default.
+     torch.set_grad_enabled(False)
+
+     # create checkpoint dir
+     path_checkpoint = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, args.checkpoint)
+     os.makedirs(path_checkpoint, exist_ok=True)
+
+     # copy the python train file
+     in_train_file = os.path.abspath(__file__)
+     out_train_file_dir = os.path.join(path_checkpoint)
+     shutil.copy2(in_train_file, out_train_file_dir)
+     shutil.copy2(os.path.join(os.path.dirname(__file__), '../', 'src', 'configs', args.loss_weight_ref_path), path_checkpoint)
+
+     # create printlog
+     pl = PrintLog(out_file=path_checkpoint + '/partial_log.txt')
+     pl.print('------------------------------------------------------')
+
+     # print some information
+     pl.print('dataset: ' + cfg.data.DATASET)
+     pl.print('structure_pose_net: ' + cfg.params.STRUCTURE_POSE_NET)
+     pl.print('refinement network type: ' + cfg.params.REF_NET_TYPE)
+     pl.print('refinement network detach shape: ' + str(cfg.params.REF_DETACH_SHAPE))
+     pl.print('graphcnn_type: ' + cfg.params.GRAPHCNN_TYPE)
+     pl.print('isflat_type: ' + cfg.params.ISFLAT_TYPE)
+     pl.print('shaperef_type: ' + cfg.params.SHAPEREF_TYPE)
+     pl.print('smal_model_type: ' + cfg.smal.SMAL_MODEL_TYPE)
+     pl.print('train_parts: ' + cfg.optim.TRAIN_PARTS)
+
+     # load model
+     if weight_dict['partseg'] > 0:
+         render_partseg = True
+     else:
+         render_partseg = False
+     model = ModelImageTo3d_withshape_withproj(
+         smal_model_type=cfg.smal.SMAL_MODEL_TYPE, smal_keyp_conf=cfg.smal.SMAL_KEYP_CONF, \
+         num_stage_comb=cfg.params.NUM_STAGE_COMB, num_stage_heads=cfg.params.NUM_STAGE_HEADS, \
+         num_stage_heads_pose=cfg.params.NUM_STAGE_HEADS_POSE, trans_sep=cfg.params.TRANS_SEP, \
+         arch=cfg.params.ARCH, n_joints=cfg.params.N_JOINTS, n_classes=cfg.params.N_CLASSES, \
+         n_keyp=cfg.params.N_KEYP, n_bones=cfg.params.N_BONES, n_betas=cfg.params.N_BETAS, n_betas_limbs=cfg.params.N_BETAS_LIMBS, \
+         n_breeds=cfg.params.N_BREEDS, n_z=cfg.params.N_Z, image_size=cfg.params.IMG_SIZE, \
+         silh_no_tail=cfg.params.SILH_NO_TAIL, thr_keyp_sc=cfg.params.KP_THRESHOLD, add_z_to_3d_input=cfg.params.ADD_Z_TO_3D_INPUT,
+         n_segbps=cfg.params.N_SEGBPS, add_segbps_to_3d_input=cfg.params.ADD_SEGBPS_TO_3D_INPUT, add_partseg=cfg.params.ADD_PARTSEG, n_partseg=cfg.params.N_PARTSEG, \
+         fix_flength=cfg.params.FIX_FLENGTH, render_partseg=render_partseg, structure_z_to_betas=cfg.params.STRUCTURE_Z_TO_B, \
+         structure_pose_net=cfg.params.STRUCTURE_POSE_NET, nf_version=cfg.params.NF_VERSION, ref_net_type=cfg.params.REF_NET_TYPE, \
+         ref_detach_shape=cfg.params.REF_DETACH_SHAPE, graphcnn_type=cfg.params.GRAPHCNN_TYPE, isflat_type=cfg.params.ISFLAT_TYPE, shaperef_type=cfg.params.SHAPEREF_TYPE)
+     model = model.to(device)
+
+     # define parameters that should be optimized
+     if cfg.optim.TRAIN_PARTS == 'all_with_shapedirs': # do not use this option!
+         params = chain(model.breed_model.parameters(), \
+                        model.model_3d.parameters(), \
+                        model.model_learnable_shapedirs.parameters())
+     elif cfg.optim.TRAIN_PARTS == 'all_without_shapedirs':
+         params = chain(model.breed_model.parameters(), \
+                        model.model_3d.parameters())
+     elif cfg.optim.TRAIN_PARTS == 'model3donly_noshape_noshapedirs':
+         params = chain(model.model_3d.parameters())
+     elif cfg.optim.TRAIN_PARTS == 'all_noresnetclass_without_shapedirs':
+         params = chain(model.breed_model.linear_breeds.parameters(), \
+                        model.model_3d.parameters())
+     elif cfg.optim.TRAIN_PARTS == 'breed_model':
+         params = chain(model.breed_model.parameters())
+     elif cfg.optim.TRAIN_PARTS == 'flength_trans_betas_only':
+         params = chain(model.model_3d.output_info_linear_models[1].parameters(), \
+                        model.model_3d.output_info_linear_models[2].parameters(), \
+                        model.model_3d.output_info_linear_models[3].parameters(), \
+                        model.breed_model.linear_betas.parameters())
+     elif cfg.optim.TRAIN_PARTS == 'all_without_shapedirs_with_refinement':
+         params = chain(model.breed_model.parameters(), \
+                        model.model_3d.parameters(), \
+                        model.refinement_model.parameters())
+     elif cfg.optim.TRAIN_PARTS == 'refinement_model':
+         params = chain(model.refinement_model.parameters())
+     elif cfg.optim.TRAIN_PARTS == 'refinement_model_and_shape':
+         params = chain(model.refinement_model.parameters(), \
+                        model.breed_model.parameters())
+     else:
+         raise NotImplementedError
+
+     # create optimizer
+     optimizer = RMSprop(params, lr=cfg.optim.LR, momentum=cfg.optim.MOMENTUM, weight_decay=cfg.optim.WEIGHT_DECAY)
+     start_epoch = 0
+     best_acc = 0
+
+     # load pretrained model or parts of the model
+     if args.command == "start":
+         path_model_file_hg = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, args.model_file_hg)
+         path_model_file_shape = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, args.model_file_shape)
+         path_model_file_3d = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, args.model_file_3d)
+         # (1) load pretrained shape model
+         # -> usually we do not work with a pretrained model here
+         if os.path.isfile(path_model_file_shape):
+             pl.print('Loading model weights for shape network from a separate file: {}'.format(path_model_file_shape))
+             checkpoint_shape = torch.load(path_model_file_shape)
+             state_dict_shape = checkpoint_shape['state_dict']
+             # model.load_state_dict(state_dict_complete, strict=False)
+             # --- Problem: there is the last layer which predicts betas and we might change the numbers of betas
+             # NEW: allow to load the model even if the number of betas is different
+             model_dict = model.state_dict()
+             # i) filter out unnecessary keys and remove weights for layers that have changed shapes (smal.shapedirs, resnet18.fc.weight, ...)
+             state_dict_shape_new = OrderedDict()
+             for k, v in state_dict_shape.items():
+                 if k in model_dict:
+                     if v.shape==model_dict[k].shape:
+                         state_dict_shape_new[k] = v
+                     else:
+                         state_dict_shape_new[k] = model_dict[k]
+             # ii) overwrite entries in the existing state dict
+             model_dict.update(state_dict_shape_new)
+             # iii) load the new state dict
+             model.load_state_dict(model_dict)
+         # (2) load pretrained 3d network
+         # -> we recommend to load a pretrained model
+         if os.path.isfile(path_model_file_3d):
+             assert os.path.isfile(path_model_file_3d)
+             pl.print('Loading model weights (2d-to-3d) from file: {}'.format(path_model_file_3d))
+             checkpoint_3d = torch.load(path_model_file_3d)
+             state_dict_3d = checkpoint_3d['state_dict']
+             model.load_state_dict(state_dict_3d, strict=False)
+         else:
+             pl.print('no model (2d-to-3d) loaded')
+         # (3) initialize weights for stacked hourglass
+         # -> the stacked hourglass needs to be pretrained
+         assert os.path.isfile(path_model_file_hg)
+         pl.print('Loading model weights (stacked hourglass) from file: {}'.format(path_model_file_hg))
+         checkpoint = torch.load(path_model_file_hg)
+         state_dict = checkpoint['state_dict']
+         if sorted(state_dict.keys())[0].startswith('module.'):
+             new_state_dict = OrderedDict()
+             for k, v in state_dict.items():
+                 name = k[7:] # remove 'module.' of dataparallel
+                 new_state_dict[name]=v
+             state_dict = new_state_dict
+         model.stacked_hourglass.load_state_dict(state_dict)
+     elif args.command == "continue":
+         path_model_file_complete = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, args.model_file_complete)
+         pl.print('Loading complete model weights from file: {}'.format(path_model_file_complete))
+         checkpoint = torch.load(path_model_file_complete)
+         model.load_state_dict(checkpoint['state_dict'], strict=False)
+         if args.new_optimizer == 0:
+             pl.print('load optimizer state')
+             start_epoch = checkpoint['epoch']
+             best_acc = checkpoint['best_acc']
+             optimizer.load_state_dict(checkpoint['optimizer'])
+         else:
+             pl.print('do not load optimizer state')
+
+
+
+     # load loss module
+     loss_module = Loss(smal_model_type=cfg.smal.SMAL_MODEL_TYPE, data_info=StanExt.DATA_INFO, nf_version=cfg.params.NF_VERSION).to(device)
+     loss_module_ref = LossRef(smal_model_type=cfg.smal.SMAL_MODEL_TYPE, data_info=StanExt.DATA_INFO, nf_version=cfg.params.NF_VERSION).to(device)
+
+     # print weight_dict
+     pl.print("weight_dict: ")
+     pl.print(weight_dict)
+     pl.print("weight_dict_ref: ")
+     pl.print(weight_dict_ref)
+
+
+     if cfg.data.DATASET in ['stanext24_withgc', 'stanext24_withgc_big']:
+         # NEW for ground contact
+         pl.print("WARNING: we use a data sampler with ground contact that is not fully ready!")
+         pl.print('use a very standard data loader that is not suitable for breed losses!')
+         dataset_mode='complete_with_gc'
+         train_dataset = StanExt(image_path=None, is_train=True, dataset_mode=dataset_mode, V12=cfg.data.V12, val_opt=cfg.data.VAL_OPT, shorten_dataset_to=cfg.data.SHORTEN_VAL_DATASET_TO)
+         train_loader = DataLoader(
+             train_dataset,
+             batch_size=cfg.optim.BATCH_SIZE, shuffle=True,
+             num_workers=args.workers, pin_memory=True,
+             drop_last=True)
+
+         val_dataset = StanExt(image_path=None, is_train=False, dataset_mode=dataset_mode, V12=cfg.data.V12, val_opt=cfg.data.VAL_OPT)
+         val_loader = DataLoader(
+             val_dataset,
+             batch_size=cfg.optim.BATCH_SIZE, shuffle=False,
+             num_workers=args.workers, pin_memory=True,
+             drop_last=True) # drop last, need to check that!!
+     elif cfg.data.DATASET in ['stanext24_withgc_cs0', 'stanext24_withgc_csaddnonflat', 'stanext24_withgc_csaddnonflatmorestanding']: # cs0: custom sampler 0
+         dataset_mode='complete_with_gc'
+         if cfg.data.DATASET == 'stanext24_withgc_cs0':
+             add_nonflat = False
+             more_standing = False
+             assert cfg.optim.BATCH_SIZE == 12
+             pl.print('use CustomGCSampler without nonflat images')
+         elif cfg.data.DATASET == 'stanext24_withgc_csaddnonflat':
+             add_nonflat = True
+             more_standing = False
+             pl.print('use CustomGCSampler (with 12 flat and with 2 nonflat images)')
+             assert cfg.optim.BATCH_SIZE == 14
+         else: # stanext24_withgc_csaddnonflatmorestanding
+             add_nonflat = True
+             more_standing = True
+             pl.print('use CustomGCSampler (with 12 flat and with 2 nonflat images, more standing poses)')
+             assert cfg.optim.BATCH_SIZE == 14
+         train_dataset = StanExt(image_path=None, is_train=True, dataset_mode=dataset_mode, V12=cfg.data.V12, val_opt=cfg.data.VAL_OPT, add_nonflat=add_nonflat)
+         data_sampler_info_gc = train_dataset.get_data_sampler_info_gc()
+         batch_sampler = CustomGCSampler
+         train_custom_batch_sampler = batch_sampler(data_sampler_info_gc=data_sampler_info_gc, batch_size=cfg.optim.BATCH_SIZE, add_nonflat=add_nonflat, more_standing=more_standing)
+         train_loader = DataLoader(
+             train_dataset,
+             batch_sampler=train_custom_batch_sampler,
+             num_workers=args.workers, pin_memory=True)
+         val_dataset = StanExt(image_path=None, is_train=False, dataset_mode=dataset_mode, V12=cfg.data.V12, val_opt=cfg.data.VAL_OPT)
+         val_loader = DataLoader(
+             val_dataset,
+             batch_size=cfg.optim.BATCH_SIZE, shuffle=False,
+             num_workers=args.workers, pin_memory=True,
+             drop_last=True) # drop last, need to check that!!
+     elif cfg.data.DATASET == 'stanext24_withgc_noclasses':
+         dataset_mode='complete_with_gc'
+         add_nonflat = True
+         assert cfg.optim.BATCH_SIZE == 14
+         pl.print('use CustomGCSamplerNoCLass (with nonflat images)')
+         train_dataset = StanExt(image_path=None, is_train=True, dataset_mode=dataset_mode, V12=cfg.data.V12, val_opt=cfg.data.VAL_OPT, add_nonflat=add_nonflat)
+         data_sampler_info_gc = train_dataset.get_data_sampler_info_gc()
+         batch_sampler = CustomGCSamplerNoCLass
+         train_custom_batch_sampler = batch_sampler(data_sampler_info_gc=data_sampler_info_gc, batch_size=cfg.optim.BATCH_SIZE, add_nonflat=add_nonflat)
+         train_loader = DataLoader(
+             train_dataset,
+             batch_sampler=train_custom_batch_sampler,
+             num_workers=args.workers, pin_memory=True)
+         val_dataset = StanExt(image_path=None, is_train=False, dataset_mode=dataset_mode, V12=cfg.data.V12, val_opt=cfg.data.VAL_OPT)
+         val_loader = DataLoader(
+             val_dataset,
+             batch_size=cfg.optim.BATCH_SIZE, shuffle=False,
+             num_workers=args.workers, pin_memory=True,
+             drop_last=True) # drop last, need to check that!!
+
+
+     else:
+
+         dataset_mode='complete'
+
+         # load data sampler
+         if ('0' in weight_dict['breed_options']) or ('1' in weight_dict['breed_options']) or ('2' in weight_dict['breed_options']):
+             # remark: you will not need this data loader, it was only relevant for some of our experiments related to clades
+             batch_sampler = CustomBatchSampler  # NOTE: CustomBatchSampler is not imported in this file, so this branch raises NameError as written
+             pl.print('use CustomBatchSampler')
+         else:
+             # this sampler will always load two dogs of the same breed right after each other
+             batch_sampler = CustomPairBatchSampler
+             pl.print('use CustomPairBatchSampler')
+
+         # load dataset (train and {test or val})
+         train_dataset = StanExt(image_path=None, is_train=True, dataset_mode=dataset_mode, V12=cfg.data.V12, val_opt=cfg.data.VAL_OPT)
+         data_sampler_info = train_dataset.get_data_sampler_info()
+         train_custom_batch_sampler = batch_sampler(data_sampler_info=data_sampler_info, batch_size=cfg.optim.BATCH_SIZE)
+         train_loader = DataLoader(
+             train_dataset,
+             batch_sampler=train_custom_batch_sampler,
+             num_workers=args.workers, pin_memory=True)
+
+         if cfg.data.VAL_METRICS == 'no_loss':
+             # this is the option that we choose normally
+             # here we load val/test images using a standard sampler
+             # using a standard sampler at test time is better, but it prevents us from evaluating all the loss functions used at training time
+             # -> with this option here we calculate iou and pck for the val/test batches
+             val_dataset = StanExt(image_path=None, is_train=False, dataset_mode=dataset_mode, V12=cfg.data.V12, val_opt=cfg.data.VAL_OPT, shorten_dataset_to=cfg.data.SHORTEN_VAL_DATASET_TO)
+             val_loader = DataLoader(
+                 val_dataset,
+                 batch_size=cfg.optim.BATCH_SIZE, shuffle=False,
+                 num_workers=args.workers, pin_memory=True)
+         else:
+             # this is an option we might choose for debugging purposes
+             # here we load val/test images using our custom sampler for pairs of dogs of the same breed
+             val_dataset = StanExt(image_path=None, is_train=False, dataset_mode=dataset_mode, V12=cfg.data.V12, val_opt=cfg.data.VAL_OPT)
+             data_sampler_info = val_dataset.get_data_sampler_info()
+             val_custom_batch_sampler = batch_sampler(data_sampler_info=data_sampler_info, batch_size=cfg.optim.BATCH_SIZE, drop_last=True)
+             val_loader = DataLoader(
+                 val_dataset,
+                 batch_sampler=val_custom_batch_sampler,
+                 num_workers=args.workers, pin_memory=True)
+
+
+
+
+
+
+     # save results one time before starting
+     '''
+     save_imgs_path = None # '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/rubbish/'
+     valid_string, valid_acc = do_validation_epoch(val_loader, model, loss_module, loss_module_ref, device,
+                                                   StanExt.DATA_INFO,
+                                                   weight_dict=weight_dict,
+                                                   weight_dict_ref=weight_dict_ref,
+                                                   acc_joints=StanExt.ACC_JOINTS,
+                                                   metrics=cfg.data.VAL_METRICS,
+                                                   save_imgs_path=save_imgs_path)
+     predictions = np.zeros((1,1))
+     valid_loss = - valid_acc
+     # print metrics
+     epoch = 0
+     tqdm.write(' | VAL: ' + valid_string)
+
+     # remember best acc (acc is actually iou) and save checkpoint
+     is_best = valid_acc > best_acc
+     best_acc = max(valid_acc, best_acc)
+     save_checkpoint({
+         'epoch': epoch + 1,
+         'arch': cfg.params.ARCH,
+         'state_dict': model.state_dict(),
+         'best_acc': best_acc,
+         'optimizer' : optimizer.state_dict(),
+     }, predictions, is_best, checkpoint=path_checkpoint, snapshot=args.snapshot)
+     '''
+
+
+
+
+     # train and eval
+     lr = cfg.optim.LR
+     pl.print('initial learning rate: ' + str(lr))
+     for epoch in trange(0, cfg.optim.EPOCHS, desc='Overall', ascii=True):
+         lr = adjust_learning_rate(optimizer, epoch, lr, cfg.optim.SCHEDULE, cfg.optim.GAMMA)
+         if epoch >= start_epoch:
+             # train for one epoch
+             train_string, train_acc = do_training_epoch(train_loader, model, loss_module, loss_module_ref, device,
+                                                         StanExt.DATA_INFO,
+                                                         optimizer,
+                                                         weight_dict=weight_dict,
+                                                         weight_dict_ref=weight_dict_ref,
+                                                         acc_joints=StanExt.ACC_JOINTS)
+             # evaluate on validation set
+             save_imgs_path = None # '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/debugging/rubbish/'
+             valid_string, valid_acc = do_validation_epoch(val_loader, model, loss_module, loss_module_ref, device,
+                                                           StanExt.DATA_INFO,
+                                                           weight_dict=weight_dict,
+                                                           weight_dict_ref=weight_dict_ref,
+                                                           acc_joints=StanExt.ACC_JOINTS,
+                                                           metrics=cfg.data.VAL_METRICS,
+                                                           save_imgs_path=save_imgs_path)
+             predictions = np.zeros((1,1))
+             train_loss = - train_acc
+             valid_loss = - valid_acc
+             # print metrics
+             tqdm.write(f'[{epoch + 1:3d}/{cfg.optim.EPOCHS:3d}] lr={lr:0.2e}' + ' | TRAIN: ' + train_string + ' | VAL: ' + valid_string)
+             pl.print_log_only(f'[{epoch + 1:3d}/{cfg.optim.EPOCHS:3d}] lr={lr:0.2e}' + ' | TRAIN: ' + train_string + ' | VAL: ' + valid_string)
+
+             # remember best acc (acc is actually iou) and save checkpoint
+             is_best = valid_acc > best_acc
+             best_acc = max(valid_acc, best_acc)
+             save_checkpoint({
+                 'epoch': epoch + 1,
+                 'arch': cfg.params.ARCH,
+                 'state_dict': model.state_dict(),
+                 'best_acc': best_acc,
+                 'optimizer' : optimizer.state_dict(),
+             }, predictions, is_best, checkpoint=path_checkpoint, snapshot=args.snapshot)
+
+
460
+ if __name__ == '__main__':
461
+
462
+ # use as follows:
463
+ # python scripts/train_image_to_3d_withshape_withbreedrel.py --workers 12 --checkpoint=barc_new_v2 start --model-file-hg dogs_hg8_ksp_24_sev12_v3/model_best.pth.tar --model-file-3d Normflow_CVPR_set8_v3k2_v1/checkpoint.pth.tar
464
+
465
+ parser = argparse.ArgumentParser(description='Train a image-to-3d model.')
466
+
467
+ # arguments that we have no matter if we start a new training run or if we load the full network where training is somewhere in the middle
468
+ parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
469
+ help='path to save checkpoint (default: checkpoint)')
470
+ parser.add_argument('-cg', '--config', default='barc_cfg_train.yaml', type=str, metavar='PATH',
471
+ help='name of config file (default: barc_cfg_train.yaml within src/configs folder)')
472
+ parser.add_argument('-lw', '--loss-weight-path', default='barc_loss_weights.json', type=str, metavar='PATH',
473
+ help='name of json file which contains the loss weights')
474
+ parser.add_argument('-lwr', '--loss-weight-ref-path', default='refinement_loss_weights.json', type=str, metavar='PATH',
475
+ help='name of json file which contains the loss weights for the refinement network')
476
+ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
477
+ help='number of data loading workers (default: 4)')
478
+ parser.add_argument('--snapshot', default=0, type=int,
479
+ help='save models for every #snapshot epochs (default: 0)')
480
+
481
+ # argument that decides if we continue a training run (loading full network) or start from scratch (using only pretrained parts)
482
+ subparsers = parser.add_subparsers(dest="command") # parser.add_subparsers(help="subparsers")
483
+ parser_start = subparsers.add_parser('start') # start training
484
+ parser_continue = subparsers.add_parser('continue') # continue training
485
+
486
+ # arguments that we only have if we start a new training run
487
+ # remark: some parts can / need to be pretrained (stacked hourglass, 3d network)
488
+ parser_start.add_argument('--model-file-hg', default='', type=str, metavar='PATH',
489
+ help='path to saved model weights (stacked hour glass)')
490
+ parser_start.add_argument('--model-file-3d', default='', type=str, metavar='PATH',
491
+ help='path to saved model weights (2d-to-3d model)')
492
+ parser_start.add_argument('--model-file-shape', default='', type=str, metavar='PATH',
493
+ help='path to saved model weights (resnet, shape branch)')
494
+
495
+ # arguments that we only have if we continue training the full network
496
+ parser_continue.add_argument('--model-file-complete', default='', type=str, metavar='PATH',
497
+ help='path to saved model weights (full model)')
498
+ parser_continue.add_argument('--new-optimizer', default=0, type=int,
499
+ help='should we restart the optimizer? 0:no, 1: yes (default: 0)')
500
+ main(parser.parse_args())
501
+
502
+
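The argument layout added above splits the command line into shared options plus two subcommands, `start` and `continue`. The following is a minimal sketch of how that layout behaves, not the training script itself: it copies only a few of the arguments, `build_parser` is a hypothetical helper, and the file names passed in are made up for illustration.

```python
import argparse


def build_parser():
    """Reduced copy of the CLI layout: shared options plus two subcommands."""
    parser = argparse.ArgumentParser(description='sketch of the start/continue CLI')
    # options shared by both subcommands live on the top-level parser
    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N')
    # dest="command" stores the chosen subcommand name on the namespace
    subparsers = parser.add_subparsers(dest='command')
    parser_start = subparsers.add_parser('start')
    parser_continue = subparsers.add_parser('continue')
    # options that exist only for one subcommand
    parser_start.add_argument('--model-file-hg', default='', type=str, metavar='PATH')
    parser_continue.add_argument('--model-file-complete', default='', type=str, metavar='PATH')
    parser_continue.add_argument('--new-optimizer', default=0, type=int)
    return parser


parser = build_parser()

# shared options must appear before the subcommand name
args_start = parser.parse_args(['-j', '8', 'start', '--model-file-hg', 'hg.pth.tar'])
args_cont = parser.parse_args(['continue', '--model-file-complete', 'full.pth.tar',
                               '--new-optimizer', '1'])
# no subcommand given: argparse does not treat this as an error by default
args_none = parser.parse_args([])

print(args_start.command, args_start.workers, args_start.model_file_hg)
print(args_cont.command, args_cont.new_optimizer)
print(args_none.command)
```

Two consequences follow from this layout. Subcommand-specific attributes such as `model_file_hg` exist on the namespace only when that subcommand was selected, so code that reads them should check `args.command` first. And because subparsers are optional by default, running without `start` or `continue` yields `command=None`; passing `required=True` to `add_subparsers` (Python 3.7+) would turn that into a usage error instead.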
scripts/gradio_demo.py CHANGED
@@ -152,7 +152,7 @@ def run_bbox_inference(input_image):
152
  out_path = os.path.join(cfg.paths.ROOT_OUT_PATH, 'gradio_examples', 'test2.png')
153
  img, bbox = detect_object(model=model_bbox, img_path_or_img=input_image, confidence=0.5)
154
  fig = plt.figure() # plt.figure(figsize=(20,30))
155
- plt.imsave(out_path, img)
156
  return img, bbox
157
 
158
 
@@ -406,7 +406,7 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
406
  # save mesh
407
  my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
408
  my_mesh_tri.visual.vertex_colors = vert_colors
409
- my_mesh_tri.export(root_out_path + name + '_res_e000' + '.obj')
410
 
411
  else:
412
 
@@ -465,11 +465,13 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
465
  pred_keyp = pred_keyp_raw[:, :24, :]
466
 
467
  # save silhouette reprojection visualization
 
468
  if i==0:
469
  img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB')
470
  img_silh.save(root_out_path_details + name + '_silh_ainit.png')
471
  my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
472
  my_mesh_tri.export(root_out_path_details + name + '_res_ainit.obj')
 
473
 
474
  # silhouette loss
475
  diff_silh = torch.abs(pred_silh_images[0, 0, :, :] - target_hg_silh)
@@ -533,7 +535,7 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
533
  loop.set_description(f"Body Fitting = {total_loss.item():.3f}")
534
 
535
  # save the result three times (0, 150, 300)
536
- if i % 150 == 0:
537
  # save silhouette image
538
  img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB')
539
  img_silh.save(root_out_path_details + name + '_silh_e' + format(i, '03d') + '.png')
@@ -547,14 +549,14 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
547
  pred_tex_max = np.max(pred_tex, axis=2)
548
  im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
549
  out_path = root_out_path + name + '_comp_pred_e' + format(i, '03d') + '.png'
550
- plt.imsave(out_path, im_masked)
551
  # save mesh
552
  my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
553
  my_mesh_tri.visual.vertex_colors = vert_colors
554
- my_mesh_tri.export(root_out_path + name + '_res_e' + format(i, '03d') + '.obj')
555
  # save focal length (together with the mesh this is enough to create an overlay in blender)
556
- out_file_flength = root_out_path_details + name + '_flength_e' + format(i, '03d') # + '.npz'
557
- np.save(out_file_flength, optimed_camera_flength.detach().cpu().numpy())
558
  current_i += 1
559
 
560
  # prepare output mesh
@@ -564,8 +566,8 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
564
  [0, 0, 1, 1],
565
  [0, 0, 0, 1]])
566
  result_path = os.path.join(save_imgs_path, test_name_list[0] + '_z')
567
- mesh.export(file_obj=result_path + '.glb')
568
- result_gltf = result_path + '.glb'
569
  return result_gltf
570
 
571
 
 
152
  out_path = os.path.join(cfg.paths.ROOT_OUT_PATH, 'gradio_examples', 'test2.png')
153
  img, bbox = detect_object(model=model_bbox, img_path_or_img=input_image, confidence=0.5)
154
  fig = plt.figure() # plt.figure(figsize=(20,30))
155
+ # plt.imsave(out_path, img)
156
  return img, bbox
157
 
158
 
 
406
  # save mesh
407
  my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
408
  my_mesh_tri.visual.vertex_colors = vert_colors
409
+ # my_mesh_tri.export(root_out_path + name + '_res_e000' + '.obj')
410
 
411
  else:
412
 
 
465
  pred_keyp = pred_keyp_raw[:, :24, :]
466
 
467
  # save silhouette reprojection visualization
468
+ """
469
  if i==0:
470
  img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB')
471
  img_silh.save(root_out_path_details + name + '_silh_ainit.png')
472
  my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
473
  my_mesh_tri.export(root_out_path_details + name + '_res_ainit.obj')
474
+ """
475
 
476
  # silhouette loss
477
  diff_silh = torch.abs(pred_silh_images[0, 0, :, :] - target_hg_silh)
 
535
  loop.set_description(f"Body Fitting = {total_loss.item():.3f}")
536
 
537
  # save the result three times (0, 150, 300)
538
+ if i == 300: # if i % 150 == 0:
539
  # save silhouette image
540
  img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB')
541
  img_silh.save(root_out_path_details + name + '_silh_e' + format(i, '03d') + '.png')
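The hunk above narrows the save condition from `i % 150 == 0` to `i == 300`. A minimal sketch of the effect follows, assuming the optimization loop runs `i` over `range(301)`; the loop bounds are not visible in this diff, and `saved_steps` is a hypothetical helper used only for illustration.

```python
def saved_steps(n_steps, should_save):
    """Return the iteration indices at which results would be written."""
    return [i for i in range(n_steps) if should_save(i)]


N_STEPS = 301  # assumption: the loop covers iterations 0..300 inclusive

old_schedule = saved_steps(N_STEPS, lambda i: i % 150 == 0)
new_schedule = saved_steps(N_STEPS, lambda i: i == 300)

print(old_schedule)  # [0, 150, 300]
print(new_schedule)  # [300]

# if the loop were shorter, the new condition would never fire
print(saved_steps(300, lambda i: i == 300))  # []
```

Under that assumption only the final iteration is written, so the context comment "save the result three times (0, 150, 300)" no longer describes the behavior. The last line also shows the fragility of a hard-coded step: if the iteration count ever changes, comparing against the last loop index would be more robust than comparing against the literal `300`.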
 
549
  pred_tex_max = np.max(pred_tex, axis=2)
550
  im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
551
  out_path = root_out_path + name + '_comp_pred_e' + format(i, '03d') + '.png'
552
+ # plt.imsave(out_path, im_masked)
553
  # save mesh
554
  my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
555
  my_mesh_tri.visual.vertex_colors = vert_colors
556
+ # my_mesh_tri.export(root_out_path + name + '_res_e' + format(i, '03d') + '.obj')
557
  # save focal length (together with the mesh this is enough to create an overlay in blender)
558
+ # out_file_flength = root_out_path_details + name + '_flength_e' + format(i, '03d') # + '.npz'
559
+ # np.save(out_file_flength, optimed_camera_flength.detach().cpu().numpy())
560
  current_i += 1
561
 
562
  # prepare output mesh
 
566
  [0, 0, 1, 1],
567
  [0, 0, 0, 1]])
568
  result_path = os.path.join(save_imgs_path, test_name_list[0] + '_z')
569
+ mesh.export(file_obj=result_path + '.obj')
570
+ result_gltf = result_path + '.obj'
571
  return result_gltf
572
 
573