lopho committed on
Commit
b2f876f
•
1 Parent(s): 149cc2d

forgot about the nested package structure

Files changed (23)
  1. makeavid_sd/README.md +0 -1
  2. makeavid_sd/{makeavid_sd/__init__.py → __init__.py} +0 -0
  3. makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/__init__.py +0 -0
  4. makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/dataset.py +0 -0
  5. makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_attention_pseudo3d.py +0 -0
  6. makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_embeddings.py +0 -0
  7. makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_resnet_pseudo3d.py +0 -0
  8. makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_trainer.py +0 -0
  9. makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_unet_pseudo3d_blocks.py +0 -0
  10. makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_unet_pseudo3d_condition.py +0 -0
  11. makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/train.py +0 -0
  12. makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/train.sh +0 -0
  13. makeavid_sd/{makeavid_sd/inference.py → inference.py} +0 -0
  14. makeavid_sd/requirements.txt +0 -2
  15. makeavid_sd/setup.py +0 -11
  16. makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/__init__.py +0 -0
  17. makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_attention_pseudo3d.py +0 -0
  18. makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_cross_attention.py +0 -0
  19. makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_embeddings.py +0 -0
  20. makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_resnet_pseudo3d.py +0 -0
  21. makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_unet_pseudo3d_blocks.py +0 -0
  22. makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_unet_pseudo3d_condition.py +0 -0
  23. makeavid_sd/trainer_xla.py +0 -104
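
The renames above flatten a doubly nested package: everything under makeavid_sd/makeavid_sd/ moves up one level. A sketch of the layout before and after this commit, reconstructed from the file list (an illustration, not part of the diff):

    before:
    makeavid_sd/
        README.md               (deleted in this commit)
        requirements.txt        (deleted in this commit)
        setup.py                (deleted in this commit)
        trainer_xla.py          (deleted in this commit)
        makeavid_sd/
            __init__.py
            inference.py
            flax_impl/ ...
            torch_impl/ ...

    after:
    makeavid_sd/
        __init__.py
        inference.py
        flax_impl/ ...
        torch_impl/ ...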
makeavid_sd/README.md DELETED
@@ -1 +0,0 @@
- # makeavid-sd-tpu
 
 
makeavid_sd/{makeavid_sd/__init__.py → __init__.py} RENAMED
File without changes
makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/__init__.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/dataset.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_attention_pseudo3d.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_embeddings.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_resnet_pseudo3d.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_trainer.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_unet_pseudo3d_blocks.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_unet_pseudo3d_condition.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/train.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/train.sh RENAMED
File without changes
makeavid_sd/{makeavid_sd/inference.py → inference.py} RENAMED
File without changes
makeavid_sd/requirements.txt DELETED
@@ -1,2 +0,0 @@
- torch
- torch_xla
 
 
 
makeavid_sd/setup.py DELETED
@@ -1,11 +0,0 @@
- from setuptools import setup
- setup(
-     name = 'makeavid_sd',
-     version = '0.1.0',
-     description = 'makeavid sd',
-     author = 'Lopho',
-     author_email = 'contact@lopho.org',
-     platforms = ['any'],
-     license = 'GNU Affero General Public License v3',
-     url = 'http://github.com/lopho/makeavid-sd-tpu'
- )

makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/__init__.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_attention_pseudo3d.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_cross_attention.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_embeddings.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_resnet_pseudo3d.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_unet_pseudo3d_blocks.py RENAMED
File without changes
makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_unet_pseudo3d_condition.py RENAMED
File without changes
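
With setup.py and the inner package directory gone, the tree is no longer declared as an installable distribution; the flattened package is presumably imported directly from the repository root. A hypothetical smoke test of the post-commit layout (import paths inferred from the renames above, not part of this commit):

    # run from the repository root after this commit (hypothetical check)
    import makeavid_sd.flax_impl    # no longer hidden behind the doubled makeavid_sd/makeavid_sd/ nesting
    import makeavid_sd.torch_impl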
makeavid_sd/trainer_xla.py DELETED
@@ -1,104 +0,0 @@
- import os
- os.environ['PJRT_DEVICE'] = 'TPU'
-
- from tqdm.auto import tqdm
- import torch
- from torch.utils.data import DataLoader
- from torch_xla.core import xla_model
- from diffusers import UNetPseudo3DConditionModel
- from dataset import load_dataset
-
-
- class TempoTrainerXLA:
-     def __init__(self,
-             pretrained: str = 'lxj616/make-a-stable-diffusion-video-timelapse',
-             lr: float = 1e-4,
-             dtype: torch.dtype = torch.float32,
-     ) -> None:
-         self.dtype = dtype
-         self.device: torch.device = xla_model.xla_device(0)
-         unet: UNetPseudo3DConditionModel = UNetPseudo3DConditionModel.from_pretrained(
-                 pretrained,
-                 subfolder = 'unet'
-         ).to(dtype = dtype, memory_format = torch.contiguous_format)
-         unfreeze_all: bool = False
-         unet = unet.train()
-         if not unfreeze_all:
-             unet.requires_grad_(False)
-             for name, param in unet.named_parameters():
-                 if 'temporal_conv' in name:
-                     param.requires_grad_(True)
-             for block in [*unet.down_blocks, unet.mid_block, *unet.up_blocks]:
-                 if hasattr(block, 'attentions') and block.attentions is not None:
-                     for attn_block in block.attentions:
-                         for transformer_block in attn_block.transformer_blocks:
-                             transformer_block.requires_grad_(False)
-                             transformer_block.attn_temporal.requires_grad_(True)
-                             transformer_block.norm_temporal.requires_grad_(True)
-         else:
-             unet.requires_grad_(True)
-         self.model: UNetPseudo3DConditionModel = unet.to(device = self.device)
-         #self.model = torch.compile(self.model, backend = 'aot_torchxla_trace_once')
-         self.params = lambda: filter(lambda p: p.requires_grad, self.model.parameters())
-         self.optim: torch.optim.Optimizer = torch.optim.AdamW(self.params(), lr = lr)
-         def lr_warmup(warmup_steps: int = 0):
-             def lambda_lr(step: int) -> float:
-                 if step < warmup_steps:
-                     return step / warmup_steps
-                 else:
-                     return 1.0
-             return lambda_lr
-         self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optim, lr_lambda = lr_warmup(warmup_steps = 60), last_epoch = -1)
-
-     @torch.no_grad()
-     def train(self, dataloader: DataLoader, epochs: int = 1, log_every: int = 1, save_every: int = 1000) -> None:
-         # 'latent_model_input'
-         # 'encoder_hidden_states'
-         # 'timesteps'
-         # 'noise'
-         global_step: int = 0
-         for epoch in range(epochs):
-             pbar = tqdm(dataloader, dynamic_ncols = True, smoothing = 0.01)
-             for b in pbar:
-                 latent_model_input: torch.Tensor = b['latent_model_input'].to(device = self.device)
-                 encoder_hidden_states: torch.Tensor = b['encoder_hidden_states'].to(device = self.device)
-                 timesteps: torch.Tensor = b['timesteps'].to(device = self.device)
-                 noise: torch.Tensor = b['noise'].to(device = self.device)
-                 with torch.enable_grad():
-                     self.optim.zero_grad(set_to_none = True)
-                     y = self.model(latent_model_input, timesteps, encoder_hidden_states).sample
-                     loss = torch.nn.functional.mse_loss(noise, y)
-                     loss.backward()
-                     self.optim.step()
-                     self.scheduler.step()
-                 xla_model.mark_step()
-                 if global_step % log_every == 0:
-                     pbar.set_postfix({ 'loss': loss.detach().item(), 'epoch': epoch })
-
- def main():
-     pretrained: str = 'lxj616/make-a-stable-diffusion-video-timelapse'
-     dataset_path: str = './storage/dataset/tempofunk'
-     dtype: torch.dtype = torch.bfloat16
-     trainer = TempoTrainerXLA(
-             pretrained = pretrained,
-             lr = 1e-5,
-             dtype = dtype
-     )
-     dataloader: DataLoader = load_dataset(
-             dataset_path = dataset_path,
-             pretrained = pretrained,
-             batch_size = 1,
-             num_frames = 10,
-             num_workers = 1,
-             dtype = dtype
-     )
-     trainer.train(
-             dataloader = dataloader,
-             epochs = 1000,
-             log_every = 1,
-             save_every = 1000
-     )
-
- if __name__ == '__main__':
-     main()
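
For context, the deleted trainer follows torch_xla's lazy execution model: PyTorch ops on XLA tensors only record a graph, and xla_model.mark_step() at the end of each batch compiles and runs the pending graph on the TPU. A minimal self-contained sketch of that pattern, with a toy linear model and random data standing in for the pseudo-3D UNet and dataloader (an illustration, not the deleted code):

    import os
    os.environ['PJRT_DEVICE'] = 'TPU'  # same runtime selection as the deleted trainer

    import torch
    from torch_xla.core import xla_model

    device = xla_model.xla_device()            # first available XLA (TPU) device
    model = torch.nn.Linear(8, 8).to(device)   # toy stand-in for the UNet
    optim = torch.optim.AdamW(model.parameters(), lr = 1e-4)

    for step in range(10):
        x = torch.randn(4, 8, device = device)
        optim.zero_grad(set_to_none = True)
        loss = torch.nn.functional.mse_loss(model(x), x)
        loss.backward()
        optim.step()
        # Everything above only queued XLA ops; mark_step() compiles and
        # executes the pending graph, materializing the parameter update.
        xla_model.mark_step()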