# gpt2-medium-persian / src / create_config.py
# m3hrdadfi's picture
# Hello gpt2-persian
# 21d29cb
# raw
# history blame
# No virus
# 2.32 kB
import ast
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from transformers import (
HfArgumentParser,
AutoConfig
)
# Module-level logger named after this module; its level is set in main().
logger = logging.getLogger(__name__)
@dataclass
class ConfigArguments:
    """
    Arguments describing which config we are going to set up.

    ``params`` is normally supplied as the string form of a Python literal
    (for example ``"{'n_layer': 12}"``) and is parsed in place by
    ``__post_init__``. The field annotation is left as ``Optional[str]`` so
    the declared interface of the dataclass is unchanged, even though the
    attribute holds the parsed object after construction.
    """

    output_dir: str = field(
        default=".",
        metadata={"help": "The output directory where the config will be written."},
    )
    name_or_path: Optional[str] = field(
        default=None,
        metadata={
            # Adjacent literals are concatenated verbatim, so the separating
            # space has to live inside the first literal.
            "help": "The model checkpoint for weights initialization. "
                    "Don't set if you want to train a model from scratch."
        },
    )
    params: Optional[str] = field(
        default=None,
        metadata={"help": "Custom configuration for the specific `name_or_path`"},
    )

    def __post_init__(self):
        """Parse ``params`` from its literal string form, if one was given.

        * Falsy values (``None``, ``""``) are left untouched.
        * A value that is already a dict is kept as-is, without being run
          through ``ast.literal_eval`` (which rejects non-string, non-AST
          input and would only produce a misleading error message).
        * A string is parsed with ``ast.literal_eval``. On failure the
          problem is reported and the raw string is kept, so construction
          never raises because of a bad ``params`` value.
        * Any resulting value that is not a dict is reported, because a
          dict of keyword overrides is what this field is meant to carry.
        """
        if not self.params:
            return
        if isinstance(self.params, dict):
            return

        if isinstance(self.params, str):
            try:
                self.params = ast.literal_eval(self.params)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
                # These are the exceptions ast.literal_eval is documented to
                # raise on malformed input. Best effort: report and move on.
                print(f"Your custom parameters are not acceptable due to {e}")
                return

        if not isinstance(self.params, dict):
            print(
                "Your custom parameters are not acceptable: expected a dict, "
                f"got {type(self.params).__name__}"
            )
def main():
    """Load a configuration via ``AutoConfig`` and write it to ``output_dir``.

    Arguments are read into a ``ConfigArguments`` instance, either from a
    single ``.json`` file given as the only command-line argument or from
    regular command-line flags. The configuration named by ``name_or_path``
    is loaded, with ``params`` applied as keyword overrides when it is a
    non-empty dict, and then saved with ``save_pretrained``.

    Raises:
        ValueError: If ``name_or_path`` was not provided. This script always
            loads its configuration through ``AutoConfig.from_pretrained``,
            which needs a model name or path to start from.
    """
    parser = HfArgumentParser([ConfigArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        config_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        config_args = parser.parse_args_into_dataclasses()[0]

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)

    # Fail early with a clear message instead of handing None to
    # AutoConfig.from_pretrained and surfacing an unrelated error from there.
    if config_args.name_or_path is None:
        raise ValueError(
            "`name_or_path` is required: pass a model identifier or a local path "
            "to load the base configuration from."
        )

    logger.info(f"Setting up configuration {config_args.name_or_path} with extra params {config_args.params}")

    # Only a non-empty dict is usable as keyword overrides; anything else
    # (None, an unparsed string, another literal) means "no overrides".
    overrides = config_args.params if isinstance(config_args.params, dict) else {}
    config = AutoConfig.from_pretrained(config_args.name_or_path, **overrides)

    # Log success only after the write has actually happened.
    config.save_pretrained(config_args.output_dir)
    logger.info(f"Your configuration saved here {config_args.output_dir}/config.json")
# Script entry point: call main() only when this file is executed directly.
if __name__ == '__main__':
main()