multimodalart HF staff commited on
Commit
2e8ed6c
1 Parent(s): ac07ed8
Files changed (1) hide show
  1. app.py +28 -4
app.py CHANGED
@@ -86,17 +86,41 @@ def train(*inputs):
86
  lr_warmup_steps=0,
87
  max_train_steps=Training_Steps,
88
  num_class_images=200
89
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  elif(inputs[-4] == "object"):
91
  class_data_dir = None
92
  elif(inputs[-4] == "style"):
93
  class_data_dir = None
94
 
95
- args = argparse.Namespace(
96
  image_captions_filename = True,
97
  train_text_encoder = True,
98
  stop_text_encoder_training = stptxt,
99
- save_n_steps = 0
100
  dump_only_text_encoder = True,
101
  pretrained_model_name_or_path = "./stable-diffusion-v1-5",
102
  instance_data_dir="instance_images",
@@ -114,7 +138,7 @@ def train(*inputs):
114
  lr_warmup_steps = 0,
115
  max_train_steps=Training_Steps,
116
  )
117
- run_training(args)
118
  os.rmdir('instance_images')
119
  with gr.Blocks(css=css) as demo:
120
  with gr.Box():
 
86
  lr_warmup_steps=0,
87
  max_train_steps=Training_Steps,
88
  num_class_images=200
89
+ )
90
+ args_unet = argparse.Namespace(
91
+ image_captions_filename = True,
92
+ train_only_unet=True,
93
+ Session_dir="output_model",
94
+ save_starting_step=0,
95
+ save_n_steps=0,
96
+ pretrained_model_name_or_path="./stable-diffusion-v1-5",
97
+ instance_data_dir="instance_images",
98
+ output_dir="output_model",
99
+ instance_prompt="",
100
+ seed=42,
101
+ resolution=512,
102
+ mixed_precision="fp16",
103
+ train_batch_size=1,
104
+ gradient_accumulation_steps=1,
105
+ gradient_checkpointing=False,
106
+ use_8bit_adam=True,
107
+ learning_rate=2e-6,
108
+ lr_scheduler="polynomial",
109
+ lr_warmup_steps=0,
110
+ max_train_steps=Training_Steps
111
+ )
112
+ run_training(args_txt_encoder)
113
+ run_training(args_unet)
114
  elif(inputs[-4] == "object"):
115
  class_data_dir = None
116
  elif(inputs[-4] == "style"):
117
  class_data_dir = None
118
 
119
+ args_general = argparse.Namespace(
120
  image_captions_filename = True,
121
  train_text_encoder = True,
122
  stop_text_encoder_training = stptxt,
123
+ save_n_steps = 0,
124
  dump_only_text_encoder = True,
125
  pretrained_model_name_or_path = "./stable-diffusion-v1-5",
126
  instance_data_dir="instance_images",
 
138
  lr_warmup_steps = 0,
139
  max_train_steps=Training_Steps,
140
  )
141
+ run_training(args_general)
142
  os.rmdir('instance_images')
143
  with gr.Blocks(css=css) as demo:
144
  with gr.Box():