gpt-neo-1.3B-persian / src /create_config.py
m3hrdadfi's picture
Initialize
69dd1b0
import ast
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from transformers import (
HfArgumentParser,
AutoConfig
)
# Module-level logger named after this module; main() raises its level to INFO.
logger = logging.getLogger(__name__)
@dataclass
class ConfigArguments:
    """
    Arguments describing which configuration to build and where to write it.

    `params` is supplied as a string holding a Python dict literal (for example
    ``"{'n_layer': 2}"``). During initialization it is parsed into a real
    `dict`; an already-built `dict` is accepted unchanged. If the value cannot
    be parsed, or does not evaluate to a `dict`, a message is printed and
    `params` is reset to `None` so that no half-parsed value leaks downstream.
    """

    output_dir: str = field(
        default=".",
        metadata={"help": "The output directory where the config will be written."},
    )
    name_or_path: Optional[str] = field(
        default=None,
        metadata={
            # The trailing space keeps the two adjacent literals from fusing
            # into "initialization.Don't" in the rendered help text.
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    params: Optional[str] = field(
        default=None,
        metadata={"help": "Custom configuration for the specific `name_or_path`"},
    )

    def __post_init__(self):
        """Normalize `params` into a `dict` (or `None` when it is unusable)."""
        if not self.params:
            # Nothing supplied (None / empty): leave the value untouched.
            return

        if isinstance(self.params, str):
            try:
                # literal_eval only accepts Python literals, never arbitrary code.
                parsed = ast.literal_eval(self.params)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
                print(f"Your custom parameters are not acceptable due to {e}")
                self.params = None
                return
        else:
            # Programmatic callers may hand over an already-parsed object.
            parsed = self.params

        if not isinstance(parsed, dict):
            print(
                "Your custom parameters are not acceptable due to "
                f"expecting a dict but got {type(parsed).__name__}"
            )
            self.params = None
            return

        self.params = parsed
def main():
    """
    Build a model configuration and write it to the requested output directory.

    Arguments are read either from a single ``.json`` file given as the only
    command-line argument, or from regular command-line flags. The base
    configuration is loaded with `AutoConfig.from_pretrained` for
    `name_or_path`; when `params` is a dict, its entries are forwarded as
    keyword arguments to that call. The result is saved with
    `save_pretrained` into `output_dir`.

    Raises:
        ValueError: if `name_or_path` was not provided, since there is then
            no base configuration to load.
    """
    parser = HfArgumentParser([ConfigArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        config_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        config_args = parser.parse_args_into_dataclasses()[0]

    # Fail early with a clear message instead of an obscure lookup error
    # from the library when no checkpoint/identifier was given.
    if not config_args.name_or_path:
        raise ValueError(
            "`name_or_path` is required: pass a model identifier or a local path "
            "to load the base configuration from."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)

    # Only a dict can be expanded into keyword overrides; anything else is ignored.
    overrides = config_args.params if isinstance(config_args.params, dict) else {}
    logger.info(
        "Setting up configuration %s with extra params %s",
        config_args.name_or_path,
        overrides or None,
    )

    config = AutoConfig.from_pretrained(config_args.name_or_path, **overrides)

    # Save first, then report, so the message is only emitted on success.
    config.save_pretrained(config_args.output_dir)
    logger.info(
        "Your configuration saved here %s",
        os.path.join(config_args.output_dir, "config.json"),
    )
# Run the config builder only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()