John6666 committed on
Commit
e81646e
·
verified ·
1 Parent(s): d518d04

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +181 -104
utils.py CHANGED
@@ -9,6 +9,7 @@ from constants import (
9
  DIRECTORY_LORAS,
10
  DIRECTORY_MODELS,
11
  DIFFUSECRAFT_CHECKPOINT_NAME,
 
12
  CACHE_HF,
13
  STORAGE_ROOT,
14
  )
@@ -28,6 +29,7 @@ from urllib3.util import Retry
28
  import shutil
29
  import subprocess
30
 
 
31
  USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
32
 
33
 
@@ -66,7 +68,8 @@ class ModelInformation:
66
  )
67
  self.filename_url = self.filename_url if self.filename_url else ""
68
  self.description = json_data.get("description", "")
69
- if self.description is None: self.description = ""
 
70
  self.model_name = json_data.get("model", {}).get("name", "")
71
  self.model_type = json_data.get("model", {}).get("type", "")
72
  self.nsfw = json_data.get("model", {}).get("nsfw", False)
@@ -76,118 +79,175 @@ class ModelInformation:
76
  self.original_json = copy.deepcopy(json_data)
77
 
78
 
79
- def retrieve_model_info(url):
80
- json_data = request_json_data(url)
81
- if not json_data:
82
- return None
83
- model_descriptor = ModelInformation(json_data)
84
- return model_descriptor
 
 
 
 
 
85
 
86
 
87
- def download_things(directory, url, hf_token="", civitai_api_key="", romanize=False):
88
- url = url.strip()
89
- downloaded_file_path = None
90
 
91
- if "drive.google.com" in url:
92
- original_dir = os.getcwd()
93
- os.chdir(directory)
94
- os.system(f"gdown --fuzzy {url}")
95
- os.chdir(original_dir)
96
- elif "huggingface.co" in url:
97
- url = url.replace("?download=true", "")
98
- # url = urllib.parse.quote(url, safe=':/') # fix encoding
99
- if "/blob/" in url:
100
- url = url.replace("/blob/", "/resolve/")
101
- user_header = f'"Authorization: Bearer {hf_token}"'
102
 
103
- filename = unidecode(url.split('/')[-1]) if romanize else url.split('/')[-1]
 
 
 
 
 
104
 
105
- if hf_token:
106
- os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {filename}")
107
- else:
108
- os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {filename}")
109
 
110
- downloaded_file_path = os.path.join(directory, filename)
 
111
 
112
- elif "civitai.com" in url:
 
113
 
114
- if not civitai_api_key:
115
- print("\033[91mYou need an API key to download Civitai models.\033[0m")
116
-
117
- model_profile = retrieve_model_info(url)
118
- if (
119
- model_profile is not None
120
- and model_profile.download_url
121
- and model_profile.filename_url
122
- ):
123
- url = model_profile.download_url
124
- filename = unidecode(model_profile.filename_url) if romanize else model_profile.filename_url
125
- else:
126
- if "?" in url:
127
- url = url.split("?")[0]
128
- filename = ""
129
 
130
- url_dl = url + f"?token={civitai_api_key}"
131
- print(f"Filename: {filename}")
132
 
133
- param_filename = ""
134
- if filename:
135
- param_filename = f"-o '{filename}'"
 
 
136
 
137
- aria2_command = (
138
- f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
139
- f'-k 1M -s 16 -d "{directory}" {param_filename} "{url_dl}"'
140
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  os.system(aria2_command)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
- if param_filename and os.path.exists(os.path.join(directory, filename)):
144
- downloaded_file_path = os.path.join(directory, filename)
145
-
146
- # # PLAN B
147
- # # Follow the redirect to get the actual download URL
148
- # curl_command = (
149
- # f'curl -L -sI --connect-timeout 5 --max-time 5 '
150
- # f'-H "Content-Type: application/json" '
151
- # f'-H "Authorization: Bearer {civitai_api_key}" "{url}"'
152
- # )
153
-
154
- # headers = os.popen(curl_command).read()
155
-
156
- # # Look for the redirected "Location" URL
157
- # location_match = re.search(r'location: (.+)', headers, re.IGNORECASE)
158
-
159
- # if location_match:
160
- # redirect_url = location_match.group(1).strip()
161
-
162
- # # Extract the filename from the redirect URL's "Content-Disposition"
163
- # filename_match = re.search(r'filename%3D%22(.+?)%22', redirect_url)
164
- # if filename_match:
165
- # encoded_filename = filename_match.group(1)
166
- # # Decode the URL-encoded filename
167
- # decoded_filename = urllib.parse.unquote(encoded_filename)
168
-
169
- # filename = unidecode(decoded_filename) if romanize else decoded_filename
170
- # print(f"Filename: {filename}")
171
-
172
- # aria2_command = (
173
- # f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
174
- # f'-k 1M -s 16 -d "{directory}" -o "{filename}" "{redirect_url}"'
175
- # )
176
- # return_code = os.system(aria2_command)
177
-
178
- # # if return_code != 0:
179
- # # raise RuntimeError(f"Failed to download file: {filename}. Error code: {return_code}")
180
- # downloaded_file_path = os.path.join(directory, filename)
181
- # if not os.path.exists(downloaded_file_path):
182
- # downloaded_file_path = None
183
-
184
- # if not downloaded_file_path:
185
- # # Old method
186
- # if "?" in url:
187
- # url = url.split("?")[0]
188
- # url = url + f"?token={civitai_api_key}"
189
- # os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  else:
192
  os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
193
 
@@ -216,14 +276,15 @@ def extract_parameters(input_string):
216
  if "Steps:" in input_string:
217
  input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
218
  else:
219
- print("Invalid metadata")
 
 
220
  parameters["prompt"] = input_string
221
  return parameters
222
 
223
  parm = input_string.split("Negative prompt:")
224
  parameters["prompt"] = parm[0].strip()
225
  if "Steps:" not in parm[1]:
226
- print("Steps not detected")
227
  parameters["neg_prompt"] = parm[1].strip()
228
  return parameters
229
  parm = parm[1].split("Steps:")
@@ -306,7 +367,8 @@ def get_model_type(repo_id: str):
306
  model = api.model_info(repo_id=repo_id, timeout=5.0)
307
  tags = model.tags
308
  for tag in tags:
309
- if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
 
310
 
311
  except Exception:
312
  return default
@@ -433,9 +495,9 @@ def get_folder_size_gb(folder_path):
433
  return total_size_gb
434
 
435
 
436
- def get_used_storage_gb():
437
  try:
438
- used_gb = get_folder_size_gb(STORAGE_ROOT)
439
  print(f"Used Storage: {used_gb:.2f} GB")
440
  except Exception as e:
441
  used_gb = 999
@@ -455,6 +517,21 @@ def delete_model(removal_candidate):
455
  shutil.rmtree(diffusers_model)
456
 
457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  def progress_step_bar(step, total):
459
  # Calculate the percentage for the progress bar width
460
  percentage = min(100, ((step / total) * 100))
 
9
  DIRECTORY_LORAS,
10
  DIRECTORY_MODELS,
11
  DIFFUSECRAFT_CHECKPOINT_NAME,
12
+ CACHE_HF_ROOT,
13
  CACHE_HF,
14
  STORAGE_ROOT,
15
  )
 
29
  import shutil
30
  import subprocess
31
 
32
+ IS_ZERO_GPU = bool(os.getenv("SPACES_ZERO_GPU"))
33
  USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
34
 
35
 
 
68
  )
69
  self.filename_url = self.filename_url if self.filename_url else ""
70
  self.description = json_data.get("description", "")
71
+ if self.description is None:
72
+ self.description = ""
73
  self.model_name = json_data.get("model", {}).get("name", "")
74
  self.model_type = json_data.get("model", {}).get("type", "")
75
  self.nsfw = json_data.get("model", {}).get("nsfw", False)
 
79
  self.original_json = copy.deepcopy(json_data)
80
 
81
 
82
def get_civit_params(url):
    """
    Resolve a Civitai URL into the parameters needed to download it.

    The metadata is fetched with ``request_json_data`` and wrapped in
    ``ModelInformation``.

    Args:
        url: Civitai URL supplied by the caller.

    Returns:
        tuple: ``(download_url, filename, model_url)`` when the metadata
        carries both a download URL and a filename; otherwise
        ``(url, None, None)`` so the caller can fall back to a direct
        download of the original URL.
    """
    try:
        json_data = request_json_data(url)
        mdc = ModelInformation(json_data)
        if mdc.download_url and mdc.filename_url:
            return mdc.download_url, mdc.filename_url, mdc.model_url
        # The exception must actually be raised: merely instantiating it
        # lets the function fall through and return None, which breaks
        # callers that unpack three values.
        raise ValueError("Invalid Civitai model URL")
    except Exception as e:
        print(f"Error retrieving Civitai metadata: {e} — fallback to direct download")
        return url, None, None
93
 
94
 
95
def civ_redirect_down(url, dir_, civitai_api_key, romanize, alternative_name):
    """
    Download a Civitai file by resolving its redirect first.

    ``curl`` is used to read the response headers and find the redirected
    ``Location`` URL; ``aria2c`` then downloads from that URL.

    Args:
        url: Civitai download URL.
        dir_: Destination directory.
        civitai_api_key: API key sent as a Bearer token to resolve the redirect.
        romanize: When true, the filename taken from the redirect URL is
            passed through ``unidecode``.
        alternative_name: Preferred output filename; takes precedence over
            the filename found in the redirect URL.

    Returns:
        tuple: ``(output_path, filename)``. ``output_path`` is ``None`` when
        nothing was downloaded; ``filename`` is ``None`` when no filename
        could be determined.
    """
    import shlex

    filename = None

    # Reuse a previously downloaded file when the target name is known.
    if alternative_name:
        output_path = os.path.join(dir_, alternative_name)
        if os.path.exists(output_path):
            return output_path, alternative_name

    # Follow the redirect to get the actual download URL.
    # Arguments are shell-quoted because the URL comes from the caller.
    auth_header = f"Authorization: Bearer {civitai_api_key}"
    curl_command = (
        'curl -L -sI --connect-timeout 5 --max-time 5 '
        '-H "Content-Type: application/json" '
        f'-H {shlex.quote(auth_header)} {shlex.quote(url)}'
    )
    headers = os.popen(curl_command).read()

    # Look for the redirected "Location" URL
    location_match = re.search(r'location: (.+)', headers, re.IGNORECASE)
    if not location_match:
        # Without a redirect target there is nothing to download here;
        # report failure explicitly (keeping the known filename) so the
        # caller can try another method.
        return None, (alternative_name or None)

    redirect_url = location_match.group(1).strip()

    # Extract the filename from the redirect URL's "Content-Disposition"
    filename_match = re.search(r'filename%3D%22(.+?)%22', redirect_url)
    if filename_match:
        # Decode the URL-encoded filename
        decoded_filename = urllib.parse.unquote(filename_match.group(1))
        filename = unidecode(decoded_filename) if romanize else decoded_filename

    filename_base = alternative_name if alternative_name else filename
    if not filename_base:
        return None, None

    output_path = os.path.join(dir_, filename_base)
    if os.path.exists(output_path):
        return output_path, filename_base

    aria2_command = (
        'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
        f'-k 1M -s 16 -d {shlex.quote(str(dir_))} '
        f'-o {shlex.quote(filename_base)} {shlex.quote(redirect_url)}'
    )
    # The exit code is deliberately ignored; success is judged by whether
    # the output file exists afterwards.
    os.system(aria2_command)

    if not os.path.exists(output_path):
        return None, filename_base

    return output_path, filename_base
148
+
149
+
150
def civ_api_down(url, dir_, civitai_api_key, civ_filename):
    """
    Download a Civitai file by passing the API key as a ``token`` query parameter.

    This method is susceptible to being blocked because it generates a lot of temp redirect links with aria2c.
    If an API key limit is reached, generating a new API key and using it can fix the issue.

    Args:
        url: Civitai download URL.
        dir_: Destination directory.
        civitai_api_key: API key appended to the URL as ``token``.
        civ_filename: Output filename; when empty, aria2c picks the name.

    Returns:
        The expected output path when ``civ_filename`` is given (whether or
        not the download succeeded), otherwise ``None``.
    """
    import shlex

    output_path = None

    # Use "&" when the URL already has a query string; a second "?" would
    # produce a malformed URL.
    separator = "&" if "?" in url else "?"
    url_dl = f"{url}{separator}token={civitai_api_key}"

    # Shell-quote every interpolated value: the URL and filename are not
    # under this function's control.
    quoted_dir = shlex.quote(str(dir_))
    quoted_url = shlex.quote(url_dl)

    if not civ_filename:
        aria2_command = f'aria2c -c -x 1 -s 1 -d {quoted_dir} {quoted_url}'
        os.system(aria2_command)
    else:
        output_path = os.path.join(dir_, civ_filename)
        # Skip the download when the file is already present.
        if not os.path.exists(output_path):
            aria2_command = (
                'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
                f'-k 1M -s 16 -d {quoted_dir} -o {shlex.quote(civ_filename)} {quoted_url}'
            )
            os.system(aria2_command)

    return output_path
171
+
172
+
173
def drive_down(url, dir_):
    """
    Download a Google Drive file into ``dir_`` using ``gdown``.

    The Drive file id is appended to the downloaded file's name so that a
    later call for the same URL can find the file in ``dir_`` and skip the
    download.

    Args:
        url: Google Drive URL.
        dir_: Destination directory.

    Returns:
        Path of the existing or newly downloaded file, or ``None`` when
        ``gdown`` reports no output path.
    """
    import gdown

    drive_id, _ = gdown.parse_url.parse_url(url, warning=False)

    # A file id is needed both for the lookup and for the rename; guard
    # against parse_url yielding no id (presumably None — confirm against
    # the gdown version in use).
    if drive_id:
        for dfile in os.listdir(dir_):
            if drive_id in dfile:
                return os.path.join(dir_, dfile)

    original_path = gdown.download(url, f"{dir_}/", fuzzy=True)
    if not original_path:
        # Nothing was downloaded, so there is nothing to rename.
        return None
    if not drive_id:
        return original_path

    dir_name, base_name = os.path.split(original_path)
    # splitext also handles names without an extension, where
    # rsplit(".", 1) would fail to unpack.
    name, ext = os.path.splitext(base_name)
    output_path = os.path.join(dir_name, f"{name}_{drive_id}{ext}")

    os.rename(original_path, output_path)

    return output_path
197
+
198
+
199
def hf_down(url, dir_, hf_token, romanize):
    """
    Download a file from a Hugging Face URL into ``dir_`` with aria2c.

    Args:
        url: Hugging Face file URL; ``/blob/`` links are rewritten to
            ``/resolve/`` and a ``?download=true`` suffix is dropped.
        dir_: Destination directory.
        hf_token: Optional token sent as a Bearer authorization header.
        romanize: When true, the output filename is passed through
            ``unidecode``.

    Returns:
        The expected output path (the last URL segment inside ``dir_``),
        whether the file already existed or a download was attempted.
    """
    import shlex

    url = url.replace("?download=true", "")
    # url = urllib.parse.quote(url, safe=':/')  # fix encoding

    filename = url.split('/')[-1]
    if romanize:
        filename = unidecode(filename)
    output_path = os.path.join(dir_, filename)

    # Skip the download when the file is already present.
    if os.path.exists(output_path):
        return output_path

    if "/blob/" in url:
        url = url.replace("/blob/", "/resolve/")

    options = "--console-log-level=error --summary-interval=10"
    if hf_token:
        options += " --header=" + shlex.quote(f"Authorization: Bearer {hf_token}")
    else:
        options = "--optimize-concurrent-downloads " + options

    # Shell-quote the URL, directory and filename: they may contain spaces
    # or shell metacharacters.
    os.system(
        f"aria2c {options} -c -x 16 -k 1M -s 16 {shlex.quote(url)} "
        f"-d {shlex.quote(str(dir_))} -o {shlex.quote(filename)}"
    )

    return output_path
219
+
220
+
221
+ def download_things(directory, url, hf_token="", civitai_api_key="", romanize=False):
222
+ url = url.strip()
223
+ downloaded_file_path = None
224
+
225
+ if "drive.google.com" in url:
226
+ downloaded_file_path = drive_down(url, directory)
227
+ elif "huggingface.co" in url:
228
+ downloaded_file_path = hf_down(url, directory, hf_token, romanize)
229
+ elif "civitai.com" in url:
230
+ if not civitai_api_key:
231
+ msg = "You need an API key to download Civitai models."
232
+ print(f"\033[91m{msg}\033[0m")
233
+ gr.Warning(msg)
234
+ return None
235
+
236
+ url, civ_filename, civ_page = get_civit_params(url)
237
+ if civ_page and not IS_ZERO_GPU:
238
+ print(f"\033[92mCivitai model: {civ_filename} [page: {civ_page}]\033[0m")
239
+
240
+ downloaded_file_path, civ_filename = civ_redirect_down(url, directory, civitai_api_key, romanize, civ_filename)
241
+
242
+ if not downloaded_file_path:
243
+ msg = (
244
+ "Download failed.\n"
245
+ "If this is due to an API limit, generating a new API key may resolve the issue.\n"
246
+ "Attempting to download using the old method..."
247
+ )
248
+ print(msg)
249
+ gr.Warning(msg)
250
+ downloaded_file_path = civ_api_down(url, directory, civitai_api_key, civ_filename)
251
  else:
252
  os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
253
 
 
276
  if "Steps:" in input_string:
277
  input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
278
  else:
279
+ msg = "Generation data is invalid."
280
+ gr.Warning(msg)
281
+ print(msg)
282
  parameters["prompt"] = input_string
283
  return parameters
284
 
285
  parm = input_string.split("Negative prompt:")
286
  parameters["prompt"] = parm[0].strip()
287
  if "Steps:" not in parm[1]:
 
288
  parameters["neg_prompt"] = parm[1].strip()
289
  return parameters
290
  parm = parm[1].split("Steps:")
 
367
  model = api.model_info(repo_id=repo_id, timeout=5.0)
368
  tags = model.tags
369
  for tag in tags:
370
+ if tag in MODEL_TYPE_CLASS.keys():
371
+ return MODEL_TYPE_CLASS.get(tag, default)
372
 
373
  except Exception:
374
  return default
 
495
  return total_size_gb
496
 
497
 
498
+ def get_used_storage_gb(path_storage=STORAGE_ROOT):
499
  try:
500
+ used_gb = get_folder_size_gb(path_storage)
501
  print(f"Used Storage: {used_gb:.2f} GB")
502
  except Exception as e:
503
  used_gb = 999
 
517
  shutil.rmtree(diffusers_model)
518
 
519
 
520
def clear_hf_cache():
    """
    Delete the Hugging Face cache directory at ``CACHE_HF_ROOT``.

    Prints whether the cache was cleared or was not found. Any exception
    is caught and printed rather than propagated.
    """
    try:
        cache_present = os.path.exists(CACHE_HF_ROOT)
        if not cache_present:
            print(f"No Hugging Face cache found at: {CACHE_HF_ROOT}")
            return
        # ignore_errors=True makes removal best-effort.
        shutil.rmtree(CACHE_HF_ROOT, ignore_errors=True)
        print(f"Hugging Face cache cleared: {CACHE_HF_ROOT}")
    except Exception as e:
        print(f"Error clearing Hugging Face cache: {e}")
533
+
534
+
535
  def progress_step_bar(step, total):
536
  # Calculate the percentage for the progress bar width
537
  percentage = min(100, ((step / total) * 100))