|
import json |
|
import os |
|
from functools import lru_cache |
|
from typing import List |
|
|
|
from iopaint.schema import ModelType, ModelInfo |
|
from loguru import logger |
|
from pathlib import Path |
|
|
|
from iopaint.const import ( |
|
DEFAULT_MODEL_DIR, |
|
DIFFUSERS_SD_CLASS_NAME, |
|
DIFFUSERS_SD_INPAINT_CLASS_NAME, |
|
DIFFUSERS_SDXL_CLASS_NAME, |
|
DIFFUSERS_SDXL_INPAINT_CLASS_NAME, |
|
ANYTEXT_NAME, |
|
) |
|
from iopaint.model.original_sd_configs import get_config_files |
|
|
|
|
|
def cli_download_model(model: str):
    """Download ``model`` from the command line.

    Erase models registered in ``iopaint.model.models``, and the AnyText model,
    are fetched through their own ``download()`` method. Any other name is
    treated as a Hugging Face repository id and fetched with
    ``DiffusionPipeline.download`` (``fp16`` variant), wrapped by
    ``handle_from_pretrained_exceptions``.
    """
    from iopaint.model import models
    from iopaint.model.utils import handle_from_pretrained_exceptions

    # Both "built-in" kinds share exactly the same download path.
    if (model in models and models[model].is_erase_model) or model == ANYTEXT_NAME:
        logger.info(f"Downloading {model}...")
        models[model].download()
        logger.info("Done.")
        return

    logger.info(f"Downloading model from Huggingface: {model}")
    from diffusers import DiffusionPipeline

    target_path = handle_from_pretrained_exceptions(
        DiffusionPipeline.download,
        pretrained_model_name=model,
        variant="fp16",
        resume_download=True,
    )
    logger.info(f"Done. Downloaded to {target_path}")
|
|
|
|
|
def folder_name_to_show_name(name: str) -> str:
    """Turn a hub-cache folder name such as ``models--org--repo`` into ``org/repo``.

    Every ``models--`` occurrence is removed, then each remaining ``--``
    separator becomes ``/``.
    """
    without_prefix = name.replace("models--", "")
    return "/".join(without_prefix.split("--"))
|
|
|
|
|
@lru_cache(maxsize=512)
def get_sd_model_type(model_abs_path: str) -> ModelType:
    """Classify a single-file SD checkpoint as inpainting or regular.

    A file whose name contains "inpaint" (case-insensitive) is classified as
    inpainting without being opened. Otherwise the checkpoint is loaded into a
    9-input-channel ``StableDiffusionInpaintPipeline``:

    * a successful load means inpainting;
    * a ValueError mentioning the ``[320, 4, 3, 3]`` tensor means a regular
      SD model (presumably a 4-input-channel UNet that cannot fit the
      9-channel layout);
    * any other error is re-raised.

    Results are memoized per path via ``lru_cache``.
    """
    if "inpaint" in Path(model_abs_path).name.lower():
        return ModelType.DIFFUSERS_SD_INPAINT

    from diffusers import StableDiffusionInpaintPipeline

    try:
        StableDiffusionInpaintPipeline.from_single_file(
            model_abs_path,
            load_safety_checker=False,
            num_in_channels=9,
            config_files=get_config_files(),
        )
    except ValueError as err:
        if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(err):
            return ModelType.DIFFUSERS_SD
        raise
    return ModelType.DIFFUSERS_SD_INPAINT
|
|
|
|
|
@lru_cache(maxsize=512)
def get_sdxl_model_type(model_abs_path: str) -> ModelType:
    """Classify a single-file SDXL checkpoint as inpainting or regular.

    A file whose *name* contains "inpaint" (case-insensitive) is classified as
    inpainting without being opened, the same rule ``get_sd_model_type`` uses.
    Otherwise the checkpoint is loaded into a 9-input-channel
    ``StableDiffusionXLInpaintPipeline``:

    * if it loads, the UNet's ``in_channels`` decides the type (9 means
      inpainting);
    * a ValueError mentioning the ``[320, 4, 3, 3]`` tensor means a regular
      SDXL model;
    * any other error is re-raised.

    Results are memoized per path via ``lru_cache``.
    """
    # Match on the file name only. Matching on the full path would flag every
    # model stored under a directory whose name happens to contain "inpaint".
    if "inpaint" in Path(model_abs_path).name.lower():
        return ModelType.DIFFUSERS_SDXL_INPAINT

    from diffusers import StableDiffusionXLInpaintPipeline

    try:
        pipe = StableDiffusionXLInpaintPipeline.from_single_file(
            model_abs_path,
            load_safety_checker=False,
            num_in_channels=9,
            config_files=get_config_files(),
        )
    except ValueError as e:
        # Presumably raised when the checkpoint's UNet has 4 input channels,
        # i.e. a non-inpainting model.
        if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
            return ModelType.DIFFUSERS_SDXL
        raise

    if pipe.unet.config.in_channels == 9:
        return ModelType.DIFFUSERS_SDXL_INPAINT
    return ModelType.DIFFUSERS_SDXL
|
|
|
|
|
def _load_model_type_cache(cache_file: Path) -> dict:
    """Return the ``{file name: model type}`` mapping stored in ``cache_file``.

    The cache is only an optimization, so a missing, unreadable or malformed
    file yields an empty dict (with a warning for the latter two) instead of
    an error.
    """
    if not cache_file.exists():
        return {}
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        logger.warning(f"Ignoring unreadable model type cache {cache_file}: {e}")
        return {}
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not isinstance(data, dict):
        logger.warning(f"Ignoring malformed model type cache {cache_file}")
        return {}
    return data


def _save_model_type_cache(cache_file: Path, cache: dict) -> None:
    """Write ``cache`` to ``cache_file`` as JSON, logging (not raising) I/O errors."""
    try:
        with open(cache_file, "w", encoding="utf-8") as fw:
            json.dump(cache, fw, indent=2, ensure_ascii=False)
    except OSError as e:
        logger.warning(f"Failed to write model type cache {cache_file}: {e}")


def _scan_single_file_dir(model_dir: Path, detect_model_type) -> List[ModelInfo]:
    """Scan one directory for ``.safetensors``/``.ckpt`` files and classify them.

    ``detect_model_type`` is called with the file's absolute path only for
    files not already present in the directory's ``iopaint_cache.json``. The
    updated cache is written back once, after the loop, if the directory
    exists.
    """
    cache_file = model_dir / "iopaint_cache.json"
    model_type_cache = _load_model_type_cache(cache_file)

    found = []
    for it in model_dir.glob("*.*"):
        if it.suffix not in (".safetensors", ".ckpt"):
            continue
        model_abs_path = str(it.absolute())
        model_type = model_type_cache.get(it.name)
        if model_type is None:
            # Detection may load the whole checkpoint, so remember the result.
            model_type = detect_model_type(model_abs_path)
            model_type_cache[it.name] = model_type
        found.append(
            ModelInfo(
                name=it.name,
                path=model_abs_path,
                model_type=model_type,
                is_single_file_diffusers=True,
            )
        )

    if model_dir.exists():
        _save_model_type_cache(cache_file, model_type_cache)
    return found


def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
    """List single-file SD and SDXL checkpoints stored under ``cache_dir``.

    SD checkpoints are looked up in ``<cache_dir>/stable_diffusion`` and
    classified with ``get_sd_model_type``. SDXL checkpoints are looked up in
    ``<cache_dir>/stable_diffusion_xl`` and classified with
    ``get_sdxl_model_type``. Each directory keeps its own
    ``iopaint_cache.json`` so that files are not re-classified on every scan.
    SD results come first, then SDXL.
    """
    cache_dir = Path(cache_dir)
    res = _scan_single_file_dir(cache_dir / "stable_diffusion", get_sd_model_type)
    res.extend(
        _scan_single_file_dir(cache_dir / "stable_diffusion_xl", get_sdxl_model_type)
    )
    return res
|
|
|
|
|
def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
    """List the erase models from ``iopaint.model.models`` that are downloaded.

    ``model_dir`` is accepted for signature parity with the other scanners but
    is not used here. Each entry uses the model's registry name as both
    ``name`` and ``path``.
    """
    from iopaint.model import models

    return [
        ModelInfo(name=model_name, path=model_name, model_type=ModelType.INPAINT)
        for model_name, entry in models.items()
        if entry.is_erase_model and entry.is_downloaded()
    ]
|
|
|
|
|
def scan_diffusers_models() -> List[ModelInfo]:
    """Discover diffusers-format models in the Hugging Face hub cache.

    Walks ``HF_HUB_CACHE`` for ``model_index.json`` files and classifies each
    repository by its ``_class_name``; any repository whose name contains
    "PowerPaint" is classified as ``DIFFUSERS_OTHER``. Unreadable, malformed or
    unsupported indexes are skipped rather than aborting the scan. Each
    repository is reported once: the first readable, supported snapshot found
    wins.

    Returns:
        ``ModelInfo`` entries whose ``name`` and ``path`` are the repo id shown
        by ``folder_name_to_show_name``.
    """
    from huggingface_hub.constants import HF_HUB_CACHE

    class_name_to_type = {
        DIFFUSERS_SD_CLASS_NAME: ModelType.DIFFUSERS_SD,
        DIFFUSERS_SD_INPAINT_CLASS_NAME: ModelType.DIFFUSERS_SD_INPAINT,
        DIFFUSERS_SDXL_CLASS_NAME: ModelType.DIFFUSERS_SDXL,
        DIFFUSERS_SDXL_INPAINT_CLASS_NAME: ModelType.DIFFUSERS_SDXL_INPAINT,
        "StableDiffusionInstructPix2PixPipeline": ModelType.DIFFUSERS_OTHER,
        "PaintByExamplePipeline": ModelType.DIFFUSERS_OTHER,
        "KandinskyV22InpaintPipeline": ModelType.DIFFUSERS_OTHER,
        "AnyText": ModelType.DIFFUSERS_OTHER,
    }

    available_models = []
    seen_names = set()
    for it in Path(HF_HUB_CACHE).glob("**/*/model_index.json"):
        try:
            with open(it, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # e.g. a snapshot symlink whose blob is missing, or invalid JSON:
            # skip this entry instead of failing the whole scan.
            continue
        if not isinstance(data, dict):
            continue

        # Assumes the hub layout <repo folder>/snapshots/<revision>/model_index.json
        name = folder_name_to_show_name(it.parent.parent.parent.name)
        if name in seen_names:
            continue

        class_name = data.get("_class_name")
        if "PowerPaint" in name:
            model_type = ModelType.DIFFUSERS_OTHER
        elif isinstance(class_name, str) and class_name in class_name_to_type:
            model_type = class_name_to_type[class_name]
        else:
            continue

        seen_names.add(name)
        available_models.append(
            ModelInfo(
                name=name,
                path=name,
                model_type=model_type,
            )
        )
    return available_models
|
|
|
|
|
def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
    """Find diffusers-format model folders anywhere below ``cache_dir``.

    Every ``model_index.json`` found marks a model folder. Its ``_class_name``
    selects the model type, and only the four ``DIFFUSERS_SD*_CLASS_NAME``
    pipeline classes are accepted. Folders whose index cannot be read or is
    not a JSON object are logged and skipped; duplicate display names are
    reported once.

    Returns:
        ``ModelInfo`` entries whose ``path`` is the absolute model folder.
    """
    class_name_to_type = {
        DIFFUSERS_SD_CLASS_NAME: ModelType.DIFFUSERS_SD,
        DIFFUSERS_SD_INPAINT_CLASS_NAME: ModelType.DIFFUSERS_SD_INPAINT,
        DIFFUSERS_SDXL_CLASS_NAME: ModelType.DIFFUSERS_SDXL,
        DIFFUSERS_SDXL_INPAINT_CLASS_NAME: ModelType.DIFFUSERS_SDXL_INPAINT,
    }

    available_models = []
    seen_names = set()
    for it in Path(cache_dir).glob("**/*/model_index.json"):
        try:
            with open(it, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):  # unreadable file or invalid JSON
            data = None
        if not isinstance(data, dict):
            logger.error(
                f"Failed to load {it}, please try revert from original model or fix model_index.json by hand."
            )
            continue

        name = folder_name_to_show_name(it.parent.name)
        if name in seen_names:
            continue

        class_name = data.get("_class_name")
        if not isinstance(class_name, str) or class_name not in class_name_to_type:
            continue

        seen_names.add(name)
        available_models.append(
            ModelInfo(
                name=name,
                path=str(it.parent.absolute()),
                model_type=class_name_to_type[class_name],
            )
        )
    return available_models
|
|
|
|
|
def scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
    """Collect converted diffusers models from the SD and SDXL folders.

    Looks in ``<cache_dir>/stable_diffusion`` first, then
    ``<cache_dir>/stable_diffusion_xl``, and concatenates the results.
    """
    root = Path(cache_dir)
    found: List[ModelInfo] = []
    for sub_dir in ("stable_diffusion", "stable_diffusion_xl"):
        found.extend(_scan_converted_diffusers_models(root / sub_dir))
    return found
|
|
|
|
|
def scan_models() -> List[ModelInfo]:
    """Return every locally available model.

    The model directory is ``$XDG_CACHE_HOME`` when that variable is set,
    otherwise ``DEFAULT_MODEL_DIR``. The results are ordered as follows:

    1. downloaded erase models;
    2. single-file SD/SDXL checkpoints;
    3. diffusers models in the Hugging Face hub cache;
    4. converted diffusers models.
    """
    root_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
    return [
        *scan_inpaint_models(root_dir),
        *scan_single_file_diffusion_models(root_dir),
        *scan_diffusers_models(),
        *scan_converted_diffusers_models(root_dir),
    ]
|
|