Spaces:
Runtime error
Runtime error
import logging | |
import os | |
import sys | |
from os.path import dirname, join | |
from utils.config import Config | |
from utils.distributed import init_distributed_mode, is_main_process | |
from utils.logger import setup_logger | |
# Module-level logger shared by the setup helpers in this file.
logger = logging.getLogger(__name__)
def setup_config():
    """Load the run configuration and apply debug-mode overrides.

    Loading is delegated to ``Config.get_config()``. Per the original notes,
    that call combines the YAML config with command-line overrides via
    OmegaConf and converts types, e.g. ``'None'`` (str) -> ``None``. That
    work happens inside ``Config`` and is not visible here.

    When ``debug`` is set, wandb logging is switched off.

    Returns:
        The loaded config object.
    """
    cfg = Config.get_config()
    # Debug runs should never report to wandb.
    if cfg.debug:
        cfg.wandb.enable = False
    return cfg
def setup_evaluate_config(config):
    """Apply evaluation-mode defaults to ``config`` in place.

    Disables wandb logging. When no output directory was configured, results
    go to an ``eval`` folder placed next to ``config.pretrained_path``.

    Args:
        config: Config object whose ``evaluate`` flag must be truthy.

    Returns:
        The same ``config`` object, after modification.

    Raises:
        AssertionError: If ``config.evaluate`` is falsy.
    """
    assert config.evaluate
    config.wandb.enable = False

    if config.output_dir is None:
        checkpoint_dir = dirname(config.pretrained_path)
        config.output_dir = join(checkpoint_dir, "eval")

    return config
def setup_output_dir(output_dir, excludes=None):
    """Create ``output_dir`` if missing; otherwise warn about leftover contents.

    Existing directories are not protected from being overwritten. The check
    is advisory only: if the directory already exists, any entry not listed
    in ``excludes`` is reported in a warning and execution continues. Entries
    whose names contain ``"slurm"`` or ``".out"`` (job logs) are also ignored.

    Args:
        output_dir: Path of the run's output directory.
        excludes: Entry names to ignore when inspecting an existing
            directory. Defaults to ``("code",)``.
    """
    # Avoid a shared mutable default; keep the historical default of "code".
    if excludes is None:
        excludes = ("code",)

    if not os.path.exists(output_dir):
        # exist_ok=True: another job/process may create the directory between
        # the existence check above and this call.
        os.makedirs(output_dir, exist_ok=True)
        return

    ignored = set(excludes)
    remaining = sorted(
        entry
        for entry in os.listdir(output_dir)
        if entry not in ignored and "slurm" not in entry and ".out" not in entry
    )
    # Only warn when something would actually be overwritten/mixed in;
    # previously an empty list was logged on every rerun.
    if remaining:
        logger.warning("remaining dirs or files: %s", remaining)
def setup_main():
    """Prepare config, distributed state, output directory and logging.

    Shared entry point for pretraining and all downstream tasks.

    Steps:
        1. Load the config.
        2. Apply evaluation defaults when ``evaluate`` is set.
        3. Initialise distributed mode.
        4. On the main process only: prepare the output directory, configure
           the "vindlu" logger, log the config and dump it to ``config.json``.

    Returns:
        The (possibly updated) config object.
    """
    config = setup_config()
    if getattr(config, "evaluate", False):
        config = setup_evaluate_config(config)
    init_distributed_mode(config)

    # Non-main ranks skip all filesystem and logger setup.
    if not is_main_process():
        return config

    setup_output_dir(config.output_dir, excludes=["code"])
    setup_logger(output=config.output_dir, color=True, name="vindlu")
    logger.info(f"config: {Config.pretty_text(config)}")
    config_path = os.path.join(config.output_dir, "config.json")
    Config.dump(config, config_path)
    return config