abhishek HF staff commited on
Commit
327a449
·
1 Parent(s): 9f9c9d7

autotrain spacerunner

Browse files
Files changed (3) hide show
  1. app.py +125 -96
  2. requirements.autotrain +21 -0
  3. script.py +0 -0
app.py CHANGED
@@ -1,12 +1,14 @@
1
  import os
2
- is_spaces = True if os.environ.get('SPACE_ID') else False
3
 
4
- if(is_spaces):
 
 
5
  import spaces
6
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
7
  import sys
8
 
9
  from dotenv import load_dotenv
 
10
  load_dotenv()
11
 
12
  # Add the current working directory to the Python path
@@ -22,11 +24,14 @@ import json
22
  import yaml
23
  from slugify import slugify
24
  from transformers import AutoProcessor, AutoModelForCausalLM
25
- if(not is_spaces):
 
 
26
  from toolkit.job import get_job
27
 
28
  MAX_IMAGES = 150
29
 
 
30
  def load_captioning(uploaded_images, concept_sentence):
31
  updates = []
32
  if len(uploaded_images) <= 1:
@@ -34,11 +39,9 @@ def load_captioning(uploaded_images, concept_sentence):
34
  "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)"
35
  )
36
  elif len(uploaded_images) > MAX_IMAGES:
37
- raise gr.Error(
38
- f"For now, only {MAX_IMAGES} or less images are allowed for training"
39
- )
40
  # Update for the captioning_area
41
- #for _ in range(3):
42
  updates.append(gr.update(visible=True))
43
  # Update visibility and image for each captioning row and image
44
  for i in range(1, MAX_IMAGES + 1):
@@ -50,23 +53,25 @@ def load_captioning(uploaded_images, concept_sentence):
50
 
51
  # Update for image component - display image if available, otherwise hide
52
  image_value = uploaded_images[i - 1] if visible else None
53
-
54
  updates.append(gr.update(value=image_value, visible=visible))
55
 
56
- #Update value of captioning area
57
  text_value = "[trigger]" if visible and concept_sentence else None
58
  updates.append(gr.update(value=text_value, visible=visible))
59
 
60
- #Update for the sample caption area
61
  updates.append(gr.update(visible=True))
62
  updates.append(gr.update(placeholder=f'A photo of {concept_sentence} holding a sign that reads "Hello friend"'))
63
- updates.append(gr.update(placeholder=f'A mountainous landscape in the style of {concept_sentence}'))
64
- updates.append(gr.update(placeholder=f'A {concept_sentence} in a mall'))
65
  return updates
66
 
67
- if(is_spaces):
 
68
  load_captioning = spaces.GPU()(load_captioning)
69
 
 
70
  def create_dataset(*inputs):
71
  print("Creating dataset")
72
  images = inputs[0]
@@ -74,56 +79,60 @@ def create_dataset(*inputs):
74
  if not os.path.exists(destination_folder):
75
  os.makedirs(destination_folder)
76
 
77
- jsonl_file_path = os.path.join(destination_folder, 'metadata.jsonl')
78
- with open(jsonl_file_path, 'a') as jsonl_file:
79
  for index, image in enumerate(images):
80
  new_image_path = shutil.copy(image, destination_folder)
81
-
82
  original_caption = inputs[index + 1]
83
  file_name = os.path.basename(new_image_path)
84
 
85
  data = {"file_name": file_name, "prompt": original_caption}
86
 
87
  jsonl_file.write(json.dumps(data) + "\n")
88
-
89
  return destination_folder
90
 
 
91
  def run_captioning(images, concept_sentence, *captions):
92
  device = "cuda" if torch.cuda.is_available() else "cpu"
93
  torch_dtype = torch.float16
94
- model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
 
 
95
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
96
 
97
  captions = list(captions)
98
  for i, image_path in enumerate(images):
99
  print(captions[i])
100
  if isinstance(image_path, str): # If image is a file path
101
- image = Image.open(image_path).convert('RGB')
102
-
103
  prompt = "<DETAILED_CAPTION>"
104
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
105
-
106
  generated_ids = model.generate(
107
- input_ids=inputs["input_ids"],
108
- pixel_values=inputs["pixel_values"],
109
- max_new_tokens=1024,
110
- num_beams=3
111
  )
112
-
113
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
114
- parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
115
- caption_text = parsed_answer['<DETAILED_CAPTION>'].replace("The image shows ", "")
116
- if(concept_sentence):
 
 
117
  caption_text = f"{caption_text} [trigger]"
118
  captions[i] = caption_text
119
-
120
-
121
  yield captions
122
  model.to("cpu")
123
  del model
124
  del processor
125
-
 
126
  def start_training(
 
 
127
  lora_name,
128
  concept_sentence,
129
  steps,
@@ -144,57 +153,76 @@ def start_training(
144
  config = yaml.safe_load(f)
145
 
146
  # Update the config with user inputs
147
- config['config']['name'] = slugged_lora_name
148
- config['config']['process'][0]['model']['low_vram'] = True
149
- config['config']['process'][0]['train']['skip_first_sample'] = True
150
- config['config']['process'][0]['train']['steps'] = int(steps)
151
- config['config']['process'][0]['train']['lr'] = float(lr)
152
- config['config']['process'][0]['network']['linear'] = int(rank)
153
- config['config']['process'][0]['network']['linear_alpha'] = int(rank)
154
- config['config']['process'][0]['datasets'][0]['folder_path'] = dataset_folder
155
- if(concept_sentence):
156
- config['config']['process'][0]['trigger_word'] = concept_sentence
157
- if(sample_1 or sample_2 or sample_2):
158
- config['config']['process'][0]['train']['disable_sampling'] = False
159
- config['config']['process'][0]['sample']["sample_every"] = steps
160
- config['config']['process'][0]['sample']['prompts'] = []
161
- if(sample_1):
162
- config['config']['process'][0]['sample']['prompts'].append(sample_1)
163
- if(sample_2):
164
- config['config']['process'][0]['sample']['prompts'].append(sample_2)
165
- if(sample_3):
166
- config['config']['process'][0]['sample']['prompts'].append(sample_3)
167
  else:
168
- config['config']['process'][0]['train']['disable_sampling'] = True
169
  # Save the updated config
170
  config_path = f"config/{slugged_lora_name}.yaml"
171
  with open(config_path, "w") as f:
172
  yaml.dump(config, f)
173
- if(is_spaces):
174
  print("Started training with spacerunner...")
175
- pass
176
- #do the spacerunner things here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  else:
178
- #run the job locally
179
  job = get_job(config_path)
180
  job.run()
181
  job.cleanup()
182
 
183
  return f"Training completed successfully. Model saved as {slugged_lora_name}"
184
 
 
185
  theme = gr.themes.Monochrome(
186
  text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
187
- font=[gr.themes.GoogleFont('Source Sans Pro'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
188
  )
189
- css = '''
190
  #component-1{text-align:center}
191
  .main_ui_logged_out{opacity: 0.3; pointer-events: none}
192
  .tabitem{border: 0px}
193
- '''
 
194
 
195
  def swap_visibilty(profile: gr.OAuthProfile | None):
196
  print(profile)
197
- if(is_spaces):
198
  if profile is None:
199
  return gr.update(elem_classes=["main_ui_logged_out"])
200
  else:
@@ -202,19 +230,26 @@ def swap_visibilty(profile: gr.OAuthProfile | None):
202
  return gr.update(elem_classes=["main_ui_logged_in"])
203
  else:
204
  return gr.update(elem_classes=["main_ui_logged_in"])
205
-
 
206
  with gr.Blocks(theme=theme, css=css) as demo:
207
- gr.Markdown('''# LoRA Ease for FLUX 🧞‍♂️
208
- ### Train a high quality FLUX LoRA in a breeze ༄ using [Ostris' AI Toolkit](https://github.com/ostris/ai-toolkit) and [AutoTrain Advanced](https://github.com/huggingface/autotrain-advanced)''')
209
- if(is_spaces):
 
 
210
  gr.LoginButton("Sign in with Hugging Face to train your LoRA on Spaces", visible=is_spaces)
211
  with gr.Tab("Train on Spaces" if is_spaces else "Train locally"):
212
  with gr.Column() as main_ui:
213
  with gr.Row():
214
- lora_name = gr.Textbox(label="The name of your LoRA", info="This has to be a unique name", placeholder="e.g.: Persian Miniature Painting style, Cat Toy")
215
- #training_option = gr.Radio(
 
 
 
 
216
  # label="What are you training?", choices=["object", "style", "character", "face", "custom"]
217
- #)
218
  concept_sentence = gr.Textbox(
219
  label="Trigger word/sentence",
220
  info="Trigger word or sentence to be used",
@@ -233,9 +268,11 @@ with gr.Blocks(theme=theme, css=css) as demo:
233
  )
234
  with gr.Column(scale=3, visible=False) as captioning_area:
235
  with gr.Column():
236
- gr.Markdown("""# Custom captioning
 
237
  You can optionally add a custom caption for each image (or use an AI model for this). [trigger] will represent your concept sentence/trigger word.
238
- """)
 
239
  do_captioning = gr.Button("Add AI captions with Florence-2")
240
  output_components = [captioning_area]
241
  caption_list = []
@@ -251,28 +288,30 @@ with gr.Blocks(theme=theme, css=css) as demo:
251
  scale=2,
252
  show_label=False,
253
  show_share_button=False,
254
- show_download_button=False
255
  )
256
  locals()[f"caption_{i}"] = gr.Textbox(
257
  label=f"Caption {i}", scale=15, interactive=True
258
  )
259
-
260
  output_components.append(locals()[f"captioning_row_{i}"])
261
  output_components.append(locals()[f"image_{i}"])
262
  output_components.append(locals()[f"caption_{i}"])
263
  caption_list.append(locals()[f"caption_{i}"])
264
-
265
  with gr.Accordion("Advanced options", open=False):
266
  steps = gr.Number(label="Steps", value=1000, minimum=1, maximum=10000, step=1)
267
  lr = gr.Number(label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-3, step=1e-6)
268
  rank = gr.Number(label="LoRA Rank", value=16, minimum=4, maximum=128, step=4)
269
-
270
  with gr.Accordion("Sample prompts", visible=False) as sample:
271
- gr.Markdown("Include sample prompts to test out your trained model. Don't forget to include your trigger word/sentence (optional)")
 
 
272
  sample_1 = gr.Textbox(label="Test prompt 1")
273
  sample_2 = gr.Textbox(label="Test prompt 2")
274
  sample_3 = gr.Textbox(label="Test prompt 3")
275
-
276
  output_components.append(sample)
277
  output_components.append(sample_1)
278
  output_components.append(sample_2)
@@ -281,7 +320,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
281
  progress_area = gr.Markdown("")
282
 
283
  with gr.Tab("Train locally" if is_spaces else "Instructions"):
284
- gr.Markdown(f'''To use FLUX LoRA Ease locally with this UI, you can clone this repository (yes, HF Spaces are git repos!)
 
285
  ```bash
286
  git clone https://huggingface.co/spaces/flux-train/flux-lora-trainer
287
  cd flux-lora-trainer
@@ -312,23 +352,14 @@ with gr.Blocks(theme=theme, css=css) as demo:
312
  python app.py
313
  ```
314
  If you prefer command line, you can run Ostris' [AI Toolkit](https://github.com/ostris/ai-toolkit) yourself directly.
315
- ''')
316
-
 
317
  dataset_folder = gr.State()
318
 
319
- images.upload(
320
- load_captioning,
321
- inputs=[images, concept_sentence],
322
- outputs=output_components,
323
- queue=False
324
- )
325
 
326
- start.click(
327
- fn=create_dataset,
328
- inputs=[images] + caption_list,
329
- outputs=dataset_folder,
330
- queue=False
331
- ).then(
332
  fn=start_training,
333
  inputs=[
334
  lora_name,
@@ -342,14 +373,12 @@ with gr.Blocks(theme=theme, css=css) as demo:
342
  sample_3,
343
  ],
344
  outputs=progress_area,
345
- queue=False
346
  )
347
 
348
- do_captioning.click(
349
- fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list
350
- )
351
  demo.load(fn=swap_visibilty, outputs=main_ui, queue=False)
352
 
353
  if __name__ == "__main__":
354
  demo.queue()
355
- demo.launch(share=True)
 
1
  import os
 
2
 
3
+ is_spaces = True if os.environ.get("SPACE_ID") else False
4
+
5
+ if is_spaces:
6
  import spaces
7
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
8
  import sys
9
 
10
  from dotenv import load_dotenv
11
+
12
  load_dotenv()
13
 
14
  # Add the current working directory to the Python path
 
24
  import yaml
25
  from slugify import slugify
26
  from transformers import AutoProcessor, AutoModelForCausalLM
27
+ import subprocess
28
+
29
+ if not is_spaces:
30
  from toolkit.job import get_job
31
 
32
  MAX_IMAGES = 150
33
 
34
+
35
  def load_captioning(uploaded_images, concept_sentence):
36
  updates = []
37
  if len(uploaded_images) <= 1:
 
39
  "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)"
40
  )
41
  elif len(uploaded_images) > MAX_IMAGES:
42
+ raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training")
 
 
43
  # Update for the captioning_area
44
+ # for _ in range(3):
45
  updates.append(gr.update(visible=True))
46
  # Update visibility and image for each captioning row and image
47
  for i in range(1, MAX_IMAGES + 1):
 
53
 
54
  # Update for image component - display image if available, otherwise hide
55
  image_value = uploaded_images[i - 1] if visible else None
56
+
57
  updates.append(gr.update(value=image_value, visible=visible))
58
 
59
+ # Update value of captioning area
60
  text_value = "[trigger]" if visible and concept_sentence else None
61
  updates.append(gr.update(value=text_value, visible=visible))
62
 
63
+ # Update for the sample caption area
64
  updates.append(gr.update(visible=True))
65
  updates.append(gr.update(placeholder=f'A photo of {concept_sentence} holding a sign that reads "Hello friend"'))
66
+ updates.append(gr.update(placeholder=f"A mountainous landscape in the style of {concept_sentence}"))
67
+ updates.append(gr.update(placeholder=f"A {concept_sentence} in a mall"))
68
  return updates
69
 
70
+
71
+ if is_spaces:
72
  load_captioning = spaces.GPU()(load_captioning)
73
 
74
+
75
  def create_dataset(*inputs):
76
  print("Creating dataset")
77
  images = inputs[0]
 
79
  if not os.path.exists(destination_folder):
80
  os.makedirs(destination_folder)
81
 
82
+ jsonl_file_path = os.path.join(destination_folder, "metadata.jsonl")
83
+ with open(jsonl_file_path, "a") as jsonl_file:
84
  for index, image in enumerate(images):
85
  new_image_path = shutil.copy(image, destination_folder)
86
+
87
  original_caption = inputs[index + 1]
88
  file_name = os.path.basename(new_image_path)
89
 
90
  data = {"file_name": file_name, "prompt": original_caption}
91
 
92
  jsonl_file.write(json.dumps(data) + "\n")
93
+
94
  return destination_folder
95
 
96
+
97
  def run_captioning(images, concept_sentence, *captions):
98
  device = "cuda" if torch.cuda.is_available() else "cpu"
99
  torch_dtype = torch.float16
100
+ model = AutoModelForCausalLM.from_pretrained(
101
+ "microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
102
+ ).to(device)
103
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
104
 
105
  captions = list(captions)
106
  for i, image_path in enumerate(images):
107
  print(captions[i])
108
  if isinstance(image_path, str): # If image is a file path
109
+ image = Image.open(image_path).convert("RGB")
110
+
111
  prompt = "<DETAILED_CAPTION>"
112
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
113
+
114
  generated_ids = model.generate(
115
+ input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
 
 
 
116
  )
117
+
118
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
119
+ parsed_answer = processor.post_process_generation(
120
+ generated_text, task=prompt, image_size=(image.width, image.height)
121
+ )
122
+ caption_text = parsed_answer["<DETAILED_CAPTION>"].replace("The image shows ", "")
123
+ if concept_sentence:
124
  caption_text = f"{caption_text} [trigger]"
125
  captions[i] = caption_text
126
+
 
127
  yield captions
128
  model.to("cpu")
129
  del model
130
  del processor
131
+
132
+
133
  def start_training(
134
+ profile: gr.OAuthProfile | None,
135
+ oauth_token: gr.OAuthToken | None,
136
  lora_name,
137
  concept_sentence,
138
  steps,
 
153
  config = yaml.safe_load(f)
154
 
155
  # Update the config with user inputs
156
+ config["config"]["name"] = slugged_lora_name
157
+ config["config"]["process"][0]["model"]["low_vram"] = True
158
+ config["config"]["process"][0]["train"]["skip_first_sample"] = True
159
+ config["config"]["process"][0]["train"]["steps"] = int(steps)
160
+ config["config"]["process"][0]["train"]["lr"] = float(lr)
161
+ config["config"]["process"][0]["network"]["linear"] = int(rank)
162
+ config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
163
+ config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
164
+ if concept_sentence:
165
+ config["config"]["process"][0]["trigger_word"] = concept_sentence
166
+ if sample_1 or sample_2 or sample_2:
167
+ config["config"]["process"][0]["train"]["disable_sampling"] = False
168
+ config["config"]["process"][0]["sample"]["sample_every"] = steps
169
+ config["config"]["process"][0]["sample"]["prompts"] = []
170
+ if sample_1:
171
+ config["config"]["process"][0]["sample"]["prompts"].append(sample_1)
172
+ if sample_2:
173
+ config["config"]["process"][0]["sample"]["prompts"].append(sample_2)
174
+ if sample_3:
175
+ config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
176
  else:
177
+ config["config"]["process"][0]["train"]["disable_sampling"] = True
178
  # Save the updated config
179
  config_path = f"config/{slugged_lora_name}.yaml"
180
  with open(config_path, "w") as f:
181
  yaml.dump(config, f)
182
+ if is_spaces:
183
  print("Started training with spacerunner...")
184
+ # copy config to dataset_folder
185
+ shutil.copy(config_path, dataset_folder)
186
+ # get location of this script
187
+ script_location = os.path.dirname(os.path.abspath(__file__))
188
+ # copy script.py from current directory to dataset_folder
189
+ shutil.copy(script_location + "/script.py", dataset_folder)
190
+ # copy requirements.autotrain to dataset_folder as requirements.txt
191
+ shutil.copy(script_location + "/requirements.autotrain", dataset_folder + "/requirements.txt")
192
+ # command to run autotrain spacerunner
193
+ cmd = f"autotrain spacerunner --project-name {slugged_lora_name} --script-path {dataset_folder}"
194
+ cmd += f" --username {profile.name} --token {oauth_token} --backend spaces-l4x1"
195
+ outcome = subprocess.run(cmd)
196
+ if outcome.returncode == 0:
197
+ return f"""# Your training has started.
198
+ ## - Training Status: <a href='https://huggingface.co/spaces/{profile.name}/autotrain-{slugged_lora_name}?logs=container'>{profile.name}/autotrain-{slugged_lora_name}</a> <small>(in the logs tab)</small>
199
+ ## - Model page: <a href='https://huggingface.co/{profile.name}/{slugged_lora_name}'>{profile.name}/{slugged_lora_name}</a> <small>(will be available when training finishes)</small>"""
200
+ else:
201
+ print("Error: ", outcome.stderr)
202
+ raise gr.Error("Something went wrong. Make sure the name of your LoRA is unique and try again")
203
  else:
204
+ # run the job locally
205
  job = get_job(config_path)
206
  job.run()
207
  job.cleanup()
208
 
209
  return f"Training completed successfully. Model saved as {slugged_lora_name}"
210
 
211
+
212
  theme = gr.themes.Monochrome(
213
  text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
214
+ font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
215
  )
216
+ css = """
217
  #component-1{text-align:center}
218
  .main_ui_logged_out{opacity: 0.3; pointer-events: none}
219
  .tabitem{border: 0px}
220
+ """
221
+
222
 
223
  def swap_visibilty(profile: gr.OAuthProfile | None):
224
  print(profile)
225
+ if is_spaces:
226
  if profile is None:
227
  return gr.update(elem_classes=["main_ui_logged_out"])
228
  else:
 
230
  return gr.update(elem_classes=["main_ui_logged_in"])
231
  else:
232
  return gr.update(elem_classes=["main_ui_logged_in"])
233
+
234
+
235
  with gr.Blocks(theme=theme, css=css) as demo:
236
+ gr.Markdown(
237
+ """# LoRA Ease for FLUX 🧞‍♂️
238
+ ### Train a high quality FLUX LoRA in a breeze ༄ using [Ostris' AI Toolkit](https://github.com/ostris/ai-toolkit) and [AutoTrain Advanced](https://github.com/huggingface/autotrain-advanced)"""
239
+ )
240
+ if is_spaces:
241
  gr.LoginButton("Sign in with Hugging Face to train your LoRA on Spaces", visible=is_spaces)
242
  with gr.Tab("Train on Spaces" if is_spaces else "Train locally"):
243
  with gr.Column() as main_ui:
244
  with gr.Row():
245
+ lora_name = gr.Textbox(
246
+ label="The name of your LoRA",
247
+ info="This has to be a unique name",
248
+ placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
249
+ )
250
+ # training_option = gr.Radio(
251
  # label="What are you training?", choices=["object", "style", "character", "face", "custom"]
252
+ # )
253
  concept_sentence = gr.Textbox(
254
  label="Trigger word/sentence",
255
  info="Trigger word or sentence to be used",
 
268
  )
269
  with gr.Column(scale=3, visible=False) as captioning_area:
270
  with gr.Column():
271
+ gr.Markdown(
272
+ """# Custom captioning
273
  You can optionally add a custom caption for each image (or use an AI model for this). [trigger] will represent your concept sentence/trigger word.
274
+ """
275
+ )
276
  do_captioning = gr.Button("Add AI captions with Florence-2")
277
  output_components = [captioning_area]
278
  caption_list = []
 
288
  scale=2,
289
  show_label=False,
290
  show_share_button=False,
291
+ show_download_button=False,
292
  )
293
  locals()[f"caption_{i}"] = gr.Textbox(
294
  label=f"Caption {i}", scale=15, interactive=True
295
  )
296
+
297
  output_components.append(locals()[f"captioning_row_{i}"])
298
  output_components.append(locals()[f"image_{i}"])
299
  output_components.append(locals()[f"caption_{i}"])
300
  caption_list.append(locals()[f"caption_{i}"])
301
+
302
  with gr.Accordion("Advanced options", open=False):
303
  steps = gr.Number(label="Steps", value=1000, minimum=1, maximum=10000, step=1)
304
  lr = gr.Number(label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-3, step=1e-6)
305
  rank = gr.Number(label="LoRA Rank", value=16, minimum=4, maximum=128, step=4)
306
+
307
  with gr.Accordion("Sample prompts", visible=False) as sample:
308
+ gr.Markdown(
309
+ "Include sample prompts to test out your trained model. Don't forget to include your trigger word/sentence (optional)"
310
+ )
311
  sample_1 = gr.Textbox(label="Test prompt 1")
312
  sample_2 = gr.Textbox(label="Test prompt 2")
313
  sample_3 = gr.Textbox(label="Test prompt 3")
314
+
315
  output_components.append(sample)
316
  output_components.append(sample_1)
317
  output_components.append(sample_2)
 
320
  progress_area = gr.Markdown("")
321
 
322
  with gr.Tab("Train locally" if is_spaces else "Instructions"):
323
+ gr.Markdown(
324
+ f"""To use FLUX LoRA Ease locally with this UI, you can clone this repository (yes, HF Spaces are git repos!)
325
  ```bash
326
  git clone https://huggingface.co/spaces/flux-train/flux-lora-trainer
327
  cd flux-lora-trainer
 
352
  python app.py
353
  ```
354
  If you prefer command line, you can run Ostris' [AI Toolkit](https://github.com/ostris/ai-toolkit) yourself directly.
355
+ """
356
+ )
357
+
358
  dataset_folder = gr.State()
359
 
360
+ images.upload(load_captioning, inputs=[images, concept_sentence], outputs=output_components, queue=False)
 
 
 
 
 
361
 
362
+ start.click(fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder, queue=False).then(
 
 
 
 
 
363
  fn=start_training,
364
  inputs=[
365
  lora_name,
 
373
  sample_3,
374
  ],
375
  outputs=progress_area,
376
+ queue=False,
377
  )
378
 
379
+ do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)
 
 
380
  demo.load(fn=swap_visibilty, outputs=main_ui, queue=False)
381
 
382
  if __name__ == "__main__":
383
  demo.queue()
384
+ demo.launch(share=True)
requirements.autotrain ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/diffusers.git
2
+ lycoris-lora==1.8.3
3
+ flatten_json
4
+ pyyaml
5
+ oyaml
6
+ tensorboard
7
+ kornia
8
+ invisible-watermark
9
+ einops
10
+ toml
11
+ albumentations
12
+ pydantic
13
+ omegaconf
14
+ k-diffusion
15
+ open_clip_torch
16
+ prodigyopt
17
+ controlnet_aux==0.0.7
18
+ python-dotenv
19
+ lpips
20
+ pytorch_fid
21
+ optimum-quanto
script.py ADDED
File without changes