Spaces:
Sleeping
Sleeping
| import os | |
| from pathlib import Path | |
| from typing import Optional, List, Tuple, Dict | |
| from dataclasses import dataclass, field | |
| from omegaconf import OmegaConf, MISSING | |
| from utils.class_registry import ClassRegistry | |
| from models.methods import methods_registry | |
| from metrics.metrics import metrics_registry | |
# Registry of config-section classes. Sections are added via
# ``args.add_to_registry(<section name>)`` and the top-level ``Args`` schema is
# built from it with ``args.make_dataclass_from_classes``.
args = ClassRegistry()
# Registered as the ``exp`` section: load_config reads config.exp.config and
# config.exp.config_dir, so this class must be a registered dataclass.
@args.add_to_registry("exp")
@dataclass
class ExperimentArgs:
    """Experiment-level settings, exposed as the ``exp`` section of the config.

    ``load_config`` joins ``exp.config_dir`` and ``exp.config`` to locate the
    YAML file that is merged over these defaults.
    """

    # Directory holding config files; defaults to ``configs/`` next to this module.
    config_dir: str = str(Path(__file__).resolve().parent / "configs")
    # Config file name relative to ``config_dir``; required (OmegaConf MISSING).
    config: str = MISSING
    output_dir: str = "results_dir"
    seed: int = 1
    # Taken from the EXP_ROOT environment variable at import time; "." if unset.
    root: str = os.getenv("EXP_ROOT", ".")
    domain: str = "human_faces"
    # NOTE(review): presumably toggles Weights & Biases logging — confirm.
    wandb: bool = False
# NOTE(review): registry key reconstructed as "data" — confirm it matches the
# section name used in the YAML config files.
@args.add_to_registry("data")
@dataclass
class DataArgs:
    """Data settings, exposed as the ``data`` section of the config."""

    # NOTE(review): presumably the input directory used at inference time; empty by default.
    inference_dir: str = ""
    # Name of the image transform to use (default "face_1024").
    transform: str = "face_1024"
# NOTE(review): registry key reconstructed as "inference" — confirm it matches
# the section name used in the YAML config files.
@args.add_to_registry("inference")
@dataclass
class InferenceArgs:
    """Inference settings, exposed as the ``inference`` section of the config."""

    # Name of the inference runner to use.
    inference_runner: str = "base_inference_runner"
    # Mutable default goes through default_factory so instances never share one dict.
    editings_data: Dict = field(default_factory=dict)
# Registered as the ``model`` section: load_config reads config.model.method.
@args.add_to_registry("model")
@dataclass
class ModelArgs:
    """Model settings, exposed as the ``model`` section of the config.

    ``load_config`` keeps only the ``methods_args`` entry whose key equals
    ``model.method``.
    """

    method: str = "fse_full"
    # Device as a string; NOTE(review): "0" presumably means GPU index 0 — confirm.
    device: str = "0"
    batch_size: int = 4
    workers: int = 4
    checkpoint_path: str = ""
# Section classes generated from the method and metric registries. load_config
# treats the keys of ``methods_args`` as method names (compared against
# ``model.method``); NOTE(review): presumably one field per registered entry.
MethodsArgs = methods_registry.make_dataclass_from_args("MethodsArgs")
args.add_to_registry("methods_args")(MethodsArgs)
MetricsArgs = metrics_registry.make_dataclass_from_args("MetricsArgs")
args.add_to_registry("metrics")(MetricsArgs)
# Top-level schema combining every registered section; load_config passes it to
# OmegaConf.structured. NOTE(review): presumably reads the registry at call
# time, so keep this after all registrations.
Args = args.make_dataclass_from_classes("Args")
def load_config(cli_args: Optional[List[str]] = None):
    """Build the experiment config from defaults, a YAML file and CLI overrides.

    Layers, lowest to highest precedence:

    1. the structured defaults declared by ``Args``;
    2. the YAML file at ``os.path.join(exp.config_dir, exp.config)``;
    3. ``key=value`` dotlist overrides from the command line.

    Afterwards ``methods_args`` is pruned so that only the entry for the
    selected ``model.method`` (after CLI overrides) remains.

    Args:
        cli_args: dotlist overrides such as ``["exp.config=run.yaml"]``. When
            ``None`` (the default) ``OmegaConf.from_cli`` reads ``sys.argv[1:]``.

    Returns:
        The merged OmegaConf config.

    Raises:
        ValueError: if ``exp.config`` is not given on the command line.
    """
    config = OmegaConf.structured(Args)
    conf_cli = OmegaConf.from_cli(cli_args)

    # The YAML file's location must come from the CLI. Only ``exp.config`` is
    # mandatory; ``exp.config_dir`` falls back to the schema default instead of
    # crashing when it is omitted.
    exp_cli = conf_cli.get("exp") or {}
    config_name = exp_cli.get("config")
    if config_name is None:
        raise ValueError(
            "exp.config must be given on the command line, e.g. exp.config=<file>.yaml"
        )
    config.exp.config = config_name
    config.exp.config_dir = exp_cli.get("config_dir", config.exp.config_dir)

    config_path = os.path.join(config.exp.config_dir, config.exp.config)
    conf_file = OmegaConf.load(config_path)

    # Merge CLI overrides *before* pruning so that ``model.method=...`` on the
    # command line decides which ``methods_args`` entry survives.
    config = OmegaConf.merge(config, conf_file, conf_cli)

    selected_method = config.model.method
    for method in list(config.methods_args.keys()):
        if method != selected_method:
            # Attribute deletion (DictConfig.__delattr__), as in the original code,
            # rather than ``del cfg[key]``. NOTE(review): presumably because item
            # deletion is rejected on struct-mode configs — confirm.
            delattr(config.methods_args, method)
    return config