Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download | |
| import os | |
| from pathlib import Path | |
| import shutil | |
| import gc | |
| import re | |
| import urllib.parse | |
| import subprocess | |
| import time | |
| from typing import Any | |
def get_token():
    """Return the locally stored Hugging Face token.

    If the lookup raises for any reason, ``False`` is returned instead of
    propagating the error.
    """
    # NOTE(review): recent huggingface_hub releases recommend
    # `huggingface_hub.get_token()` over `HfFolder.get_token()` — confirm the
    # pinned version before switching.
    try:
        return HfFolder.get_token()
    except Exception:
        return False
def set_token(token):
    """Persist *token* as the local Hugging Face token.

    Failures are reported on stdout rather than raised.
    """
    try:
        HfFolder.save_token(token)
    except Exception:
        print("Error: Failed to save token.")
def get_state(state: dict, key: str):
    """Return ``state[key]``, or print a notice and return None if *key* is absent."""
    if key not in state:
        print(f"State '{key}' not found.")
        return None
    return state[key]
def set_state(state: dict, key: str, value: Any):
    """Store *value* under *key* in *state*, mutating the dict in place.

    Returns None.
    """
    state[key] = value
def get_user_agent():
    """Return a fixed desktop Firefox 127 (Windows 10, x64) User-Agent string."""
    agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
    return agent
def is_repo_exists(repo_id: str, repo_type: str="model"):
    """Report whether *repo_id* (of *repo_type*) exists on the Hugging Face Hub.

    If the Hub query raises, the error is printed and True is returned: an
    unreachable Hub is deliberately treated as "exists" rather than "missing".
    """
    token = get_token()
    api = HfApi(token=token)
    try:
        found = api.repo_exists(repo_id=repo_id, repo_type=repo_type, token=token)
    except Exception as e:
        print(f"Error: Failed to connect {repo_id} ({repo_type}). {e}")
        return True  # err on the side of "exists"
    return bool(found)
# Hub "diffusers:<PipelineClass>" tag -> short model-family label.
MODEL_TYPE_CLASS = {
    "diffusers:StableDiffusionPipeline": "SD 1.5",
    "diffusers:StableDiffusionXLPipeline": "SDXL",
    "diffusers:FluxPipeline": "FLUX",
}


def get_model_type(repo_id: str):
    """Classify a Hub model repo.

    The checks run in this order:

    - "LoRA" if the repo contains ``pytorch_lora_weights.safetensors``.
    - The string "None" if it has no ``model_index.json``.
    - Otherwise, the label of the first repo tag listed in MODEL_TYPE_CLASS.

    Falls back to "SDXL" when no tag matches or when any Hub call raises.
    """
    hf_token = get_token()
    api = HfApi(token=hf_token)
    fallback = "SDXL"
    try:
        if api.file_exists(repo_id=repo_id, filename="pytorch_lora_weights.safetensors", token=hf_token):
            return "LoRA"
        if not api.file_exists(repo_id=repo_id, filename="model_index.json", token=hf_token):
            return "None"
        tags = api.model_info(repo_id=repo_id, token=hf_token).tags
        # A None/non-iterable tags value raises here and is handled below.
        label = next((MODEL_TYPE_CLASS[t] for t in tags if t in MODEL_TYPE_CLASS), fallback)
    except Exception:
        return fallback
    return label
def list_uniq(l):
    """Return the distinct items of *l* as a list, in first-occurrence order.

    Items must be hashable. ``dict.fromkeys`` keeps insertion order, so this
    de-duplicates in a single O(n) pass and accepts any iterable.
    """
    return list(dict.fromkeys(l))
def list_sub(a, b):
    """Return the items of *a* that do not occur in *b*.

    Order and duplicates of *a* are preserved. Membership is checked against
    a set of *b* when all items are hashable, giving O(len(a) + len(b)).
    Otherwise it falls back to linear list membership.
    """
    source = list(a)
    # Materialize b once so a one-shot iterable is not consumed mid-loop.
    exclude = list(b)
    try:
        lookup = set(exclude)
        return [e for e in source if e not in lookup]
    except TypeError:
        # Unhashable items (e.g. lists) cannot go in a set.
        return [e for e in source if e not in exclude]
def is_repo_name(s):
    """Return an ``re.Match`` if *s* has the ``owner/name`` shape, else None.

    Each part may contain word characters, ``-`` and ``.``. Callers use the
    result for its truthiness.
    """
    pattern = r'^[\w_\-\.]+/[\w_\-\.]+$'
    matched = re.fullmatch(pattern, s)
    return matched
def get_hf_url(repo_id: str, repo_type: str="model"):
    """Return the Hub web URL of *repo_id*.

    Any *repo_type* other than "dataset" or "space" is treated as a model repo.
    """
    if repo_type == "dataset":
        return f"https://huggingface.co/datasets/{repo_id}"
    if repo_type == "space":
        return f"https://huggingface.co/spaces/{repo_id}"
    return f"https://huggingface.co/{repo_id}"
def split_hf_url(url: str):
    """Split a Hugging Face file URL into ``(repo_id, filename, subfolder, repo_type)``.

    Expected shape::

        https://huggingface.co/[datasets/|spaces/]<owner>/<repo>/<resolve|blob>/<revision>/[<subfolder>/]<filename>[?download=true]

    ``subfolder`` is None for files at the repo root. Subfolder and filename
    are URL-unquoted. ``repo_type`` is "model", "dataset" or "space".

    If *url* cannot be parsed, the problem is printed and ``("", "", "", "")``
    is returned, so callers can always unpack four values.
    """
    pattern = r'^(?:https?://huggingface\.co/)(?:(datasets|spaces)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.\w+)(?:\?download=true)?$'
    try:
        m = re.match(pattern, url)
    except TypeError as e:  # url is not a str
        print(e)
        return "", "", "", ""
    if m is None:
        print(f"Error: Not a Hugging Face file URL: {url}")
        return "", "", "", ""
    kind, repo_id, subfolder, filename = m.groups()
    repo_type = {"datasets": "dataset", "spaces": "space"}.get(kind, "model")
    subfolder = urllib.parse.unquote(subfolder) if subfolder else None
    return repo_id, urllib.parse.unquote(filename), subfolder, repo_type
def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
    """Download one file, given its Hugging Face file URL, into *directory*.

    The URL is parsed with split_hf_url. Returns the local path reported by
    ``hf_hub_download``, or None if the URL cannot be parsed or the download
    fails. Errors are printed, not raised.

    *progress* is not used in the body. The ``gr.Progress(track_tqdm=True)``
    default presumably lets Gradio surface tqdm bars when this is invoked
    from an event handler (TODO confirm).
    """
    hf_token = get_token()
    try:
        # Parse inside the try so an unexpected parse result is reported as a
        # failed download instead of crashing the caller.
        repo_id, filename, subfolder, repo_type = split_hf_url(url)
        if not repo_id or not filename:
            print(f"Failed to download: unrecognized Hugging Face URL {url}")
            return None
        print(f"Downloading {url} to {directory}")
        # subfolder=None is hf_hub_download's default, so one call covers both
        # root-level and sub-folder files.
        return hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder or None,
                               repo_type=repo_type, local_dir=directory, token=hf_token)
    except Exception as e:
        print(f"Failed to download: {e}")
        return None
def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
    """Download *url* into *directory*, choosing the tool by host.

    - drive.google.com: ``gdown --fuzzy``, run with *directory* as its working dir.
    - huggingface.co: normalized to a ``/resolve/`` URL and passed to download_hf_file.
    - civitai.com: ``aria2c`` with ``token=<civitai_api_key>`` appended as a
      query parameter. Skipped with a message when no key is given.
    - anything else: plain ``aria2c``.

    External tools are started from argument lists (no shell), so characters
    in *url* or *directory* are never interpreted by a shell. Failures,
    including a missing aria2c/gdown binary, are printed, not raised.
    Nothing is returned.
    """
    try:
        aria2c_cmd = ["aria2c", "--console-log-level=error", "--summary-interval=10",
                      "-c", "-x", "16", "-k", "1M", "-s", "16", "-d", str(directory)]
        url = url.strip()
        if "drive.google.com" in url:
            # cwd= rather than os.chdir(): no process-wide side effect, and the
            # working directory cannot be left changed if the call fails.
            subprocess.run(["gdown", "--fuzzy", url], cwd=directory)
        elif "huggingface.co" in url:
            url = url.replace("?download=true", "").replace("/blob/", "/resolve/")
            download_hf_file(directory, url)
        elif "civitai.com" in url:
            if not civitai_api_key:
                print("You need an API key to download Civitai models.")
                return
            # Log before appending the key so the API key never reaches the logs.
            print(f"Downloading {url}")
            sep = "&" if "?" in url else "?"
            subprocess.run([*aria2c_cmd, f"{url}{sep}token={civitai_api_key}"])
        else:
            subprocess.run([*aria2c_cmd, url])
    except Exception as e:
        print(f"Failed to download: {e}")
def get_local_file_list(dir_path, recursive=False):
    """Return string paths of files under *dir_path* whose names contain a dot.

    With ``recursive=True`` files at any depth match (``**/*.*``). Otherwise
    only files exactly one sub-directory below *dir_path* match (``*/*.*``).
    Files directly inside *dir_path* are not included in that case.
    """
    # NOTE(review): the non-recursive pattern skips dir_path's own files —
    # confirm that is intended.
    glob_pattern = "**/*.*" if recursive else "*/*.*"
    return [str(entry) for entry in Path(dir_path).glob(glob_pattern) if entry.is_file()]
def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Resolve *url* to a usable local path (or HF repo id), downloading if needed.

    Resolution order:

    - No "http" in *url*, it looks like ``owner/name``, and it is not an
      existing path: returned unchanged as a HF repo id.
    - No "http" in *url* and it is an existing path: returned unchanged.
    - ``<temp_dir>/<last URL segment>`` already exists: that path.
    - huggingface.co URL: downloaded via download_hf_file.
    - Anything else: downloaded via download_thing. The new file is found by
      diffing temp_dir's file listing before and after the download.

    Returns the resolved string, or "" on failure. Errors are printed, not raised.
    """
    try:
        is_remote = "http" in url
        cached_path = f"{temp_dir}/{url.split('/')[-1]}"
        if not is_remote and is_repo_name(url) and not Path(url).exists():
            print(f"Use HF Repo: {url}")
            new_file = url
        elif not is_remote and Path(url).exists():
            print(f"Use local file: {url}")
            new_file = url
        elif Path(cached_path).exists():
            print(f"File to download already exists: {url}")
            new_file = cached_path
        elif "huggingface.co" in url:
            url = url.replace("?download=true", "").replace("/blob/", "/resolve/")
            new_file = download_hf_file(temp_dir, url)
        else:
            print(f"Start downloading: {url}")
            # huggingface.co URLs are handled above, so the listing here is
            # always recursive.
            before = get_local_file_list(temp_dir, True)
            download_thing(temp_dir, url.strip(), civitai_key)
            added = list_sub(get_local_file_list(temp_dir, True), before)
            new_file = added[0] if added else ""
        if not new_file:
            print(f"Download failed: {url}")
            return ""
        print(f"Download completed: {url}")
        return new_file
    except Exception as e:
        print(f"Download failed: {url} {e}")
        return ""
def download_repo(repo_id: str, dir_path: str, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
    """Snapshot-download the weight files of *repo_id* into *dir_path*.

    Only ``*.safetensors`` and ``*.bin`` files are requested, ``*.fp16.*``
    variants are skipped, and files are re-downloaded even if cached
    (``force_download=True``). Returns True on success. On failure, prints
    the error, shows a Gradio warning, and returns False.
    """
    # NOTE(review): "/*.safetensors" and "/*.bin" look meant to exclude
    # root-level checkpoints, but Hub file paths carry no leading "/", so these
    # ignore patterns probably never match — verify.
    token = get_token()
    wanted = ["*.safetensors", "*.bin"]
    skipped = ["*.fp16.*", "/*.safetensors", "/*.bin"]
    try:
        snapshot_download(repo_id=repo_id, local_dir=dir_path, token=token, allow_patterns=wanted,
                          ignore_patterns=skipped, force_download=True)
    except Exception as e:
        message = f"Error: Failed to download {repo_id}. {e}"
        print(message)
        gr.Warning(message)
        return False
    return True
def upload_repo(repo_id: str, dir_path: str, is_private: bool, is_pr: bool=False, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
    """Create model repo *repo_id* if needed and upload *dir_path* to its root.

    *is_private* sets visibility on creation. *is_pr* uploads as a pull
    request. Returns the repo's web URL, or "" after printing the error if
    any step raises.
    """
    token = get_token()
    api = HfApi(token=token)
    try:
        progress(0, desc="Start uploading...")
        api.create_repo(repo_id=repo_id, token=token, private=is_private, exist_ok=True)
        api.upload_folder(repo_id=repo_id, folder_path=dir_path, path_in_repo="", create_pr=is_pr, token=token)
        progress(1, desc="Uploaded.")
    except Exception as e:
        print(f"Error: Failed to upload to {repo_id}. {e}")
        return ""
    return get_hf_url(repo_id, "model")
def gate_repo(repo_id: str, gated_str: str, repo_type: str="model"):
    """Set the access-gating mode of a Hub repo.

    "auto" or "manual" enables that gating mode. Any other value disables
    gating. Errors are printed, not raised.
    """
    token = get_token()
    api = HfApi(token=token)
    try:
        gated = gated_str if gated_str in ("auto", "manual") else False
        api.update_repo_settings(repo_id=repo_id, gated=gated, repo_type=repo_type, token=token)
    except Exception as e:
        print(f"Error: Failed to update settings {repo_id}. {e}")
# Choices for duplicate_hf_repo's subfolder_type: "None" copies files to the
# destination root, "user_repo" nests them under "<owner>_<repo>/".
HF_SUBFOLDER_NAME = ["None", "user_repo"]


def duplicate_hf_repo(src_repo: str, dst_repo: str, src_repo_type: str, dst_repo_type: str,
                      is_private: bool, subfolder_type: str=HF_SUBFOLDER_NAME[1], progress=gr.Progress(track_tqdm=True)):
    """Copy every file of *src_repo* into *dst_repo*, one file at a time.

    The destination repo is created if missing (with *is_private*). When
    *subfolder_type* is "user_repo", files go under ``<owner>_<repo>/`` in the
    destination. Returns the destination web URL (pointing at the subfolder
    when one is used), or "" after printing the error if anything fails.
    """
    import tempfile

    hf_token = get_token()
    api = HfApi(token=hf_token)
    try:
        subfolder = src_repo.replace("/", "_") if subfolder_type == "user_repo" else ""
        progress(0, desc="Start duplicating...")
        api.create_repo(repo_id=dst_repo, repo_type=dst_repo_type, private=is_private, exist_ok=True, token=hf_token)
        # Stage files in a throwaway local_dir instead of the shared HF cache.
        # Cache snapshot entries are symlinks into a blob store, so deleting
        # the returned path there would leave the (possibly huge) blob on disk.
        with tempfile.TemporaryDirectory() as staging_dir:
            for path in api.list_repo_files(repo_id=src_repo, repo_type=src_repo_type, token=hf_token):
                file = hf_hub_download(repo_id=src_repo, filename=path, repo_type=src_repo_type,
                                       local_dir=staging_dir, token=hf_token)
                local = Path(file)
                if not local.exists():
                    continue
                dst_path = f"{subfolder}/{path}" if subfolder else path
                if local.is_dir():  # unused for now
                    api.upload_folder(repo_id=dst_repo, folder_path=file, path_in_repo=dst_path,
                                      repo_type=dst_repo_type, token=hf_token)
                elif local.is_file():
                    api.upload_file(repo_id=dst_repo, path_or_fileobj=file, path_in_repo=dst_path,
                                    repo_type=dst_repo_type, token=hf_token)
                # Free disk space before fetching the next file.
                if local.exists():
                    local.unlink()
        progress(1, desc="Duplicated.")
        dst_url = get_hf_url(dst_repo, dst_repo_type)
        return f"{dst_url}/tree/main/{subfolder}" if subfolder else dst_url
    except Exception as e:
        print(f"Error: Failed to duplicate repo {src_repo} to {dst_repo}. {e}")
        return ""
# Absolute directory containing this module, as a string.
BASE_DIR = str(Path(__file__).resolve().parent)
# Civitai API key read from the environment (None when unset).
CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")


def get_file(url: str, path: str): # requires aria2, gdown
    """Fetch *url* into directory *path* via get_download_file, using CIVITAI_API_KEY.

    The path returned by get_download_file is discarded. Nothing is returned.
    """
    print(f"Downloading {url} to {path}...")
    get_download_file(path, url, CIVITAI_API_KEY)
def git_clone(url: str, path: str, pip: bool=False, addcmd: str=""): # requires git
    """Clone *url* into ``BASE_DIR/path`` and optionally set the clone up.

    Args:
        url: Git repository URL.
        path: Parent directory for the clone, relative to BASE_DIR (created if missing).
        pip: If True, run ``pip install -r requirements.txt`` inside the clone.
        addcmd: Optional extra shell command run inside the clone.

    Side effect: changes the process working directory (to ``BASE_DIR/path``,
    and into the clone when *pip* or *addcmd* run) and does not restore it.
    Command failures are printed, not raised. The setup steps are skipped if
    the clone directory does not exist afterwards.
    """
    def _run_reported(cmd):
        # Run a shell command in the current directory; report a non-zero exit.
        print(f'Running {cmd} at {Path.cwd()}')
        if subprocess.run(cmd, shell=True).returncode != 0:
            print(f'Error occurred at running {cmd}')

    target = Path(BASE_DIR, path)
    os.makedirs(str(target), exist_ok=True)
    os.chdir(target)
    print(f"Cloning {url} to {path}...")
    _run_reported(f'git clone {url}')
    # git names the clone after the URL's last segment, minus a trailing ".git".
    repo_name = url.rstrip("/").split("/")[-1]
    if repo_name.endswith(".git"):
        repo_name = repo_name[:-len(".git")]
    clone_dir = Path(BASE_DIR, path, repo_name)
    if not clone_dir.is_dir():
        return
    if pip:
        os.chdir(clone_dir)
        _run_reported('pip install -r requirements.txt')
    if addcmd:
        os.chdir(clone_dir)
        _run_reported(addcmd)
def run(cmd: str, timeout: float=0):
    """Run shell command *cmd* in the current working directory.

    With ``timeout == 0``, the command runs to completion.

    Otherwise, the command is started and allowed to run for at most
    *timeout* seconds:

    - If it exits sooner, run() returns right away.
    - If it is still running, it is sent SIGTERM, is killed if it does not
      exit within 10 more seconds, and is then reaped.

    A non-zero exit status is printed (except after a forced termination).
    Nothing is raised for command failures.
    """
    print(f'Running {cmd} at {Path.cwd()}')
    if timeout == 0:
        if subprocess.run(cmd, shell=True).returncode != 0:
            print(f'Error occurred at running {cmd}')
        return
    proc = subprocess.Popen(cmd, shell=True)
    try:
        # Return as soon as the command exits instead of always sleeping the full timeout.
        returncode = proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        # NOTE(review): with shell=True this signals the shell; grandchildren
        # the shell spawned may survive.
        proc.terminate()
        try:
            proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()  # reap the child so it does not linger as a zombie
        print(f'Terminated in {timeout} seconds')
        return
    if returncode != 0:
        print(f'Error occurred at running {cmd}')