maxin-cn commited on
Commit
ceb3f0c
1 Parent(s): 6d2e0dc

Update models/__init__.py

Browse files
Files changed (1) hide show
  1. models/__init__.py +1 -21
models/__init__.py CHANGED
@@ -2,8 +2,6 @@ import os
2
  import sys
3
  sys.path.append(os.path.split(sys.path[0])[0])
4
 
5
- from .dit import DiT_models
6
- from .uvit import UViT_models
7
  from .unet import UNet3DConditionModel
8
  from torch.optim.lr_scheduler import LambdaLR
9
 
@@ -28,25 +26,7 @@ def get_lr_scheduler(optimizer, name, **kwargs):
28
 
29
  def get_models(args):
30
 
31
- if 'DiT' in args.model:
32
- return DiT_models[args.model](
33
- input_size=args.latent_size,
34
- num_classes=args.num_classes,
35
- class_guided=args.class_guided,
36
- num_frames=args.num_frames,
37
- use_lora=args.use_lora,
38
- attention_mode=args.attention_mode
39
- )
40
- elif 'UViT' in args.model:
41
- return UViT_models[args.model](
42
- input_size=args.latent_size,
43
- num_classes=args.num_classes,
44
- class_guided=args.class_guided,
45
- num_frames=args.num_frames,
46
- use_lora=args.use_lora,
47
- attention_mode=args.attention_mode
48
- )
49
- elif 'TAV' in args.model:
50
  pretrained_model_path = args.pretrained_model_path
51
  return UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_concat=args.use_mask)
52
  else:
 
2
  import sys
3
  sys.path.append(os.path.split(sys.path[0])[0])
4
 
 
 
5
  from .unet import UNet3DConditionModel
6
  from torch.optim.lr_scheduler import LambdaLR
7
 
 
26
 
27
  def get_models(args):
28
 
29
+ if 'TAV' in args.model:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  pretrained_model_path = args.pretrained_model_path
31
  return UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_concat=args.use_mask)
32
  else: