mkshing committed
Commit 47390c8
Parent: e2a20af

apply v0.2.0

Files changed (6):
  1. app_inference.py +1 -1
  2. app_training.py +1 -2
  3. inference.py +10 -2
  4. requirements.txt +1 -1
  5. train_svdiff.py +108 -64
  6. trainer.py +2 -0
app_inference.py CHANGED
@@ -12,7 +12,7 @@ from utils import find_exp_dirs
 
 SAMPLE_MODEL_IDS = [
     'svdiff-library/svdiff_dog_example',
-    'mshing/svdiff_kumamon_example',
+    'svdiff-library/svdiff_chair_example',
 ]
 
 
app_training.py CHANGED
@@ -64,7 +64,7 @@ def create_training_demo(trainer: Trainer,
             label='Resolution')
         num_training_steps = gr.Number(
             label='Number of Training Steps', value=1000, precision=0)
-        learning_rate = gr.Number(label='Learning Rate', value=0.005)
+        learning_rate = gr.Number(label='Learning Rate', value=0.001)
         gradient_accumulation = gr.Number(
             label='Number of Gradient Accumulation',
             value=1,
@@ -91,7 +91,6 @@ def create_training_demo(trainer: Trainer,
         gr.Markdown('''
         - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
         - It takes a few minutes to download the base model first.
-        - It will take about 8 minutes to train for 1000 steps with a T4 GPU.
         - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
         - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
         - You need to set the environment variable `WANDB_API_KEY` if you'd like to use [W&B](https://wandb.ai/site). See [W&B documentation](https://docs.wandb.ai/guides/track/advanced/environment-variables).
inference.py CHANGED
@@ -8,7 +8,7 @@ import PIL.Image
 import torch
 from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
 from huggingface_hub import ModelCard
-from svdiff_pytorch import load_unet_for_svdiff, SCHEDULER_MAPPING, image_grid
+from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING, image_grid
 
 
 
@@ -58,18 +58,26 @@ class InferencePipeline:
         for module in unet.modules():
             if hasattr(module, "perform_svd"):
                 module.perform_svd()
-        unet = unet.to(self.device, dtype=torch.float16)
+        if self.device.type != 'cpu':
+            unet = unet.to(self.device, dtype=torch.float16)
+        text_encoder = load_text_encoder_for_svdiff(base_model_id, spectral_shifts_ckpt=model_id, subfolder="text_encoder")
+        if self.device.type != 'cpu':
+            text_encoder = text_encoder.to(self.device, dtype=torch.float16)
+        else:
+            text_encoder = text_encoder.to(self.device)
         if base_model_id != self.base_model_id:
             if self.device.type == 'cpu':
                 pipe = DiffusionPipeline.from_pretrained(
                     base_model_id,
                     unet=unet,
+                    text_encoder=text_encoder,
                     use_auth_token=self.hf_token
                 )
             else:
                 pipe = DiffusionPipeline.from_pretrained(
                     base_model_id,
                     unet=unet,
+                    text_encoder=text_encoder,
                     torch_dtype=torch.float16,
                     use_auth_token=self.hf_token
                 )
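
For context, here is a minimal sketch of how the v0.2.0 inference pieces above fit together outside the Space. The base model id, checkpoint id, and prompt are placeholders, and passing `spectral_shifts_ckpt` to `load_unet_for_svdiff` is assumed by analogy with the text-encoder loader shown in this commit, not something this diff demonstrates.

```python
# Hedged sketch: load SVDiff spectral shifts (UNet + text encoder) into a diffusers pipeline.
import torch
from diffusers import DiffusionPipeline
from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff

base_model_id = "runwayml/stable-diffusion-v1-5"      # placeholder base model
model_id = "svdiff-library/svdiff_dog_example"        # placeholder spectral-shift checkpoint

# spectral_shifts_ckpt for the UNet loader is an assumption; the text-encoder call mirrors the diff.
unet = load_unet_for_svdiff(base_model_id, spectral_shifts_ckpt=model_id, subfolder="unet")
for module in unet.modules():
    if hasattr(module, "perform_svd"):
        module.perform_svd()  # bake the SVD factors, as InferencePipeline does above
text_encoder = load_text_encoder_for_svdiff(base_model_id, spectral_shifts_ckpt=model_id, subfolder="text_encoder")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cpu":  # fp16 only off-CPU, matching the new branch in inference.py
    unet = unet.to(device, dtype=torch.float16)
    text_encoder = text_encoder.to(device, dtype=torch.float16)

pipe = DiffusionPipeline.from_pretrained(
    base_model_id,
    unet=unet,
    text_encoder=text_encoder,
    torch_dtype=torch.float16 if device.type != "cpu" else None,
).to(device)

image = pipe("a photo of a dog", num_inference_steps=25).images[0]  # placeholder prompt
```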
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-svdiff-pytorch
+svdiff-pytorch>=0.2.0
 bitsandbytes==0.35.0
 python-slugify==7.0.0
 tomesd
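
If it helps, a tiny sketch of checking the new pin at runtime; the distribution name is taken from this requirements.txt, everything else is standard library plus `packaging`.

```python
# Sketch: assert the installed svdiff-pytorch satisfies the new >=0.2.0 requirement.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("svdiff-pytorch"))
assert installed >= Version("0.2.0"), f"svdiff-pytorch {installed} is older than the required 0.2.0"
```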
train_svdiff.py CHANGED
@@ -7,6 +7,7 @@ import warnings
 from pathlib import Path
 from typing import Optional
 from packaging import version
+import itertools
 
 import numpy as np
 import torch
@@ -22,7 +23,7 @@ from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, PretrainedConfig
+from transformers import CLIPTextModel, AutoTokenizer, PretrainedConfig
 
 import diffusers
 from diffusers import __version__
@@ -33,7 +34,7 @@ from diffusers import (
     StableDiffusionPipeline,
     DPMSolverMultistepScheduler,
 )
-from svdiff_pytorch import load_unet_for_svdiff, SCHEDULER_MAPPING
+from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING
 from diffusers.loaders import AttnProcsLayers
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
@@ -78,26 +79,6 @@ These are SVDiff weights for {base_model}. The weights were trained on {prompt}
     f.write(yaml + model_card)
 
 
-def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
-    text_encoder_config = PretrainedConfig.from_pretrained(
-        pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=revision,
-    )
-    model_class = text_encoder_config.architectures[0]
-
-    if model_class == "CLIPTextModel":
-        from transformers import CLIPTextModel
-
-        return CLIPTextModel
-    elif model_class == "RobertaSeriesModelWithTransformation":
-        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
-
-        return RobertaSeriesModelWithTransformation
-    else:
-        raise ValueError(f"{model_class} is not supported.")
-
-
 def parse_args(input_args=None):
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
@@ -271,9 +252,15 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=5e-4,
+        default=1e-3,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
+    parser.add_argument(
+        "--learning_rate_1d",
+        type=float,
+        default=1e-6,
+        help="Initial learning rate (after the potential warmup period) to use for 1-d weights",
+    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -380,6 +367,11 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_token_merging", action="store_true", help="Whether or not to use tomesd on prior generation"
     )
+    parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        help="Whether to train spectral shifts of the text encoder. If set, the text encoder should be float32 precision.",
+    )
     if input_args is not None:
         args = parser.parse_args(input_args)
     else:
@@ -594,6 +586,11 @@ def main(args):
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -700,14 +697,14 @@ def main(args):
         use_fast=False,
     )
 
-    # import correct text encoder class
-    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
-
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-    text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
-    )
+    if args.train_text_encoder:
+        text_encoder = load_text_encoder_for_svdiff(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
+    else:
+        text_encoder = CLIPTextModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        )
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = load_unet_for_svdiff(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=True)
 
@@ -716,26 +713,26 @@ def main(args):
     text_encoder.requires_grad_(False)
     unet.requires_grad_(False)
     optim_params = []
+    optim_params_1d = []
     for n, p in unet.named_parameters():
         if "delta" in n:
             p.requires_grad = True
-            optim_params.append(p)
+            if "norm" in n:
+                optim_params_1d.append(p)
+            else:
+                optim_params.append(p)
+    if args.train_text_encoder:
+        for n, p in text_encoder.named_parameters():
+            if "delta" in n:
+                p.requires_grad = True
+                if "norm" in n:
+                    optim_params_1d.append(p)
+                else:
+                    optim_params.append(p)
+
     total_params = sum(p.numel() for p in optim_params)
     print(f"Number of Trainable Parameters: {total_params * 1.e-6:.2f} M")
 
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
-    weight_dtype = torch.float32
-    if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    # Move unet, vae and text_encoder to device and cast to weight_dtype
-    # unet.to(accelerator.device, dtype=weight_dtype)
-    vae.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
-
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
@@ -751,12 +748,26 @@ def main(args):
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        if args.train_text_encoder:
+            text_encoder.gradient_checkpointing_enable()
 
-    if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+    # Check that all trainable models are in full precision
+    low_precision_error_string = (
+        "Please make sure to always have all model weights in full float32 precision when starting training - even if"
+        " doing mixed precision training. copy of the weights should still be float32."
+    )
+
+    if accelerator.unwrap_model(unet).dtype != torch.float32:
+        raise ValueError(
+            f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
        )
 
+    if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
+        raise ValueError(
+            f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
+            f" {low_precision_error_string}"
+        )
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
@@ -782,7 +793,7 @@ def main(args):
 
     # Optimizer creation
     optimizer = optimizer_class(
-        optim_params,
+        [{"params": optim_params}, {"params": optim_params_1d, "lr": args.learning_rate_1d}],
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
@@ -826,9 +837,29 @@ def main(args):
     )
 
     # Prepare everything with our `accelerator`.
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.train_text_encoder:
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+        )
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
+        )
+
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move unet, vae and text_encoder to device and cast to weight_dtype
+    # unet.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
+    if not args.train_text_encoder:
+        text_encoder.to(accelerator.device, dtype=weight_dtype)
+
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -842,14 +873,27 @@ def main(args):
     if accelerator.is_main_process:
         accelerator.init_trackers("svdiff-pytorch", config=vars(args))
 
-    def save_weights(step):
+    # cache keys to save
+    state_dict_keys = [k for k in accelerator.unwrap_model(unet).state_dict().keys() if "delta" in k]
+    if args.train_text_encoder:
+        state_dict_keys_te = [k for k in accelerator.unwrap_model(text_encoder).state_dict().keys() if "delta" in k]
+
+    def save_weights(step, save_path=None):
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
-            save_path = os.path.join(args.output_dir, f"checkpoint-{step}")
+            if save_path is None:
+                save_path = os.path.join(args.output_dir, f"checkpoint-{step}")
             os.makedirs(save_path, exist_ok=True)
-            unet_model = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-            state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
+            state_dict = accelerator.unwrap_model(unet, keep_fp32_wrapper=True).state_dict()
+            # state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
+            state_dict = {k: state_dict[k] for k in state_dict_keys}
             save_file(state_dict, os.path.join(save_path, "spectral_shifts.safetensors"))
+            if args.train_text_encoder:
+                state_dict = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True).state_dict()
+                # state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
+                state_dict = {k: state_dict[k] for k in state_dict_keys_te}
+                save_file(state_dict, os.path.join(save_path, "spectral_shifts_te.safetensors"))
+
             print(f"[*] Weights saved at {save_path}")
 
     # Train!
@@ -897,6 +941,8 @@ def main(args):
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
+        if args.train_text_encoder:
+            text_encoder.train()
         for step, batch in enumerate(train_dataloader):
             # Skip steps until we reach the resumed step
             if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -952,7 +998,11 @@ def main(args):
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = unet.parameters()
+                    params_to_clip = (
+                        itertools.chain(unet.parameters(), text_encoder.parameters())
+                        if args.train_text_encoder
+                        else unet.parameters()
+                    )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                     optimizer.step()
                     lr_scheduler.step()
@@ -970,7 +1020,7 @@ def main(args):
                     # accelerator.save_state(save_path)
                     # logger.info(f"Saved state to {save_path}")
 
-            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "lr_1d": lr_scheduler.get_last_lr()[1]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
 
@@ -982,14 +1032,8 @@ def main(args):
                 log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
 
     accelerator.wait_for_everyone()
-    save_weights(global_step)
-    # put the latest checkpoint to output-dir
-    save_path = args.output_dir
-    unet_model = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-    state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
-    save_file(state_dict, os.path.join(save_path, "spectral_shifts.safetensors"))
-    print(f"[*] Weights saved at {save_path}")
-
+    # put the last checkpoint to output-dir
+    save_weights(global_step, save_path=args.output_dir)
     if accelerator.is_main_process:
         if args.push_to_hub:
             save_model_card(
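
The core training change above is the split learning-rate scheme: all "delta" (spectral-shift) parameters are trainable, but the 1-d shifts on norm layers go into their own optimizer group driven by the new `--learning_rate_1d`. A hedged, self-contained sketch of that pattern, with `model` standing in for the SVDiff-wrapped UNet or text encoder and the 1e-3 / 1e-6 values mirroring the new defaults:

```python
# Sketch of the two-group optimizer introduced in this commit (not the script itself).
import torch

def build_svdiff_optimizer(model: torch.nn.Module, lr: float = 1e-3, lr_1d: float = 1e-6):
    optim_params, optim_params_1d = [], []
    for name, param in model.named_parameters():
        if "delta" in name:  # only spectral shifts are trainable
            param.requires_grad = True
            # 1-d shifts on norm layers get the smaller, dedicated learning rate
            (optim_params_1d if "norm" in name else optim_params).append(param)
        else:
            param.requires_grad = False
    return torch.optim.AdamW(
        [{"params": optim_params}, {"params": optim_params_1d, "lr": lr_1d}],
        lr=lr,
    )
```

Because the scheduler then tracks one learning rate per parameter group, the new `lr_1d` entry in the training logs simply reads `lr_scheduler.get_last_lr()[1]`, as the diff does.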
trainer.py CHANGED
@@ -124,6 +124,8 @@ class Trainer:
                     --train_batch_size=1 \
                     --gradient_accumulation_steps={gradient_accumulation} \
                     --learning_rate={learning_rate} \
+                    --learning_rate_1d=1e-6 \
+                    --train_text_encoder \
                     --lr_scheduler=constant \
                     --lr_warmup_steps=0 \
                     --max_train_steps={n_steps} \
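
For reference, a hedged sketch of the kind of invocation the Space's Trainer now assembles. Only the two new flags and the values shown above come from this diff; the `accelerate launch` entry point, the base model id, the data/output paths, and the instance prompt are assumptions for illustration.

```python
# Sketch: launch train_svdiff.py with the v0.2.0 flags trainer.py now appends.
import subprocess

cmd = [
    "accelerate", "launch", "train_svdiff.py",                              # assumed entry point
    "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",    # placeholder base model
    "--instance_data_dir", "./training_data",                               # placeholder path
    "--instance_prompt", "photo of a sks dog",                              # placeholder prompt
    "--output_dir", "./experiments/sample",                                 # placeholder path
    "--train_batch_size", "1",
    "--gradient_accumulation_steps", "1",
    "--learning_rate", "1e-3",
    "--learning_rate_1d", "1e-6",   # new in v0.2.0: separate LR for 1-d (norm) spectral shifts
    "--train_text_encoder",         # new in v0.2.0: also train text-encoder spectral shifts
    "--lr_scheduler", "constant",
    "--lr_warmup_steps", "0",
    "--max_train_steps", "1000",
]
subprocess.run(cmd, check=True)
```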