File size: 2,169 Bytes
b80dc1e
642e32f
b80dc1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
642e32f
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import sys
import yaml
from loguru import logger
from typing import Dict, Any


def complete_path(file_name: str = "v7-base.yaml") -> str:
    """
    Normalize a model-configuration name into a path to a YAML file.

    The path is only assembled here; no check is made that the file exists.

    Parameters:
        file_name (str): The filename or path, with default 'v7-base.yaml'.

    Returns:
        str: The path, with a '.yaml' extension appended when the name has no
            YAML extension, and prefixed with './config/model' when the name
            has no directory component.
    """
    # '.yml' is an equally valid YAML extension; append '.yaml' only when neither is present
    if not file_name.endswith((".yaml", ".yml")):
        file_name += ".yaml"

    # A bare filename (no directory part) is looked up in the default model config folder
    if os.path.dirname(file_name) == "":
        file_name = os.path.join("./config/model", file_name)

    return file_name


def load_model_cfg(file_path: str) -> Dict[str, Any]:
    """
    Read a YAML configuration file, ensure necessary keys are present, and return its content as a dictionary.

    The given name is first normalized with `complete_path`. A missing 'nc' key
    is filled with the default 80 (with a warning); a missing 'model' key is an error.

    Args:
        file_path (str): The path to the YAML configuration file.

    Returns:
        Dict[str, Any]: The contents of the YAML file as a dictionary.

    Raises:
        FileNotFoundError: If the YAML file cannot be found.
        yaml.YAMLError: If there is an error parsing the YAML file.
        ValueError: If the top level of the YAML document is not a mapping,
            or if the required 'model' key is missing.
    """
    file_path = complete_path(file_path)

    # Keep the try body limited to the I/O and parsing that can actually raise these errors
    try:
        # YAML files are UTF-8 by convention; do not depend on the platform locale
        with open(file_path, "r", encoding="utf-8") as file:
            # An empty document parses to None; treat it as an empty mapping
            model_cfg = yaml.safe_load(file) or {}
    except FileNotFoundError:
        logger.error(f"YAML file not found: {file_path}")
        raise
    except yaml.YAMLError as e:
        logger.error(f"Error parsing YAML file: {e}")
        raise

    # A list or scalar at the top level cannot hold the 'nc'/'model' keys
    if not isinstance(model_cfg, dict):
        logger.error("The configuration file must contain a mapping at the top level.")
        raise ValueError(
            f"Expected a mapping at the top level of the YAML file, got {type(model_cfg).__name__}"
        )

    # Check for required keys and set defaults if not present
    if "nc" not in model_cfg:
        model_cfg["nc"] = 80
        logger.warning("'nc' not found in the YAML file. Setting default 'nc' to 80.")

    if "model" not in model_cfg:
        logger.error("'model' is missing in the configuration file.")
        raise ValueError("Missing required key: 'model'")

    return model_cfg


def custom_logger():
    """
    Reconfigure the loguru logger to write to stderr with a compact format.

    Calls `logger.remove()` first, then adds a single stderr sink whose lines
    show a month-day timestamp, the padded level name, and the message.
    """
    line_format = "<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>"

    logger.remove()
    logger.add(sys.stderr, format=line_format)