TheNetherWatcher's picture
Upload folder using huggingface_hub
d0ffe9c verified
import json
import logging
from os import PathLike
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from pydantic.v1 import BaseConfig, BaseSettings, Field
from pydantic.env_settings import (EnvSettingsSource, InitSettingsSource,
SecretsSettingsSource,
SettingsSourceCallable)
from animatediff import get_dir
from animatediff.schedulers import DiffusionScheduler
# Configure the root logger at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# File suffixes listed as checkpoint extensions.
CKPT_EXTENSIONS = [
    ".pt",
    ".ckpt",
    ".pth",
    ".safetensors",
]
class JsonSettingsSource:
    """Callable settings source that reads values from one or more JSON files.

    The instance is called with the settings object and returns a plain dict.
    Files are merged in the order given using ``dict.update``, so a later file
    overrides the top-level keys of an earlier one (no deep merge).
    """

    __slots__ = ["json_config_path"]

    def __init__(
        self,
        json_config_path: Optional[Union[PathLike, list[PathLike]]] = None,
    ) -> None:
        """Normalise ``json_config_path`` into a list of ``Path`` objects.

        Args:
            json_config_path: A single path, a list/tuple of paths, or ``None``
                (the default) for "no JSON files".
        """
        if json_config_path is None:
            self.json_config_path = []
        elif isinstance(json_config_path, (list, tuple)):
            # Accept tuples as well as lists; a tuple would otherwise be passed
            # to Path() as a single argument and fail with a TypeError.
            self.json_config_path = [Path(path) for path in json_config_path]
        else:
            self.json_config_path = [Path(json_config_path)]

    def __call__(self, settings: BaseSettings) -> Dict[str, Any]:  # noqa C901
        """Load and merge every configured JSON file.

        Args:
            settings: The settings object being populated; its class name is
                used in log/error messages and ``__config__.env_file_encoding``
                is used to decode the files.

        Returns:
            The merged configuration; an empty dict when no paths are set.

        Raises:
            FileNotFoundError: If a configured path does not exist or is not a
                regular file.
        """
        classname = settings.__class__.__name__
        encoding = settings.__config__.env_file_encoding

        merged_config: Dict[str, Any] = {}
        for idx, path in enumerate(self.json_config_path, start=1):
            # is_file() is already False for a missing path, so no separate exists() check.
            if not path.is_file():
                raise FileNotFoundError(f"{classname}: config #{idx} at {path} not found or not a file")
            logger.debug("%s: loading config #%d from %s", classname, idx, path)
            merged_config.update(json.loads(path.read_text(encoding=encoding)))
            logger.debug("%s: config state #%d: %s", classname, idx, merged_config)

        logger.debug("%s: loaded config: %s", classname, merged_config)
        return merged_config

    def __repr__(self) -> str:
        return f"JsonSettingsSource(json_config_path={repr(self.json_config_path)})"
class JsonConfig(BaseConfig):
    """Config base class that adds a JSON file source to a settings class.

    ``json_config_path`` may be set as a class attribute or supplied as an
    init keyword argument; the init keyword wins when both are present.
    """

    json_config_path: Optional[Union[Path, list[Path]]] = None
    env_file_encoding: str = "utf-8"

    @classmethod
    def customise_sources(
        cls,
        init_settings: InitSettingsSource,
        env_settings: EnvSettingsSource,
        file_secret_settings: SecretsSettingsSource,
    ) -> Tuple[SettingsSourceCallable, ...]:
        """Return the settings sources: init kwargs first, then the JSON files.

        ``json_config_path`` is removed from the init kwargs here. The
        ``env_settings`` and ``file_secret_settings`` sources are accepted but
        not included in the returned tuple.
        """
        init_kwargs = init_settings.init_kwargs
        if "json_config_path" in init_kwargs:
            # An explicitly passed path overrides the class-level value.
            config_location = init_kwargs.pop("json_config_path")
        else:
            config_location = cls.json_config_path

        logger.debug(f"Using JsonSettingsSource for {cls.__name__}")
        json_source = JsonSettingsSource(json_config_path=config_location)

        sources = (init_settings, json_source)
        return sources
class InferenceConfig(BaseSettings):
    """Inference settings whose values come from init kwargs or a JSON file.

    Both fields are required: neither declares a default.
    """

    # Keyword arguments for the UNet (exact schema is not visible here; see the JSON configs).
    unet_additional_kwargs: dict[str, Any]
    # Keyword arguments for the noise scheduler (schema likewise defined by the JSON configs).
    noise_scheduler_kwargs: dict[str, Any]

    class Config(JsonConfig):
        # Annotation only; the value is inherited from JsonConfig unless passed at init.
        json_config_path: Path
def get_infer_config(
    is_v2: bool,
    is_sdxl: bool,
) -> InferenceConfig:
    """Build the ``InferenceConfig`` for the requested model family.

    Args:
        is_v2: Select ``inference/motion_v2.json`` instead of
            ``inference/default.json``.
        is_sdxl: Select ``inference/motion_sdxl.json``; takes precedence over
            ``is_v2``.

    Returns:
        The settings loaded from the chosen file under ``get_dir("config")``.
    """
    if is_v2:
        relative_path = "inference/motion_v2.json"
    else:
        relative_path = "inference/default.json"
    config_path: Path = get_dir("config").joinpath(relative_path)

    if is_sdxl:
        # SDXL overrides whichever non-SDXL file was picked above.
        config_path = get_dir("config").joinpath("inference/motion_sdxl.json")

    return InferenceConfig(json_config_path=config_path)
class ModelConfig(BaseSettings):
    """Per-run model/generation settings loaded from init kwargs or a JSON file.

    ``name``, ``path`` and ``motion_module`` are required; every other field
    has a default.
    """

    # Config name, not actually used for much of anything
    name: str = Field(...)
    # Path to the model
    path: Path = Field(...)
    # Path to a separate VAE (judging by the field name); empty string by default
    vae_path: str = ""
    # Path to the motion module
    motion_module: Path = Field(...)
    context_schedule: str = "uniform"
    lcm_map: Dict[str, Any] = Field({})
    gradual_latent_hires_fix_map: Dict[str, Any] = Field({})
    # whether to compile the model with TorchDynamo
    compile: bool = Field(False)
    tensor_interpolation_slerp: bool = Field(True)
    # Seed(s) for the random number generators
    seed: list[int] = Field([])
    # Scheduler to use
    scheduler: DiffusionScheduler = Field(DiffusionScheduler.k_dpmpp_2m)
    # Number of inference steps to run
    steps: int = 25
    # CFG scale to use
    guidance_scale: float = 7.5
    unet_batch_size: int = 1
    # skip the last N-1 layers of the CLIP text encoder
    clip_skip: int = 1
    prompt_fixed_ratio: float = 0.5
    head_prompt: str = ""
    prompt_map: Dict[str, str] = Field({})
    tail_prompt: str = ""
    # Anti-prompt(s) to use
    n_prompt: list[str] = Field([])
    is_single_prompt_mode: bool = Field(False)
    lora_map: Dict[str, Any] = Field({})
    motion_lora_map: Dict[str, float] = Field({})
    ip_adapter_map: Dict[str, Any] = Field({})
    img2img_map: Dict[str, Any] = Field({})
    region_map: Dict[str, Any] = Field({})
    controlnet_map: Dict[str, Any] = Field({})
    upscale_config: Dict[str, Any] = Field({})
    stylize_config: Dict[str, Any] = Field({})
    output: Dict[str, Any] = Field({})
    result: Dict[str, Any] = Field({})

    class Config(JsonConfig):
        # Annotation only; the value is inherited from JsonConfig unless passed at init.
        json_config_path: Path

    @property
    def save_name(self):
        """Return ``"<name>-<model file stem>"``, both parts lower-cased."""
        name_part = self.name.lower()
        stem_part = self.path.stem.lower()
        return f"{name_part}-{stem_part}"
def get_model_config(config_path: Path) -> ModelConfig:
    """Create a ``ModelConfig`` using ``config_path`` as its JSON config file.

    Args:
        config_path: Location of the JSON file, passed as ``json_config_path``.

    Returns:
        The populated ``ModelConfig``.
    """
    return ModelConfig(json_config_path=config_path)