| import os | |
| import requests | |
| import hashlib | |
| import re | |
| from typing import Sequence, Mapping, Any, Union, Set | |
| from pathlib import Path | |
| import shutil | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download, constants as hf_constants | |
| import torch | |
| import numpy as np | |
| from PIL import Image, ImageChops | |
| import yaml | |
| from core.settings import * | |
# Size cap, in GiB, for the Hugging Face hub cache; enforce_disk_limit() deletes
# the oldest cached files once the cache grows past this value.
DISK_LIMIT_GB: int = 120
# Root of ComfyUI's model folders. Not referenced in this section of the file;
# presumably consumed elsewhere — confirm before changing.
MODELS_ROOT_DIR: str = "ComfyUI/models"
# Lazily-populated cache of IPAdapter presets keyed by preset name.
# None means "not loaded yet"; load_ipadapter_presets() fills it (or sets it to
# {} when the YAML cannot be read or parsed).
IPADAPTER_PRESETS: dict | None = None
class UniqueKeyLoader(yaml.SafeLoader):
    """
    SafeLoader variant that keeps every value of a repeated mapping key.

    Where a plain loader lets a later duplicate key overwrite an earlier one,
    this loader collects all values seen for that key into a list, in the
    order they appear in the document.

    NOTE(review): unlike SafeLoader's own construct_mapping, this does not call
    flatten_mapping, so YAML merge keys (``<<``) are presumably unsupported —
    confirm before using them in files read with this loader.
    """

    def construct_mapping(self, node, deep=False):
        # Materialise every (key, value) pair first, preserving document order
        # (key constructed before its value, pair by pair).
        pairs = [
            (self.construct_object(key_node, deep=deep),
             self.construct_object(value_node, deep=deep))
            for key_node, value_node in node.value
        ]

        grouped = {}
        for key, value in pairs:
            if key not in grouped:
                grouped[key] = value
                continue
            existing = grouped[key]
            if not isinstance(existing, list):
                # Second occurrence: promote the scalar/mapping to a list.
                grouped[key] = [existing, value]
            else:
                # NOTE: an existing list is extended in place, so a key whose
                # first value is itself a list absorbs later duplicates into it.
                existing.append(value)
        return grouped


# Route every default-tagged YAML mapping through the duplicate-grouping constructor.
UniqueKeyLoader.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, UniqueKeyLoader.construct_mapping)
def save_uploaded_file_with_hash(file_obj: "gr.File | str | os.PathLike", target_dir: str) -> str:
    """Copy an uploaded file into *target_dir* under a content-addressed name.

    The destination name is the SHA-256 hex digest of the file's bytes plus its
    original extension (lower-cased), so identical uploads are stored once.

    Args:
        file_obj: A Gradio upload object exposing ``.name`` (the temp-file path),
            or a plain path (``str`` / ``os.PathLike``). Falsy values are ignored.
        target_dir: Directory to store the file in; created if missing.

    Returns:
        The hashed filename (not the full path), or "" when *file_obj* is falsy.
    """
    import tempfile

    if not file_obj:
        return ""
    # Gradio may hand over either a tempfile-like object or a bare path.
    if isinstance(file_obj, (str, os.PathLike)):
        temp_path = os.fspath(file_obj)
    else:
        temp_path = file_obj.name

    sha256 = hashlib.sha256()
    with open(temp_path, 'rb') as f:
        for block in iter(lambda: f.read(65536), b''):
            sha256.update(block)
    file_hash = sha256.hexdigest()

    extension = os.path.splitext(temp_path)[1].lower()
    hashed_filename = f"{file_hash}{extension}"
    dest_path = os.path.join(target_dir, hashed_filename)
    os.makedirs(target_dir, exist_ok=True)

    if os.path.exists(dest_path):
        print(f"ℹ️ File already exists (deduplicated): {dest_path}")
        return hashed_filename

    # Copy to a temporary name in the same directory and rename atomically, so
    # an interrupted copy can never leave a truncated file under the final
    # hashed name (which the existence check above would then trust forever).
    fd, tmp_dest = tempfile.mkstemp(dir=target_dir, suffix=".partial")
    os.close(fd)
    try:
        shutil.copy(temp_path, tmp_dest)
        os.replace(tmp_dest, dest_path)
    except BaseException:
        try:
            os.remove(tmp_dest)
        except OSError:
            pass
        raise
    print(f"✅ Saved uploaded file as: {dest_path}")
    return hashed_filename
def bytes_to_gb(byte_size: int) -> float:
    """Convert a byte count to GiB (1024**3 bytes), rounded to two decimals.

    ``None`` and ``0`` both map to ``0.0``.
    """
    bytes_per_gib = 1024 ** 3
    if byte_size is not None and byte_size != 0:
        return round(byte_size / bytes_per_gib, 2)
    return 0.0
def get_directory_size(path: str) -> int:
    """Return the total size in bytes of the regular files under *path*.

    Symlinks are not counted, and symlinked directories are not descended into
    (``os.walk`` default). A missing *path* yields 0.

    Files that vanish or cannot be stat'ed during the walk are skipped with a
    warning, so one inaccessible entry does not cut the whole scan short.
    Unreadable directories are reported the same way.
    """
    if not os.path.exists(path):
        return 0

    def _report(err: OSError) -> None:
        print(f"Warning: Could not access {err.filename} to calculate size: {err}")

    total_size = 0
    for dirpath, _, filenames in os.walk(path, onerror=_report):
        for name in filenames:
            fp = os.path.join(dirpath, name)
            try:
                if os.path.isfile(fp) and not os.path.islink(fp):
                    total_size += os.path.getsize(fp)
            except OSError as e:
                print(f"Warning: Could not access {fp} to calculate size: {e}")
    return total_size
def enforce_disk_limit(limit_gb: float | None = None, cache_dir: str | None = None):
    """Trim a download cache until its size fits under a limit.

    Walks *cache_dir* and sums the sizes of regular (non-symlink) files, ignoring
    ``.incomplete`` and ``.lock`` files. If the total exceeds the limit, files
    are deleted oldest-first, ordered by ``os.path.getctime``, until the total
    fits. The operation is best-effort: every error is printed, never raised.

    NOTE: per the ``os.path`` docs, ``getctime`` is creation time on Windows
    but last metadata change on Unix.

    Args:
        limit_gb: Size cap in GiB. Defaults to the module-level ``DISK_LIMIT_GB``,
            read at call time.
        cache_dir: Directory to manage. Defaults to ``hf_constants.HF_HUB_CACHE``.
    """
    if limit_gb is None:
        limit_gb = DISK_LIMIT_GB
    if cache_dir is None:
        cache_dir = hf_constants.HF_HUB_CACHE
    disk_limit_bytes = limit_gb * (1024 ** 3)
    if not os.path.exists(cache_dir):
        return
    print(f"--- [Storage Manager] Checking disk usage in '{cache_dir}' (Limit: {limit_gb} GB) ---")
    try:
        all_files = []
        current_size_bytes = 0
        for dirpath, _, filenames in os.walk(cache_dir):
            for f in filenames:
                # Skip in-progress downloads and lock files.
                if f.endswith((".incomplete", ".lock")):
                    continue
                file_path = os.path.join(dirpath, f)
                # Only regular files count; symlinks are skipped.
                if os.path.isfile(file_path) and not os.path.islink(file_path):
                    try:
                        file_size = os.path.getsize(file_path)
                        creation_time = os.path.getctime(file_path)
                    except OSError:
                        # File vanished or became unreadable mid-scan.
                        continue
                    all_files.append((creation_time, file_path, file_size))
                    current_size_bytes += file_size

        print(f"--- [Storage Manager] Current usage: {bytes_to_gb(current_size_bytes)} GB ---")
        if current_size_bytes <= disk_limit_bytes:
            print("--- [Storage Manager] Disk usage is within the limit. No action needed. ---")
            return

        print("--- [Storage Manager] Usage exceeds limit. Starting cleanup... ---")
        all_files.sort(key=lambda entry: entry[0])
        # Walk the sorted list once instead of repeatedly calling pop(0),
        # which is O(n) per call and made eviction quadratic.
        for _, oldest_file_path, oldest_file_size in all_files:
            if current_size_bytes <= disk_limit_bytes:
                break
            try:
                os.remove(oldest_file_path)
                current_size_bytes -= oldest_file_size
                print(f"--- [Storage Manager] Deleted oldest file: {os.path.basename(oldest_file_path)} ({bytes_to_gb(oldest_file_size)} GB freed) ---")
            except OSError as e:
                print(f"--- [Storage Manager] Error deleting file {oldest_file_path}: {e} ---")
        print(f"--- [Storage Manager] Cleanup finished. New usage: {bytes_to_gb(current_size_bytes)} GB ---")
    except Exception as e:
        # Deliberately best-effort: disk trimming must never break a download.
        print(f"--- [Storage Manager] An unexpected error occurred: {e} ---")
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    The fallback covers containers that keep their values under a ``"result"``
    key. Returns None when neither lookup succeeds. This includes an
    out-of-range index on a plain sequence, which previously escaped as a
    TypeError from the ``"result"`` fallback.
    """
    try:
        return obj[index]
    except (KeyError, IndexError):
        pass
    try:
        return obj["result"][index]
    except (KeyError, IndexError, TypeError):
        # TypeError: obj is a sequence (not indexable by "result") or
        # obj["result"] is not subscriptable.
        return None
def sanitize_prompt(prompt: str) -> str:
    """Strip non-printable characters from *prompt*, keeping newlines and tabs.

    Non-string input yields an empty string.
    """
    if isinstance(prompt, str):
        kept = [ch for ch in prompt if ch in "\n\t" or ch.isprintable()]
        return "".join(kept)
    return ""
def sanitize_id(input_id: str) -> str:
    """Keep only the ASCII digits 0-9 from *input_id*; non-strings yield ""."""
    if not isinstance(input_id, str):
        return ""
    ascii_digits = "0123456789"
    return "".join(ch for ch in input_id if ch in ascii_digits)
def sanitize_url(url: str) -> str:
    """Return *url* with surrounding whitespace removed, if it is an http(s) URL.

    Raises:
        ValueError: *url* is not a string, or does not look like an
            ``http://`` / ``https://`` URL without embedded whitespace.
    """
    if isinstance(url, str):
        candidate = url.strip()
        if re.match(r'^https?://[^\s/$.?#].[^\s]*$', candidate) is None:
            raise ValueError("Invalid URL format or scheme. Only HTTP and HTTPS are allowed.")
        return candidate
    raise ValueError("URL must be a string.")
def sanitize_filename(filename: str) -> str:
    """Make *filename* safe to use as a single path component.

    Every ``..`` sequence is removed (one left-to-right pass), then any
    character other than word characters, ``.`` and ``-`` becomes ``_``.
    Path separators are therefore always replaced, so the result cannot
    start with ``/`` or ``\\``. Non-string input yields "".
    """
    if not isinstance(filename, str):
        return ""
    without_parent_refs = "".join(filename.split(".."))
    return re.sub(r'[^\w\.\-]', '_', without_parent_refs)
def get_civitai_file_info(version_id: str) -> dict | None:
    """Fetch metadata for the downloadable file of a Civitai model version.

    Queries ``https://civitai.com/api/v1/model-versions/{version_id}``. It returns
    the first file entry whose ``type`` is ``'Model'`` and whose name ends in
    ``.safetensors``, ``.pt`` or ``.bin`` (case-insensitive). If no entry
    matches, it returns the first file listed.

    Returns:
        The file-entry dict, or None when the request fails, the response is
        not usable JSON, or the version lists no files.
    """
    api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decoding failures on older `requests` versions.
        print(f"Warning: Civitai lookup failed for version {version_id}: {e}")
        return None

    files = data.get('files') if isinstance(data, dict) else None
    if not isinstance(files, list) or not files:
        return None
    for file_data in files:
        if not isinstance(file_data, dict):
            continue
        # A missing 'name' on one entry must not abort the whole lookup.
        name = str(file_data.get('name') or '')
        if file_data.get('type') == 'Model' and name.lower().endswith(('.safetensors', '.pt', '.bin')):
            return file_data
    return files[0]
def download_file(url: str, save_path: str, api_key: str | None = None, progress=None, desc: str = "") -> str:
    """Stream *url* to *save_path* and return a human-readable status string.

    Calls enforce_disk_limit() first and skips the download when *save_path*
    already exists. The parent directory is created if needed.

    Callers read the outcome from the message prefix:
      - "Successfully downloaded: <name>"
      - "File already exists: <name>"
      - "Download failed for <name>: <error>"

    The body is streamed to ``<save_path>.part`` and renamed into place only
    once complete. An interrupted download therefore never leaves a truncated
    file at *save_path*, which the existence check would otherwise accept as
    finished.

    Args:
        url: Source URL.
        save_path: Final destination path.
        api_key: Optional bearer token; blank values send no auth header.
        progress: Optional callable used as ``progress(fraction, desc=...)``.
        desc: Label passed to *progress*.
    """
    enforce_disk_limit()
    if os.path.exists(save_path):
        return f"File already exists: {os.path.basename(save_path)}"
    headers = {'Authorization': f'Bearer {api_key}'} if api_key and api_key.strip() else {}
    part_path = f"{save_path}.part"
    try:
        if progress:
            progress(0, desc=desc)
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # `with` releases the streamed connection even when an error occurs mid-download.
        with requests.get(url, stream=True, headers=headers, timeout=15) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0
            with open(part_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
                    if progress and total_size > 0:
                        downloaded += len(chunk)
                        progress(downloaded / total_size, desc=desc)
        os.replace(part_path, save_path)
    except Exception as e:
        return f"Download failed for {os.path.basename(save_path)}: {e}"
    finally:
        # On any failure (including KeyboardInterrupt) drop the partial file.
        if os.path.exists(part_path):
            try:
                os.remove(part_path)
            except OSError:
                pass
    return f"Successfully downloaded: {os.path.basename(save_path)}"
| def get_lora_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]: | |
| if not id_or_url or not id_or_url.strip(): | |
| return None, "No ID/URL provided." | |
| try: | |
| if source == "Civitai": | |
| version_id = sanitize_id(id_or_url) | |
| if not version_id: | |
| return None, "Invalid Civitai ID provided. Must be numeric." | |
| filename = sanitize_filename(f"civitai_{version_id}.safetensors") | |
| local_path = os.path.join(LORA_DIR, filename) | |
| file_info = get_civitai_file_info(version_id) | |
| api_key_to_use = civitai_key | |
| source_name = f"Civitai ID {version_id}" | |
| else: | |
| return None, "Invalid source." | |
| except ValueError as e: | |
| return None, f"Input validation failed: {e}" | |
| if os.path.exists(local_path): | |
| return local_path, "File already exists." | |
| if not file_info or not file_info.get('downloadUrl'): | |
| return None, f"Could not get download link for {source_name}." | |
| status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}") | |
| return (local_path, status) if "Successfully" in status else (None, status) | |
| def get_embedding_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]: | |
| if not id_or_url or not id_or_url.strip(): | |
| return None, "No ID/URL provided." | |
| try: | |
| file_ext = ".safetensors" | |
| if source == "Civitai": | |
| version_id = sanitize_id(id_or_url) | |
| if not version_id: | |
| return None, "Invalid Civitai ID. Must be numeric." | |
| file_info = get_civitai_file_info(version_id) | |
| if file_info and file_info['name'].lower().endswith(('.pt', '.bin')): | |
| file_ext = os.path.splitext(file_info['name'])[1] | |
| filename = sanitize_filename(f"civitai_{version_id}{file_ext}") | |
| local_path = os.path.join(EMBEDDING_DIR, filename) | |
| api_key_to_use = civitai_key | |
| source_name = f"Embedding Civitai ID {version_id}" | |
| else: | |
| return None, "Invalid source." | |
| except ValueError as e: | |
| return None, f"Input validation failed: {e}" | |
| if os.path.exists(local_path): | |
| return local_path, "File already exists." | |
| if not file_info or not file_info.get('downloadUrl'): | |
| return None, f"Could not get download link for {source_name}." | |
| status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}") | |
| return (local_path, status) if "Successfully" in status else (None, status) | |
| def get_vae_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]: | |
| if not id_or_url or not id_or_url.strip(): | |
| return None, "No ID/URL provided." | |
| try: | |
| file_ext = ".safetensors" | |
| if source == "Civitai": | |
| version_id = sanitize_id(id_or_url) | |
| if not version_id: | |
| return None, "Invalid Civitai ID. Must be numeric." | |
| file_info = get_civitai_file_info(version_id) | |
| if file_info and file_info['name'].lower().endswith(('.pt', '.bin')): | |
| file_ext = os.path.splitext(file_info['name'])[1] | |
| filename = sanitize_filename(f"civitai_{version_id}{file_ext}") | |
| local_path = os.path.join(VAE_DIR, filename) | |
| api_key_to_use = civitai_key | |
| source_name = f"VAE Civitai ID {version_id}" | |
| else: | |
| return None, "Invalid source." | |
| except ValueError as e: | |
| return None, f"Input validation failed: {e}" | |
| if os.path.exists(local_path): | |
| return local_path, "File already exists." | |
| if not file_info or not file_info.get('downloadUrl'): | |
| return None, f"Could not get download link for {source_name}." | |
| status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}") | |
| return (local_path, status) if "Successfully" in status else (None, status) | |
def _ensure_model_downloaded(display_name: str, progress=gr.Progress()):
    """Ensure the weights file behind a configured model exists in its model folder.

    The lookup runs in four steps:
      1. Look up *display_name* in ALL_MODEL_MAP; element [1] is the repository
         file path.
      2. Find the download entry in ALL_FILE_DOWNLOAD_MAP by that path's base
         filename.
      3. Pick the destination directory that CATEGORY_TO_DIR_MAP assigns to the
         entry's category.
      4. Fetch the file: Hugging Face files go into the hub cache and are
         symlinked into place; Civitai files are downloaded directly.

    A dangling symlink at the destination is removed and the file is fetched
    again.

    Args:
        display_name: Key into ALL_MODEL_MAP.
        progress: Callable used as ``progress(fraction, desc=...)``.

    Returns:
        The model's base filename as it appears in the destination folder.

    Raises:
        ValueError: *display_name* is unknown, or the entry's category has no
            directory mapping.
        gr.Error: the file is not listed in the download map, or fetching or
            linking it fails.
    """
    if display_name not in ALL_MODEL_MAP:
        raise ValueError(f"Model '{display_name}' not found in configuration.")
    model_info = ALL_MODEL_MAP[display_name]
    repo_filename = model_info[1]
    base_filename = os.path.basename(repo_filename)
    download_info = ALL_FILE_DOWNLOAD_MAP.get(base_filename)
    if not download_info:
        raise gr.Error(f"Model '{base_filename}' not found in file_list.yaml. Cannot download.")
    category = download_info.get("category")
    dest_dir = CATEGORY_TO_DIR_MAP.get(category)
    if not dest_dir:
        raise ValueError(f"Unknown YAML category '{category}' for '{base_filename}'.")
    dest_path = os.path.join(dest_dir, base_filename)

    if os.path.lexists(dest_path):
        if os.path.exists(dest_path):
            return base_filename
        # lexists but not exists: a symlink whose target was removed.
        print(f"⚠️ Found and removed broken symlink: {dest_path}")
        os.remove(dest_path)

    source = download_info.get("source")
    try:
        progress(0, desc=f"Downloading: {base_filename}")
        if source == "hf":
            repo_id = download_info.get("repo_id")
            hf_filename = download_info.get("repository_file_path", base_filename)
            if not repo_id:
                raise ValueError(f"repo_id is missing for HF model '{base_filename}'")
            cached_path = hf_hub_download(repo_id=repo_id, filename=hf_filename, token=os.environ.get("HF_TOKEN"))
            os.makedirs(dest_dir, exist_ok=True)
            os.symlink(cached_path, dest_path)
            print(f"✅ Symlinked '{cached_path}' to '{dest_path}'")
        elif source == "civitai":
            model_version_id = download_info.get("model_version_id")
            if not model_version_id:
                raise ValueError(f"model_version_id is missing for Civitai model '{base_filename}'")
            file_info = get_civitai_file_info(model_version_id)
            if not file_info or not file_info.get('downloadUrl'):
                raise ConnectionError(f"Could not get download URL for Civitai model version ID {model_version_id}")
            # The destination folder may not exist yet on a fresh install.
            os.makedirs(dest_dir, exist_ok=True)
            status = download_file(
                file_info['downloadUrl'], dest_path, api_key=os.environ.get("CIVITAI_API_KEY", ""), progress=progress, desc=f"Downloading: {base_filename}"
            )
            # download_file reports errors as "Download failed for ..." (lower-case
            # "failed"), so match that prefix rather than the word "Failed".
            if status.startswith("Download failed"):
                raise ConnectionError(status)
        else:
            raise NotImplementedError(f"Download source '{source}' is not implemented for '{base_filename}'")
        progress(1.0, desc=f"Downloaded: {base_filename}")
    except Exception as e:
        # Never leave a half-made link/file behind for the next existence check.
        if os.path.lexists(dest_path):
            try:
                os.remove(dest_path)
            except OSError:
                pass
        raise gr.Error(f"Failed to download and link '{display_name}': {e}") from e
    return base_filename
def ensure_controlnet_model_downloaded(filename: str, progress):
    """Ensure a ControlNet weights file is present locally, downloading it if needed.

    Does nothing for an empty *filename* or the literal string "None".

    The file's source is looked up in ALL_FILE_DOWNLOAD_MAP. Its destination
    comes from CATEGORY_TO_DIR_MAP; the category defaults to "controlnet" and
    the directory falls back to CONTROLNET_DIR. Hugging Face files are fetched
    into the hub cache and symlinked into place; Civitai files are downloaded
    directly. A dangling symlink at the destination is removed and the file is
    fetched again.

    Raises:
        gr.Error: the file is not configured, or fetching or linking it fails.
    """
    if not filename or filename == "None":
        return
    download_info = ALL_FILE_DOWNLOAD_MAP.get(filename)
    if not download_info:
        raise gr.Error(f"ControlNet model '{filename}' not found in configuration (file_list.yaml). Cannot download.")
    category = download_info.get("category", "controlnet")
    dest_dir = CATEGORY_TO_DIR_MAP.get(category, CONTROLNET_DIR)
    dest_path = os.path.join(dest_dir, filename)

    if os.path.lexists(dest_path):
        if os.path.exists(dest_path):
            return
        # lexists but not exists: a symlink whose target was removed.
        print(f"⚠️ Found and removed broken symlink: {dest_path}")
        os.remove(dest_path)

    source = download_info.get("source")
    try:
        if source == "hf":
            repo_id = download_info.get("repo_id")
            repo_filename = download_info.get("repository_file_path", filename)
            if not repo_id:
                raise ValueError("repo_id is missing for Hugging Face download.")
            progress(0, desc=f"Downloading CN: {filename}")
            cached_path = hf_hub_download(repo_id=repo_id, filename=repo_filename, token=os.environ.get("HF_TOKEN"))
            os.makedirs(dest_dir, exist_ok=True)
            os.symlink(cached_path, dest_path)
            print(f"✅ Symlinked ControlNet '{cached_path}' to '{dest_path}'")
            progress(1.0, desc=f"Downloaded CN: {filename}")
        elif source == "civitai":
            model_version_id = download_info.get("model_version_id")
            if not model_version_id:
                raise ValueError("model_version_id is missing for Civitai download.")
            file_info = get_civitai_file_info(model_version_id)
            if not file_info or not file_info.get('downloadUrl'):
                raise ConnectionError(f"Could not get download URL for Civitai model version ID {model_version_id}")
            # The destination folder may not exist yet on a fresh install.
            os.makedirs(dest_dir, exist_ok=True)
            status = download_file(
                file_info['downloadUrl'],
                dest_path,
                api_key=os.environ.get("CIVITAI_API_KEY", ""),
                progress=progress,
                desc=f"Downloading CN: {filename}"
            )
            # download_file reports errors as "Download failed for ..." (lower-case
            # "failed"), so match that prefix rather than the word "Failed".
            if status.startswith("Download failed"):
                raise ConnectionError(status)
        else:
            raise NotImplementedError(f"Download source '{source}' is not implemented for ControlNets.")
    except Exception as e:
        # Never leave a half-made link/file behind for the next existence check.
        if os.path.lexists(dest_path):
            try:
                os.remove(dest_path)
            except OSError:
                pass
        raise gr.Error(f"Failed to download ControlNet model '{filename}': {e}") from e
def load_ipadapter_presets():
    """Populate the module-level IPADAPTER_PRESETS cache from yaml/ipadapter_models.yaml.

    The YAML path is resolved relative to the parent of this module's directory.
    The file is parsed with UniqueKeyLoader and is expected to be a list of
    mappings that each carry a 'preset_name'. The cache maps preset name to
    its entry.

    Runs at most once: if the cache is already set, this returns immediately.
    That includes the case where an earlier load failed and set it to {}. On
    any read or parse error, an error is printed and the cache becomes {}.
    """
    global IPADAPTER_PRESETS
    if IPADAPTER_PRESETS is None:
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        presets_path = os.path.join(project_root, 'yaml', 'ipadapter_models.yaml')
        try:
            with open(presets_path, 'r', encoding='utf-8') as handle:
                raw_presets = yaml.load(handle, Loader=UniqueKeyLoader)
            by_name = {}
            for entry in raw_presets:
                by_name[entry['preset_name']] = entry
            IPADAPTER_PRESETS = by_name
            print("✅ IPAdapter presets loaded successfully.")
        except Exception as e:
            print(f"❌ FATAL: Could not load or parse ipadapter_models.yaml. IPAdapter will not work. Error: {e}")
            IPADAPTER_PRESETS = {}
def ensure_ipadapter_models_downloaded(preset_name: str, progress):
    """Download every asset referenced by an IPAdapter preset.

    The preset's 'clip_vision', 'ipadapter' and 'loras' fields are read in that
    order. Each may hold one filename or a list of them. For each filename, a
    synthetic ALL_MODEL_MAP entry ``ipadapter_asset_<filename>`` is registered
    (with the filename at index 1) so that _ensure_model_downloaded can resolve
    it. Errors for individual assets are printed and skipped.

    Does nothing for an empty *preset_name*; an unknown preset only prints a
    warning.

    Raises:
        RuntimeError: load_ipadapter_presets() has not been run yet.
    """
    if not preset_name:
        return
    if IPADAPTER_PRESETS is None:
        raise RuntimeError("IPAdapter presets have not been loaded. `load_ipadapter_presets` must be called on startup.")
    preset_info = IPADAPTER_PRESETS.get(preset_name)
    if not preset_info:
        print(f"⚠️ Warning: IPAdapter preset '{preset_name}' not found in configuration. Skipping download.")
        return

    pending = []
    for field, type_name in (('clip_vision', 'CLIP_VISION'), ('ipadapter', 'IPADAPTER'), ('loras', 'LORA')):
        value = preset_info.get(field)
        if not value:
            continue
        entries = value if isinstance(value, list) else [value]
        pending.extend((entry, type_name) for entry in entries)

    for asset_name, asset_type in pending:
        if not asset_name:
            continue
        synthetic_key = f"ipadapter_asset_{asset_name}"
        ALL_MODEL_MAP.setdefault(synthetic_key, (None, asset_name, asset_type, None, None))
        try:
            _ensure_model_downloaded(synthetic_key, progress)
        except Exception as e:
            print(f"❌ Error ensuring download for IPAdapter asset '{asset_name}': {e}")
def ensure_sd3_ipadapter_models_downloaded(progress):
    """Best-effort download of the SD3 IPAdapter and CLIP-vision models.

    Reads yaml/ipadapter_sd3_models.yaml, resolved relative to the parent of
    this module's directory. Its 'ipadapter' and then 'clip_vision' values are
    passed as display names to _ensure_model_downloaded. The first failure
    stops the sequence and is printed as a warning; nothing is raised.
    """
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    config_path = os.path.join(project_root, 'yaml', 'ipadapter_sd3_models.yaml')
    try:
        with open(config_path, 'r', encoding='utf-8') as handle:
            sd3_models = yaml.safe_load(handle)
        if sd3_models:
            for key in ('ipadapter', 'clip_vision'):
                if key in sd3_models:
                    _ensure_model_downloaded(sd3_models[key], progress)
    except Exception as e:
        print(f"Warning: Failed to load or download sd3 ipadapter models: {e}")
def print_welcome_message():
    """Print the startup banner (author credit and GPL-3.0 notice) to stdout."""
    author_name = "RioShiina"
    project_url = "https://huggingface.co/RioShiina"
    border = "=" * 72
    fragments = (
        f"\n{border}\n\n",
        " Thank you for using this project!\n\n",
        f" **Author:** {author_name}\n",
        f" **Find more from the author:** {project_url}\n\n",
        " This project is open-source under the GNU General Public License v3.0 (GPL-3.0).\n",
        " As it's built upon GPL-3.0 components (like ComfyUI), any modifications you\n",
        " distribute must also be open-sourced under the same license.\n\n",
        " Your respect for the principles of free software is greatly appreciated!\n\n",
        f"{border}\n",
    )
    print("".join(fragments))
def _normalize_model_type_key(name) -> str:
    """Canonical form for matching config sections to model types: lower-case, spaces to '-', dots removed."""
    return str(name).lower().replace(" ", "-").replace(".", "")


def get_model_generation_defaults(model_display_name: str, model_type: str, defaults_config: dict):
    """Merge generation defaults for a model from a layered config.

    Layers are applied in increasing priority:
      1. built-in fallbacks (steps, cfg, sampler, scheduler, prompts);
      2. ``defaults_config['Default']``;
      3. the section whose key matches *model_type*, via its ``'_defaults'`` entry;
      4. that same section's entry keyed by *model_display_name*.

    Section keys and *model_type* are compared after the same normalization
    (lower-case, spaces to '-', dots removed). This lets "SD 1.5", "sd-1.5"
    and "sd-15" all find the same section. Empty (None) sections are ignored.

    Returns:
        A new dict of generation settings.
    """
    final_defaults = {
        'steps': 25, 'cfg': 7.0, 'sampler_name': 'euler', 'scheduler': 'simple',
        'positive_prompt': '', 'negative_prompt': ''
    }
    defaults_config = defaults_config or {}
    final_defaults.update(defaults_config.get('Default') or {})

    wanted = _normalize_model_type_key(model_type)
    model_type_key = next(
        (key for key in defaults_config if _normalize_model_type_key(key) == wanted),
        None,
    )
    if model_type_key is not None:
        section = defaults_config[model_type_key]
        # An empty YAML section parses as None; treat anything non-dict as empty.
        if isinstance(section, dict):
            final_defaults.update(section.get('_defaults') or {})
            final_defaults.update(section.get(model_display_name) or {})
    return final_defaults