kyleleey commited on
Commit
288ffdd
·
1 Parent(s): 0352152

remove unused parts

Browse files
Files changed (1) hide show
  1. video3d/model_ddp.py +4 -85
video3d/model_ddp.py CHANGED
@@ -1170,99 +1170,18 @@ class Unsup3DDDP:
1170
 
1171
  # mask distribution
1172
  self.enable_mask_distribution = cfgs.get('enable_mask_distribution', False)
 
1173
  self.random_mask_law = cfgs.get('random_mask_law', 'batch_swap_noy') # batch_swap, batch_swap_noy, # random_azimuth # random_all
1174
  self.mask_distribution_path = cfgs.get('mask_distribution_path', None)
1175
- if self.enable_mask_distribution and (self.mask_distribution_path is not None):
1176
- self.class_mask_distribution = {}
1177
- for category in os.listdir(self.mask_distribution_path):
1178
- # Here we assume the category names are identical
1179
- distribution_file = osp.join(self.mask_distribution_path, category, "raw_mask_distribution.npy")
1180
- distribution = np.load(distribution_file)
1181
- self.class_mask_distribution.update(
1182
- {
1183
- category: distribution # [256, 256]
1184
- }
1185
- )
1186
- self.mask_distribution_loss_weight = cfgs.get("mask_distribution_loss_weight", 0.1)
1187
- self.mask_distribution_loss_freq = cfgs.get("mask_distribution_loss_freq", 1)
1188
-
1189
- self.mask_distribution_average = cfgs.get("mask_distribution_average", False)
1190
-
1191
- else:
1192
- self.enable_mask_distribution = False
1193
 
1194
  self.enable_clip = cfgs.get('enable_clip', False)
1195
  self.enable_clip = False
1196
- # if self.enable_clip:
1197
- # self.clip_model, _ = clip.load('ViT-B/32', self.device)
1198
- # self.clip_model = self.clip_model.eval().requires_grad_(False)
1199
- # self.clip_mean = [0.48145466, 0.4578275, 0.40821073]
1200
- # self.clip_std = [0.26862954, 0.26130258, 0.27577711]
1201
- # self.clip_reso = 224
1202
- # self.clip_render_size = 64
1203
- # self.enable_clip_text = cfgs.get('enable_clip_text', False)
1204
- # if self.enable_clip_text:
1205
- # self.clip_text_feature = {}
1206
- # for category_name in ['bear', 'elephant', 'horse', 'sheep', 'cow', 'zebra', 'giraffe']:
1207
- # text_input = clip.tokenize(['A photo of ' + category_name]).to(self.device)
1208
- # text_feature = self.clip_model.encode_text(text_input).detach() # [1, 512]
1209
- # self.clip_text_feature.update({category_name: text_feature})
1210
 
1211
  self.enable_disc = cfgs.get('enable_disc', False)
1212
- if self.enable_disc:
1213
- self.mask_discriminator_iter = cfgs.get('mask_discriminator_iter', [0, 0])
1214
- # this module is not in netInstance or netPrior
1215
-
1216
- self.mask_disc_feat_condition = cfgs.get('mask_disc_feat_condition', False)
1217
- if self.mask_disc_feat_condition:
1218
- self.mask_disc = discriminator_architecture.DCDiscriminator(in_dim=(cfgs.get('dim_of_classes', 128) + 1)).to(self.device)
1219
- else:
1220
- self.mask_disc = discriminator_architecture.DCDiscriminator(in_dim=(len(list(self.netPrior.category_id_map.keys())) + 1)).to(self.device)
1221
-
1222
- self.disc_gt = cfgs.get('disc_gt', True)
1223
- self.disc_iv = cfgs.get('disc_iv', False) # whether to use input view render in disc loss
1224
- self.disc_iv_label = cfgs.get('disc_iv_label', 'Fake')
1225
- self.disc_reg_mul = cfgs.get('disc_reg_mul', 10.)
1226
-
1227
- self.record_mask_gt = None
1228
- self.record_mask_iv = None
1229
- self.record_mask_rv = None
1230
- self.discriminator_loss = 0.
1231
- self.discriminator_loss_weight = cfgs.get('discriminator_loss_weight', 0.1)
1232
 
1233
- # the local texture for fine-tune process stage
1234
- if (self.cfgs.get('texture_way', None) is not None) or self.cfgs.get('gan_tex', False):
1235
- if self.cfgs.get('gan_tex', False):
1236
- self.few_shot_gan_tex = True
1237
- self.few_shot_gan_tex_reso = self.cfgs.get('few_shot_gan_tex_reso', 64) # used to render novel view, will upsample to out_image_size ASAP
1238
- self.few_shot_gan_tex_patch = self.cfgs.get('few_shot_gan_tex_patch', 0) # used to sample patch size on out_image_size image
1239
- if self.few_shot_gan_tex_patch > 0:
1240
- self.few_shot_gan_tex_patch_max = self.cfgs.get('few_shot_gan_tex_patch_max', 128)
1241
- assert self.few_shot_gan_tex_patch_max > self.few_shot_gan_tex_patch
1242
- self.few_shot_gan_tex_patch_num = self.cfgs.get('few_shot_gan_tex_patch_num', 1)
1243
- self.discriminator_texture = discriminator_architecture.DCDiscriminator(in_dim=3, img_size=self.few_shot_gan_tex_patch).to(self.device)
1244
- else:
1245
- self.discriminator_texture = discriminator_architecture.DCDiscriminator(in_dim=3, img_size=self.out_image_size).to(self.device)
1246
-
1247
- self.few_shot_gan_tex_real = self.cfgs.get('few_shot_gan_tex_real', 'gt')
1248
- self.few_shot_gan_tex_fake = self.cfgs.get('few_shot_gan_tex_fake', 'rv')
1249
- else:
1250
- self.few_shot_gan_tex = False
1251
-
1252
- if self.cfgs.get('clip_tex', False):
1253
- self.few_shot_clip_tex = True
1254
- self.clip_model, _ = clip.load('ViT-B/32', self.device)
1255
- self.clip_model = self.clip_model.eval().requires_grad_(False)
1256
- self.clip_mean = [0.48145466, 0.4578275, 0.40821073]
1257
- self.clip_std = [0.26862954, 0.26130258, 0.27577711]
1258
- self.clip_reso = 224
1259
- self.enable_clip_text = False
1260
- else:
1261
- self.few_shot_clip_tex = False
1262
-
1263
- else:
1264
- self.few_shot_gan_tex = False
1265
- self.few_shot_clip_tex = False
1266
 
1267
  self.enable_sds = cfgs.get('enable_sds', False)
1268
  self.enable_vsd = cfgs.get('enable_vsd', False)
 
1170
 
1171
  # mask distribution
1172
  self.enable_mask_distribution = cfgs.get('enable_mask_distribution', False)
1173
+ self.enable_mask_distribution = False
1174
  self.random_mask_law = cfgs.get('random_mask_law', 'batch_swap_noy') # batch_swap, batch_swap_noy, # random_azimuth # random_all
1175
  self.mask_distribution_path = cfgs.get('mask_distribution_path', None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1176
 
1177
  self.enable_clip = cfgs.get('enable_clip', False)
1178
  self.enable_clip = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1179
 
1180
  self.enable_disc = cfgs.get('enable_disc', False)
1181
+ self.enable_disc = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1182
 
1183
+ self.few_shot_gan_tex = False
1184
+ self.few_shot_clip_tex = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1185
 
1186
  self.enable_sds = cfgs.get('enable_sds', False)
1187
  self.enable_vsd = cfgs.get('enable_vsd', False)