multimodalart (HF staff) committed
Commit d1c3953
1 Parent(s): 7d4fd85

Update app.py

Files changed (1)
  1. app.py +35 -34
app.py CHANGED
@@ -69,12 +69,13 @@ def train(*inputs):
         file_counter += 1
 
     uses_custom = inputs[-1]
+    type_of_thing = inputs[-4]
     if(uses_custom):
         Training_Steps = int(inputs[-3])
         Train_text_encoder_for = int(inputs[-2])
     else:
         Training_Steps = file_counter*200
-    if(inputs[-4] == "person"):
+    if(type_of_thing == "person"):
        class_data_dir = "mix"
        Train_text_encoder_for=100
        args_txt_encoder = argparse.Namespace(
@@ -124,41 +125,41 @@ def train(*inputs):
        )
        run_training(args_txt_encoder)
        run_training(args_unet)
-    elif(inputs[-4] == "object"):
-        Train_text_encoder_for=30
+    elif(type_of_thing == "object" or type_of_thing == "style"):
+        if(type_of_thing == "object"):
+            Train_text_encoder_for=30
+        elif(type_of_thing == "style"):
+            Train_text_encoder_for=15
        class_data_dir = None
-    elif(inputs[-4] == "style"):
-        Train_text_encoder_for=15
-        class_data_dir = None
-
-    stptxt = int((Training_Steps*Train_text_encoder_for)/100)
-    args_general = argparse.Namespace(
-        image_captions_filename = True,
-        train_text_encoder = True,
-        stop_text_encoder_training = stptxt,
-        save_n_steps = 0,
-        dump_only_text_encoder = True,
-        pretrained_model_name_or_path = model_to_load,
-        instance_data_dir="instance_images",
-        class_data_dir=class_data_dir,
-        output_dir="output_model",
-        instance_prompt="",
-        seed=42,
-        resolution=512,
-        mixed_precision="fp16",
-        train_batch_size=1,
-        gradient_accumulation_steps=1,
-        use_8bit_adam=True,
-        learning_rate=2e-6,
-        lr_scheduler="polynomial",
-        lr_warmup_steps = 0,
-        max_train_steps=Training_Steps,
-    )
-
-    run_training(args_general)
-    os.rmdir('instance_images')
-    shutil.make_archive("output_model.zip", 'zip', "output_model")
+        stptxt = int((Training_Steps*Train_text_encoder_for)/100)
+        args_general = argparse.Namespace(
+            image_captions_filename = True,
+            train_text_encoder = True,
+            stop_text_encoder_training = stptxt,
+            save_n_steps = 0,
+            pretrained_model_name_or_path = model_to_load,
+            instance_data_dir="instance_images",
+            class_data_dir=class_data_dir,
+            output_dir="output_model",
+            instance_prompt="",
+            seed=42,
+            resolution=512,
+            mixed_precision="fp16",
+            train_batch_size=1,
+            gradient_accumulation_steps=1,
+            use_8bit_adam=True,
+            learning_rate=2e-6,
+            lr_scheduler="polynomial",
+            lr_warmup_steps = 0,
+            max_train_steps=Training_Steps,
+        )
+        run_training(args_general)
+
+    shutil.rmtree('instance_images')
+    shutil.make_archive("output_model", 'zip', "output_model")
+    shutil.rmtree("output_model")
     return gr.update(visible=True, value="output_model.zip")
+
 with gr.Blocks(css=css) as demo:
     with gr.Box():
         # You can remove this part here for your local clone
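
For reference, the housekeeping changes at the end of train() rest on two standard-library details: shutil.make_archive() appends the format's extension to its base name, so the previous base name "output_model.zip" produced "output_model.zip.zip" rather than the "output_model.zip" returned to Gradio, and os.rmdir() removes only empty directories, so it would fail on instance_images while any instance images remain in it, which shutil.rmtree() handles. The sketch below reproduces both behaviors together with the stptxt arithmetic; the values (file_counter, the placeholder file, the "object" preset) are hypothetical stand-ins, not taken from the Space.

import os
import shutil

# Hypothetical stand-ins for values the Space derives from its inputs.
file_counter = 5                      # assumed number of uploaded instance images
Training_Steps = file_counter * 200   # 1000, as in the non-custom branch
Train_text_encoder_for = 30           # "object" preset used by this commit
stptxt = int((Training_Steps * Train_text_encoder_for) / 100)
print(stptxt)                         # 300 -> text encoder trains for 300 of 1000 steps

os.makedirs("instance_images", exist_ok=True)
os.makedirs("output_model", exist_ok=True)
open("instance_images/example.jpg", "wb").close()  # directory is not empty

# os.rmdir() only removes empty directories and would raise OSError here;
# shutil.rmtree() removes the directory together with its contents.
shutil.rmtree("instance_images")

# shutil.make_archive() appends ".zip" to base_name: the old base name "output_model.zip"
# produced "output_model.zip.zip", while "output_model" yields the expected "output_model.zip".
shutil.make_archive("output_model", "zip", "output_model")
shutil.rmtree("output_model")
print(os.path.exists("output_model.zip"))           # True

With the preset values above, the 30% (object) and 15% (style) figures translate directly into the stop_text_encoder_training step count that args_general passes to run_training().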