rafaaa2105 committed on
Commit
09073c1
1 Parent(s): e47b9ec

Enhanced the download and load functions

Files changed (2)
  1. README.md +0 -1
  2. app.py +14 -13
README.md CHANGED
@@ -4,7 +4,6 @@ emoji: 🎴
 colorFrom: red
 colorTo: pink
 sdk: gradio
-sdk_version: 4.24.0
 app_file: app.py
 pinned: true
 license: apache-2.0
app.py CHANGED
@@ -24,7 +24,7 @@ def download_file(url, filename, progress=gr.Progress(track_tqdm=True)):
     if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
         print("ERROR, something went wrong")
 
-def get_civitai_model_info(model_id, progress=gr.Progress(track_tqdm=True)):
+def get_civitai_model_info(model_id):
     url = f"https://civitai.com/api/v1/models/{model_id}"
     response = requests.get(url)
     if response.status_code != 200:
@@ -37,7 +37,7 @@ def find_download_url(data, file_extension):
             return file['downloadUrl']
     return None
 
-def download_civitai_model(model_id, lora_id="", progress=gr.Progress(track_tqdm=True)):
+def download_and_load_civitai_model(model_id, lora_id="", progress=gr.Progress(track_tqdm=True)):
     try:
         model_data = get_civitai_model_info(model_id)
         if model_data is None:
@@ -52,7 +52,8 @@ def download_civitai_model(model_id, lora_id="", progress=gr.Progress(track_tqdm=True)):
             return f"Error: No suitable file found for model {model_name}."
 
         file_extension = '.ckpt' if model_ckpt_url else '.safetensors'
-        download_file(model_url, f"{model_name}{file_extension}")
+        model_filename = f"{model_name}{file_extension}"
+        download_file(model_url, model_filename)
 
         if lora_id:
             lora_data = get_civitai_model_info(lora_id)
@@ -65,12 +66,17 @@ def download_civitai_model(model_id, lora_id="", progress=gr.Progress(track_tqdm=True)):
                 return f"Error: No suitable file found for LoRA {lora_name}."
 
             download_file(lora_safetensors_url, f"{lora_name}.safetensors")
-            loras_list.append(lora_name)
+            if lora_name not in loras_list:
+                loras_list.append(lora_name)
         else:
             lora_name = "None"
 
-        models_list.append(model_name)
-        return "Model/LoRA Downloaded!"
+        if model_name not in models_list:
+            models_list.append(model_name)
+
+        # Load model after downloading
+        load_result = load_model(model_filename, lora_name, use_lora=(lora_name != "None"))
+        return f"Model/LoRA Downloaded and Loaded! {load_result}"
     except Exception as e:
         return f"Error downloading model or LoRA: {e}"
 
@@ -114,11 +120,6 @@ def generate_images(
     progress=gr.Progress(track_tqdm=True)
 ):
     if prompt is not None and prompt.strip() != "":
-        if lora_name == "None":
-            load_model(model_name, "", False)
-        elif lora_name in loras_list and lora_name != "None":
-            load_model(model_name, lora_name, True)
-
         pipe = models.get(model_name)
         if pipe is None:
             return []
@@ -172,6 +173,6 @@ with gr.Blocks(theme='ParityError/Interstellar') as demo:
 
     download_output = gr.Textbox(label="Download Output")
 
-    download_button.click(download_civitai_model, inputs=[model_id, lora_id], outputs=download_output)
+    download_button.click(download_and_load_civitai_model, inputs=[model_id, lora_id], outputs=download_output)
 
-    demo.launch()
+demo.launch()
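
For context, a minimal sketch of the load_model() helper the reworked download path now relies on. load_model(), the models cache, and the models_list/loras_list globals appear in this diff only at their call sites, so the signature, the diffusers-based loading, and the cache key below are assumptions inferred from those call sites, not the Space's actual implementation.

# Hypothetical load_model() matching the call on new line 78:
#   load_model(model_filename, lora_name, use_lora=(lora_name != "None"))
import torch
from diffusers import StableDiffusionPipeline

models = {}        # pipeline cache; generate_images() looks pipelines up here
models_list = []   # choices shown in the model dropdown
loras_list = []    # choices shown in the LoRA dropdown

def load_model(model_filename, lora_name="None", use_lora=False):
    """Load a downloaded checkpoint into a cached diffusers pipeline."""
    if model_filename in models:
        return f"{model_filename} is already loaded"

    # from_single_file accepts both .ckpt and .safetensors checkpoints
    pipe = StableDiffusionPipeline.from_single_file(model_filename)

    if use_lora and lora_name != "None":
        # the download step saved the LoRA as "<lora_name>.safetensors"
        pipe.load_lora_weights(f"{lora_name}.safetensors")

    pipe.to("cuda" if torch.cuda.is_available() else "cpu")
    models[model_filename] = pipe  # cache key is a guess; the real key is not shown in the diff
    return f"Loaded {model_filename}" + (f" with LoRA {lora_name}" if use_lora else "")

Loading at download time is what lets this commit delete the load_model() branch from generate_images(): generation now only has to fetch an already-initialized pipeline from the models cache.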