Spaces:
Runtime error
Runtime error
add dataset id to model card
Browse files
app.py
CHANGED
@@ -56,7 +56,7 @@ def set_accelerate_default_config():
|
|
56 |
except subprocess.CalledProcessError as e:
|
57 |
print(f"An error occurred: {e}")
|
58 |
|
59 |
-
def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu):
|
60 |
|
61 |
script_filename = "train_dreambooth_lora_sdxl.py" # Assuming it's in the same folder
|
62 |
|
@@ -66,6 +66,7 @@ def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instan
|
|
66 |
script_filename, # Use the local script
|
67 |
"--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
|
68 |
"--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
|
|
|
69 |
f"--instance_data_dir={instance_data_dir}",
|
70 |
f"--output_dir={lora_trained_xl_folder}",
|
71 |
"--mixed_precision=fp16",
|
@@ -160,7 +161,7 @@ def main(dataset_id,
|
|
160 |
gr.Info("Training begins ...")
|
161 |
|
162 |
instance_data_dir = repo_parts[-1]
|
163 |
-
train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu)
|
164 |
|
165 |
your_username = api.whoami(token=hf_token)["name"]
|
166 |
return f"Done, your trained model has been stored in your models library: {your_username}/{lora_trained_xl_folder}"
|
|
|
56 |
except subprocess.CalledProcessError as e:
|
57 |
print(f"An error occurred: {e}")
|
58 |
|
59 |
+
def train_dreambooth_lora_sdxl(dataset_id, instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu):
|
60 |
|
61 |
script_filename = "train_dreambooth_lora_sdxl.py" # Assuming it's in the same folder
|
62 |
|
|
|
66 |
script_filename, # Use the local script
|
67 |
"--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
|
68 |
"--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
|
69 |
+
f"--dataset_id={dataset_id}",
|
70 |
f"--instance_data_dir={instance_data_dir}",
|
71 |
f"--output_dir={lora_trained_xl_folder}",
|
72 |
"--mixed_precision=fp16",
|
|
|
161 |
gr.Info("Training begins ...")
|
162 |
|
163 |
instance_data_dir = repo_parts[-1]
|
164 |
+
train_dreambooth_lora_sdxl(dataset_id, instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu)
|
165 |
|
166 |
your_username = api.whoami(token=hf_token)["name"]
|
167 |
return f"Done, your trained model has been stored in your models library: {your_username}/{lora_trained_xl_folder}"
|