fffiloni commited on
Commit
db2a4a7
1 Parent(s): d666556

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -4,13 +4,14 @@ import os
4
  import requests
5
  import subprocess
6
  from subprocess import getoutput
7
- from huggingface_hub import snapshot_download
 
 
8
 
9
  hf_token = os.environ.get("HF_TOKEN_WITH_WRITE_PERMISSION")
10
 
11
  is_shared_ui = True if "fffiloni/train-dreambooth-lora-sdxl" in os.environ['SPACE_ID'] else False
12
 
13
-
14
  is_gpu_associated = torch.cuda.is_available()
15
  if is_gpu_associated:
16
  gpu_info = getoutput('nvidia-smi')
@@ -44,8 +45,7 @@ def get_sleep_time(hf_token):
44
  return gcTimeout
45
 
46
  def write_to_community(title, description,hf_token):
47
- from huggingface_hub import HfApi
48
- api = HfApi()
49
  api.create_discussion(repo_id=os.environ['SPACE_ID'], title=title, description=description,repo_type="space", token=hf_token)
50
 
51
 
@@ -161,8 +161,9 @@ def main(dataset_id,
161
 
162
  instance_data_dir = repo_parts[-1]
163
  train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu)
164
-
165
- return f"Done, your trained model has been stored in your models library: your_user_name/{lora_trained_xl_folder}"
 
166
 
167
  css="""
168
  #col-container {max-width: 780px; margin-left: auto; margin-right: auto;}
@@ -219,4 +220,4 @@ with gr.Blocks(css=css) as demo:
219
  outputs = [status]
220
  )
221
 
222
- demo.queue().launch()
 
4
  import requests
5
  import subprocess
6
  from subprocess import getoutput
7
+ from huggingface_hub import snapshot_download, HfApi
8
+
9
+ api = HfApi()
10
 
11
  hf_token = os.environ.get("HF_TOKEN_WITH_WRITE_PERMISSION")
12
 
13
  is_shared_ui = True if "fffiloni/train-dreambooth-lora-sdxl" in os.environ['SPACE_ID'] else False
14
 
 
15
  is_gpu_associated = torch.cuda.is_available()
16
  if is_gpu_associated:
17
  gpu_info = getoutput('nvidia-smi')
 
45
  return gcTimeout
46
 
47
  def write_to_community(title, description,hf_token):
48
+
 
49
  api.create_discussion(repo_id=os.environ['SPACE_ID'], title=title, description=description,repo_type="space", token=hf_token)
50
 
51
 
 
161
 
162
  instance_data_dir = repo_parts[-1]
163
  train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu)
164
+
165
+ your_username = api.whoami(token=hf_token)["name"]
166
+ return f"Done, your trained model has been stored in your models library: {your_username}/{lora_trained_xl_folder}"
167
 
168
  css="""
169
  #col-container {max-width: 780px; margin-left: auto; margin-right: auto;}
 
220
  outputs = [status]
221
  )
222
 
223
+ demo.queue(default_enabled=False).launch(debug=True)