# File size: 2,319 bytes
# 21d29cb
import ast
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

from transformers import (
    HfArgumentParser,
    AutoConfig
)

# Logger for this script, named after the module it lives in.
logger = logging.getLogger(name=__name__)


@dataclass
class ConfigArguments:
    """
    Arguments selecting which config to set up and where to write it.

    ``params`` is accepted as a string holding a Python literal (for example
    ``"{'vocab_size': 32000}"``) and is converted in ``__post_init__``. A value
    that is already a dict (as happens when arguments are read from a JSON
    file) is kept unchanged.
    """
    output_dir: str = field(
        default=".",
        metadata={"help": "The output directory where the config will be written."},
    )
    name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
                    "Don't set if you want to train a model from scratch."
        },
    )
    # Annotated as ``str`` so HfArgumentParser exposes a plain string flag;
    # after ``__post_init__`` it holds a dict of config overrides, or None
    # when the given value could not be used.
    params: Optional[str] = field(
        default=None,
        metadata={"help": "Custom configuration for the specific `name_or_path`"}
    )

    def __post_init__(self):
        """Turn ``params`` from a literal string into a dict of config overrides.

        Input that cannot be parsed, or that does not evaluate to a dict, is
        reported with a warning and discarded (``params`` becomes None), so no
        overrides are applied downstream.
        """
        if not self.params or isinstance(self.params, dict):
            # Nothing to convert: unset/empty, or already a mapping.
            return

        # Same logger object as the module-level ``logger`` (same name).
        log = logging.getLogger(__name__)
        try:
            # literal_eval only evaluates Python literals, never arbitrary code.
            parsed = ast.literal_eval(self.params)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
            log.warning(f"Your custom parameters are not acceptable due to: {e!r}; ignoring them.")
            self.params = None
            return

        if not isinstance(parsed, dict):
            log.warning(
                f"Your custom parameters must be a dict, got {type(parsed).__name__}; ignoring them."
            )
            self.params = None
            return

        self.params = parsed


def main():
    """Build a transformers config from the script arguments and save it.

    Arguments come either from a single ``.json`` file passed as the only
    command-line argument, or from regular command-line flags. The config is
    loaded with ``AutoConfig.from_pretrained(name_or_path, **params)`` (the
    overrides are applied only when ``params`` is a dict) and written to
    ``output_dir`` with ``save_pretrained``.
    """
    parser = HfArgumentParser([ConfigArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        config_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        config_args = parser.parse_args_into_dataclasses()[0]

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)

    if not config_args.name_or_path:
        # from_pretrained has nothing to load without an identifier; exit with a
        # usage error instead of a deep traceback from inside transformers.
        parser.error("`name_or_path` is required: pass the model checkpoint to load the config from.")

    logger.info(f"Setting up configuration {config_args.name_or_path} with extra params {config_args.params}")

    # Only a dict of overrides is forwarded; anything else means "no overrides".
    overrides = config_args.params if isinstance(config_args.params, dict) else {}
    config = AutoConfig.from_pretrained(config_args.name_or_path, **overrides)

    config.save_pretrained(config_args.output_dir)
    # Report the location only after the write succeeded.
    logger.info(f"Your configuration saved here {os.path.join(config_args.output_dir, 'config.json')}")


# Entry point: run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()