fffiloni commited on
Commit
ff441d1
1 Parent(s): 0fcf13e

add dataset id to model card

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -56,7 +56,7 @@ def set_accelerate_default_config():
56
  except subprocess.CalledProcessError as e:
57
  print(f"An error occurred: {e}")
58
 
59
- def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu):
60
 
61
  script_filename = "train_dreambooth_lora_sdxl.py" # Assuming it's in the same folder
62
 
@@ -66,6 +66,7 @@ def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instan
66
  script_filename, # Use the local script
67
  "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
68
  "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
 
69
  f"--instance_data_dir={instance_data_dir}",
70
  f"--output_dir={lora_trained_xl_folder}",
71
  "--mixed_precision=fp16",
@@ -160,7 +161,7 @@ def main(dataset_id,
160
  gr.Info("Training begins ...")
161
 
162
  instance_data_dir = repo_parts[-1]
163
- train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu)
164
 
165
  your_username = api.whoami(token=hf_token)["name"]
166
  return f"Done, your trained model has been stored in your models library: {your_username}/{lora_trained_xl_folder}"
 
56
  except subprocess.CalledProcessError as e:
57
  print(f"An error occurred: {e}")
58
 
59
+ def train_dreambooth_lora_sdxl(dataset_id, instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu):
60
 
61
  script_filename = "train_dreambooth_lora_sdxl.py" # Assuming it's in the same folder
62
 
 
66
  script_filename, # Use the local script
67
  "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
68
  "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
69
+ f"--dataset_id={dataset_id}",
70
  f"--instance_data_dir={instance_data_dir}",
71
  f"--output_dir={lora_trained_xl_folder}",
72
  "--mixed_precision=fp16",
 
161
  gr.Info("Training begins ...")
162
 
163
  instance_data_dir = repo_parts[-1]
164
+ train_dreambooth_lora_sdxl(dataset_id, instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu)
165
 
166
  your_username = api.whoami(token=hf_token)["name"]
167
  return f"Done, your trained model has been stored in your models library: {your_username}/{lora_trained_xl_folder}"