import os

# Select the TPU backend for the PJRT runtime before torch_xla is used.
os.environ['PJRT_DEVICE'] = 'TPU'

from tqdm.auto import tqdm
import torch
from torch.utils.data import DataLoader
from torch_xla.core import xla_model
from diffusers import UNetPseudo3DConditionModel

# Local dataset module providing the dataloader.
from dataset import load_dataset
class TempoTrainerXLA:
    def __init__(self,
        pretrained: str = 'lxj616/make-a-stable-diffusion-video-timelapse',
        lr: float = 1e-4,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        self.dtype = dtype
        self.device: torch.device = xla_model.xla_device(0)
        unet: UNetPseudo3DConditionModel = UNetPseudo3DConditionModel.from_pretrained(
            pretrained,
            subfolder = 'unet'
        ).to(dtype = dtype, memory_format = torch.contiguous_format)
        unfreeze_all: bool = False
        unet = unet.train()
        if not unfreeze_all:
            # Freeze the pretrained 2D UNet and train only the temporal layers:
            # the temporal convolutions plus the temporal attention and norm
            # inside each transformer block.
            unet.requires_grad_(False)
            for name, param in unet.named_parameters():
                if 'temporal_conv' in name:
                    param.requires_grad_(True)
            for block in [*unet.down_blocks, unet.mid_block, *unet.up_blocks]:
                if hasattr(block, 'attentions') and block.attentions is not None:
                    for attn_block in block.attentions:
                        for transformer_block in attn_block.transformer_blocks:
                            transformer_block.requires_grad_(False)
                            transformer_block.attn_temporal.requires_grad_(True)
                            transformer_block.norm_temporal.requires_grad_(True)
        else:
            unet.requires_grad_(True)
        self.model: UNetPseudo3DConditionModel = unet.to(device = self.device)
        #self.model = torch.compile(self.model, backend = 'aot_torchxla_trace_once')
        # Only parameters left with requires_grad = True reach the optimizer.
        self.params = lambda: filter(lambda p: p.requires_grad, self.model.parameters())
        self.optim: torch.optim.Optimizer = torch.optim.AdamW(self.params(), lr = lr)
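        # Optional sanity check: report how many parameters the optimizer will
        # actually update versus the full model size.
        trainable = sum(p.numel() for p in self.params())
        total = sum(p.numel() for p in self.model.parameters())
        print(f'trainable parameters: {trainable:,} / {total:,}')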
        def lr_warmup(warmup_steps: int = 0):
            # Linear learning-rate warmup over the first `warmup_steps` optimizer steps.
            def lambda_lr(step: int) -> float:
                if step < warmup_steps:
                    return step / warmup_steps
                else:
                    return 1.0
            return lambda_lr
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optim, lr_lambda = lr_warmup(warmup_steps = 60), last_epoch = -1)
    def train(self, dataloader: DataLoader, epochs: int = 1, log_every: int = 1, save_every: int = 1000) -> None:
        # Each batch is expected to provide the keys 'latent_model_input',
        # 'encoder_hidden_states', 'timesteps' and 'noise'.
        global_step: int = 0
        for epoch in range(epochs):
            pbar = tqdm(dataloader, dynamic_ncols = True, smoothing = 0.01)
            for b in pbar:
                latent_model_input: torch.Tensor = b['latent_model_input'].to(device = self.device)
                encoder_hidden_states: torch.Tensor = b['encoder_hidden_states'].to(device = self.device)
                timesteps: torch.Tensor = b['timesteps'].to(device = self.device)
                noise: torch.Tensor = b['noise'].to(device = self.device)
                with torch.enable_grad():
                    self.optim.zero_grad(set_to_none = True)
                    y = self.model(latent_model_input, timesteps, encoder_hidden_states).sample
                    loss = torch.nn.functional.mse_loss(noise, y)
                    loss.backward()
                    self.optim.step()
                    self.scheduler.step()
                # Materialize the lazily built XLA graph for this step.
                xla_model.mark_step()
                if global_step % log_every == 0:
                    # .item() forces a device sync, so logging every step slows training.
                    pbar.set_postfix({ 'loss': loss.detach().item(), 'epoch': epoch })
                global_step += 1
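                # `save_every` is otherwise unused; minimal checkpoint sketch, assuming
                # a plain state_dict snapshot is sufficient (xla_model.save moves XLA
                # tensors to CPU before writing).
                if global_step > 0 and global_step % save_every == 0:
                    os.makedirs('./checkpoints', exist_ok = True)
                    xla_model.save(self.model.state_dict(), f'./checkpoints/unet-{global_step}.pt')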

def main():
    pretrained: str = 'lxj616/make-a-stable-diffusion-video-timelapse'
    dataset_path: str = './storage/dataset/tempofunk'
    dtype: torch.dtype = torch.bfloat16
    trainer = TempoTrainerXLA(
        pretrained = pretrained,
        lr = 1e-5,
        dtype = dtype
    )
    dataloader: DataLoader = load_dataset(
        dataset_path = dataset_path,
        pretrained = pretrained,
        batch_size = 1,
        num_frames = 10,
        num_workers = 1,
        dtype = dtype
    )
    trainer.train(
        dataloader = dataloader,
        epochs = 1000,
        log_every = 1,
        save_every = 1000
    )

if __name__ == '__main__':
    main()