File size: 5,059 Bytes
cb31cb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# generate_yaml.py
# Author: Julie Kallini

# For importing utils
import sys
sys.path.append("..")

from jinja2 import Template
from utils import PERTURBATIONS, CHECKPOINT_WRITE_PATH, \
    PAREN_MODELS, PAREN_MODEL_PATH
import argparse
import os


if __name__ == "__main__":

    # Parse command-line arguments selecting the perturbation, train set,
    # seed, and parenthesis model for which configs are generated.
    parser = argparse.ArgumentParser(
        prog='Generate yaml for training',
        description='Generate train and dataset yaml configs for mistral training')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform BabyLM dataset')
    parser.add_argument('train_set',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=["100M", "10M"],
                        help='BabyLM train set')
    parser.add_argument('random_seed', type=int, help="Random seed")
    parser.add_argument('paren_model',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=list(PAREN_MODELS.keys()) + ["randinit"],
                        help='Parenthesis model')
    parser.add_argument('-np', '--no_pos_encodings', action='store_true',
                        help="Train GPT-2 with no positional encodings")

    # Get args
    args = parser.parse_args()

    # Resolve the parenthesis-model checkpoint path. "randinit" means no
    # pretrained checkpoint, emitted as the YAML literal "null".
    if args.paren_model != "randinit":
        paren_model_path = PAREN_MODEL_PATH + PAREN_MODELS[args.paren_model] + "/checkpoint-5000"
    else:
        paren_model_path = "null"
    paren_model_name = args.paren_model

    # Filename/template suffixes applied when positional encodings are off.
    no_pos_encodings_str = "-no-positional-encodings" if args.no_pos_encodings else ""
    no_pos_encodings_underscore = "_no_positional_encodings" if args.no_pos_encodings else ""

    # Create directory for yaml (exist_ok avoids the check-then-create race)
    yaml_directory = f"conf/babylm_{args.perturbation_type}_{args.train_set}_{paren_model_name}{no_pos_encodings_underscore}/seed{args.random_seed}"
    os.makedirs(yaml_directory, exist_ok=True)

    print("Generating GPT-2 model yaml file...")

    # Get model template, which varies due to changes in vocab size
    with open("conf/template/gpt2-small-template.yaml") as model_temp_file:
        model_template = Template(model_temp_file.read())

    # Fill model template
    tokenizer = PERTURBATIONS[args.perturbation_type]["gpt2_tokenizer"]
    vocab_size = len(tokenizer)
    model_conf = model_template.render(
        perturbation=args.perturbation_type,
        vocab_size=vocab_size,
        paren_model=paren_model_name,
        paren_model_path=paren_model_path,
        no_pos_encodings=no_pos_encodings_str,
    )

    # Write model yaml to file (one level above the per-seed directory)
    with open(
            f"conf/babylm_{args.perturbation_type}_{args.train_set}_{paren_model_name}{no_pos_encodings_underscore}/gpt2{no_pos_encodings_str}-small-{args.perturbation_type}-{paren_model_name}.yaml", "w") as model_file:
        model_file.write(model_conf)

    print("Generating train yaml file...")

    # Get train template file
    with open("conf/template/babylm_train_template.yaml") as train_temp_file:
        train_template = Template(train_temp_file.read())

    # Fill train template file
    train_conf = train_template.render(
        perturbation=args.perturbation_type,
        seed=args.random_seed,
        ckpt_path=CHECKPOINT_WRITE_PATH,
        train_set=args.train_set,
        paren_model=paren_model_name,
        no_pos_encodings=no_pos_encodings_str,
        no_pos_encodings_underscore=no_pos_encodings_underscore,
    )

    # Write train yaml to file
    with open(yaml_directory + \
            f"/train_{args.perturbation_type}_{args.train_set}_{paren_model_name}{no_pos_encodings_underscore}_seed{args.random_seed}.yaml", "w") as train_file:
        train_file.write(train_conf)

    print("Generating dataset yaml file...")

    # Get dataset temp file
    with open("conf/template/babylm_dataset_template.yaml") as dataset_temp_file:
        dataset_template = Template(dataset_temp_file.read())

    # Fill dataset template file
    dataset_conf = dataset_template.render(
        perturbation=args.perturbation_type,
        train_set=args.train_set,
        seed=args.random_seed,
    )

    # Write dataset yaml to file
    with open(yaml_directory + \
            f"/dataset_{args.perturbation_type}_{args.train_set}_seed{args.random_seed}.yaml", "w") as dataset_file:
        dataset_file.write(dataset_conf)

    # Create directory for model checkpoints
    ckpt_directory = CHECKPOINT_WRITE_PATH + f"/babylm_{args.perturbation_type}_{args.train_set}_{paren_model_name}{no_pos_encodings_underscore}"
    os.makedirs(ckpt_directory, exist_ok=True)