File size: 5,059 Bytes
cb31cb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# generate_yaml.py
# Author: Julie Kallini
# For importing utils
import sys
sys.path.append("..")
from jinja2 import Template
from utils import PERTURBATIONS, CHECKPOINT_WRITE_PATH, \
PAREN_MODELS, PAREN_MODEL_PATH
import argparse
import os
def parse_args():
    """Parse the command-line arguments for this script.

    Returns:
        argparse.Namespace with ``perturbation_type``, ``train_set``,
        ``random_seed``, ``paren_model`` and ``no_pos_encodings``.
    """
    parser = argparse.ArgumentParser(
        prog='Generate yaml for training',
        description='Generate train and dataset yaml configs for mistral training')
    # NOTE(review): argparse validates the string default of an omitted
    # optional positional against ``choices``. 'all' is not one of the
    # train_set choices, so falling back to that default always fails; the
    # perturbation_type / paren_model defaults only work if 'all' is a key
    # of PERTURBATIONS / PAREN_MODELS. Confirm the intended CLI before
    # changing these.
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform BabyLM dataset')
    parser.add_argument('train_set',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=["100M", "10M"],
                        help='BabyLM train set')
    parser.add_argument('random_seed', type=int, help="Random seed")
    parser.add_argument('paren_model',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=list(PAREN_MODELS.keys()) + ["randinit"],
                        help='Parenthesis model')
    parser.add_argument('-np', '--no_pos_encodings', action='store_true',
                        help="Train GPT-2 with no positional encodings")
    return parser.parse_args()


def babylm_config_name(perturbation, train_set, paren_model, no_pos_suffix=""):
    """Return the run name shared by the conf/ and checkpoint directories.

    Args:
        perturbation: perturbation type key.
        train_set: BabyLM train set name (e.g. "10M").
        paren_model: parenthesis model name, or "randinit".
        no_pos_suffix: "" or "_no_positional_encodings".

    >>> babylm_config_name("shuffle", "10M", "randinit")
    'babylm_shuffle_10M_randinit'
    """
    return f"babylm_{perturbation}_{train_set}_{paren_model}{no_pos_suffix}"


def render_template(template_path, **context):
    """Read the Jinja2 template at *template_path* and render it with *context*."""
    with open(template_path, encoding="utf-8") as template_file:
        return Template(template_file.read()).render(**context)


def write_text(path, text):
    """Write *text* to *path*, overwriting it; the file is closed even on error."""
    with open(path, "w", encoding="utf-8") as out_file:
        out_file.write(text)


def main():
    """Generate the model, train and dataset YAML configs for one run.

    Writes, relative to the current working directory:
      * conf/<name>/gpt2<np>-small-<perturbation>-<paren_model>.yaml
      * conf/<name>/seed<seed>/train_<...>_seed<seed>.yaml
      * conf/<name>/seed<seed>/dataset_<perturbation>_<train_set>_seed<seed>.yaml
    and creates CHECKPOINT_WRITE_PATH/<name>, where <name> comes from
    ``babylm_config_name``.
    """
    args = parse_args()

    # Named paren models point at checkpoint-5000 under PAREN_MODEL_PATH;
    # "randinit" passes the literal string "null" to the model template
    # (presumably rendered as YAML null -- confirm against the template).
    if args.paren_model != "randinit":
        paren_model_path = (PAREN_MODEL_PATH + PAREN_MODELS[args.paren_model]
                            + "/checkpoint-5000")
    else:
        paren_model_path = "null"
    paren_model_name = args.paren_model
    no_pos_encodings_str = "-no-positional-encodings" if args.no_pos_encodings else ""
    no_pos_encodings_underscore = "_no_positional_encodings" if args.no_pos_encodings else ""

    # Compute the run name once so config and checkpoint paths cannot drift.
    config_name = babylm_config_name(
        args.perturbation_type, args.train_set, paren_model_name,
        no_pos_encodings_underscore)
    config_directory = os.path.join("conf", config_name)
    yaml_directory = os.path.join(config_directory, f"seed{args.random_seed}")
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(yaml_directory, exist_ok=True)

    print("Generating GPT-2 model yaml file...")
    # The model config varies with vocab size, taken from the perturbation's
    # GPT-2 tokenizer.
    tokenizer = PERTURBATIONS[args.perturbation_type]["gpt2_tokenizer"]
    model_conf = render_template(
        "conf/template/gpt2-small-template.yaml",
        perturbation=args.perturbation_type,
        vocab_size=len(tokenizer),
        paren_model=paren_model_name,
        paren_model_path=paren_model_path,
        no_pos_encodings=no_pos_encodings_str,
    )
    model_filename = (f"gpt2{no_pos_encodings_str}-small-"
                      f"{args.perturbation_type}-{paren_model_name}.yaml")
    write_text(os.path.join(config_directory, model_filename), model_conf)

    print("Generating train yaml file...")
    train_conf = render_template(
        "conf/template/babylm_train_template.yaml",
        perturbation=args.perturbation_type,
        seed=args.random_seed,
        ckpt_path=CHECKPOINT_WRITE_PATH,
        train_set=args.train_set,
        paren_model=paren_model_name,
        no_pos_encodings=no_pos_encodings_str,
        no_pos_encodings_underscore=no_pos_encodings_underscore,
    )
    train_filename = (f"train_{args.perturbation_type}_{args.train_set}_"
                      f"{paren_model_name}{no_pos_encodings_underscore}_"
                      f"seed{args.random_seed}.yaml")
    write_text(os.path.join(yaml_directory, train_filename), train_conf)

    print("Generating dataset yaml file...")
    dataset_conf = render_template(
        "conf/template/babylm_dataset_template.yaml",
        perturbation=args.perturbation_type,
        train_set=args.train_set,
        seed=args.random_seed,
    )
    dataset_filename = (f"dataset_{args.perturbation_type}_{args.train_set}_"
                        f"seed{args.random_seed}.yaml")
    write_text(os.path.join(yaml_directory, dataset_filename), dataset_conf)

    # Create directory for model checkpoints (same "/" concatenation as before).
    ckpt_directory = CHECKPOINT_WRITE_PATH + "/" + config_name
    os.makedirs(ckpt_directory, exist_ok=True)


if __name__ == "__main__":
    main()