aliabd committed on
Commit
c6e7238
1 Parent(s): c528e7b

full working demo

.idea/gpt-neo.iml ADDED
@@ -0,0 +1,8 @@
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/gpt-neo.iml" filepath="$PROJECT_DIR$/.idea/gpt-neo.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
CODEOWNERS ADDED
@@ -0,0 +1 @@
1
+ * EleutherAI/pm-gptneo
Dockerfile ADDED
@@ -0,0 +1,15 @@
1
+ FROM gcr.io/deeplearning-platform-release/tf-cpu.1-15
2
+
3
+ WORKDIR /neogpt
4
+
5
+ # Make RUN commands use `bash --login`:
6
+ SHELL ["/bin/bash", "--login", "-c"]
7
+ ENV DEBIAN_FRONTEND=noninteractive
8
+ RUN apt-get update -y && apt-get install tmux -y
9
+ RUN conda install gcc_linux-64 gxx_linux-64 -y
10
+ ADD requirements.txt .
11
+ RUN pip install -r requirements.txt
12
+ RUN apt-get install screen htop -y
13
+ RUN python -m pip install tensorboard==1.15 cloud_tpu_profiler==1.15
14
+
15
+ CMD tmux
GPTNeo_example_notebook.ipynb ADDED
The diff for this file is too large to render.
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2020 EleutherAI
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,12 @@
1
+ import gradio as gr
2
+
3
+ title = "GPT-Neo Demo"
4
+ description = "demo for GPT-Neo by EleutherAI for text generation. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
5
+ article = "<p style='text-align: center'><a href='http://github.com/eleutherai/gpt-neo'>GPT-Neo: Large Scale Autoregressive Language Modeling with Mesh-Tensorflow</a></p>"
6
+ examples = [
7
+ ['The tower is 324 metres (1,063 ft) tall,'],
8
+ ["The Moon's orbit around Earth has"],
9
+ ["The smooth Borealis basin in the Northern Hemisphere covers 40%"]
10
+ ]
11
+
12
+ gr.Interface.load("huggingface/EleutherAI/gpt-neo-2.7B", inputs=gr.inputs.Textbox(lines=5, label="Input Text"),title=title,description=description,article=article, examples=examples).launch()
configs.py ADDED
@@ -0,0 +1,47 @@
1
+ import json
2
+ from pathlib import Path
3
+ from collections import defaultdict
4
+
5
+ DATASETS = {}
6
+
7
+ for path in Path("configs/dataset_configs").glob("*.json"):
8
+ dataset_id = path.stem
9
+ DATASETS[dataset_id] = json.loads(path.read_text())
10
+
11
+
12
+ def fetch_model_params(model):
13
+ model_path = model if model.endswith(".json") else f"configs/{model}.json"
14
+ with open(model_path) as f:
15
+ params = json.load(f)
16
+
17
+ dataset_ids = []
18
+ for d in params.get("datasets"):
19
+ if isinstance(d, list):
20
+ dataset_ids.append(d[0])
21
+ else:
22
+ dataset_ids.append(d)
23
+ no_datasets = params.get("no_dataset", False)
24
+ assert no_datasets or len(dataset_ids) > 0, "You must specify at least one dataset id in the model config"
25
+
26
+ datasets = {}
27
+ last_dataset = None
28
+ for dataset_id in dataset_ids:
29
+ assert dataset_id in DATASETS, f"Dataset '{dataset_id}' was not found under dataset_configs/ folder. Please follow the example.json in that folder."
30
+ dataset = DATASETS[dataset_id]
31
+ assert params["n_vocab"] >= dataset["n_vocab"], f"The embedding table size '{params['n_vocab']}' must be greater or equal to the vocab size used to encode the dataset '{dataset_id}' ({dataset['n_vocab']})"
32
+ datasets[dataset_id] = dataset
33
+ last_dataset = dataset
34
+
35
+ if last_dataset is not None:
36
+ params["padding_id"] = last_dataset.get("padding_id", 0)
37
+ params["eos_id"] = last_dataset.get("eos_id", 1)
38
+
39
+ params["dataset_configs"] = datasets
40
+
41
+ # Set some other parameter defaults
42
+ params["mlm_training"] = params.get("mlm_training") == True
43
+ params["causal"] = not params["mlm_training"]
44
+
45
+ # Set all other parameter values to default to None
46
+ params = defaultdict(lambda: None, params)
47
+ return params
configs/dataset_configs/example.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "n_vocab": 32768,
3
+ "path": "./tfrecords/openwebtext_*.tfrecords",
4
+ "eval_path": "",
5
+ "tokenizer_path": "./datasets/openwebtext/byte-level-bpe.tokenizer.json",
6
+ "eos_id": 1,
7
+ "padding_id": 0
8
+ }
configs/dataset_configs/openwebtext2_new_inputs.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "n_vocab": 50257,
3
+ "path": "gs://neo-datasets/openwebtext2_new_inputs/train/*.tfrecords",
4
+ "eval_path": "gs://neo-datasets/openwebtext2_new_inputs/eval/*.tfrecords",
5
+ "tokenizer_is_pretrained": true,
6
+ "tokenizer_path": "gpt2",
7
+ "eos_id": 50256,
8
+ "padding_id": 50257
9
+ }
configs/dataset_configs/pile.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "n_vocab": 50257,
3
+ "path": "gs://neo-datasets/pile/pile_*.tfrecords",
4
+ "eval_path": "gs://neo-datasets/pile_val.tfrecords",
5
+ "tokenizer_is_pretrained": true,
6
+ "tokenizer_path": "gpt2",
7
+ "eos_id": 50256,
8
+ "padding_id": 50257
9
+ }
configs/gpt2_small.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "n_head": 6,
3
+ "n_vocab": 50257,
4
+ "embed_dropout": 0.1,
5
+ "lr": 0.0006,
6
+ "lr_decay": "cosine",
7
+ "warmup_steps": 3000,
8
+ "beta1": 0.9,
9
+ "beta2": 0.95,
10
+ "epsilon": 1e-8,
11
+ "opt_name": "adam",
12
+ "weight_decay": 0,
13
+ "train_batch_size": 512,
14
+ "attn_dropout": 0.1,
15
+ "train_steps": 1000000,
16
+ "lr_decay_end": 300000,
17
+ "eval_steps": 30,
18
+ "predict_steps": 0,
19
+ "res_dropout": 0.1,
20
+ "eval_batch_size": 128,
21
+ "predict_batch_size": 8,
22
+ "iterations": 2500,
23
+ "n_embd": 768,
24
+ "datasets": ["openwebtext2_new_inputs"],
25
+ "model_path": "gs://neo-models/GPT2_SMALL",
26
+ "n_ctx": 1024,
27
+ "n_layer": 12,
28
+ "scale_by_depth": true,
29
+ "scale_by_in": false,
30
+ "attention_types" : [[["global"],12]],
31
+ "activation_function": "gelu",
32
+ "mesh_shape": "all:64",
33
+ "layout": "batch:all",
34
+ "recompute_grad": false,
35
+ "gradient_clipping": 1.0
36
+ }
configs/gpt3_13B_256.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "n_head": 40,
3
+ "n_vocab": 50257,
4
+ "embed_dropout": 0,
5
+ "lr": 0.0001,
6
+ "lr_decay": "cosine",
7
+ "warmup_steps": 3000,
8
+ "beta1": 0.9,
9
+ "beta2": 0.95,
10
+ "epsilon": 1e-8,
11
+ "ada_epsilon1": 1e-30,
12
+ "ada_epsilon2": 1e-3,
13
+ "opt_name": "adam",
14
+ "weight_decay": 0.10,
15
+ "train_batch_size": 1024,
16
+ "attn_dropout": 0,
17
+ "train_steps": 143075,
18
+ "eval_steps": 0,
19
+ "predict_steps": 1,
20
+ "res_dropout": 0,
21
+ "eval_batch_size": 128,
22
+ "predict_batch_size": 1,
23
+ "iterations": 500,
24
+ "n_embd": 5120,
25
+ "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
26
+ "model_path": "gs://neo-models/GPT3_13B",
27
+ "n_ctx": 2048,
28
+ "n_layer": 40,
29
+ "scale_by_depth": true,
30
+ "scale_by_in": false,
31
+ "attention_types" : [[["global", "local"],20]],
32
+ "mesh_shape": "x:16,y:16",
33
+ "layout": "batch:x,embd:y,memory_length:y",
34
+ "activation_function": "gelu",
35
+ "recompute_grad": true,
36
+ "gradient_clipping": 1.0,
37
+ "tokens_per_mb_per_replica": 2048,
38
+ "precision": "bfloat16"
39
+ }
40
+
configs/gpt3_13B_256_Pile.json ADDED
@@ -0,0 +1,38 @@
1
+
2
+ {
3
+ "n_head": 40,
4
+ "n_vocab": 50257,
5
+ "embed_dropout": 0,
6
+ "lr": 0.0001,
7
+ "lr_decay": "cosine",
8
+ "warmup_steps": 3000,
9
+ "beta1": 0.9,
10
+ "beta2": 0.95,
11
+ "epsilon": 1e-8,
12
+ "opt_name": "adam",
13
+ "weight_decay": 0.1,
14
+ "train_batch_size": 1024,
15
+ "attn_dropout": 0,
16
+ "train_steps": 286150,
17
+ "eval_steps": 10,
18
+ "predict_steps": 1,
19
+ "res_dropout": 0,
20
+ "eval_batch_size": 512,
21
+ "predict_batch_size": 1,
22
+ "iterations": 500,
23
+ "n_embd": 5120,
24
+ "datasets": [["pile", 25, "documents_random", 1.0]],
25
+ "model_path": "gs://neo-models/GPT3_13B_Pile",
26
+ "n_ctx": 2048,
27
+ "n_layer": 40,
28
+ "scale_by_depth": true,
29
+ "scale_by_in": false,
30
+ "attention_types" : [[["global"],40]],
31
+ "mesh_shape": "x:16,y:16",
32
+ "layout": "batch:x,memory_length:y,embd:y",
33
+ "activation_function": "gelu",
34
+ "recompute_grad": true,
35
+ "gradient_clipping": 1.0,
36
+ "tokens_per_mb_per_replica": 2048,
37
+ "precision": "bfloat16"
38
+ }
configs/gpt3_2-7B_256.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "n_head": 32,
3
+ "n_vocab": 50257,
4
+ "embed_dropout": 0,
5
+ "lr": 0.00016,
6
+ "lr_decay": "cosine",
7
+ "warmup_steps": 3000,
8
+ "beta1": 0.9,
9
+ "beta2": 0.95,
10
+ "epsilon": 1e-8,
11
+ "ada_epsilon1": 1e-30,
12
+ "ada_epsilon2": 1e-3,
13
+ "opt_name": "adam",
14
+ "weight_decay": 0.10,
15
+ "train_batch_size": 512,
16
+ "attn_dropout": 0,
17
+ "train_steps": 286150,
18
+ "eval_steps": 0,
19
+ "predict_steps": 1,
20
+ "res_dropout": 0,
21
+ "eval_batch_size": 128,
22
+ "predict_batch_size": 1,
23
+ "iterations": 500,
24
+ "n_embd": 2560,
25
+ "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
26
+ "model_path": "gs://neo-models/GPT3_2-7B",
27
+ "n_ctx": 2048,
28
+ "n_layer": 32,
29
+ "scale_by_depth": true,
30
+ "scale_by_in": false,
31
+ "attention_types" : [[["global"],32]],
32
+ "mesh_shape": "x:128,y:2",
33
+ "layout": "embd:y,batch:x",
34
+ "activation_function": "gelu",
35
+ "recompute_grad": true,
36
+ "gradient_clipping": 1.0
37
+ }
38
+
configs/gpt3_6-7B_256.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "n_head": 32,
3
+ "n_vocab": 50257,
4
+ "embed_dropout": 0,
5
+ "lr": 0.00012,
6
+ "lr_decay": "cosine",
7
+ "warmup_steps": 3000,
8
+ "beta1": 0.9,
9
+ "beta2": 0.95,
10
+ "epsilon": 1e-8,
11
+ "opt_name": "adam",
12
+ "weight_decay": 0.10,
13
+ "train_batch_size": 1024,
14
+ "attn_dropout": 0,
15
+ "train_steps": 143075,
16
+ "eval_steps": 0,
17
+ "predict_steps": 1,
18
+ "res_dropout": 0,
19
+ "eval_batch_size": 128,
20
+ "predict_batch_size": 1,
21
+ "iterations": 500,
22
+ "n_embd": 4096,
23
+ "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
24
+ "model_path": "gs://neo-models/GPT3_6-7B",
25
+ "n_ctx": 2048,
26
+ "n_layer": 32,
27
+ "scale_by_depth": true,
28
+ "scale_by_in": false,
29
+ "attention_types" : [[["global"],32]],
30
+ "mesh_shape": "x:128,y:2",
31
+ "layout": "embd:y,batch:x",
32
+ "activation_function": "gelu",
33
+ "recompute_grad": true,
34
+ "gradient_clipping": 1.0
35
+ }
36
+
configs/gpt3_PAR_small_256.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "n_head": 12,
3
+ "n_vocab": 50304,
4
+ "embed_dropout": 0,
5
+ "lr": 0.0006,
6
+ "lr_decay": "cosine",
7
+ "warmup_steps": 3000,
8
+ "beta1": 0.9,
9
+ "beta2": 0.95,
10
+ "epsilon": 1e-8,
11
+ "opt_name": "adam",
12
+ "weight_decay": 0.10,
13
+ "train_batch_size": 256,
14
+ "attn_dropout": 0,
15
+ "train_steps": 572300,
16
+ "eval_steps": 0,
17
+ "predict_steps": 1,
18
+ "res_dropout": 0,
19
+ "eval_batch_size": 64,
20
+ "predict_batch_size": 1,
21
+ "iterations": 1000,
22
+ "n_embd": 768,
23
+ "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
24
+ "model_path": "gs://neo-models/GPT3_PAR_SMALL",
25
+ "n_ctx": 2048,
26
+ "n_layer": 19,
27
+ "scale_by_depth": true,
28
+ "scale_by_in": false,
29
+ "attention_types": [[["global", "none", "none"],5], [["none"], 4]],
30
+ "mesh_shape": "x:64,y:4",
31
+ "layout": "batch:x,heads:y,vocab:y,intermediate_expanded:y",
32
+ "activation_function": "gelu",
33
+ "recompute_grad": false,
34
+ "gradient_clipping": 1.0
35
+ }
36
+
configs/gpt3_XL_256_Pile.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "n_head": 32,
3
+ "n_vocab": 50257,
4
+ "embed_dropout": 0,
5
+ "lr": 0.0002,
6
+ "lr_decay": "cosine",
7
+ "warmup_steps": 3000,
8
+ "beta1": 0.9,
9
+ "beta2": 0.95,
10
+ "epsilon": 1e-8,
11
+ "opt_name": "adam",
12
+ "weight_decay": 0.1,
13
+ "train_batch_size": 512,
14
+ "attn_dropout": 0,
15
+ "train_steps": 286150,
16
+ "eval_steps": 10,
17
+ "predict_steps": 1,
18
+ "res_dropout": 0,
19
+ "eval_batch_size": 512,
20
+ "predict_batch_size": 1,
21
+ "iterations": 500,
22
+ "n_embd": 2048,
23
+ "datasets": [["pile", 25, "documents_random", 1.0]],
24
+ "model_path": "gs://neo-models/GPT3_XL_Pile",
25
+ "n_ctx": 2048,
26
+ "n_layer": 24,
27
+ "scale_by_depth": true,
28
+ "scale_by_in": false,
29
+ "attention_types" : [[["global"],24]],
30
+ "mesh_shape": "x:128,y:2",
31
+ "layout": "batch:x,memory_length:y,embd:y",
32
+ "activation_function": "gelu",
33
+ "recompute_grad": true,
34
+ "gradient_clipping": 1.0,
35
+ "tokens_per_mb_per_replica": 2048,
36
+ "precision": "bfloat16"
37
+ }
configs/gpt3_large_256.json ADDED
@@ -0,0 +1,39 @@
1
+ {
2
+ "n_head": 16,
3
+ "n_vocab": 50304,
4
+ "embed_dropout": 0,
5
+ "lr": 0.00025,
6
+ "lr_decay": "cosine",
7
+ "warmup_steps": 3000,
8
+ "beta1": 0.9,
9
+ "beta2": 0.95,
10
+ "epsilon": 1e-8,
11
+ "ada_epsilon1": 1e-30,
12
+ "ada_epsilon2": 1e-3,
13
+ "opt_name": "adam",
14
+ "weight_decay": 0.10,
15
+ "train_batch_size": 256,
16
+ "attn_dropout": 0,
17
+ "train_steps": 572300,
18
+ "eval_steps": 0,
19
+ "predict_steps": 1,
20
+ "res_dropout": 0,
21
+ "eval_batch_size": 64,
22
+ "predict_batch_size": 1,
23
+ "iterations": 2500,
24
+ "n_embd": 1536,
25
+ "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
26
+ "model_path": "gs://neo-models/GPT3_LARGE",
27
+ "n_ctx": 2048,
28
+ "n_layer": 24,
29
+ "scale_by_depth": true,
30
+ "scale_by_in": false,
31
+ "attention_types" : [[["global"],24]],
32
+ "mesh_shape": "x:64,y:4",
33
+ "layout": "batch:x,vocab:y,heads:y",
34
+ "activation_function": "gelu",
35
+ "recompute_grad": true,
36
+ "gradient_clipping": 1.0,
37
+ "tokens_per_mb_per_replica": 2048
38
+ }
39
+
configs/gpt3_medium_256.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "n_head": 16,
3
+ "n_vocab": 50304,
4
+ "embed_dropout": 0,
5
+ "lr": 0.0003,
6
+ "lr_decay": "cosine",
7
+ "warmup_steps": 3000,
8
+ "beta1": 0.9,
9
+ "beta2": 0.95,
10
+ "epsilon": 1e-8,
11
+ "opt_name": "adam",
12
+ "weight_decay": 0.10,
13
+ "train_batch_size": 256,
14
+ "attn_dropout": 0,
15
+ "train_steps": 572300,
16
+ "eval_steps": 0,
17
+ "predict_steps": 1,
18
+ "res_dropout": 0,
19
+ "eval_batch_size": 64,
20
+ "predict_batch_size": 1,
21
+ "iterations": 2500,
22
+ "n_embd": 1024,
23
+ "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
24
+ "model_path": "gs://neo-models/GPT3_MEDIUM",
25
+ "n_ctx": 2048,
26
+ "n_layer": 24,
27
+ "scale_by_depth": true,
28
+ "scale_by_in": false,
29
+ "attention_types" : [[["global"],24]],
30
+ "mesh_shape": "x:64,y:4",
31
+ "layout": "batch:x,heads:y,vocab:y",
32
+ "activation_function": "gelu",
33
+ "recompute_grad": false,
34
+ "gradient_clipping": 1.0
35
+ }
36
+
configs/gpt3_small_256.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "n_head": 12,
3
+ "n_vocab": 50304,
4
+ "embed_dropout": 0,
5
+ "lr": 0.0006,
6
+ "lr_decay": "cosine",
7
+ "warmup_steps": 3000,
8
+ "beta1": 0.9,
9
+ "beta2": 0.95,
10
+ "epsilon": 1e-8,
11
+ "opt_name": "adam",
12
+ "weight_decay": 0.10,
13
+ "train_batch_size": 256,
14
+ "attn_dropout": 0,
15
+ "train_steps": 572300,
16
+ "eval_steps": 0,
17
+ "predict_steps": 1,
18
+ "res_dropout": 0,
19
+ "eval_batch_size": 64,
20
+ "predict_batch_size": 1,
21
+ "iterations": 2500,
22
+ "n_embd": 768,
23
+ "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
24
+ "model_path": "gs://neo-models/GPT3_SMALL",
25
+ "n_ctx": 2048,
26
+ "n_layer": 12,
27
+ "scale_by_depth": true,
28
+ "scale_by_in": false,
29
+ "attention_types": [[["global"],12]],
30
+ "mesh_shape": "x:64,y:4",
31
+ "layout": "batch:x,heads:y,vocab:y,intermediate_expanded:y",
32
+ "activation_function": "gelu",
33
+ "recompute_grad": false,
34
+ "gradient_clipping": 1.0
35
+ }
36
+
data/create_tfrecords.py ADDED
@@ -0,0 +1,263 @@
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import ftfy
6
+ import tensorflow as tf
7
+ from lm_dataformat import Reader
8
+ from tokenizers import Tokenizer
9
+ from transformers import GPT2TokenizerFast
10
+ from tqdm import tqdm
11
+ import logging
12
+ from multiprocessing import Pool, cpu_count
13
+ from itertools import repeat
14
+ import re
15
+
16
+ logging.getLogger("transformers").setLevel(logging.ERROR)
17
+
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument("--input_dir", type=str, help="Path to where your files are located. Files ending in .zst are "
20
+ "treated as archives, all others as raw text.")
21
+ parser.add_argument("--files_per", type=int, default=100000, help="Text files per tfrecord")
22
+ parser.add_argument("--name", type=str, default="openwebtext",
23
+ help="Name of output files will be name_i.tfrecords where i is the number of the file")
24
+ parser.add_argument("--output_dir", type=str, default="./tfrecords", help="Where to put tfrecords")
25
+ parser.add_argument("--encoder_path", type=str,
26
+ help="Path to encoder files, or leave unspecified to use GPT2 tokenizer")
27
+ parser.add_argument("--minimum_size", type=int, default=100, help="Minimum size a document has to be to be included")
28
+ parser.add_argument("--ftfy", action="store_false", help="normalize with ftfy")
29
+ parser.add_argument("--wikitext-detokenize", action="store_false", help="use wikitext detokenizer")
30
+ parser.add_argument("--separator", nargs="+", type=int, default=[50256],
31
+ help="separator to place between files in chunk mode")
32
+ parser.add_argument("--chunk_size", type=int, default=2048, help="How big a chunk should be in chunk mode. "
33
+ "Should equal your model's context size")
34
+ parser.add_argument("--write_dataset_config", action="store_true", help="Write the dataset config file on completion")
35
+ parser.add_argument("--processes", type=int, default=0, help="Number of processes to use. Defaults to cpu count.")
36
+
37
+ args = parser.parse_args()
38
+ if not args.output_dir.endswith("/"):
39
+ args.output_dir = args.output_dir + "/"
40
+ if not args.input_dir.endswith("/"):
41
+ args.input_dir = args.input_dir + "/"
42
+ assert len(args.separator) == 1
43
+
44
+
45
+ def wikitext_detokenizer(string):
46
+ # contractions
47
+ string = string.replace("s '", "s'")
48
+ string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
49
+ # number separators
50
+ string = string.replace(" @-@ ", "-")
51
+ string = string.replace(" @,@ ", ",")
52
+ string = string.replace(" @.@ ", ".")
53
+ # punctuation
54
+ string = string.replace(" : ", ": ")
55
+ string = string.replace(" ; ", "; ")
56
+ string = string.replace(" . ", ". ")
57
+ string = string.replace(" ! ", "! ")
58
+ string = string.replace(" ? ", "? ")
59
+ string = string.replace(" , ", ", ")
60
+ # double brackets
61
+ string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
62
+ string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
63
+ string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
64
+ string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
65
+ string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
66
+ # miscellaneous
67
+ string = string.replace("= = = =", "====")
68
+ string = string.replace("= = =", "===")
69
+ string = string.replace("= =", "==")
70
+ string = string.replace(" " + chr(176) + " ", chr(176))
71
+ string = string.replace(" \n", "\n")
72
+ string = string.replace("\n ", "\n")
73
+ string = string.replace(" N ", " 1 ")
74
+ string = string.replace(" 's", "'s")
75
+
76
+ return string
77
+
78
+
79
+ def _int64_feature(value):
80
+ """
81
+ Returns an int64_list from a bool / enum / int / uint.
82
+ """
83
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
84
+
85
+
86
+ def write_to_file(writer, data):
87
+ """
88
+ writes data to tfrecord file
89
+ """
90
+ feature = {
91
+ "text": _int64_feature(data)
92
+ }
93
+ tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
94
+ writer.write(tf_example.SerializeToString())
95
+
96
+
97
+ def get_tokenizer(args):
98
+ if args.encoder_path is None:
99
+ return GPT2TokenizerFast.from_pretrained('gpt2')
100
+ else:
101
+ return Tokenizer.from_file(args.encoder_path)
102
+
103
+
104
+ def split_list(l, n):
105
+ # splits list/string into n size chunks
106
+ return [l[i:i + n] for i in range(0, len(l), n)]
107
+
108
+
109
+ def archive_to_tokens(f, encoder, args):
110
+ # Generator that yields the contents of the files in an archive
111
+ # if data_to_prepend is not None, prepend data_to_prepend + a EOS separator to the encoded data
112
+ reader = Reader(f)
113
+ for doc in reader.stream_data(threaded=False):
114
+ if args.ftfy: # fix text with ftfy if specified
115
+ doc = ftfy.fix_text(doc, normalization='NFKC')
116
+ if args.wikitext_detokenize:
117
+ doc = wikitext_detokenizer(doc)
118
+ doc = encoder.encode(doc) + args.separator # read document from lmd and append separator token
119
+ yield split_list(doc, args.chunk_size) # split into n_ctx + 1 size chunks
120
+
121
+
122
+ def write_files(files, files_per, output_dir, out_name, start_no, write_remainder=False, process_no=None):
123
+ # writes a list of files to .tfrecords
124
+ if files == None:
125
+ return
126
+ chunks = split_list(files, files_per)
127
+
128
+ if len(chunks[-1]) != files_per and not write_remainder: # pop the last file if it's length != files per
129
+ remainder = chunks.pop(-1)
130
+ else:
131
+ remainder = None # assuming files = remainder from an old chunk here
132
+ files_per = len(chunks[-1])
133
+
134
+ for files in chunks:
135
+ fp = f"{output_dir}/{out_name}_{start_no}"
136
+ if process_no is not None:
137
+ fp += f"_{process_no}"
138
+ fp += f"_{files_per}" # add number of files in tfrecord to end of fp
139
+ fp += ".tfrecords"
140
+ with tf.io.TFRecordWriter(fp) as writer:
141
+ for f in files:
142
+ write_to_file(writer, f)
143
+ start_no += 1
144
+ return start_no, remainder
145
+
146
+
147
+ def get_files(input_dir, filetypes=None):
148
+ # gets all files of <filetypes> in input_dir
149
+ if filetypes == None:
150
+ filetypes = ["jsonl.zst", ".txt", ".xz", ".tar.gz"]
151
+ files = [list(Path(input_dir).glob(f"*{ft}")) for ft in filetypes]
152
+ return [str(item) for sublist in files for item in sublist] # flatten list of list -> list and stringify Paths
153
+
154
+
155
+ def read_checkpoint(checkpoint_path, resume_from_checkpoint=True):
156
+ # init checkpointing
157
+ if resume_from_checkpoint and os.path.isfile(checkpoint_path):
158
+ try:
159
+ resume_files_processed, tfrecord_count = [int(i) for i in open(checkpoint_path, "r").read().split(", ")]
160
+ print(f"\nResuming from tfrecord no. {tfrecord_count} / file no. {resume_files_processed}")
161
+ return resume_files_processed, tfrecord_count
162
+ except:
163
+ pass
164
+ return 0, 0
165
+
166
+
167
+ def create_tfrecords(params, write_remainder=True, write_every_n_files=1, save_checkpoints=False,
168
+ resume_from_checkpoint=False, display_pbar=False):
169
+ # iterates through files in input_dir, splitting into <args.chunk_size> chunks and saving a tfrecords file every <args.files_per> chunks.
170
+ files, args, process_no = params
171
+ enc = get_tokenizer(args) # get tokenizer
172
+
173
+ # init metadata
174
+ discarded_files = 0
175
+ files_processed = 0
176
+ pbar = tqdm(desc=f"Writing TFRecord Files to {args.output_dir}. Parsed 0 input files. files_written ",
177
+ disable=not display_pbar)
178
+ checkpoint_path = f"{args.output_dir}/checkpoint.txt"
179
+ resume_files_processed, tfrecord_count = read_checkpoint(checkpoint_path, resume_from_checkpoint)
180
+
181
+ data_to_prepend = []
182
+ tokenized_files_array = []
183
+
184
+ for f in files:
185
+ for tokenized_files in archive_to_tokens(f, enc, args):
186
+ files_processed += 1
187
+ if files_processed < resume_files_processed:
188
+ continue # resume from checkpoint
189
+
190
+ # if the last chunk < chunk size, but > minimum_size, take it and append it to the beginning of the next file
191
+ n_tokens = len(tokenized_files[-1])
192
+ if n_tokens < args.chunk_size:
193
+ data = tokenized_files.pop(-1)
194
+ if n_tokens >= args.minimum_size:
195
+ data_to_prepend.extend(data)
196
+ else:
197
+ discarded_files += 1
198
+
199
+ if len(data_to_prepend) >= args.chunk_size:
200
+ # if length of data_to_prepend becomes greater than chunk size, add concatted files to tokenized files
201
+ tokenized_files_array.append(data_to_prepend[:args.chunk_size])
202
+ data_to_prepend = data_to_prepend[args.chunk_size:]
203
+ # add tokenized files > chunk size to main array
204
+ tokenized_files_array.extend(tokenized_files)
205
+
206
+ if len(tokenized_files_array) >= args.files_per * write_every_n_files: # write every n files
207
+ _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per,
208
+ output_dir=args.output_dir, out_name=args.name,
209
+ start_no=tfrecord_count, process_no=process_no)
210
+ pbar.update(_tfrecord_count - tfrecord_count) # update progress bar
211
+ pbar.set_description(
212
+ f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ")
213
+ tfrecord_count = _tfrecord_count
214
+ tokenized_files_array = remainder if remainder is not None else [] # add remaining files to next chunk
215
+ with open(checkpoint_path, "w") as checkpoint_file:
216
+ checkpoint_file.write(f"{files_processed}, {tfrecord_count}")
217
+
218
+ if len(tokenized_files_array) >= args.files_per: # also write at end
219
+ _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per,
220
+ output_dir=args.output_dir, out_name=args.name,
221
+ start_no=tfrecord_count, process_no=process_no)
222
+ pbar.update(_tfrecord_count - tfrecord_count)
223
+ pbar.set_description(
224
+ f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ")
225
+ tfrecord_count = _tfrecord_count
226
+ with open(checkpoint_path, "w") as checkpoint_file:
227
+ checkpoint_file.write(f"{files_processed}, {tfrecord_count}")
228
+ else:
229
+ remainder = tokenized_files_array # add remaining to remainder
230
+
231
+ if write_remainder:
232
+ # write out the remaining files even if there's less than files_per
233
+ write_files(remainder, files_per=args.files_per, output_dir=args.output_dir, out_name=args.name,
234
+ start_no=tfrecord_count, write_remainder=True)
235
+
236
+ successful_files = files_processed - discarded_files
237
+ return {"discarded": discarded_files, "processed": files_processed, "successful": successful_files}
238
+
239
+
240
+ def create_tfrecords_mp(files, args):
241
+ files = split_list(files, len(files) // args.processes)
242
+ with Pool(processes=args.processes) as pool:
243
+ pbar = tqdm(pool.imap(create_tfrecords, zip(files, repeat(args), range(len(files)))))
244
+ meta = {"discarded": 0, "processed": 0, "successful": 0}
245
+ for results in pbar:
246
+ pbar.update()
247
+ for k, v in results.items():
248
+ meta[k] += v # update metadata
249
+ return meta
250
+
251
+
252
+ if __name__ == "__main__":
253
+ os.makedirs(args.output_dir, exist_ok=True) # make output dir if it doesn't exist
254
+ files = get_files(args.input_dir)
255
+ args.chunk_size += 1 # we shift the data by 1 to the right for targets, so increment the chunk size here
256
+
257
+ if args.processes == 0:
258
+ args.processes = cpu_count()
259
+ if args.processes > 1:
260
+ results = create_tfrecords_mp(files, args)
261
+ else:
262
+ results = create_tfrecords((files, args, 0), display_pbar=True)
263
+ print(results)
data/encoders.py ADDED
@@ -0,0 +1,28 @@
1
+ from tokenizers import Tokenizer
2
+ from transformers import GPT2Tokenizer, GPT2TokenizerFast
3
+
4
+ def fetch_encoder(params):
5
+ no_dataset = params.get('no_dataset', False)
6
+ if no_dataset:
7
+ return None
8
+
9
+ dataset = next(iter(params['dataset_configs'].values())) # Get the first value from the dict
10
+ path = dataset["tokenizer_path"]
11
+ is_pretrained = dataset.get("tokenizer_is_pretrained", False)
12
+
13
+ if is_pretrained:
14
+ tok = GPT2TokenizerFast.from_pretrained(path)
15
+
16
+ # Will add a padding token id of 50257 at run-time
17
+ tok.add_special_tokens({'pad_token': '<|padding|>'})
18
+ return tok
19
+
20
+ return Tokenizer.from_file(path)
21
+
22
+
23
+ # GPT2Tokenizer and Tokenizer have different ways of fetching token ids
24
+ def encode(encoder, text):
25
+ result = encoder.encode(text)
26
+ if isinstance(result, list):
27
+ return result
28
+ return result.ids
data/train_tokenizer.py ADDED
@@ -0,0 +1,73 @@
1
+ import os
2
+ import random
3
+ import argparse
4
+ import shutil
5
+ from glob import glob
6
+ from pathlib import Path
7
+
8
+ from lm_dataformat import Reader
9
+ from tokenizers import (Tokenizer, decoders, models, pre_tokenizers,
10
+ processors, trainers)
11
+ from tokenizers.normalizers import NFKC
12
+ from tqdm import tqdm
13
+
14
+ # parser
15
+
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument("--base_dir", type=str, help="Path to where your files are located. Files ending in .zst are treated as \
18
+ archives, all others as raw text.")
19
+ parser.add_argument("--output_dir", type=str, default="tokenizers", help="Where to put the tokenizer")
20
+ parser.add_argument("--file_type", type=str, choices=["xz", "txt"], default="xz", help="Extension of file to parse")
21
+ parser.add_argument("--vocab_size", type=int, help="Size of vocabulary", required = True)
22
+ args = parser.parse_args()
23
+
24
+ # main script
25
+
26
+ data_path = Path(args.base_dir)
27
+ archives = glob(str(data_path / f"*.{args.file_type}"))
28
+
29
+ out_path = Path(args.output_dir)
30
+
31
+ if os.path.exists(out_path):
32
+ shutil.rmtree(out_path)
33
+
34
+ if not out_path.is_dir():
35
+ out_path.mkdir()
36
+
37
+ for arch in tqdm(archives):
38
+ name = os.path.basename(arch).split(".")[0] + ".txt"
39
+ fp = out_path / name
40
+
41
+ if args.file_type == 'xz':
42
+ g = Reader(arch).stream_data()
43
+
44
+ with open(fp, "w") as f:
45
+ for s in g:
46
+ f.write(s)
47
+ f.write("\n\n")
48
+ elif args.file_type == 'txt':
49
+ shutil.copyfile(str(arch), str(fp))
50
+
51
+ data_files = glob(str(out_path / "*.txt"))
52
+ data_files = random.sample(data_files, int(0.2 * len(data_files)))
53
+
54
+ assert len(data_files) > 0, 'No data files found'
55
+
56
+ # Initialize a tokenizer
57
+ tokenizer = Tokenizer(models.BPE())
58
+
59
+ # Customize pre-tokenization and decoding
60
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
61
+ tokenizer.decoder = decoders.ByteLevel()
62
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
63
+ tokenizer.normalizer = NFKC()
64
+
65
+ # And then train
66
+ trainer = trainers.BpeTrainer(vocab_size=args.vocab_size, min_frequency=2, special_tokens=["<|endoftext|>", "<|padding|>"])
67
+ tokenizer.train(trainer, data_files)
68
+
69
+ # And Save it
70
+ tokenizer_path = out_path / "byte-level-bpe.tokenizer.json"
71
+ tokenizer.save(str(tokenizer_path), pretty=True)
72
+
73
+ print(f'tokenizer saved at {str(tokenizer_path)}')
docker-compose.yml ADDED
@@ -0,0 +1,67 @@
1
+ version: '3'
2
+ services:
3
+
4
+ mongo:
5
+ image: mongo
6
+ ports:
7
+ - 127.0.0.1:27017:27017
8
+ environment:
9
+ MONGO_INITDB_ROOT_USERNAME: user
10
+ MONGO_INITDB_ROOT_PASSWORD: password
11
+ MONGO_INITDB_DATABASE: db
12
+ expose:
13
+ - 27017
14
+ networks:
15
+ - omniboard
16
+ volumes:
17
+ - ./data:/data/db
18
+
19
+ mongoClientTemp:
20
+ image: mongo:latest
21
+ container_name: mongoClientTemp
22
+ links:
23
+ - mongo:mongo
24
+ command: mongo --host mongo -u user -p password --eval "db.getSiblingDB('db').createUser({user:'readonly', pwd:'password', roles:[{role:'read',db:'db'}]});"
25
+ depends_on:
26
+ - mongo
27
+ networks:
28
+ - omniboard
29
+
30
+ omniboard_readonly:
31
+ #image: vivekratnavel/omniboard:latest
32
+ build: https://github.com/lucidrains/omniboard.git
33
+ command: ["--mu", "mongodb://readonly:password@mongo:27017/db"]
34
+ ports:
35
+ - 0.0.0.0:8081:9000
36
+ networks:
37
+ - omniboard
38
+ depends_on:
39
+ - mongo
40
+
41
+ omniboard:
42
+ #image: vivekratnavel/omniboard:latest
43
+ build: https://github.com/lucidrains/omniboard.git
44
+ command: ["--mu", "mongodb://user:password@mongo:27017/db?authSource=admin"]
45
+ expose:
46
+ - 9000
47
+ networks:
48
+ - omniboard
49
+ depends_on:
50
+ - mongo
51
+
52
+ nginx:
53
+ image: dhswt/nginx-basic-auth:1.3
54
+ environment:
55
+ - HTPASSWD=isaac: #put passwd here
56
+ - FORWARD_HOST=omniboard
57
+ - FORWARD_PORT=9000
58
+ networks:
59
+ - omniboard
60
+ depends_on:
61
+ - omniboard
62
+ ports:
63
+ - 0.0.0.0:8080:80
64
+ expose:
65
+ - 8080
66
+ networks:
67
+ omniboard:
encoders.py ADDED
@@ -0,0 +1,28 @@
1
+ from tokenizers import Tokenizer
2
+ from transformers import GPT2Tokenizer, GPT2TokenizerFast
3
+
4
+ def fetch_encoder(params):
5
+ no_dataset = params.get('no_dataset', False)
6
+ if no_dataset:
7
+ return None
8
+
9
+ dataset = next(iter(params['dataset_configs'].values())) # Get the first value from the dict
10
+ path = dataset["tokenizer_path"]
11
+ is_pretrained = dataset.get("tokenizer_is_pretrained", False)
12
+
13
+ if is_pretrained:
14
+ tok = GPT2TokenizerFast.from_pretrained(path)
15
+
16
+ # Will add a padding token id of 50257 at run-time
17
+ tok.add_special_tokens({'pad_token': '<|padding|>'})
18
+ return tok
19
+
20
+ return Tokenizer.from_file(path)
21
+
22
+
23
+ # GPT2Tokenizer and Tokenizer have different ways of fetching token ids
24
+ def encode(encoder, text, gpt=True):
25
+ result = encoder.encode(text)
26
+ if isinstance(result, list):
27
+ return result
28
+ return result.ids
export.py ADDED
@@ -0,0 +1,14 @@
1
+ import tensorflow.compat.v1 as tf
2
+
3
+ def export_model(estimator, export_dir, params,
4
+ checkpoint_path=None):
5
+
6
+
7
+ def serving_input_receiver_fn():
8
+ t = tf.placeholder(dtype=tf.int64,
9
+ shape=[1, params["n_ctx"]],
10
+ name='input_example_tensor')
11
+ return tf.estimator.export.ServingInputReceiver(t, t)
12
+
13
+ return estimator.export_saved_model(
14
+ export_dir, serving_input_receiver_fn, checkpoint_path=checkpoint_path)
gradio/demo.py ADDED
@@ -0,0 +1,12 @@
1
+ import gradio as gr
2
+
3
+ title = "GPT-Neo Demo"
4
+ description = "demo for GPT-Neo by EleutherAI for text generation. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
5
+ article = "<p style='text-align: center'><a href='http://github.com/eleutherai/gpt-neo'>GPT-Neo: Large Scale Autoregressive Language Modeling with Mesh-Tensorflow</a></p>"
6
+ examples = [
7
+ ['The tower is 324 metres (1,063 ft) tall,'],
8
+ ["The Moon's orbit around Earth has"],
9
+ ["The smooth Borealis basin in the Northern Hemisphere covers 40%"]
10
+ ]
11
+
12
+ gr.Interface.load("huggingface/EleutherAI/gpt-neo-2.7B", inputs=gr.inputs.Textbox(lines=5, label="Input Text"),title=title,description=description,article=article, examples=examples).launch()
inputs.py ADDED
@@ -0,0 +1,384 @@
1
+ import numpy as np
2
+ import tensorflow.compat.v1 as tf
3
+ from functools import partial
4
+ from data.encoders import encode
5
+ import random
6
+ import re
7
+ import logging
8
+ from itertools import cycle
9
+ from utils import natural_sort
10
+
11
+
12
+ ### IN USE ###
13
+
14
+ def _get_number_of_documents(filename):
15
+ # extracts number of files from a filename formatted "<name>_<num_documents>.tfrecords."
16
+ # if no pattern is matched, returns None
17
+ match = re.search("_(\d{1,}).tfrecords$", filename)
18
+ return int(match.group(1)) if match is not None else match
19
+
20
+
21
+ def _get_number_of_documents_by_iteration(filename):
22
+ # extracts number of files from a tfrecord document in the event it doesn't have metadata in the filename
23
+ # this could be very slow.
24
+ logging.warning(
25
+ "inputs/sequential_input() found no metadata found in filename - iterating through first tfrecord to find global length")
26
+ count = 0
27
+ for item in tf.io.tf_record_iterator(filename):
28
+ count += 1
29
+ return count
30
+
31
+
32
+ def _get_skip_index(all_files, n_batches):
33
+ prev_cumsum = 0
34
+ cumsum = 0
35
+ global_n_documents = None
36
+ for count, f in cycle(enumerate(all_files)):
37
+ prev_cumsum = cumsum
38
+ if _get_number_of_documents(f) is not None:
39
+ cumsum += _get_number_of_documents(f)
40
+ elif global_n_documents is None:
41
+ global_n_documents = _get_number_of_documents_by_iteration(f)
42
+ cumsum += global_n_documents
43
+ else:
44
+ cumsum += global_n_documents
45
+ if cumsum == n_batches:
46
+ remainder = 0
47
+ skip_idx = count + 1
48
+ elif cumsum > n_batches:
49
+ remainder = n_batches - prev_cumsum
50
+ skip_idx = count
51
+ break
52
+ return skip_idx, remainder
53
+
54
+
55
+ def _parse_function(example_proto):
56
+ features = {
57
+ "text": tf.VarLenFeature(tf.int64)
58
+ }
59
+ parsed_features = tf.parse_single_example(example_proto, features)
60
+ return tf.sparse.to_dense(parsed_features["text"], parsed_features["text"].dense_shape[0])
61
+
62
+
63
+ def autoregressive_sample_text(params, x):
64
+ vals1 = x[:params["n_ctx"]]
65
+ vals2 = x[1:params["n_ctx"] + 1]
66
+
67
+ vals1 = tf.reshape(vals1, [params["n_ctx"]])
68
+ vals2 = tf.reshape(vals2, [params["n_ctx"]])
69
+ vals1 = tf.cast(vals1, dtype=tf.int32)
70
+ vals2 = tf.cast(vals2, dtype=tf.int32)
71
+ return vals1, vals2
72
+
73
+
74
+ def sequential_input(params, global_step=None, eval=False):
75
+ """
76
+ Input fn that reads tfrecords encoded with a fixed chunk size (== n_ctx + 1), and that either:
77
+
78
+ - has the number of documents for each tfrecord file encoded in the title in the format
79
+ <name>_<n_documents>.tfrecords.
80
+
81
+ OR
82
+
83
+ - has a fixed number of documents per tfrecord file.
84
+
85
+ If the glob pattern above isn't matched, we assume that each document has the same number of samples as the first tfrecord read.
86
+ If this isn't the case, it may result in errors, or some samples being missed.
87
+
88
+ This means we can calculate the number of samples we've seen so far using the global step,
89
+ and can use dataset.skip() to iterate through the list of filenames, as opposed to the whole dataset, which is incredibly inefficient.
90
+
91
+ If training is starting and stopping often, as with TPU pre-emption, reading the whole dataset sequentially appears to improve model
92
+ performance, as it results in less repeated data.
93
+ """
94
+ if not eval:
95
+ assert global_step is not None
96
+ logging.warning(
97
+ "Changing batch size with sequential_input() will result in some data being skipped or repeated. Please ensure your batch size stays constant throughout training.")
98
+ batch_size = params['eval_batch_size' if eval else 'train_batch_size']
99
+
100
+ filenames = []
101
+ for dataset_config in params['dataset_configs'].values(): # iterate through each dataset and read params
102
+ path_key = 'path' if not eval else 'eval_path'
103
+ path = dataset_config[path_key]
104
+ filenames.extend(
105
+ tf.io.gfile.glob(path)) # then glob all files that fit the pattern specified in dataset_configs
106
+
107
+ filenames = natural_sort(filenames)
108
+ shuffle_filenames = params.get("shuffle_input_filenames", True)
109
+ if shuffle_filenames:
110
+ seed = params.get('seed', 1) # shuffle deterministically
111
+ random.seed(seed)
112
+ random.shuffle(filenames)
113
+
114
+ dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat() # repeat filenames to infinity
115
+
116
+ if not eval:
117
+ # skip forward first in the filenames list, then skip the remaining amount in the parsed tfrecords files
118
+ skip_idx, remainder = _get_skip_index(filenames, n_batches=global_step * params[
119
+ "train_batch_size"]) # TODO: fix for > 1 epoch
120
+ dataset = dataset.skip(skip_idx) # skip to skip idx
121
+
122
+ # read tfrecord examples and skip remainder
123
+ dataset = dataset.apply(tf.data.TFRecordDataset)
124
+ dataset = dataset.skip(remainder)
125
+ else:
126
+ # shuffle filenames if in eval mode
127
+ dataset = dataset.shuffle(len(filenames))
128
+ dataset = dataset.apply(tf.data.TFRecordDataset)
129
+
130
+ # parse the tokenized data from the tfrecord files and shuffle
131
+ dataset = dataset.map(_parse_function, num_parallel_calls=1)
132
+ dataset = dataset.map(partial(autoregressive_sample_text, params), num_parallel_calls=1)
133
+
134
+ # batch data and repeat to infinity
135
+ dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
136
+ return dataset.repeat()
137
+
138
+
139
+ def pred_input(params, logger, enc=None,
140
+ path_to_prompt=""):
141
+ unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
142
+ "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
143
+ "researchers was the fact that the unicorns spoke perfect English."
144
+
145
+ text = unicorns if path_to_prompt == "" else open(path_to_prompt, "r").read()
146
+ tokens = encode(enc, text)
147
+
148
+ if len(tokens) > params["n_ctx"]:
149
+ logger.info("The length of your input prompt is longer than the model's context length - truncating input.")
150
+ tokens = tokens[len(tokens) - params["n_ctx"]:]
151
+ if len(tokens) < params["n_ctx"]:
152
+ tokens = tf.pad(tokens, [[0, params["n_ctx"] - len(tokens)]], constant_values=params["padding_id"])
153
+ t = tf.broadcast_to(tokens, [params["batch_size"], params["n_ctx"]])
154
+ dataset = tf.data.Dataset.from_tensors(t)
155
+
156
+ def _dummy_labels(x):
157
+ return x, x
158
+
159
+ dataset = dataset.map(_dummy_labels)
160
+ return dataset
161
+
162
+
163
+ def handle_pred_output(predictions, logger, enc, params, out_name="test"):
164
+ with tf.gfile.Open(f"{out_name}.txt", "w") as f:
165
+ for i, p in enumerate(predictions):
166
+ p = p["outputs"]
167
+
168
+ # remove eos + padding ids from output
169
+ idx = np.argmax(p == params['eos_id'])
170
+ if idx > 0:
171
+ p = p[:idx]
172
+ idx = np.argmax(p == params['padding_id'])
173
+ if idx > 0:
174
+ p = p[:idx]
175
+
176
+ text = enc.decode(p)
177
+ f.write("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
178
+ f.write(text)
179
+ f.write("\n" + "=" * 80 + "\n")
180
+
181
+ logger.info("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
182
+ logger.info(text)
183
+ logger.info("\n" + "=" * 80 + "\n")
184
+
185
+
186
+ ### DEPRECATED ###
187
+
188
+ def generic_text(params, eval=False, sample_text_fn=None, **kwargs):
189
+ logging.warning("DEPRECATION WARNING: generic_text will be phased out in future versions.")
190
+ i = 0 if not eval else 1
191
+
192
+ weights = []
193
+ datasets = []
194
+
195
+ for dataset in params["datasets"]:
196
+ dataset_id, stitch, datatype, weight = dataset
197
+
198
+ assert dataset_id in params[
199
+ 'dataset_configs'], f'Unknown dataset id {dataset_id} given. Please make sure your dataset ids contain that configuration'
200
+ dataset_config = params['dataset_configs'][dataset_id]
201
+
202
+ path_key = 'path' if not eval else 'eval_path'
203
+ path = dataset_config[path_key]
204
+
205
+ datasets.append(text_dataset(
206
+ tf.io.gfile.glob(path),
207
+ params,
208
+ stitch=stitch,
209
+ datatype=datatype,
210
+ batch=False,
211
+ sample_text_fn=sample_text_fn
212
+ ))
213
+
214
+ weights.append(weight)
215
+
216
+ batch_size = params['eval_batch_size' if eval else 'train_batch_size']
217
+
218
+ seed = params.get('seed', None)
219
+ dataset = tf.data.experimental.sample_from_datasets(datasets, weights=weights, seed=seed)
220
+ dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
221
+ return dataset
222
+
223
+
224
+ def text_dataset(files, params, stitch, datatype, batch=True, sample_text_fn=None):
225
+ seed = params.get('seed', None)
226
+ deterministic = seed is not None
227
+ num_parallel_calls = 1 if deterministic else tf.data.experimental.AUTOTUNE
228
+
229
+ dataset = tf.data.Dataset.from_tensor_slices(files)
230
+
231
+ if deterministic:
232
+ dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
233
+ else:
234
+ dataset = dataset.apply(
235
+ tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False))
236
+
237
+ if "documents" in datatype:
238
+ def _parse_function(example_proto):
239
+ features = {
240
+ # "hash": tf.VarLenFeature(tf.string),
241
+ "text": tf.VarLenFeature(tf.int64)
242
+ }
243
+ parsed_features = tf.parse_single_example(example_proto, features)
244
+ return parsed_features["text"], parsed_features["text"].dense_shape[0]
245
+ else:
246
+ def _parse_function(example_proto):
247
+ features = {
248
+ "text": tf.VarLenFeature(tf.int64)
249
+ }
250
+ parsed_features = tf.parse_single_example(example_proto, features)
251
+ return parsed_features["text"] # Assuming the text is not sparse
252
+
253
+ dataset = dataset.map(_parse_function, num_parallel_calls=1)
254
+
255
+ # Subsample method
256
+ if "documents" in datatype:
257
+ # Since samples can be less than the correct length, and TPUs don't like variable lengths, this function stitches together enough samples
258
+ # to have a text at least 1024 tokens long. For this to work the stitch parameter must be correctly tuned so that
259
+ # stitch * min(characters_in_text) >= amount
260
+ def _stitch_text(x, y):
261
+ x = tf.sparse.to_dense(x)
262
+
263
+ def _get_x(i):
264
+ return tf.gather(x[i], tf.range(y[i]))
265
+
266
+ out = _get_x(0)
267
+ eos_id = params['eos_id']
268
+
269
+ for i in range(1, stitch):
270
+ out = tf.concat([out, [eos_id], _get_x(i)], axis=0) # text1<|endoftext|>text2
271
+
272
+ return out
273
+
274
+ # Hack-y way to stitch together multiple texts
275
+
276
+ dataset = dataset.shuffle(1000 * stitch, seed=seed).batch(stitch, drop_remainder=True).map(_stitch_text,
277
+ num_parallel_calls=num_parallel_calls)
278
+
279
+ # Sample 1024(+1) tokens from the stitched together text
280
+ is_random_documents = datatype == "documents_random"
281
+ if sample_text_fn is not None:
282
+ _sample_text = partial(sample_text_fn, random_documents=is_random_documents)
283
+ else:
284
+ _sample_text = autoregressive_sample_text_random_documents if is_random_documents else autoregressive_sample_text
285
+ _sample_text = partial(_sample_text, params)
286
+
287
+ dataset = dataset.map(_sample_text, num_parallel_calls=num_parallel_calls)
288
+
289
+ if batch:
290
+ dataset = dataset.batch(params["train_batch_size"], drop_remainder=True).prefetch(params["iterations"] * 2)
291
+
292
+ dataset = dataset.repeat()
293
+
294
+ return dataset
295
+
296
+
297
+ def autoregressive_sample_text_random_documents(params, x):
298
+ seed = params.get('seed', None)
299
+ s = tf.size(x)
300
+ r = tf.random.uniform([], maxval=s - (params["n_ctx"] + 1), dtype=tf.dtypes.int32, seed=seed)
301
+ r1 = tf.range(r, r + params["n_ctx"])
302
+ r2 = tf.range(r + 1, (r + 1) + params["n_ctx"])
303
+ r1 = tf.reshape(r1, [params["n_ctx"]]) # Somehow, this makes the compiler happy
304
+ r2 = tf.reshape(r2, [params[
305
+ "n_ctx"]]) # TPUs want constant sized input, and these reshapes makes it recognize the shape of the input
306
+ vals1 = tf.gather(x, r1)
307
+ vals2 = tf.gather(x, r2)
308
+
309
+ vals1 = tf.reshape(vals1, [params["n_ctx"]])
310
+ vals2 = tf.reshape(vals2, [params["n_ctx"]])
311
+ vals1 = tf.cast(vals1, dtype=tf.int32)
312
+ vals2 = tf.cast(vals2, dtype=tf.int32)
313
+ return vals1, vals2
314
+
315
+
316
+ def mlm_sample_text(params, x, random_documents=False):
317
+ seed = params.get('seed', None)
318
+ ctx_len = params["n_ctx"]
319
+ assert 'mlm_mask_id' in params, 'the key `mlm_mask_id` must be set on your config to do masked language model training, specifying the id of the reserved mask token'
320
+
321
+ mask_id = params['mlm_mask_id']
322
+ cls_token_id = params.get('mlm_cls_token_id', None)
323
+ num_tokens = params.get('n_vocab', None)
324
+
325
+ mask_ignore_ids = set(params.get('mlm_mask_ignore_ids', []))
326
+ mask_ignore_ids.add(cls_token_id)
327
+
328
+ mask_prob = params.get('mlm_mask_prob', 0.15)
329
+ same_token_prob = params.get('mlm_same_token_prob', 0.10)
330
+ random_token_prob = params.get('mlm_random_token_prob', 0.)
331
+
332
+ seq_len = ctx_len if cls_token_id is None else (ctx_len - 1)
333
+
334
+ if random_documents:
335
+ s = tf.size(x)
336
+ r = tf.random.uniform([], maxval=(s - seq_len), dtype=tf.dtypes.int32, seed=seed)
337
+ r1 = tf.range(r, r + seq_len)
338
+ r1 = tf.reshape(r1, [seq_len])
339
+ features = tf.gather(x, r1)
340
+ else:
341
+ features = x[:seq_len]
342
+
343
+ # add cls token id if specified by `mlm_cls_token_id`
344
+ if cls_token_id is not None:
345
+ features = tf.pad(features, [[1, 0]], constant_values=cls_token_id)
346
+
347
+ features = tf.cast(features, dtype=tf.int32)
348
+ shape = features.shape
349
+
350
+ # determine which tokens are mask-able
351
+ can_mask = tf.not_equal(features, 0)
352
+ for ignore_id in mask_ignore_ids:
353
+ can_mask &= tf.not_equal(features, ignore_id)
354
+
355
+ # generate boolean mask for masking ids
356
+ mask_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed), mask_prob)
357
+ mask_mask &= can_mask
358
+
359
+ # generate mask for actually replacing the tokens, for allowing a small number of tokens to stay the same
360
+ replace_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),
361
+ 1 - same_token_prob)
362
+
363
+ # randomly replace some tokens with random tokens before masking
364
+ if random_token_prob > 0:
365
+ random_token_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),
366
+ random_token_prob)
367
+ random_tokens = tf.random.uniform(shape, minval=1, maxval=num_tokens, dtype=tf.dtypes.int32, seed=seed)
368
+
369
+ # make sure random tokens do not include illegal token ids specified by `mlm_mask_ignore_ids`
370
+ random_can_mask = tf.not_equal(random_tokens, 0)
371
+ for ignore_id in mask_ignore_ids:
372
+ random_can_mask &= tf.not_equal(random_tokens, ignore_id)
373
+
374
+ features = tf.where(random_token_mask & random_can_mask, random_tokens, features)
375
+
376
+ # mask the tokens
377
+ mask_tokens = tf.ones(shape, dtype=tf.int32) * mask_id
378
+ masked_features = tf.where(mask_mask & replace_mask, mask_tokens, features)
379
+
380
+ # labels will be set to 0 for all non-masked tokens
381
+ labels = tf.where(mask_mask, tf.zeros(shape, dtype=tf.int32), features)
382
+
383
+ masked_features, labels = map(lambda t: tf.reshape(t, [ctx_len]), (masked_features, labels))
384
+ return masked_features, labels
main.py ADDED
@@ -0,0 +1,257 @@
+ """GPT-like model in Mesh-Tensorflow"""
+
+ from functools import partial
+ import mesh_tensorflow as mtf
+ import tensorflow.compat.v1 as tf
+ from tensorflow.python.tpu import tpu_config, tpu_estimator
+ from tensorflow_estimator.python.estimator import estimator as estimator_lib
+ from utils import save_config, expand_attention_types_params, yes_or_no, remove_gs_or_filepath, setup_logging, \
+     check_dataset
+ from inputs import sequential_input, pred_input, handle_pred_output, mlm_sample_text, generic_text
+ from export import export_model
+ from model_fns import model_fn
+ from data.encoders import fetch_encoder
+ from configs import fetch_model_params
+ from tasks import task_descriptors
+ import argparse
+ import json
+ import numpy
+
+
+ def parse_args():
+     # Parse command line arguments
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--tpu", type=str, help="Name of TPU to train on, if any.")
+     parser.add_argument("--gpu_ids", nargs="+", type=str, default=["device:GPU:0"],
+                         help="If training on GPU, can specify your GPU names in a list - i.e 'device:GPU:0 device:GPU:1'")
+     parser.add_argument("--model", type=str, default=None, help="JSON file that contains model parameters.")
+     parser.add_argument("--steps_per_checkpoint", type=int, default=5000, help="Save a model checkpoint every X steps.")
+     parser.add_argument("--auto_layout", action="store_true", help="If set, generates and prints the most memory "
+                                                                    "efficient layout according to MTF auto layout.")
+     parser.add_argument("--auto_layout_and_mesh_shape", action="store_true",
+                         help="If set, generates and prints the most memory efficient layout and mesh shape according to"
+                              " MTF auto layout.")
+     parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and "
+                                                            "starts a new training run")
+     parser.add_argument("--predict", action="store_true", help="If set, uses the model to predict rather than train.")
+     parser.add_argument("--eval", action="store_true", help="If set, run model in evaluation mode.")
+     parser.add_argument("--prompt", type=str, help="path to .txt file containing a prompt for prediction. If empty, "
+                                                    "defaults to unicorns.",
+                         default="")
+     parser.add_argument("--check_dataset", action="store_true",
+                         help="If set, outputs sample from the dataset and quits.")
+     parser.add_argument("--sacred_id", type=str, default="nosacred", help="Sacred run id.")
+     parser.add_argument("--entmax_sampling", action="store_true", help="(experimental) use entmax sampling")
+     parser.add_argument("--export", action="store_true", help="If set, will export the model.")
+     args = parser.parse_args()
+     assert args.model is not None, "Model must be set"
+     return args
+
+
+ def main(args):
+     # Setup logging
+     logger = setup_logging(args)
+
+     # Read params of model
+     params = fetch_model_params(args.model)
+
+     # Fetch appropriate input functions
+     input_fn = params.get("input_fn", "sequential_input")
+     if input_fn == "sequential_input":
+         input_fn = sequential_input
+     elif input_fn == "generic_text":
+         input_fn = generic_text
+     pred_input_fn = pred_input
+     handle_pred_output_fn = handle_pred_output
+
+     # get current step
+     current_step = int(estimator_lib._load_global_step_from_checkpoint_dir(params["model_path"]))
+     logger.info(f"Current step {current_step}")
+
+     if params["mlm_training"]:
+         mlm_sample_text_fn = partial(mlm_sample_text, params)
+         input_fn = partial(generic_text, sample_text_fn=mlm_sample_text_fn)
+         if args.check_dataset:
+             check_dataset(input_fn, params)
+
+     # Fetch encoder per params
+     encoder = fetch_encoder(params)
+
+     pred_input_fn = partial(pred_input_fn, path_to_prompt=args.prompt, logger=logger, enc=encoder)
+
+     # Sample from Dataset if check dataset flag is on
+     if args.check_dataset:
+         check_dataset(input_fn, params, global_step=current_step)
+
+     # Confirm deletion of checkpoint files if --new flag is set
+     if args.new:
+         if yes_or_no(f"Are you sure you want to remove '{params['model_path']}' to start afresh?"):
+             remove_gs_or_filepath(params["model_path"])
+         else:
+             exit()
+
+     # Save config to logdir for experiment management
+     save_config(params, params["model_path"])
+
+     # Add to params: auto_layout, auto_layout_and_mesh_shape, use_tpu, num_cores
+     mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
+     params["num_cores"] = mesh_shape.size
+     params["auto_layout"] = args.auto_layout
+     params["auto_layout_and_mesh_shape"] = args.auto_layout_and_mesh_shape
+     params["use_tpu"] = True if not args.tpu is None else False
+     params["gpu_ids"] = args.gpu_ids
+     params["steps_per_checkpoint"] = args.steps_per_checkpoint
+     # Expand attention types param
+     params["attention_types"] = expand_attention_types_params(params["attention_types"])
+     assert len(params["attention_types"]) == params["n_layer"]  # Assert that the length of expanded list = num layers
+     params["predict_batch_size"] = params.get("predict_batch_size", 1)  # Default to 1
+     params["predict"] = args.predict
+     params['model'] = params.get("model", "GPT")  # Default model selection to GPT since it's the only option for now
+     params["export"] = args.export
+     # Set sampling parameters
+     params["sampling_use_entmax"] = args.entmax_sampling
+
+     # Sample quality of MoE models suffers when using the faster sampling method, so default to slow_sampling if
+     # moe layers are present
+     params["slow_sampling"] = True if params["moe_layers"] is not None else False
+
+     logger.info(f"params = {params}")
+
+     # Get eval tasks from params
+     eval_tasks = params.get("eval_tasks", [])
+     has_predict_or_eval_steps_or_eval_tasks = params["predict_steps"] > 0 or params["eval_steps"] > 0 or len(
+         eval_tasks) > 0
+
+     for t in eval_tasks:
+         assert t in task_descriptors, f"Eval task '{t}' is not known"
+         task_descriptors[t]["init_fn"](params)
+
+     # Set up TPUs and Estimator
+     if args.tpu == "colab":
+         tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver() if params["use_tpu"] else None
+     else:
+         tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(args.tpu) if params["use_tpu"] else None
+
+     config = tpu_config.RunConfig(
+         cluster=tpu_cluster_resolver,
+         model_dir=params["model_path"],
+         save_checkpoints_steps=None,  # Disable the default saver
+         save_checkpoints_secs=None,  # Disable the default saver
+         log_step_count_steps=params["iterations"],
+         save_summary_steps=params["iterations"],
+         tpu_config=tpu_config.TPUConfig(
+             num_shards=mesh_shape.size,
+             iterations_per_loop=params["iterations"],
+             num_cores_per_replica=1,
+             per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))
+
+     estimator = tpu_estimator.TPUEstimator(
+         use_tpu=params["use_tpu"],
+         model_fn=model_fn,
+         config=config,
+         train_batch_size=params["train_batch_size"],
+         eval_batch_size=params["train_batch_size"],
+         predict_batch_size=params["predict_batch_size"],
+         params=params)
+
+     def _make_task_estimator(task):
+         task_params = params.copy()
+         task_params["eval_task"] = task
+         return tpu_estimator.TPUEstimator(
+             use_tpu=params["use_tpu"],
+             model_fn=model_fn,
+             config=config,
+             train_batch_size=params["train_batch_size"],
+             eval_batch_size=params["eval_batch_size"],
+             predict_batch_size=params["predict_batch_size"],
+             params=task_params)
+
+     eval_task_estimators = {
+         task: _make_task_estimator(task)
+         for task in eval_tasks
+     }
+
+     if args.export:
+         export_model(estimator, "export", params)
+         return
+
+     if args.predict:
+         # Predict
+         predictions = estimator.predict(input_fn=pred_input_fn)
+         logger.info("Predictions generated")
+         enc = fetch_encoder(params)
+         handle_pred_output_fn(predictions, logger, enc, params, out_name=f"predictions_{args.sacred_id}_{current_step}")
+         return
+
+     def save_eval_results(task, eval_results):
+         def as_python(x):
+             if isinstance(x, numpy.generic):
+                 return x.item()
+             return x
+         eval_results = {k: as_python(v) for k, v in eval_results.items()}
+         with open(f'eval_{args.sacred_id}.jsonl', 'a') as fh:
+             json.dump({'task': task, 'current_step': current_step, **eval_results}, fh)
+             fh.write('\n')
+
+     def run_eval():
+         logger.info("Running evaluation...")
+         eval_results = estimator.evaluate(
+             input_fn=partial(input_fn, eval=True),
+             steps=params["eval_steps"])
+         logger.info(f"Eval results: {eval_results}")
+         save_eval_results('validation', eval_results)
+
+     def run_eval_tasks():
+         for task in eval_tasks:
+             logger.info(f"Starting evaluation task '{task}'")
+             task_info = task_descriptors[task]["get_task_info_fn"](params)
+             task_estimator = eval_task_estimators[task]
+             task_input_fn = task_descriptors[task]["input_fn"]
+             eval_results = task_estimator.evaluate(
+                 input_fn=task_input_fn,
+                 steps=task_info["n_steps"],
+                 name=task)
+             logger.info(f"Eval task '{task}' results: {eval_results}")
+             save_eval_results(task, eval_results)
+
+     if args.eval:
+         run_eval_tasks()
+         if params["eval_steps"] > 0:
+             run_eval()
+         return
+
+     elif has_predict_or_eval_steps_or_eval_tasks:
+         # Eval and train - stop and predict and/or eval every checkpoint
+         while current_step < params["train_steps"]:
+             next_checkpoint = min(current_step + args.steps_per_checkpoint,
+                                   params["train_steps"])
+
+             estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), max_steps=next_checkpoint)
+             current_step = next_checkpoint
+
+             if params["predict_steps"] > 0:
+                 logger.info("Running prediction...")
+                 predictions = estimator.predict(input_fn=pred_input_fn)
+                 enc = fetch_encoder(params)
+                 handle_pred_output_fn(predictions, logger, enc, params, out_name=f"predictions_{args.sacred_id}_{current_step}")
+
+             if params["eval_steps"] > 0:
+                 run_eval()
+
+             if eval_tasks:
+                 run_eval_tasks()
+
+         return
+     else:
+         # Else, just train
+         while current_step < params["train_steps"]:
+             # Else, don't stop and restart
+             estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), max_steps=params["train_steps"])
+
+
+ if __name__ == "__main__":
+     tf.disable_v2_behavior()
+     args = parse_args()
+     main(args)
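For reference, main.py only consumes a params dict produced by fetch_model_params from the --model JSON. A minimal, purely illustrative sketch of the keys it reads follows; the values are hypothetical and do not correspond to any real GPT-Neo configuration.

    # Hypothetical sketch of the config keys main.py reads; illustrative values only.
    example_params = {
        "model_path": "gs://my-bucket/gpt-neo-test",  # checkpoint directory (assumed path)
        "mesh_shape": "x:4,y:2",                      # 8-way TPU mesh
        "layout": "batch:x,embd:y",
        "n_layer": 12,
        "attention_types": [[["global"], 12]],        # expanded to one entry per layer
        "train_steps": 10000,
        "eval_steps": 0,
        "predict_steps": 0,
        "iterations": 500,                            # TPU loop / summary interval
        "train_batch_size": 32,
        "mlm_training": False,
        "moe_layers": None,
    }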
model_fns.py ADDED
@@ -0,0 +1,305 @@
+ import mesh_tensorflow as mtf
+ import tensorflow.compat.v1 as tf
+ from tensorflow.python.tpu import tpu_estimator
+ import mesh_tensorflow.transformer as mtf_transformer
+ from optimizers import get_optimizer
+ from utils import (create_host_call, get_graph_info, remove_batch_from_layout, simd_mesh_setup, add_mode_to_params,
+                    get_batch_size, auto_layout, auto_layout_and_mesh_shape)
+ from models.utils import biasmask_attn_weights
+ from tensorflow.python.ops import resources
+ from sample import sample_autoregressive
+ from models.gpt2 import gpt2
+ import math
+
+
+ def model_fn(features, labels, mode, params):
+     # Get global step
+     global_step = tf.train.get_global_step()
+
+     # Construct mtf graph + mesh from params
+     graph = mtf.Graph()
+     mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
+     layout_rules = mtf.convert_to_layout_rules(params["layout"])
+
+     # Mesh setup
+     if params["use_tpu"]:
+         var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)
+     else:
+         var_placer = None
+         gpu_ids = params["gpu_ids"]
+         mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
+             mesh_shape, layout_rules, gpu_ids)
+
+     # Trainable variable precision
+     # Store to checkpoints in master type, train in slice type, compute in activation type
+     if params["precision"] == "bfloat16":
+         variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32,
+                                            activation_dtype=tf.bfloat16)
+     else:
+         variable_dtype = mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32)
+
+     # Build mtf mesh object
+     mesh = mtf.Mesh(graph, "my_mesh", var_placer)
+
+     # Build mtf_features & seq length dict for getting number of microbatches
+     # We need to pack inputs into a dict to pass into serialize_training_step
+     features_dict = {"inputs": features, "labels": labels}
+     sequence_length_dict = {"inputs": params["n_ctx"], "labels": params["n_ctx"]}
+
+     params = add_mode_to_params(params, mode)
+     batch_size = get_batch_size(params)
+
+     batch_dim = mtf.Dimension("batch", batch_size)
+     batch_dims = [batch_dim]
+     feature_length = sequence_length_dict["inputs"]
+     length_dim = mtf.Dimension("sequence", feature_length)
+
+     mtf_features = {}
+     for key, x in features_dict.items():
+         if x is not None:
+             feature_shape = mtf.Shape(batch_dims + [length_dim])
+             if type(features_dict[key]) == dict:
+                 features_dict[key] = features_dict[key]["feature"]
+             x = tf.cast(features_dict[key], tf.int32)
+             x = tf.reshape(x, feature_shape.to_integer_list)
+             mtf_features[key] = mtf.import_fully_replicated(
+                 mesh, x, feature_shape, name=key)
+
+     # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
+     other_features = {}
+     memory_length_dim = mtf.Dimension("memory_length", length_dim.size)
+
+     attn_bias = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype) if params["causal"] else None
+
+     # Add attn_bias into mtf_features
+     other_features["attn_bias"] = attn_bias
+
+     # Define other Dimensions that we'll need inside the model
+     embd_dim = mtf.Dimension("embd", params["n_embd"])
+     vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
+     # We need this because gathering when both the args have the same dimension in them breaks things
+     # This dim is specifically for the weights
+     # This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
+     embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])
+
+     other_features["embd_dim"] = embd_dim
+     other_features["vocab_dim"] = vocab_dim
+     other_features["embed_sequence_dim"] = embed_sequence_dim
+     other_features["memory_length_dim"] = memory_length_dim
+
+     if mode == tf.estimator.ModeKeys.PREDICT:
+         # Set up the model for prediction
+         inputs = mtf_features["inputs"]
+         if params["remove_partial_sequences"] is None:
+             params["remove_partial_sequences"] = False
+
+         export = params.get("export", False)
+
+         if not export:
+             mtf_samples = sample_autoregressive(
+                 inputs, other_features=other_features, params=params, variable_dtype=variable_dtype,
+                 remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"],
+                 sampling_use_entmax=params['sampling_use_entmax'], max_steps=params["predict_max_steps"])
+
+         else:
+             with mtf.utils.outside_all_rewrites():
+                 with tf.variable_scope('gpt2'):
+                     mtf_samples, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
+                                                                variable_dtype=variable_dtype, context=None)
+
+         mtf_samples = mtf.anonymize(mtf_samples)
+         inputs = mtf.anonymize(inputs)
+         lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
+         inputs = lowering.export_to_tf_tensor(inputs)
+         outputs = lowering.export_to_tf_tensor(mtf_samples)
+         predictions = {
+             "inputs": inputs,
+             "outputs": outputs}
+
+         def scaffold_fn():
+             return tf.train.Scaffold(
+                 local_init_op=tf.group(
+                     tf.train.Scaffold.default_local_init_op(),
+                     lowering.copy_masters_to_slices(),
+                     name="mtf_local_init_op"),
+                 ready_op=tf.concat(
+                     [tf.report_uninitialized_variables(),
+                      resources.report_uninitialized_resources()],
+                     axis=0,
+                     name="mtf_ready_op"))
+
+         return tpu_estimator.TPUEstimatorSpec(
+             mode=tf.estimator.ModeKeys.PREDICT,
+             predictions=predictions,
+             scaffold_fn=scaffold_fn,
+             prediction_hooks=[mtf.MtfRestoreHook(lowering)])
+
+     # We're not predicting, so we better be training or evaluating
+     assert (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL)
+
+     if mode == tf.estimator.ModeKeys.TRAIN:
+         # Gets number of microbatches per batch for serialized training
+         # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed
+         num_microbatches = int(mtf_transformer.utils.serialize_num_microbatches(batch_dim=batch_dim,
+                                                                                 sequence_length=sequence_length_dict,
+                                                                                 mesh_shape=mesh_shape,
+                                                                                 layout_rules=layout_rules,
+                                                                                 tokens_per_microbatch_per_replica=
+                                                                                 params["tokens_per_mb_per_replica"]))
+     else:
+         num_microbatches = 1
+
+     params["num_microbatches"] = num_microbatches  # Add num microbatches to params
+
+     if num_microbatches > 1:
+
+         # For serialize_training_step we need to modify the model to output results in a dict
+         def serialized_fn(mtf_features):
+             if params["model"] == "GPT":
+                 with tf.variable_scope('gpt2'):
+                     logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
+                                                           variable_dtype=variable_dtype)
+                 return {"logits": logits, "loss": loss, "loss_batch": loss_batch}
+             else:
+                 raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")
+
+         # Serialize the training step - Gradients are accumulated locally and reduced once.
+         var_grads, output_dict = mtf.serialize_training_step(mtf_features, serialized_fn, batch_dim, num_microbatches)
+         loss = output_dict["loss"]
+         loss_batch = output_dict["loss_batch"]
+         logits = output_dict["logits"]
+     else:
+         # If we're not splitting into microbatches, return logits & loss as is
+         if params["model"] == "GPT":
+             with mtf.utils.outside_all_rewrites():
+                 with tf.variable_scope('gpt2'):
+                     logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
+                                                           variable_dtype=variable_dtype, context=None)
+         else:
+             raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")
+
+     # Auto layout generation
+     if params["auto_layout"]:
+         auto_layout(graph, mesh_shape, logits, loss)
+     if params["auto_layout_and_mesh_shape"]:
+         auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)
+
+     if mode == tf.estimator.ModeKeys.TRAIN:
+         # In TRAIN mode, get optimizer
+         if params["num_microbatches"] > 1:
+             # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
+             # So we pass them in here
+             _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype,
+                                                      inp_var_grads=var_grads)
+         else:
+             # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
+             _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype)
+         # Log summaries to tensorboard
+         mtf.scalar_summary("loss", loss)
+         # Log gradients if in params
+         if params["log_grads"] not in [None, False]:
+             for g in var_grads:
+                 grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
+                 mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
+     else:
+         # For now, we can only export fully-replicated tensors.
+         # This has to be done before lowering or they will not be included in the graph
+         mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
+         max_logits = mtf.argmax(logits, vocab_dim)
+         del logits
+         fully_replicated_mean_logits = mtf.anonymize(mean_logits)
+         fully_replicated_max_logits = mtf.anonymize(max_logits)
+         fully_replicated_loss_batch = mtf.anonymize(loss_batch)
+
+     # Gets & prints info about no. trainable vars in the model & dimension names
+     get_graph_info(graph)
+
+     # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
+     lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
+     tf_loss = lowering.export_to_tf_tensor(loss)
+     tf_loss = tf.cast(tf_loss, tf.float32)
+
+     if mode == tf.estimator.ModeKeys.TRAIN:
+         # Use our patched version until mtf updates theirs
+         host_call = create_host_call(params['model_path'])
+         mtf.utils.remove_summaries()
+
+         # Creates train_op
+         tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
+         tf_update_ops.append(tf.assign_add(global_step, 1))  # Need to manually increment global_step
+         tf.logging.info(f"tf_update_ops: {tf_update_ops}")
+         train_op = tf.group(tf_update_ops)
+     else:
+         tf_mean_logits = lowering.export_to_tf_tensor(fully_replicated_mean_logits)
+         tf_max_logits = lowering.export_to_tf_tensor(fully_replicated_max_logits)
+         tf_loss_batch = tf.to_float(lowering.export_to_tf_tensor(fully_replicated_loss_batch))
+
+     with mtf.utils.outside_all_rewrites():
+         # Copy master variables to slices. Must be called first.
+         restore_hook = mtf.MtfRestoreHook(lowering)
+         if mode == tf.estimator.ModeKeys.TRAIN:
+             # Set up the checkpoint server and return the TPUEstimatorSpec
+             saver = tf.train.Saver(
+                 tf.global_variables(),
+                 sharded=True,
+                 max_to_keep=10,
+                 keep_checkpoint_every_n_hours=2,
+                 defer_build=False,
+                 save_relative_paths=True)
+             tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
+             saver_listener = mtf.MtfCheckpointSaverListener(lowering)
+             saver_hook = tf.train.CheckpointSaverHook(
+                 params["model_path"],
+                 save_steps=params["steps_per_checkpoint"],
+                 saver=saver,
+                 listeners=[saver_listener])
+
+             return tpu_estimator.TPUEstimatorSpec(
+                 tf.estimator.ModeKeys.TRAIN,
+                 loss=tf_loss,
+                 host_call=host_call,
+                 train_op=train_op,
+                 training_hooks=[restore_hook, saver_hook])
+
+         elif mode == tf.estimator.ModeKeys.EVAL:
+             # Evaluation metrics
+             def _perplexity(loss):
+                 perplexity = tf.exp(loss)
+                 return tf.metrics.mean(perplexity)
+
+             def _bits_per_byte(loss):
+                 bpb = loss * (0.29335 / math.log(2))
+                 return tf.metrics.mean(bpb)
+
+             def _metric_fn(tf_mean_logits, tf_loss_batch):
+                 mean_logits = tf.metrics.mean(tf_mean_logits)
+                 loss = tf.reduce_mean(tf_loss_batch)
+                 perp = _perplexity(loss)
+                 bpb = _bits_per_byte(loss)
+                 return {"mean_logits": mean_logits, "perplexity": perp, "bits per byte": bpb}
+
+             def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
+                 eos_token = params["eos_id"]
+                 answer_positions = tf.where(tf.math.not_equal(labels, eos_token))
+
+                 correct_answers = tf.gather_nd(tf.math.equal(tf_max_logits, labels), answer_positions)
+                 accuracy = tf.metrics.mean(tf.cast(correct_answers, tf.float32))
+
+                 # I guess tf_loss_batch has z_loss and maybe other stuff added to it
+                 # so maybe this should be calculated separately in the future
+                 answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
+                 log_perplexity = tf.metrics.mean(answer_loss)
+
+                 return {"lambada_acc": accuracy, "lambada_log_ppl": log_perplexity}
+
+             eval_task = params["eval_task"]
+             if eval_task == "lambada":
+                 eval_metrics = (_lambada_metric_fn, [labels, tf_max_logits, tf_loss_batch])
+             else:
+                 eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])
+
+             return tpu_estimator.TPUEstimatorSpec(
+                 tf.estimator.ModeKeys.EVAL,
+                 evaluation_hooks=[restore_hook],
+                 loss=tf_loss,
+                 eval_metrics=eval_metrics)
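The eval metrics above are simple transforms of the mean token loss: perplexity is exp(loss), and bits-per-byte rescales the nat-valued per-token loss by a tokens-per-byte factor (0.29335 here, presumably the tokenizer's average tokens per byte on the eval data) and converts nats to bits by dividing by log 2. A standalone sketch of the same arithmetic, under that assumed interpretation:

    import math

    def perplexity(loss_nats):
        # per-token perplexity from mean cross-entropy measured in nats
        return math.exp(loss_nats)

    def bits_per_byte(loss_nats, tokens_per_byte=0.29335):
        # assumed interpretation: rescale per-token nats to per-byte bits
        return loss_nats * tokens_per_byte / math.log(2)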
models/activations.py ADDED
@@ -0,0 +1,95 @@
+ import mesh_tensorflow as mtf
+ import tensorflow.compat.v1 as tf
+ import random
+
+ BASE_FNS = {'gelu': mtf.gelu,
+             'relu': mtf.relu,
+             'sigmoid': mtf.sigmoid,
+             'tanh': mtf.tanh,
+             'selu': mtf.selu,
+             'elu': mtf.elu,
+             'abs': mtf.abs,
+             'sin': mtf.sin,
+             'cos': mtf.cos,
+             'sign': mtf.sign,
+             'silu': mtf.swish,
+             'softplus': mtf.softplus
+             }
+
+
+ def _arcsinh(x):
+     return mtf.log(x + mtf.sqrt(1 + x ** 2))
+
+
+ def _var(x, init):
+     return mtf.get_variable(x.mesh, f"activation-{random.randint(0, 2 ** 32):x}", [],
+                             initializer=tf.constant_initializer(init), dtype=x.dtype)
+
+
+ def _pos_var(x, val):
+     return mtf.softplus(_var(x, 0)) + val
+
+
+ def _rrelu(x):
+     negative_scale = random.random()
+     return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)
+
+
+ def _elish(x):
+     cond = mtf.cast(mtf.greater(x, 0), x.dtype)
+     exp = mtf.exp(x)
+     return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp + 1)
+
+
+ CUSTOM_FNS = {'lrelu001': lambda x: mtf.leaky_relu(x, alpha=0.01),
+               'lrelu020': lambda x: mtf.leaky_relu(x, alpha=0.20),
+               'id': lambda x: x,
+               'triangle_relax': lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(5 * x) / 25 - mtf.sin(7 * x) / 49,
+               'square_relax': lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(5 * x) / 5 - mtf.cos(7 * x) / 7,
+               'spike': lambda x: 1 / (1 + x ** 2),
+               'spike2': lambda x: mtf.exp(-x ** 2),
+               'tanhshrink': lambda x: x - mtf.tanh(x),
+               'softsign': lambda x: x / (mtf.abs(x) + 1),
+               'softmax': lambda x: mtf.softmax(x, x.shape[-1]),
+               'logsoftmax': lambda x: mtf.log_softmax(x, x.shape[-1]),
+               'bipolarsigmoid': lambda x: mtf.sigmoid(x) * 2 - 1,
+               'rrelu': _rrelu,
+               'elish': _elish,
+               'arcsinh': _arcsinh,
+               'aria': lambda x: x * (_var(x, 0) + _var(x, 1) / (
+                       _pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x) ** (1 / _pos_var(x, 1)))),
+               'prelu': lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)),
+               'parcsinh': lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)),
+               'psoftplus': lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0),
+               'proottanh': lambda x: (x ** _pos_var(x, 2) + _pos_var(x, 1)) ** (1 / _pos_var(x, 3)) * mtf.tanh(x),
+               'maxsig': lambda x: mtf.maximum(x, mtf.sigmoid(x)),
+               'cosid': lambda x: mtf.cos(x) - x,
+               'minsin': lambda x: mtf.minimum(x, mtf.sin(x)),
+               'maxtanh': lambda x: mtf.maximum(x, mtf.tanh(x)),
+               'mish': lambda x: x * mtf.tanh(mtf.softplus(x)),
+               'tanhexp': lambda x: x * mtf.tanh(mtf.exp(x)),
+               'lisht': lambda x: x * mtf.tanh(x),
+               'seagull': lambda x: mtf.log(1 + x ** 2),
+               'snake': lambda x: x + mtf.sin(x) ** 2,
+               'roottanh': lambda x: (x ** 2 + 1) ** (1 / 3) * mtf.tanh(x),
+               'softplusmone': lambda x: mtf.softplus(x) - 1
+               }
+
+
+ def get_activation_fn(params):
+     if "activation_fn" in params:
+         activation_fn = params["activation_fn"]
+     else:
+         print("Defaulting to GELU activation (see here: https://arxiv.org/abs/1606.08415)")
+         activation_fn = "gelu"
+
+     if activation_fn in BASE_FNS:
+         return BASE_FNS[activation_fn]
+
+     if activation_fn in CUSTOM_FNS:
+         return CUSTOM_FNS[activation_fn]
+
+     raise ValueError('unknown activation function "activation_fn" in config')
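get_activation_fn is keyed off the optional `activation_fn` entry in the model config and falls back to GELU. A minimal usage sketch (the params dict here is hypothetical, not a full model config):

    from models.activations import get_activation_fn

    params = {"activation_fn": "mish"}   # any key from BASE_FNS or CUSTOM_FNS
    act = get_activation_fn(params)      # returns a callable that operates on mtf tensors
    # act(x) is then applied inside the mesh-tensorflow graph, e.g. by the mlp layers below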
models/gpt2/gpt2.py ADDED
@@ -0,0 +1,217 @@
+ """GPT-like model in Mesh-Tensorflow"""
+ import tensorflow.compat.v1 as tf
+ import mesh_tensorflow.transformer as mtf_transformer
+
+ from models.utils import parse_inputs, entmax_cross_entropy_with_logits
+ from models.layers import *
+
+
+ # --------------------------------------------------------------------------------
+ # TRANSFORMER BLOCK:
+
+ def block(params, scope, layer_num, bias, sequence_dim, memory_length_dim, pos_emb, variable_dtype, context=None):
+     use_mlp_glu = params["mlp_glu"] == True
+     use_scale_norm = params["scalenorm"] == True
+     use_moe = exists(params["moe_layers"]) and (layer_num in params["moe_layers"])
+     use_rezero = params["rezero"] == True
+     macaron_attention = params["macaron"] == True
+
+     def fn(x):
+         with tf.variable_scope(scope):
+             nx = x.shape[-1]  # Grab last dimension from input
+
+             if use_rezero:
+                 prenorm = identity
+             elif use_scale_norm:
+                 prenorm = scale_norm
+             else:
+                 prenorm = layer_norm
+
+             pre_residual_fn = rezero if use_rezero else identity
+
+             attention_type = params["attention_types"][layer_num]
+
+             if macaron_attention:
+                 mult = 0.5
+                 mlp_fn = mlp_glu if use_mlp_glu else mlp
+                 intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)
+                 # Define intermediate layer of mlp - to split
+                 dim_intermediate_expanded = mtf.Dimension("intermediate_expanded", intermediate_size)
+                 m = mlp_fn(x, "mlp_macaron", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params)
+
+                 x = x + (m * mult)
+             else:
+                 mult = 1
+
+             if attention_type != "none":
+                 res_x = prenorm(x, "norm_1", variable_dtype=variable_dtype, params=params)
+                 a = attn(res_x, "attn", nx, attention_type=attention_type,
+                          params=params, bias=bias, dim_seq=sequence_dim, memory_length_dim=memory_length_dim,
+                          variable_dtype=variable_dtype, context=context, pos_emb=pos_emb)
+             else:
+                 a = x
+
+             x = x + pre_residual_fn(a, "norm_rezero_1", dtype=variable_dtype)
+
+             res_x = prenorm(x, "norm_2", variable_dtype=variable_dtype, params=params)
+
+             if use_moe:
+                 moe_params = mtf.transformer.moe.HParams()
+                 mtf.transformer.moe.set_default_moe_hparams(moe_params)
+                 moe_params.add_hparam("moe_min_expert_capacity", 1)
+                 moe_params.add_hparam("moe_use_experts_attention", False)
+
+                 # Override defaults
+                 for k, v in params["moe_params"].items():
+                     moe_params.add_hparam(k, v)
+
+                 moe_train = params["mode"] == "train"
+
+                 m, aux_loss = mtf.transformer.moe.transformer_moe_layer_v1(res_x, x.shape[-1], moe_params,
+                                                                            train=moe_train,
+                                                                            mesh_shape=params["mesh_shape"],
+                                                                            layout=params["layout"],
+                                                                            activation=params.get("moe_activation", "relu"),
+                                                                            variable_dtype=variable_dtype,
+                                                                            num_microbatches=params["num_microbatches"])
+                 m = mtf.dropout(m, rate=params["res_dropout"], name="moe_dropout")
+             else:
+
+                 mlp_fn = mlp_glu if use_mlp_glu else mlp
+                 intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)
+
+                 # Define intermediate layer of mlp - to split
+                 dim_intermediate_expanded = mtf.Dimension("intermediate_expanded", intermediate_size)
+
+                 m = mlp_fn(res_x, "mlp", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params)
+                 aux_loss = mtf.zeros(x.mesh, mtf.Shape([]), dtype=variable_dtype.slice_dtype)
+
+             x = x + pre_residual_fn((m * mult), "norm_rezero_2", variable_dtype)
+             return x, aux_loss
+
+     return fn
+
+
+ # --------------------------------------------------------------------------------
+ # GPT2 MODEL:
+
+ def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
+     """A GPT style model implemented in mesh tensorflow."""
+
+     x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(mtf_features, other_features)
+
+     if is_incremental_inference(context):
+         # reshape inputs if in inference mode
+         x = mtf.gather(x, context.position - 1, sequence_dim)
+         x = mtf.reshape(x, [batch_dim])
+
+     use_axial_pos_emb = exists(params["axial_pos_emb"])
+     use_rotary_emb = exists(params["rotary_emb"])
+
+     # Text encoding
+     wte = mtf.get_variable(mesh, "wte", mtf.Shape([vocab_dim, embd_dim]),
+                            initializer=tf.random_normal_initializer(stddev=0.02),
+                            master_dtype=variable_dtype.master_dtype,
+                            slice_dtype=variable_dtype.slice_dtype,
+                            activation_dtype=variable_dtype.activation_dtype)
+
+     with tf.variable_scope("token_embd"):
+         # Text embedding
+         h = mtf.gather(wte, x, vocab_dim)
+         if params["embed_dropout"] > 0 and params["mode"] == "train":
+             h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout")
+
+     # Position encoding
+
+     if use_rotary_emb:
+         wpe = None
+         layer_pos_emb = rotary_positional_emb(mesh, sequence_dim, params, variable_dtype)
+     elif use_axial_pos_emb:
+         wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)
+         layer_pos_emb = None
+     else:
+         # Use standard position encoding
+         wpe = mtf.get_variable(mesh, "wpe", mtf.Shape([embed_sequence_dim, embd_dim]),
+                                initializer=tf.random_normal_initializer(stddev=0.01),
+                                master_dtype=variable_dtype.master_dtype,
+                                slice_dtype=variable_dtype.slice_dtype,
+                                activation_dtype=variable_dtype.activation_dtype)
+         layer_pos_emb = None
+
+     if exists(wpe):
+         with tf.variable_scope("pos_embd"):
+             # Positional embedding
+             position_indices = mtf.range(mesh, sequence_dim, tf.int64) if not is_incremental_inference(context) else (
+                     context.position - 1)
+             pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
+             if params["embed_dropout"] > 0 and params["mode"] == "train":
+                 pos_emb = mtf.dropout(pos_emb, rate=params["embed_dropout"], name="wte_dropout")
+             h += pos_emb
+
+     aux_losses = 0  # instantiate auxiliary losses (for MOE models)
+
+     for layer in range(params["n_layer"]):
+         # attn blocks
+         share_parameters = exists(params["share_parameters"]) and params["share_parameters"] == True
+         block_scope = f"h{layer}" if not share_parameters else ""
+
+         block_fn = block(params=params, scope=block_scope, layer_num=layer,
+                          bias=other_features["attn_bias"],
+                          sequence_dim=sequence_dim,
+                          memory_length_dim=other_features["memory_length_dim"],
+                          pos_emb=layer_pos_emb,
+                          variable_dtype=variable_dtype,
+                          context=context)
+
+         # If true and in train mode, enable gradient checkpointing
+         recompute_grad = params["recompute_grad"] and (params["mode"] == "train") == True
+         h, loss = block_fn(h) if not recompute_grad else mtf.recompute_grad(block_fn, [h])
+         aux_losses += loss
+
+     no_weight_tie_emb = params["no_weight_tie"] == True
+     if no_weight_tie_emb:
+         with tf.variable_scope("wte_final_linear"):
+             logits = linear(h, "linear_out", vocab_dim, variable_dtype=variable_dtype, params=params)
+     else:
+         # Layer normalize & affine transform
+         h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
+         seq_dim = sequence_dim if not is_incremental_inference(context) else mtf.Dimension("sequence", 1)
+         with tf.variable_scope("wte_final_einsum"):
+             # Equivalent to tf.matmul
+             logits = mtf.einsum([h, wte], output_shape=[batch_dim, seq_dim, vocab_dim])
+
+     if params["mode"] in ["train", "eval"]:
+         labels = mtf_features["labels"]
+         z_loss = params.get("z_loss", 1e-4)  # an auxiliary loss used to stabilize mtf xentropy
+
+         # Go to full precision for the logits
+         logits = mtf.cast(logits, tf.float32)
+
+         use_entmax_loss = params.get("entmax_loss", False)
+         loss_fn = mtf.layers.softmax_cross_entropy_with_logits if not use_entmax_loss else entmax_cross_entropy_with_logits
+
+         with tf.variable_scope("xentropy_final"):
+             loss_batch = loss_fn(logits=logits, targets=labels,
+                                  vocab_dim=logits.shape[-1], z_loss=z_loss)
+
+         # For non-autoregressive models (masked language modeling training)
+         # Make sure labels with padding tokens are not counted in the loss
+         if not params["causal"]:
+             padding_id = params.get("padding_id", 0)
+             loss_batch = mtf.where(mtf.not_equal(labels, padding_id), loss_batch, mtf.zeros_like(loss_batch))
+
+         with tf.variable_scope("reduce_mean_final"):
+             loss = mtf.reduce_mean(loss_batch)
+
+         loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)
+         loss /= params["num_microbatches"]
+         # Convert to train dtype
+         loss = mtf.cast(loss, variable_dtype.slice_dtype)
+     else:
+         loss = None
+         loss_batch = None
+
+     # Cast back to checkpoint dtype
+     logits = mtf.cast(logits, variable_dtype.master_dtype)
+     return logits, loss, loss_batch
models/layers.py ADDED
@@ -0,0 +1,357 @@
+ import mesh_tensorflow as mtf
+ import tensorflow.compat.v1 as tf
+ import math
+ import mesh_tensorflow.transformer as mtf_transformer
+
+ from models.activations import get_activation_fn
+
+
+ # --------------------------------------------------------------------------------
+ # LAYERS:
+
+ sentinel = object()
+
+
+ def exists(x):
+     return x is not None
+
+
+ def identity(x, *args, **kwargs):
+     return x
+
+
+ def is_incremental_inference(context):
+     return exists(context) and context.mode == "incremental"
+
+
+ def norm(x, axis, epsilon=1e-8):
+     x -= mtf.reduce_mean(x, reduced_dim=axis, name="norm_reduce_mean_u")
+     s = mtf.reduce_mean(mtf.square(x), reduced_dim=axis, name="norm_reduce_mean_s")
+     return x * mtf.rsqrt(s + epsilon)
+
+
+ def rezero(x, scope, dtype):
+     with tf.variable_scope(scope):
+         g = mtf.get_variable(x.mesh, "g", [], initializer=tf.constant_initializer(0), dtype=dtype)
+         return x * g
+
+
+ def scale_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None):
+     if axis is sentinel:
+         axis = x.shape[-1]
+
+     with tf.variable_scope(scope):
+         g = mtf.get_variable(x.mesh, "g", [], initializer=tf.constant_initializer(1),
+                              master_dtype=variable_dtype.master_dtype,
+                              slice_dtype=variable_dtype.slice_dtype,
+                              activation_dtype=variable_dtype.activation_dtype)
+
+         x = norm(x, axis, epsilon)
+         x = x * g
+         return x
+
+
+ def layer_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None):
+     """Normalize to mean = 0, std = 1, then do a diagonal affine transform."""
+     if axis is sentinel:
+         axis = x.shape[-1]
+
+     with tf.variable_scope(scope):
+         n_state = x.shape[-1]
+
+         g = mtf.get_variable(x.mesh, "g", [n_state], initializer=tf.constant_initializer(1),
+                              master_dtype=variable_dtype.master_dtype,
+                              slice_dtype=variable_dtype.slice_dtype,
+                              activation_dtype=variable_dtype.activation_dtype)
+         b = mtf.get_variable(x.mesh, "b", [n_state], initializer=tf.constant_initializer(0),
+                              master_dtype=variable_dtype.master_dtype,
+                              slice_dtype=variable_dtype.slice_dtype,
+                              activation_dtype=variable_dtype.activation_dtype)
+
+         x = norm(x, axis, epsilon)
+         x = x * g + b
+         return x
+
+
+ def linear_attention(q, k, v):
+     batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3])
+     q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in")
+     k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in")
+
+     dim_in = k.shape[-1]
+
+     q = mtf.softmax(q, dim_in)
+     k = mtf.softmax(k, seq_dim)
+
+     context = mtf.einsum([k, v], output_shape=[batch_dim, head_dim, dim_in, dim_out])
+     attn = mtf.einsum([q, context], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
+     return attn
+
+
+ def causal_linear_attention(q, k, v, eps=1e-6):
+     batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3])
+     q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in")
+     k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in")
+
+     dim_in = k.shape[-1]
+
+     q = mtf.softmax(q, dim_in)
+     k = mtf.exp(k)
+
+     cumulative_k = mtf.cumsum(k, seq_dim) + eps
+     D_inv = 1. / mtf.einsum([q, cumulative_k], output_shape=[batch_dim, seq_dim, head_dim])
+
+     context = mtf.einsum([k, v], output_shape=[batch_dim, seq_dim, head_dim, dim_in, dim_out])
+     cumulative_context = mtf.cumsum(context, seq_dim)
+
+     attn = mtf.einsum([q, cumulative_context, D_inv], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
+     return attn
+
+
+ def linear(x, scope, nf, *, w_init_stdev=0.02, variable_dtype, params=None, scale=False):
+     # nf = number of features
+     if params["scale_by_depth"] and scale:
+         # Scale by sqrt(num_layers), only happens at the final projection before a res block output
+         w_init_stdev = w_init_stdev * (1. / math.sqrt(params["n_layer"]))
+     if params["scale_by_in"]:  # Scale by sqrt(num_input_features)
+         w_init_stdev = w_init_stdev * (1. / math.sqrt(x.shape[-1].size))  # Dimension is a namedtuple of (name, size)
+     # Not in the variable_scope because mtf already has a variable_scope in it
+     with tf.variable_scope("conv1d_main"):
+         c = mtf.layers.dense(x, new_dims=[nf], reduced_dims=[x.shape[-1]], name=scope, use_bias=True,
+                              kernel_initializer=tf.random_normal_initializer(stddev=w_init_stdev),
+                              variable_dtype=variable_dtype,
+                              )
+         return c
+
+
+ def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh):
+     """memory / key values from all attention paper"""
+
+     dim_mem_kv = mtf.Dimension("mem_kv_sequence", num_mem_kv)
+     emb_dim = k.shape[-1]
+     mem_std = 1 / math.sqrt(emb_dim.size)
+
+     mem_k = mtf.get_variable(mesh, "mem_k", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
+                              initializer=tf.random_normal_initializer(stddev=mem_std),
+                              master_dtype=variable_dtype.master_dtype,
+                              slice_dtype=variable_dtype.slice_dtype,
+                              activation_dtype=variable_dtype.activation_dtype,
+                              )
+     mem_v = mtf.get_variable(mesh, "mem_v", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
+                              initializer=tf.random_normal_initializer(stddev=mem_std),
+                              master_dtype=variable_dtype.master_dtype,
+                              slice_dtype=variable_dtype.slice_dtype,
+                              activation_dtype=variable_dtype.activation_dtype)
+
+     mem_k, mem_v = map(lambda t: mtf.broadcast(t, [dim_batch, dim_mem_kv, dim_heads, emb_dim]),
+                        (mem_k, mem_v))
+     mem_k, mem_v = map(lambda t: mtf.rename_dimension(t, "mem_kv_sequence", "sequence"),
+                        (mem_k, mem_v))
+
+     k = mtf.concat([mem_k, k], "sequence")
+     v = mtf.concat([mem_v, v], "sequence")
+     return k, v
+
+
+ def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None, pos_emb=None):
+     # x :: [batch, seq, n_embd]
+     x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh
+
+     # n_state is the same as config["n_embd"], which is also the same as dim_embd.
+     assert n_state.size % params["n_head"] == 0
+
+     dim_heads = mtf.Dimension("heads", params["n_head"])
+
+     num_mem_kv = params.get("num_mem_kv", 0)
+     use_num_mem_kv = num_mem_kv > 0
+
+     with tf.variable_scope(scope):
+         # Compute attention inputs
+         dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"])
+         mtfparams = mtf.transformer.attention.attention_params_simple(
+             x.mesh,
+             io_dim=dim_embd,
+             kv_dim=dim_kv,
+             heads_dim=dim_heads,
+             variable_dtype=variable_dtype
+         )
+         q = mtfparams.compute_q(x)
+         k = mtfparams.compute_k(x)
+         v = mtfparams.compute_v(x)
+
+         if is_incremental_inference(context):
+             one_hot = mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype)
+             inv_one_hot = 1.0 - one_hot
+             old_k, old_v = context.get_states(2)
+             k = old_k * inv_one_hot + k * one_hot
+             v = old_v * inv_one_hot + v * one_hot
+
+         if exists(context):
+             context.record_new_states([k, v])
+
+         if exists(pos_emb):
+             cos, sin = pos_emb
+             k = apply_rotary_emb(k, cos, sin)
+
+             if is_incremental_inference(context):
+                 seq_dim = cos.shape.get_dim_by_name('sequence')
+                 cos = mtf.gather(cos, context.position - 1, seq_dim)
+                 sin = mtf.gather(sin, context.position - 1, seq_dim)
+
+             q = apply_rotary_emb(q, cos, sin)
+
+         with tf.variable_scope("attention"):
+             if attention_type == "local":
+                 # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
+                 radius = params.get("local_attention_radius", 256)
+
+                 if is_incremental_inference(context):
+                     q *= one_hot
+
+                 a = mtf_transformer.attention.local_attention_1d(
+                     q, k, v,
+                     length_dim=k.shape[1],
+                     key_dim=dim_kv,
+                     value_dim=dim_kv,
+                     radius=radius,
+                     length_dim_num_splits=1,
+                     fully_autoregressive=params["causal"],
+                     attention_kwargs={},
+                 )
+
+                 if is_incremental_inference(context):
+                     a = mtf.gather(a, context.position - 1, dim_seq)
+
+             elif attention_type == "global":
+
+                 # TODO: pass in fake context
+                 # Broadcast mask bias across batch and heads
+                 if exists(bias):
+                     if not is_incremental_inference(context):
+                         broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])
+                     else:
+                         # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
+                         bias = mtf.gather(bias, context.position - 1, dim_seq)
+                         broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]])
+
+                 # memory key / values, from all-attention paper
+                 if use_num_mem_kv:
+                     k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh)
+
+                 k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
+                 v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)
+
+                 attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0
+
+                 a = mtf_transformer.attention.attention(
+                     q, k, v,
+                     memory_length_dim=memory_length_dim,
+                     key_dim=dim_kv,
+                     value_dim=dim_kv,
+                     bias=broadcasted_bias,
+                     dropout_rate=attn_dropout_rate
+                 )
+
+             elif attention_type == "linear":
+                 linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention
+                 a = linear_attn_fn(q, k, v)
+
+             else:
+                 raise NotImplementedError("Unknown attention type {}!".format(attention_type))
+
+         with tf.variable_scope("compute_output"):
+             a = mtfparams.compute_output(a, x_shape)
+
+         with tf.variable_scope("compute_output_bias"):
+             b = mtf.get_variable(x.mesh, "o_b", [dim_embd], initializer=tf.constant_initializer(0),
+                                  master_dtype=variable_dtype.master_dtype,
+                                  slice_dtype=variable_dtype.slice_dtype,
+                                  activation_dtype=variable_dtype.activation_dtype)
+             a += b
+
+         if params["mode"] == "train" and params["res_dropout"] > 0:
+             a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
+         return a
+
+
+ def mlp(x, scope, n_state, *, variable_dtype, params):
+     activation_fn = get_activation_fn(params)
+     with tf.variable_scope(scope):
+         nx = x.shape[-1]
+         h = activation_fn(linear(x, "c_fc", n_state, variable_dtype=variable_dtype, params=params))
+         h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True)
+         if params["mode"] == "train" and params["res_dropout"] > 0:
+             h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout")
+         return h2
+
+
+ def mlp_glu(x, scope, n_state, *, variable_dtype, params):
+     activation_fn = get_activation_fn(params)
+     with tf.variable_scope(scope):
+         nx = x.shape[-1]
+         h = linear(x, "c_fc", n_state, variable_dtype=variable_dtype, params=params)
+
+         h, gate = mtf.split(h, h.shape[-1], 2)
+         h *= activation_fn(gate)
+
+         h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True)
+         if params["mode"] == "train" and params["res_dropout"] > 0:
+             h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout")
+         return h2
+
+
+ def axial_positional_emb(embd_dim, mesh, params, variable_dtype):
+     # Use axial position encoding
+     axial_dim_1, axial_dim_2 = params["axial_pos_emb"]
+
+     axial_dim = mtf.Dimension("axial_dim", axial_dim_1 * axial_dim_2)
+     dim_axials = [mtf.Dimension(f"axial_dim_{i}", t) for i, t in enumerate((axial_dim_1, axial_dim_2))]
+
+     axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]),
+                                    initializer=tf.random_normal_initializer(stddev=0.01),
+                                    master_dtype=variable_dtype.master_dtype,
+                                    slice_dtype=variable_dtype.slice_dtype,
+                                    activation_dtype=variable_dtype.activation_dtype)
+
+     axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]),
+                                    initializer=tf.random_normal_initializer(stddev=0.01),
+                                    master_dtype=variable_dtype.master_dtype,
+                                    slice_dtype=variable_dtype.slice_dtype,
+                                    activation_dtype=variable_dtype.activation_dtype)
+
+     axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]),
+                                    (axial_wpe_1, axial_wpe_2))
+     wpe = (axial_wpe_1 + axial_wpe_2) / 2
+
+     wpe = mtf.reshape(wpe, [axial_dim, embd_dim])
+
+     return wpe
+
+
+ def rotary_positional_emb(mesh, sequence_dim, params, variable_dtype):
+     dtype = variable_dtype.master_dtype
+     dim_head = params["n_embd"] // params["n_head"]
+
+     dim_head = mtf.Dimension("features_per_head", dim_head)
+     half_dim_head = mtf.Dimension("half_features_per_head", dim_head.size // 2)
+
+     dim_range = mtf.range(mesh, half_dim_head, dtype) * 2 / dim_head.size
+     half_freqs = 1. / mtf.pow(mtf.constant(mesh, 10000, dtype=dtype), dim_range)
+
+     seq = mtf.range(mesh, sequence_dim, dtype)
+     half_freqs = mtf.einsum([half_freqs, seq], [sequence_dim, half_dim_head])
+
+     freqs = mtf.concat((half_freqs, half_freqs), half_dim_head.name)
+     freqs = mtf.rename_dimension(freqs, half_dim_head.name, dim_head.name)
+     return mtf.cos(freqs), mtf.sin(freqs)
+
+
+ def rotate_half(x):
+     dim_head_name = "features_per_head"
+     dim_head = x.shape.get_dim_by_name(dim_head_name)
+     half_dim_head_size = dim_head.size // 2
+     x1 = mtf.slice(x, 0, half_dim_head_size, dim_head_name)
+     x2 = mtf.slice(x, half_dim_head_size, half_dim_head_size, dim_head_name)
+     return mtf.concat((-x2, x1), dim_head.name)
+
+
+ def apply_rotary_emb(x, cos, sin):
+     rotated_x = rotate_half(x)
+     return x * cos + rotated_x * sin
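rotate_half and apply_rotary_emb implement the standard rotary position embedding update x' = x * cos + rotate_half(x) * sin. A NumPy sketch of the same operation on a single head, with illustrative shapes and names, for readers who want to check the arithmetic outside mesh-tensorflow:

    import numpy as np

    def rotate_half_np(x):
        # x: [seq, dim]; negate the second half of the feature dim and swap halves
        d = x.shape[-1] // 2
        return np.concatenate([-x[..., d:], x[..., :d]], axis=-1)

    def apply_rotary_np(x, cos, sin):
        # elementwise rotation, mirroring apply_rotary_emb above
        return x * cos + rotate_half_np(x) * sin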
models/utils.py ADDED
@@ -0,0 +1,124 @@
+ import tensorflow as tf
+ import mesh_tensorflow as mtf
+ from functools import partial
+
+
+ def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs, output_grads, alpha=1.3, dim=None,
+                     n_iter=50):
+     x, = explicit_inputs
+     y, = outputs
+     dY, = output_grads
+
+     gppr = mtf.where(mtf.greater(y, 0), mtf.pow(y, (2 - alpha)), mtf.zeros_like(y))
+     dX = dY * gppr
+
+     q = mtf.reduce_sum(dX, reduced_dim=dim) / mtf.reduce_sum(gppr, reduced_dim=dim)
+     dX = dX - q * gppr
+
+     return dX,
+
+
+ def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
+     assert alpha > 1 and alpha < 2, 'alpha must be between 1 and 2'
+
+     _gp = lambda x, alpha: x ** (alpha - 1)
+     _gp_inv = lambda x, alpha: mtf.pow(x, (1 / (alpha - 1)))
+     _p = lambda x, alpha: _gp_inv(mtf.relu(x), alpha)
+
+     dim = x.shape[-1] if dim is None else dim
+     d = dim.size
+
+     x = x * (alpha - 1)
+
+     max_val = mtf.reduce_max(x, reduced_dim=dim)
+
+     tau_lo = max_val - _gp(1, alpha)
+     tau_hi = max_val - _gp(1 / d, alpha)
+
+     f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim=dim) - 1
+
+     dm = tau_hi - tau_lo
+
+     for _ in range(n_iter):
+         dm = dm / 2
+         tau_m = tau_lo + dm
+         p_m = _p(x - tau_m, alpha)
+         f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1
+
+         mask = mtf.greater_equal((f_m * f_lo), 0)
+         tau_lo = mtf.where(mask, tau_m, tau_lo)
+
+     p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim)
+     return p_m
+
+
+ def entmax(x, alpha=1.3, dim=None, n_iter=50):
+     kwargs = dict(alpha=alpha, dim=dim, n_iter=n_iter)
+
+     return mtf.custom_gradient(
+         partial(entmax_forward, **kwargs),
+         partial(entmax_backward, **kwargs),
+         [x]
+     )
+
+
+ def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
+     if targets.dtype.is_integer:
+         # hard targets
+         if (set(targets.shape.dims) != set(logits.shape.dims).difference([vocab_dim])):
+             raise ValueError(
+                 "softmax_cross_entropy_with_logits with hard targets "
+                 "dims in targets=%s should be dims in logits=%s other than "
+                 "vocab_dim=%s" % (targets, logits, vocab_dim))
+         targets = mtf.one_hot(targets, vocab_dim, dtype=logits.dtype)
+     elif set(targets.shape.dims) != set(logits.shape.dims):
+         raise ValueError(
+             "softmax_cross_entropy_with_logits with soft targets "
+             "dims in targets=%s should be dims in logits=%s" % (targets, logits))
+
+     if vocab_dim not in logits.shape.dims:
+         raise ValueError("vocab_dim must be in logits.shape.dims")
+
+     log_entmax = mtf.log(entmax(logits, dim=vocab_dim))
+
+     loss = mtf.negative(
+         mtf.reduce_sum(log_entmax * targets, reduced_dim=vocab_dim))
+
+     return loss
+
+
+ def sample_categorical(x, dim=None):
+     dim = x.shape[-1] if dim is None else dim
+
+     cdf = mtf.cumsum(x, dim)
+     rand_uniform = mtf.random_uniform(x.mesh, x.shape - dim, minval=0, maxval=1)
+     mask = mtf.cast(mtf.greater(cdf, rand_uniform), tf.int32)
+     return mtf.argmax(mask, dim)
+
+
+ def biasmask_attn_weights(mesh, nd, ns, variable_dtype):
+     # The old mask_attn_weights applied directly to the QK;
+     # this returns a bias that the attention code from mtf adds to the attention matrix.
+     # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
+     # n_src and n_dest are both the same, i.e equal to sequence length
+     # We rename ns because we want bias to have shape [batch, heads, memory_length, sequence] to match up with QK^T
+     # Information flows from k and v (memory_length) to q (sequence)
+     i = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
+     j = mtf.range(mesh, ns, tf.int32)
+     i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))
+     dtype = variable_dtype.activation_dtype
+     return mtf.cast(mtf.less(i, j), dtype) * -1e10
+
+
+ def parse_inputs(mtf_features, other_features):
+     # Parse inputs and labels from the mtf_features / other_features input dicts
+     # All dimensions are defined inside model_fn for efficiency
+     x = mtf_features["inputs"]
+
+     batch_dim = x.shape[0]
+     sequence_dim = x.shape[1]
+     embd_dim = other_features["embd_dim"]
+     vocab_dim = other_features["vocab_dim"]
+     embed_sequence_dim = other_features["embed_sequence_dim"]
+
+     return x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim
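biasmask_attn_weights builds the causal mask as an additive bias: a query position i may only attend to memory positions j <= i, and disallowed entries get -1e10 so they vanish after the softmax. A NumPy sketch of the same mask for a sequence length of 4, to make the index arithmetic concrete:

    import numpy as np

    nd = ns = 4                                  # query length == memory length
    i = np.arange(nd)[:, None] + ns - nd         # destination (query) positions
    j = np.arange(ns)[None, :]                   # source (key/value) positions
    bias = (i < j).astype(np.float32) * -1e10    # 0 where j <= i, -1e10 where j > i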
optimizers.py ADDED
@@ -0,0 +1,176 @@
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+ import re
+ import mesh_tensorflow as mtf
+ import tensorflow.compat.v1 as tf
+
+
+ def clip_by_global_norm(grads, clip_norm):
+     """Clip the grads by global norm."""
+     global_norm = mtf.sqrt(mtf.add_n([mtf.reduce_sum(mtf.square(t)) for t in grads if t is not None]))
+     multiplier = clip_norm / mtf.maximum(global_norm, clip_norm)
+     clipped_grads = [None if t is None else t * multiplier for t in grads]
+     return clipped_grads, global_norm
+
+
+ def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
+     """Creates and returns an optimizer training op."""
+     global_step = tf.train.get_or_create_global_step()
+
+     learning_rate = tf.constant(value=params["lr"], shape=[], dtype=variable_dtype.slice_dtype)
+     clip_value = mtf.constant(mesh, params["gradient_clipping"], dtype=variable_dtype.slice_dtype)
+
+     if inp_var_grads is None:
+         var_grads = mtf.gradients([loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
+     else:
+         var_grads = inp_var_grads
+
+     # Cast to full precision
+     var_grads_fp = [mtf.cast(v, variable_dtype.slice_dtype) for v in var_grads]
+
+     # decrease LR to final lr (lr*0.1) by this step - defaults to train_steps
+     end_step = params.get("lr_decay_end", params["train_steps"])
+
+     if params["lr_decay"] == "linear":
+         learning_rate = tf.train.polynomial_decay(
+             learning_rate,
+             global_step,
+             end_step,
+             end_learning_rate=params["lr"] * 0.1,  # Decrease to 10% of initial LR according to GPT-3 paper
+             power=1.0,
+             cycle=False)
+     elif params["lr_decay"] == "cosine":
+         learning_rate = tf.train.cosine_decay(
+             learning_rate,
+             global_step,
+             end_step,
+             alpha=0.1  # Alpha is min lr value as a fraction of init lr.
+         )
+
+     if params["warmup_steps"] > 0:
+         global_steps_int = tf.cast(global_step, tf.int32)
+         warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)
+
+         dtype = variable_dtype.slice_dtype
+
+         global_steps_float = tf.cast(global_steps_int, dtype)
+         warmup_steps_float = tf.cast(warmup_steps_int, dtype)
+
+         warmup_percent_done = global_steps_float / warmup_steps_float
+         warmup_learning_rate = learning_rate * warmup_percent_done
+
+         is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
+         learning_rate = ((1.0 - is_warmup) * learning_rate +
+                          is_warmup * warmup_learning_rate)
+
+     learning_rate = mtf.import_fully_replicated(mesh, learning_rate, mtf.Shape([]), name="learning_rate")
+     mtf.scalar_summary("lr", learning_rate)
+
+     if params["opt_name"].lower() == "adam":
+         optimizer = AdamWeightDecayOptimizer(
+             learning_rate=learning_rate,
+             weight_decay_rate=params["weight_decay"],
+             beta_1=params["beta1"],
+             beta_2=params["beta2"],
+             epsilon=params["epsilon"],
+             exclude_from_weight_decay=["norm", "bias"],
+             variable_dtype=variable_dtype
+         )
+     else:
+         optimizer = mtf.optimize.AdafactorOptimizer(
+             learning_rate=params["lr"],
+             decay_rate=params["weight_decay"],
+             beta1=params["beta1"],
+             epsilon1=params["ada_epsilon1"],
+             epsilon2=params["ada_epsilon2"]
+         )
+
+     if params["gradient_clipping"] is not None:
+         (var_grads_fp, _) = clip_by_global_norm(var_grads_fp, clip_norm=clip_value)
+
+     update_ops = optimizer.apply_grads(var_grads_fp, mesh.graph.trainable_variables)
+     return learning_rate, update_ops, var_grads_fp
+
+
+ class AdamWeightDecayOptimizer(mtf.optimize.Optimizer):
+     """A basic Adam optimizer that includes "correct" L2 weight decay."""
+
+     def __init__(self,
+                  learning_rate,
+                  weight_decay_rate=0.0,
+                  beta_1=0.9,
+                  beta_2=0.999,
+                  epsilon=1e-6,
+                  exclude_from_weight_decay=None,
+                  variable_dtype=None):
+         """Constructs a AdamWeightDecayOptimizer."""
+
+         self.learning_rate = learning_rate
+         self.weight_decay_rate = weight_decay_rate
+         self.beta_1 = beta_1
111
+ self.beta_2 = beta_2
112
+ self.epsilon = epsilon
113
+ self.exclude_from_weight_decay = exclude_from_weight_decay
114
+ self.variable_dtype = variable_dtype
115
+
116
+ def apply_grad(self, grad, var):
117
+ """See base class."""
118
+ if grad is None:
119
+ tf.logging.warning("Gradient is None for variable %s" % var.name)
120
+ return []
121
+
122
+ grad = mtf.to_float(grad)
123
+
124
+ assignments = []
125
+
126
+ m = mtf.get_variable(
127
+ var.mesh, var.name + "/adam_m", var.shape,
128
+ initializer=tf.zeros_initializer(),
129
+ # master_dtype=self.variable_dtype.master_dtype,
130
+ # slice_dtype=self.variable_dtype.slice_dtype,
131
+ # activation_dtype=self.variable_dtype.activation_dtype,
132
+ trainable=False)
133
+
134
+ v = mtf.get_variable(
135
+ var.mesh, var.name + "/adam_v", var.shape,
136
+ initializer=tf.zeros_initializer(),
137
+ # master_dtype=self.variable_dtype.master_dtype,
138
+ # slice_dtype=self.variable_dtype.slice_dtype,
139
+ # activation_dtype=self.variable_dtype.activation_dtype,
140
+ trainable=False)
141
+
142
+ # Standard Adam update.
143
+ next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad
144
+ next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad)
145
+
146
+ update = next_m / (mtf.sqrt(next_v) + self.epsilon)
147
+
148
+ # Just adding the square of the weights to the loss function is *not*
149
+ # the correct way of using L2 regularization/weight decay with Adam,
150
+ # since that will interact with the m and v parameters in strange ways.
151
+ #
152
+ # Instead we want to decay the weights in a manner that doesn't interact
153
+ # with the m/v parameters. This is equivalent to adding the square
154
+ # of the weights to the loss with plain (non-momentum) SGD.
155
+ if self._do_use_weight_decay(var.name):
156
+ update += mtf.to_float(var.value) * self.weight_decay_rate
157
+
158
+ update_with_lr = self.learning_rate * update
159
+
160
+ var_update = mtf.assign_sub(var, update_with_lr)
161
+
162
+ assignments.extend(
163
+ [var_update,
164
+ mtf.assign(m, next_m),
165
+ mtf.assign(v, next_v)])
166
+ return assignments
167
+
168
+ def _do_use_weight_decay(self, param_name):
169
+ """Whether to use L2 weight decay for `param_name`."""
170
+ if not self.weight_decay_rate:
171
+ return False
172
+ if self.exclude_from_weight_decay:
173
+ for r in self.exclude_from_weight_decay:
174
+ if re.search(r, param_name) is not None:
175
+ return False
176
+ return True
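The long comment in apply_grad is the heart of this optimizer: weight decay is applied to the update directly instead of being folded into the loss, so it never contaminates the m/v statistics. A single-step NumPy sketch of that update (illustration only; like the class above, no bias correction is applied to m and v):

import numpy as np

def adamw_step(w, g, m, v, lr=1e-4, beta1=0.9, beta2=0.999,
               eps=1e-6, weight_decay=0.01, decay_this_param=True):
    m = beta1 * m + (1.0 - beta1) * g          # first moment estimate
    v = beta2 * v + (1.0 - beta2) * g * g      # second moment estimate
    update = m / (np.sqrt(v) + eps)            # Adam direction
    if decay_this_param:                       # skipped for "norm"/"bias" variables
        update = update + weight_decay * w     # decoupled weight decay
    return w - lr * update, m, v

w, m, v = np.ones(3), np.zeros(3), np.zeros(3)
w, m, v = adamw_step(w, np.array([0.1, -0.2, 0.3]), m, v)
print(w)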
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ google-api-python-client
2
+ jsonlines
3
+ lm_dataformat
4
+ mesh-tensorflow==0.1.18
5
+ numpy
6
+ oauth2client
7
+ ortools
8
+ pytest
9
+ sacred
10
+ tensorflow==2.5.0
11
+ tensorflow-datasets==3.2.1
12
+ tokenizers==0.9.4
13
+ transformers==4.1.1
14
+ tpunicorn
15
+ absl-py
16
+ ftfy
17
+ sacred
18
+ pymongo
run_experiment.py ADDED
@@ -0,0 +1,265 @@
1
+ import atexit
2
+ import sacred
3
+ import argparse
4
+ import time
5
+ import math
6
+ import subprocess
7
+ import shutil
8
+ import os
9
+ import json
10
+ import threading
11
+ import requests
12
+ import glob
13
+ from configs import fetch_model_params
14
+ import socket
15
+ import subprocess
16
+ import queue
17
+ import sys
18
+ import signal
19
+
20
+
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument('--tpu', type=str, required=True) # Name of TPU to train on, if any
23
+ parser.add_argument('--model', type=str, required=True) # JSON file that contains model parameters
24
+ parser.add_argument('--experiment_name', type=str, required=True) # name of experiment (will show up in omniboard)
25
+ parser.add_argument('--steps_per_checkpoint', type=int, default=5000)
26
+ parser.add_argument('--autostack', action="store_false")
27
+ parser.add_argument('--auto_layout', action="store_true")
28
+ parser.add_argument('--auto_layout_and_mesh_shape', action="store_true")
29
+ parser.add_argument('--new', action='store_true')
30
+ parser.add_argument('--test', action='store_true')
31
+ parser.add_argument('--eval', action='store_true')
32
+ parser.add_argument('--predict', action='store_true')
33
+ parser.add_argument('--no_delete_tpu', action='store_true')
34
+ parser.add_argument('--initial_heartbeat_timeout', type=int, default=7200)
35
+ parser.add_argument('--heartbeat_timeout', type=int, default=1800) # kill and restart if nothing logged to tensorboard in this many seconds
36
+ args = parser.parse_args()
37
+
38
+ params = fetch_model_params(args.model)
39
+
40
+ ex = sacred.Experiment(args.experiment_name)
41
+ ex.observers.append(sacred.observers.QueuedMongoObserver(url='127.0.0.1:27017', db_name='db', username='user', password='password'))
42
+
43
+
44
+ def get_open_port(lo=8000, hi=8100):
45
+ for i in range(lo, hi):
46
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
47
+ if s.connect_ex(('localhost', i)) != 0:
48
+ return i
49
+
50
+
51
+ def train_thread(args, tpu, id, q):
52
+ print('starting training on', tpu)
53
+
54
+ # pass binary flags through
55
+ opts = ''
56
+ for flag in ['auto_layout', 'auto_layout_and_mesh_shape', 'new', 'test', 'predict', 'eval', ]:
57
+ if args.__getattribute__(flag):
58
+ opts += ' --' + flag
59
+
60
+ for flag in ['autostack', ]:
61
+ if not args.__getattribute__(flag):
62
+ opts += ' --' + flag
63
+
64
+ cmd = "python3 main.py --tpu {tpu} --model run_configs/config_{id}.json --steps_per_checkpoint {steps_per_checkpoint} {opts} --sacred_id {run_id}".format(tpu=tpu, id=id, steps_per_checkpoint=args.steps_per_checkpoint, opts=opts, run_id=id)
65
+ print('Running:', cmd)
66
+ proc = subprocess.Popen(cmd, shell=True)
67
+
68
+ # poll until it's exited
69
+ while proc.poll() is None:
70
+ time.sleep(60)
71
+ try:
72
+ nq, *nargs = q.get_nowait()
73
+ if nq == 'kill':
74
+ print('train thread received kill signal from logging thread')
75
+ # first send SIGTERM
76
+ proc.terminate()
77
+
78
+ time.sleep(60)
79
+
80
+ # if it still hasn't exited, we send SIGKILL
81
+ if proc.poll() is None:
82
+ print('SIGTERM not successful, sending SIGKILL')
83
+ proc.kill()
84
+
85
+ except queue.Empty:
86
+ pass
87
+
88
+ print('exited training!')
89
+ if proc.returncode == 0:
90
+ print('exited gracefully')
91
+ os.kill(os.getpid(), signal.SIGINT)
92
+ return
93
+
94
+ if args.no_delete_tpu:
95
+ print('recreate done, exiting train_thread - not killing tpu!')
96
+ return
97
+ print("Recreating {} in 60sec...".format(tpu))
98
+ time.sleep(60)
99
+ os.system("pu recreate {} --yes --retry 3600 --retry-randomness 1.5".format(tpu))
100
+ print('recreate done, exiting train_thread')
101
+
102
+ # clear out queue
103
+ while True:
104
+ try:
105
+ q.get_nowait()
106
+ print('dropped request in queue after pu recreate')
107
+ except queue.Empty:
108
+ break
109
+
110
+
111
+ def get_json(uri, params=None, timeout=15):
112
+ resp = requests.get(uri, params=params, timeout=timeout)
113
+ resp.raise_for_status()
114
+ return resp.json()
115
+
116
+
117
+ def get_tag_sets(base_uri):
118
+ j = get_json(f'{base_uri}/data/plugin/scalars/tags', {'experiment': ''})
119
+ assert isinstance(j, dict)
120
+ return {
121
+ run: j[run].keys()
122
+ for run in j.keys()
123
+ }
124
+
125
+
126
+ def get_scalar_data(base_uri, run, tag):
127
+ j = get_json(f'{base_uri}/data/plugin/scalars/scalars', {'experiment': '', 'run': run, 'tag': tag})
128
+ assert isinstance(j, list)
129
+ return j
130
+
131
+
132
+ def get_run_data(port):
133
+ base_uri = f'http://localhost:{port}/'
134
+ r = {}
135
+ try:
136
+ tag_sets = get_tag_sets(base_uri)
137
+ runs = tag_sets.keys()
138
+ if '.' in runs:
139
+ if 'loss' in tag_sets['.']:
140
+ r['loss'] = get_scalar_data(base_uri, '.', 'loss')
141
+ if 'eval' in runs:
142
+ if 'loss' in tag_sets['eval']:
143
+ r['val_loss'] = get_scalar_data(base_uri, 'eval', 'loss')
144
+ if 'eval_lambada' in runs:
145
+ if 'lambada_acc' in tag_sets['eval_lambada']:
146
+ r['lambada_acc'] = get_scalar_data(base_uri, 'eval_lambada', 'lambada_acc')
147
+ if 'lambada_log_ppl' in tag_sets['eval_lambada']:
148
+ r['lambada_ppl'] = [
149
+ [t, s, math.exp(lp)]
150
+ for [t, s, lp] in get_scalar_data(base_uri, 'eval_lambada', 'lambada_log_ppl')
151
+ ]
152
+ except:
153
+ import traceback
154
+ traceback.print_exc()
155
+ return r
156
+
157
+
158
+ @ex.main
159
+ def main(_run):
160
+ print('Starting run', _run._id)
161
+ print('experiment main invoked with argv:', " ".join(sys.argv))
162
+ print('WARNING: please remember to remove old metric log files from the model directory.')
163
+
164
+ os.makedirs('run_configs', exist_ok=True)
165
+ shutil.copy(args.model if args.model.endswith('.json') else 'configs/{}.json'.format(args.model), 'run_configs/config_{}.json'.format(_run._id))
166
+
167
+ tensorboard_port = get_open_port()
168
+ print('Tensorboard at port:', tensorboard_port)
169
+ print('Tensorboard url: ', 'http://eleutherai.bmk.sh:'+ str(tensorboard_port))
170
+ os.system("screen -S tensorboard_{} -d -m bash -c 'tensorboard --logdir {} --port {} --bind_all --reload_multifile=true || tensorboard --logdir {} --port {} --reload_multifile=true'".format(_run._id, params["model_path"], tensorboard_port,params["model_path"], tensorboard_port,))
171
+ atexit.register(goodbye, _run._id)
172
+
173
+ curr_step = {}
174
+ seen_predictions = set()
175
+
176
+ heartbeat_timeout = args.initial_heartbeat_timeout * 2
177
+ while True:
178
+ last_tb_log_time = time.time()
179
+ start_time = time.time()
180
+ q = queue.Queue()
181
+ trainthd = threading.Thread(target=train_thread, args=(args, args.tpu, _run._id, q))
182
+ trainthd.start()
183
+
184
+ while trainthd.is_alive():
185
+ time.sleep(60)
186
+
187
+ if start_time + args.initial_heartbeat_timeout < time.time():
188
+ # after initial args.initial_heartbeat_timeout grace period, now we want to set the timeout threshold much lower
189
+ heartbeat_timeout = args.heartbeat_timeout
190
+
191
+ print('Polling tensorboard for metrics...')
192
+ data = get_run_data(tensorboard_port)
193
+ for k in data.keys():
194
+ for ts, step, val in data[k]:
195
+ if step <= curr_step.get(k, -1):
196
+ continue
197
+ _run.log_scalar(k, val, step)
198
+ if k == 'loss':
199
+ _run.log_scalar('tb_ts', ts, step)
200
+ print('Logged to sacred: step={},loss={},tb_ts={}'.format(step, val, ts))
201
+
202
+ # found something new, so logging!
203
+ last_tb_log_time = time.time()
204
+
205
+ curr_step[k] = step
206
+
207
+ for f in glob.glob('predictions_{}_*'.format(_run._id)):
208
+ if f in seen_predictions:
209
+ continue
210
+ print('collecting prediction file', f)
211
+ ex.add_artifact(f)
212
+
213
+ seen_predictions.add(f)
214
+
215
+ # collect eval metrics from jsonl
216
+ if os.path.exists(f'eval_{_run._id}.jsonl'):
217
+ with open(f'eval_{_run._id}.jsonl') as fh:
218
+ for line in fh:
219
+ ob = json.loads(line)
220
+ val_step = ob['global_step']
221
+ val_task = ob['task']
222
+ for metr in ob.keys():
223
+ k = 'fs.' + val_task + '.' + metr
224
+ if metr in ['task', 'global_step']: continue
225
+ if val_step <= curr_step.get(k, -1): continue
226
+ _run.log_scalar(k, ob[metr], val_step)
227
+ curr_step[k] = val_step
228
+
229
+ if time.time() - last_tb_log_time > heartbeat_timeout:
230
+ # the run hasn't logged in a while, so we restart it
231
+ q.put(('kill',))
232
+
233
+ # give training thread some time to do its thing and recreate tpu
234
+ while trainthd.is_alive():
235
+ print('logging thread waiting for stalled run to be killed and for TPU recreate to finish')
236
+ time.sleep(60)
237
+
238
+ # reset heartbeat timeout to initial
239
+ heartbeat_timeout = args.initial_heartbeat_timeout
240
+ last_tb_log_time = time.time()
241
+
242
+
243
+ if args.no_delete_tpu:
244
+ break
245
+
246
+
247
+ def goodbye(id):
248
+ print("You are now leaving the Python sector.")
249
+ print("Sie verlassen den pythonischen Sektor.")
250
+
251
+ os.system("screen -S tensorboard_{} -X quit".format(id))
252
+
253
+
254
+ if __name__ == '__main__':
255
+ for file in glob.glob("**/*", recursive=True):
256
+ if file.split('.')[-1] in ['py']:
257
+ print('Adding', file, 'to sacred')
258
+ ex.add_source_file(file)
259
+
260
+ ex.add_config({
261
+ 'tpu_name': args.tpu,
262
+ **params
263
+ })
264
+
265
+ ex.run()
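The control flow here is a watchdog: the logging loop in main() polls TensorBoard, and once nothing new has been logged for heartbeat_timeout seconds it asks train_thread, through the queue, to terminate (and if necessary SIGKILL) the training subprocess before the TPU is recreated. A stripped-down sketch of that pattern (illustration only; the command and timeouts are made up):

import queue, subprocess, threading, time

def worker(cmd, q):
    proc = subprocess.Popen(cmd, shell=True)
    while proc.poll() is None:
        time.sleep(1)
        try:
            if q.get_nowait() == 'kill':
                proc.terminate()              # SIGTERM first
                time.sleep(5)
                if proc.poll() is None:
                    proc.kill()               # escalate to SIGKILL
        except queue.Empty:
            pass

q = queue.Queue()
t = threading.Thread(target=worker, args=("sleep 60", q))
t.start()
last_heartbeat = time.time()                  # would be refreshed on real progress
while t.is_alive():
    time.sleep(1)
    if time.time() - last_heartbeat > 10:     # no heartbeat for too long
        q.put('kill')
t.join()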
sample.py ADDED
@@ -0,0 +1,218 @@
1
+ import mesh_tensorflow as mtf
2
+ import tensorflow.compat.v1 as tf
3
+ import mesh_tensorflow.transformer as mtf_transformer
4
+
5
+ from models.utils import entmax, sample_categorical
6
+ from models.gpt2 import gpt2
7
+
8
+ def sample_autoregressive(partial_sequences,
9
+ other_features,
10
+ params,
11
+ stop_at_token=50256,
12
+ max_steps=None,
13
+ temperature=0.9,
14
+ variable_dtype=mtf.VariableDType(tf.float32),
15
+ encoder_output=None,
16
+ encoder_sequence_id=None,
17
+ encoder_inputs=None,
18
+ shared_params=None,
19
+ has_partial_sequences=True,
20
+ encoder_layer_outputs=None,
21
+ never_end=False,
22
+ remove_partial_sequences=False,
23
+ sampling_keep_top_k=-1,
24
+ sampling_use_entmax = False,
25
+ bos_id=50256,
26
+ ):
27
+ """Sample randomly one token at a time.
28
+
29
+ The partial_sequences represent partial sequences to be continued. The
30
+ first tokens of each sequence are nonzero representing the given partial
31
+ sequences and the last tokens of each sequence are zeros, representing what
32
+ needs to be filled in.
33
+
34
+ If there are no partial sequences (you want to sample from the beginning),
35
+ then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
36
+ has_partial_sequences=False (so we can skip computation).
37
+
38
+ Args:
39
+ partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
40
+ stop_at_token: an optional integer eos id. Stop when we produce it.
41
+ max_steps: an optional integer, the max number of steps to decode.
42
+ temperature: an optional floating point value between 0.0 and 1.0 0.0
43
+ means argmax, 1.0 means sample according to predicted distribution.
44
+ variable_dtype: a mtf.VariableDType
45
+ encoder_output: an optional Tensor
46
+ encoder_sequence_id: an optional Tensor
47
+ encoder_inputs: an optional Tensor
48
+ shared_params: an optional dictionary
49
+ has_partial_sequences: a boolean
50
+ encoder_layer_outputs: optional - readonly list of tensor activations when
51
+ decoding, one per each input layer + the embedding layer
52
+ never_end: a boolean - if set, then avoid generating stop_at_token
53
+ remove_partial_sequences: a boolean - whether to remove the partial
54
+ sequences from the output
55
+ sampling_keep_top_k: an integer - if not -1, only sample from the top k
56
+ logits.
57
+ bos_id: beginning of sequence id
58
+
59
+ Returns:
60
+ a Tensor with shape [<batch_dims>, length_dim]
61
+ """
62
+
63
+ inputs = partial_sequences # Partial sequences to fill in
64
+ batch_dims = inputs.shape.dims[:-1]
65
+ length_dim = inputs.shape.dims[-1]
66
+ padding_id = params.get("padding_id", 0)
67
+ slow_sampling = params.get("slow_sampling", False)
68
+
69
+
70
+ initial_position = mtf.reduce_sum(
71
+ mtf.to_int32(mtf.not_equal(inputs, padding_id)), reduced_dim=length_dim) # Gets position where zero padding starts
72
+
73
+ length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
74
+ input_full_attention = True # for now hardcode this to true bc lazy
75
+ if input_full_attention:
76
+ # Vanilla autoregressive model - each position can see previous positions.
77
+ # Think this feeds in to the loop fn and tells each position where it can attend to?
78
+ read_priority = write_priority = length_range * mtf.to_int32(
79
+ mtf.greater(length_range, initial_position))
80
+ else:
81
+ read_priority = write_priority = length_range
82
+
83
+ # Builds context to pass around internally
84
+ # The 'first part' context records initial states of k / v / x
85
+
86
+ if not slow_sampling:
87
+ context_first_part = mtf_transformer.transformer.Context(
88
+ model=None,
89
+ mesh=inputs.mesh,
90
+ batch_dims=batch_dims,
91
+ length_dim=length_dim,
92
+ variable_dtype=variable_dtype,
93
+ mode="first_part",
94
+ position=length_range,
95
+ position_is_default=True,
96
+ new_states=[],
97
+ initial_position=initial_position,
98
+ sequence_id=None,
99
+ encoder_output=encoder_output,
100
+ encoder_sequence_id=encoder_sequence_id,
101
+ constant_states=[],
102
+ shared_params=shared_params,
103
+ encoder_layer_outputs=encoder_layer_outputs,
104
+ write_priority=write_priority,
105
+ read_priority=read_priority,
106
+ inputs=inputs,
107
+ encoder_inputs=encoder_inputs)
108
+
109
+ with tf.variable_scope("gpt2"):
110
+ logits, _, _ = gpt2.model({"inputs": inputs}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context=context_first_part)
111
+
112
+ if not has_partial_sequences:
113
+ initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]
114
+ else:
115
+ initial_states = context_first_part.new_states
116
+ else:
117
+ initial_states = []
118
+
119
+ if not has_partial_sequences:
120
+ partial_sequences_eos_count = 0
121
+
122
+ if stop_at_token is not None:
123
+ partial_sequences_eos_count = mtf.reduce_sum(
124
+ mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
125
+ reduced_dim=length_dim)
126
+
127
+ def cond_fn(position, ids, *unused_states):
128
+ """Should we run another loop iteration?"""
129
+ past_end = mtf.greater_equal(position, length_dim.size)
130
+ if max_steps:
131
+ past_end = mtf.logical_or(
132
+ past_end, mtf.greater_equal(position - initial_position, max_steps))
133
+
134
+ is_done = past_end
135
+ if stop_at_token is not None:
136
+ eos_count = mtf.reduce_sum(
137
+ mtf.to_int32(mtf.equal(ids, stop_at_token)),
138
+ reduced_dim=length_dim)
139
+ has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
140
+ is_done = mtf.logical_or(is_done, has_additional_eos)
141
+ all_done = mtf.reduce_all(is_done)
142
+ return mtf.logical_not(all_done)
143
+
144
+ def body_fn(position, ids, *states):
145
+ """One step in the decode loop."""
146
+ nonlocal sampling_keep_top_k
147
+
148
+ context = mtf_transformer.transformer.Context(
149
+ model=None,
150
+ mesh=inputs.mesh,
151
+ batch_dims=batch_dims,
152
+ length_dim=length_dim,
153
+ variable_dtype=variable_dtype,
154
+ mode="incremental",
155
+ position=position,
156
+ position_is_default=True,
157
+ states=states,
158
+ new_states=[],
159
+ initial_position=position,
160
+ sequence_id=None,
161
+ encoder_output=encoder_output,
162
+ encoder_sequence_id=encoder_sequence_id,
163
+ shared_params=shared_params,
164
+ encoder_layer_outputs=encoder_layer_outputs,
165
+ write_priority=write_priority,
166
+ read_priority=read_priority,
167
+ inputs=ids,
168
+ encoder_inputs=encoder_inputs) if not slow_sampling else None
169
+
170
+ with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
171
+ logits, _, _ = gpt2.model({"inputs": ids}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context = context)
172
+
173
+ if not sampling_use_entmax:
174
+ # By default, do top_k sampling of 0.9
175
+ if sampling_keep_top_k == -2:
176
+ sampling_keep_top_k = int(logits.shape[-1].size * 0.1)
177
+
178
+ if sampling_keep_top_k != -1:
179
+ if sampling_keep_top_k <= 0:
180
+ raise ValueError("sampling_keep_top_k must either be -1 or positive.")
181
+ k_largest = mtf.nth_largest_element(
182
+ logits, n=sampling_keep_top_k,
183
+ reduced_dim=other_features["vocab_dim"])
184
+ logits = mtf.where(mtf.less_equal(logits, k_largest),
185
+ mtf.ones_like(logits) * -1e6, logits)
186
+
187
+ ids_this_step = mtf.sample_with_temperature(
188
+ logits, other_features["vocab_dim"], temperature)
189
+ else:
190
+ ids_this_step = sample_categorical(entmax(logits))
191
+
192
+ if slow_sampling:
193
+ ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False)
194
+ else:
195
+ ids_this_step = mtf.reshape(ids_this_step, (batch_dims))
196
+
197
+ one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
198
+ one_new_id = ids_this_step * one_hot
199
+ new_ids = (1 - one_hot) * ids + one_new_id
200
+ new_position = position + 1
201
+
202
+ ret = [new_position, new_ids]
203
+ if context is not None:
204
+ ret += context.new_states
205
+ return ret
206
+
207
+ while_loop_inputs = [initial_position, inputs] + initial_states
208
+ final_position, outputs = mtf.while_loop(
209
+ cond_fn, body_fn, while_loop_inputs)[:2]
210
+ del final_position
211
+ if has_partial_sequences and remove_partial_sequences:
212
+ # Remove partial sequences from outputs
213
+ partial_length = mtf.reduce_sum(
214
+ mtf.to_int32(mtf.not_equal(partial_sequences, padding_id)),
215
+ reduced_dim=length_dim)
216
+ outputs = mtf.dynamic_shift(
217
+ outputs, -partial_length, length_dim, wrap=False)
218
+ return outputs
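One detail of body_fn worth seeing in isolation is the top-k filter: every logit at or below the cut-off is pushed to -1e6 so that temperature sampling effectively never selects it. A NumPy sketch of the same effect (illustration only; the exact index convention of mtf.nth_largest_element is glossed over here):

import numpy as np

def top_k_filter(logits, k):
    # keep only the k largest logits; everything else becomes ~impossible
    cutoff = np.sort(logits)[-k]
    return np.where(logits < cutoff, -1e6, logits)

logits = np.array([1.0, 3.0, 0.5, 2.0, -1.0])
print(top_k_filter(logits, 2))   # only 3.0 and 2.0 survive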
tasks.py ADDED
@@ -0,0 +1,116 @@
1
+ import os.path
2
+ import json
3
+ import requests
4
+ import numpy as np
5
+ import ftfy
6
+ from data.encoders import fetch_encoder, encode
7
+ import tensorflow as tf
8
+ import re
9
+ from functools import partial
10
+
11
+ lambada_src_uri = 'http://eaidata.bmk.sh/data/lambada_test.jsonl'
12
+ normalization = 'NFKC'
13
+
14
+
15
+ # Note: this task is called "lambada" but it really refers to OpenAI's version
16
+ # of the task, which actually differs in some ways from the task described in
17
+ # the original paper. So, strictly speaking, accuracy values from this task
18
+ # should not be compared to accuracy values from the original lambada task.
19
+ # For more information, see
20
+ # https://github.com/openai/gpt-2/issues/131
21
+
22
+ def lambada_create_tokens_data(params, path):
23
+ with open(path, 'w') as f:
24
+ req = requests.get(lambada_src_uri)
25
+ req.raise_for_status()
26
+ jsons = [json.loads(l) for l in req.iter_lines()]
27
+ texts = [ftfy.fix_text(j['text'], normalization=normalization) for j in jsons]
28
+ enc = fetch_encoder(params)
29
+ arrays = [encode(enc, t) for t in texts]
30
+ json.dump(arrays, f)
31
+ return arrays
32
+
33
+
34
+ def lambada_read_or_create_tokens_data(params, path):
35
+ # if you tell me where the file should go, i will helpfully create it for you
36
+ if not os.path.exists(path):
37
+ return lambada_create_tokens_data(params, path)
38
+ with open(path) as f:
39
+ return json.load(f)
40
+
41
+
42
+ def bin_pack(params, tokens_data):
43
+ eos_token = params['eos_id']
44
+ n_ctx = params['n_ctx']
45
+ dummy_token = 1
46
+ pad_batch_size = params['eval_batch_size']
47
+ bins = []
48
+ for a in tokens_data:
49
+ if len(bins) == 0 or len(bins[-1]) + len(a) + 1 > n_ctx:
50
+ bins.append([])
51
+ bins[-1] += a
52
+ bins[-1].append(eos_token)
53
+ while len(bins) % pad_batch_size != 0:
54
+ bins.append([])
55
+ bins_array = np.full((len(bins), n_ctx), dummy_token, dtype=np.uint16)
56
+ for i, b in enumerate(bins):
57
+ bins_array[i, 0:len(b)] = b
58
+ return bins_array
59
+
60
+
61
+ def lambada_init(params):
62
+ ds_configs = params['dataset_configs']
63
+ l = [
64
+ ds_configs[ds_id].get('lambada_tokens_path', "./lambada.json")
65
+ for ds_id, _, _, _ in params['datasets']
66
+ ]
67
+ assert len(l) > 0, 'lambada_tokens_path not found in the dataset config'
68
+ lt_path = l[0]
69
+ assert lt_path.endswith('.json'), 'lambada_tokens_path must have extension json'
70
+
71
+ tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
72
+ bins_array = bin_pack(params, tokens_data)
73
+ params['lambada_tokens_path'] = lt_path
74
+ params['lambada_n_steps'] = len(bins_array) // params['eval_batch_size']
75
+
76
+
77
+ def lambada_get_task_info(params):
78
+ return {
79
+ 'n_steps': params['lambada_n_steps'],
80
+ }
81
+
82
+
83
+ # The LAMBADA evaluation code looks at the logits of each position just before an eos_token
84
+ def lambada_input(params):
85
+ eos_token = 50256 if params['n_vocab'] >= 50257 else 0
86
+ n_ctx = params['n_ctx']
87
+ lt_path = params['lambada_tokens_path']
88
+ tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
89
+ bins_array = bin_pack(params, tokens_data)
90
+ dataset = tf.data.Dataset.from_tensor_slices(bins_array)
91
+
92
+ def _get_output(bin):
93
+ bin = tf.cast(bin, dtype=tf.int32)
94
+ indexes = tf.range(n_ctx)
95
+ results = tf.gather(bin, (indexes + 1) % n_ctx)
96
+ eos_next_positions = tf.math.equal(tf.gather(bin, (indexes + 2) % n_ctx), eos_token)
97
+ output = tf.where(eos_next_positions, results, tf.constant(eos_token, shape=[n_ctx]))
98
+ bin = tf.reshape(bin, [n_ctx])
99
+ bin = tf.cast(bin, dtype=tf.int32)
100
+ output = tf.reshape(output, [n_ctx])
101
+ output = tf.cast(output, dtype=tf.int32)
102
+ return bin, output
103
+
104
+ dataset = dataset.map(_get_output)
105
+ dataset = dataset.batch(params['eval_batch_size'], drop_remainder=True)
106
+ dataset = dataset.repeat()
107
+ return dataset
108
+
109
+
110
+ task_descriptors = {
111
+ 'lambada': {
112
+ 'init_fn': lambada_init,
113
+ 'get_task_info_fn': lambada_get_task_info,
114
+ 'input_fn': lambada_input,
115
+ }
116
+ }
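bin_pack greedily concatenates examples (each followed by an eos token) into rows of length n_ctx, pads the number of rows up to a multiple of eval_batch_size, and fills unused positions with a dummy token. A toy run with n_ctx=8 and eval_batch_size=2 (illustration only, reusing the logic above with made-up token ids):

import numpy as np

n_ctx, eos, dummy = 8, 0, 1
tokens_data = [[5, 6, 7], [8, 9], [10, 11, 12, 13, 14]]

bins = []
for a in tokens_data:
    if len(bins) == 0 or len(bins[-1]) + len(a) + 1 > n_ctx:
        bins.append([])
    bins[-1] += a
    bins[-1].append(eos)
while len(bins) % 2 != 0:          # pad row count to eval_batch_size == 2
    bins.append([])

arr = np.full((len(bins), n_ctx), dummy, dtype=np.uint16)
for i, b in enumerate(bins):
    arr[i, :len(b)] = b
print(arr)
# [[ 5  6  7  0  8  9  0  1]   the first two examples fit in one row
#  [10 11 12 13 14  0  1  1]]  the third starts a new row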
test_models.py ADDED
@@ -0,0 +1,180 @@
1
+ import pytest
2
+ import traceback
3
+ import logging
4
+ from collections import defaultdict
5
+ from contextlib import contextmanager
6
+
7
+ import tensorflow as tf
8
+ tf.compat.v1.enable_eager_execution()
9
+ import mesh_tensorflow as mtf
10
+ from mesh_tensorflow import placement_mesh_impl
11
+
12
+ from inputs import mlm_sample_text
13
+ from models.gpt2 import gpt2
14
+ from models.utils import biasmask_attn_weights, entmax, sample_categorical
15
+
16
+ from sample import sample_autoregressive
17
+
18
+ # helper functions
19
+
20
+ @contextmanager
21
+ def not_raises(exception):
22
+ try:
23
+ yield
24
+ except exception:
25
+ logging.error(traceback.format_exc())
26
+ raise pytest.fail("DID RAISE {0}".format(exception))
27
+
28
+ # fixtures
29
+
30
+ params = defaultdict(lambda: None, {
31
+ "n_head": 1,
32
+ "n_ctx": 4,
33
+ "n_embd": 2,
34
+ "n_vocab": 256,
35
+ "embed_dropout": 0.,
36
+ "n_layer": 2,
37
+ "num_microbatches": 1,
38
+ "train_batch_size": 1,
39
+ "causal": True,
40
+ "attention_types": ['global', 'local'],
41
+ "res_dropout": 0.1,
42
+ "rotary_emb": True,
43
+ "activation_function": "gelu",
44
+ "moe_layers": (1,),
45
+ "num_mem_kv": 16,
46
+ "no_weight_tie": True,
47
+ "moe_params": {
48
+ 'moe_dropout_rate': 0.0
49
+ },
50
+ "mesh_shape": [],
51
+ "layout": {},
52
+ "local_attention_radius": 128,
53
+ "share_parameters": True,
54
+ "rezero": True
55
+ })
56
+
57
+ # tests
58
+
59
+ def test_model():
60
+ graph = mtf.Graph()
61
+ mesh = mtf.Mesh(graph, "my_mesh")
62
+
63
+ seq_len = params["n_ctx"]
64
+
65
+ batch_dim = mtf.Dimension("batch", 1)
66
+ sequence_dim = mtf.Dimension("sequence", seq_len)
67
+
68
+ features = {
69
+ 'inputs': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32),
70
+ 'labels': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
71
+ }
72
+
73
+ # create mask
74
+
75
+ num_mem_kv = params.get('num_mem_kv', 0)
76
+ length_dim = mtf.Dimension('sequence', seq_len)
77
+ memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
78
+ embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
79
+ embd_dim = mtf.Dimension("embd", params["n_embd"])
80
+ vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
81
+
82
+ other_features = {}
83
+ variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
84
+
85
+ other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype)
86
+ other_features["embd_dim"] = embd_dim
87
+ other_features["vocab_dim"] = vocab_dim
88
+ other_features["embed_sequence_dim"] = embed_sequence_dim
89
+ other_features["memory_length_dim"] = memory_length_dim
90
+
91
+ with not_raises(Exception):
92
+ logits, _, _ = gpt2.model(features, other_features, params, mesh, variable_dtype=variable_dtype)
93
+
94
+ mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
95
+ lowering = mtf.Lowering(graph, {mesh: mesh_impl})
96
+ logits = lowering.export_to_tf_tensor(logits)
97
+
98
+
99
+ def test_sampling():
100
+ graph = mtf.Graph()
101
+ mesh = mtf.Mesh(graph, "my_mesh")
102
+
103
+ batch_dim = mtf.Dimension("batch", 1)
104
+ sequence_dim = mtf.Dimension("sequence", 1)
105
+
106
+ inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
107
+ inputs = mtf.pad(inputs, [0, 3], sequence_dim.name)
108
+
109
+ # create mask
110
+
111
+ seq_len = params["n_ctx"]
112
+ num_mem_kv = params.get('num_mem_kv', 0)
113
+ length_dim = mtf.Dimension('sequence', seq_len)
114
+ memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
115
+ embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
116
+ embd_dim = mtf.Dimension("embd", params["n_embd"])
117
+ vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
118
+
119
+ other_features = {}
120
+
121
+ other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, mtf.VariableDType(tf.float32))
122
+ other_features["embd_dim"] = embd_dim
123
+ other_features["vocab_dim"] = vocab_dim
124
+ other_features["embed_sequence_dim"] = embed_sequence_dim
125
+ other_features["memory_length_dim"] = memory_length_dim
126
+
127
+ params["mode"] = "predict"
128
+
129
+ with not_raises(Exception):
130
+ samples = sample_autoregressive(
131
+ inputs, other_features=other_features, params=params, variable_dtype=mtf.VariableDType(),
132
+ remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"], sampling_use_entmax=True)
133
+
134
+ mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
135
+ lowering = mtf.Lowering(graph, {mesh: mesh_impl})
136
+ samples = lowering.export_to_tf_tensor(samples)
137
+
138
+ # mlm
139
+
140
+ mlm_params = defaultdict(lambda: None, {
141
+ "n_head": 1,
142
+ "n_ctx": 4,
143
+ "n_embd": 1,
144
+ "n_vocab": 256,
145
+ "embed_dropout": 0.,
146
+ "n_layer": 2,
147
+ "num_microbatches": 1,
148
+ "train_batch_size": 1,
149
+ "attention_types": ['global', 'local'],
150
+ "res_dropout": 0.1,
151
+ "mesh_shape": [],
152
+ "layout": {},
153
+ "share_parameters": True,
154
+ "mlm_training": True,
155
+ "mlm_mask_id": 3,
156
+ "mlm_cls_token_id": 4,
157
+ "mlm_random_token_prob": 0.1
158
+ })
159
+
160
+ def test_mlm_sample_text():
161
+ document = tf.random.normal((16,))
162
+ with not_raises(Exception):
163
+ features, labels = mlm_sample_text(mlm_params, document, random_documents = True)
164
+ assert features.shape == (mlm_params['n_ctx'],)
165
+
166
+ # entmax
167
+
168
+ def test_entmax():
169
+ graph = mtf.Graph()
170
+ mesh = mtf.Mesh(graph, "my_mesh")
171
+ length = mtf.Dimension("tensor_length", 8)
172
+ tensor = mtf.range(mesh, length, tf.float32)
173
+ output = entmax(tensor)
174
+ grad = mtf.gradients([output], [tensor])[0]
175
+ sample = sample_categorical(output, length)
176
+
177
+ mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
178
+ lowering = mtf.Lowering(graph, {mesh: mesh_impl})
179
+ sample = lowering.export_to_tf_tensor(sample)
180
+ grad = lowering.export_to_tf_tensor(grad)
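Both tests rely on the same trick to stay runnable without a TPU: PlacementMeshImpl with an empty shape and layout lowers the whole Mesh-TensorFlow graph onto a single local device. A minimal standalone sketch of that lowering pattern (illustration only, with a trivial graph in place of the model):

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
from mesh_tensorflow import placement_mesh_impl

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
dim = mtf.Dimension("d", 4)
x = mtf.range(mesh, dim, tf.float32)       # [0, 1, 2, 3]
y = mtf.reduce_sum(x * x)                  # scalar mtf.Tensor

# lower onto one local device and pull the result back out as a tf tensor
mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
print(lowering.export_to_tf_tensor(y))     # 0 + 1 + 4 + 9 = 14.0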
utils.py ADDED
@@ -0,0 +1,291 @@
1
+ import re
2
+ from urllib.parse import urlparse
3
+ from shutil import rmtree
4
+ import logging
5
+ import os
6
+ from pathlib import Path
7
+ import sys
8
+ import tensorflow.compat.v1 as tf
9
+ import tensorflow.compat.v2 as tf2
10
+ import mesh_tensorflow as mtf
11
+ from data.encoders import fetch_encoder
12
+ import re
13
+
14
+ def setup_logging(args):
15
+ Path("logs").mkdir(exist_ok=True)
16
+ tf.logging.set_verbosity(logging.INFO)
17
+ tf.get_logger().propagate = False # Remove double log on console
18
+ name = os.path.splitext(os.path.basename(args.model))[0]
19
+ handlers = [
20
+ logging.FileHandler(f"logs/{name}.log"),
21
+ logging.StreamHandler(sys.stdout)
22
+ ]
23
+ logger = logging.getLogger("tensorflow")
24
+ logger.handlers = handlers
25
+ return logger
26
+
27
+
28
+ def get_batch_size(params):
29
+ return params[f"{params['mode']}_batch_size"]
30
+
31
+
32
+ def add_mode_to_params(params, mode):
33
+ if mode == tf.estimator.ModeKeys.PREDICT:
34
+ params["mode"] = "predict"
35
+ elif mode == tf.estimator.ModeKeys.EVAL:
36
+ params["mode"] = "eval"
37
+ elif mode == tf.estimator.ModeKeys.TRAIN:
38
+ params["mode"] = "train"
39
+ else:
40
+ raise ValueError(f"Invalid mode {mode}")
41
+ return params
42
+
43
+
44
+ def simd_mesh_setup(params, mesh_shape, layout_rules):
45
+ """Constructs SimdMesh function - instructions on how to evenly split tensors across all TPU cores"""
46
+
47
+ num_hosts = params["context"].num_hosts
48
+ host_placement_fn = params["context"].tpu_host_placement_function
49
+ device_list = [host_placement_fn(host_id=i) for i in range(num_hosts)]
50
+ tf.logging.info(f"device_list = {device_list}")
51
+
52
+ # TODO: Better estimation of replica cache size?
53
+ replica_cache_size = 300 * 1000000 # 300M per replica
54
+
55
+ # Worker 0 caches all the TPU binaries
56
+ worker0_mem = replica_cache_size * params["context"].num_replicas
57
+ devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
58
+ var_placer = mtf.utils.BalancedVariablePlacer(device_list, devices_memory_usage)
59
+ mesh_devices = [""] * mesh_shape.size
60
+ mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
61
+ mesh_shape, layout_rules, mesh_devices, params["context"].device_assignment)
62
+
63
+ return var_placer, mesh_impl
64
+
65
+
66
+ def remove_batch_from_layout(layout):
67
+ """
68
+ The tf-mesh layout splits across batch size, remove it.
69
+ Useful for prediction steps, when you no longer want large batches.
70
+
71
+ :param layout: string describing tf-mesh layout
72
+ :return: layout minus batch dimension
73
+ """
74
+ layout = layout.split(',')
75
+ ret_layout = ""
76
+ for i in layout:
77
+ if "batch" in i:
78
+ pass
79
+ else:
80
+ ret_layout += f"{i},"
81
+ return ret_layout[:-1]
82
+
83
+
84
+ def yes_or_no(question):
85
+ while True:
86
+ reply = str(input(question+' (y/n): ')).lower().strip()
87
+ if reply[:1] == 'y':
88
+ return True
89
+ if reply[:1] == 'n':
90
+ return False
91
+
92
+
93
+ def remove_gs_or_filepath(path):
94
+ parsed_url = urlparse(path)
95
+ if parsed_url.scheme == "gs":
96
+ os.system(f"gsutil rm -rf {path}")
97
+ return
98
+ rmtree(path)
99
+
100
+
101
+ def save_config(params_dict, logdir):
102
+ print(f"Saving config to {logdir}")
103
+ text = "{\n\n"
104
+ total_params = len(params_dict)
105
+ for count, key in enumerate(params_dict):
106
+ config_value = str(params_dict[key])
107
+ if re.search('[a-zA-Z]', config_value):
108
+ if config_value.lower() != 'true':
109
+ if config_value.lower() != 'false':
110
+ if config_value[0] != '[':
111
+ # TODO: Making a manual exception for parsing epsilon right now since it's the only number in
112
+ # scientific notation. Should fix this.
113
+ if key != "epsilon":
114
+ config_value = f'"{config_value}"'
115
+ if count == total_params - 1:
116
+ text += f'"{str(key)}"' + ' : ' + config_value + '\n\n'
117
+ else:
118
+ text += f'"{str(key)}"' + ' : ' + config_value + ',\n\n'
119
+ text += '\n\n}'
120
+ sess = tf.InteractiveSession()
121
+ summary_op = tf.summary.text("run_config", tf.convert_to_tensor(text))
122
+ summary_writer = tf.summary.FileWriter(f"{logdir}/config", sess.graph)
123
+ text = sess.run(summary_op)
124
+ summary_writer.add_summary(text, 0)
125
+ summary_writer.flush()
126
+ summary_writer.close()
127
+ tf.reset_default_graph()
128
+ print('Done!')
129
+
130
+
131
+ def expand_attention_types_params(params_list):
132
+ newlist = []
133
+ for item in params_list:
134
+ for _ in range(item[1]):
135
+ newlist.extend(item[0])
136
+ return newlist
137
+
138
+
139
+ def get_n_trainable_vars(graph):
140
+ """
141
+ Gets number of trainable vars in a MTF model.
142
+
143
+ :param graph: Mesh-Tensorflow graph
144
+ :return: None
145
+ """
146
+ total_parameters = 0
147
+ for variable in graph.trainable_variables:
148
+ shape = variable.shape.dims
149
+ variable_parameters = 1
150
+ for dim in shape:
151
+ variable_parameters *= dim.size
152
+ total_parameters += variable_parameters
153
+ print(f"\n\nN TRAINABLE VARS:\n{total_parameters:,}\n\n")
154
+
155
+
156
+ def print_dim_names(graph):
157
+ """
158
+ Print names of all Dimensions
159
+ :param graph: Mesh-Tensorflow graph
160
+ :return: None
161
+ """
162
+ all_dim_names = []
163
+ for variable in graph.all_variables:
164
+ names = variable.shape.dimension_names
165
+ all_dim_names.append(names)
166
+
167
+ # Print all dim names in graph & write to file
168
+ all_dim_names = [item for sublist in all_dim_names for item in sublist] # Flatten all dims
169
+ unique_dims = list(set(all_dim_names))
170
+ print("ALL DIM NAMES:")
171
+ for dim_name in unique_dims:
172
+ print(dim_name)
173
+ print('\n')
174
+
175
+
176
+ def get_graph_info(graph):
177
+ """
178
+ Wrapper fn that calculates number of trainable vars in an MTF graph & prints all dim_names to file
179
+ TODO: how to get un-trainable dim-names too, batch etc.
180
+
181
+ :param graph: Mesh-Tensorflow graph
182
+ :return: None
183
+ """
184
+ get_n_trainable_vars(graph)
185
+ print_dim_names(graph)
186
+
187
+
188
+ def loss_denominator(targets, num_microbatches):
189
+ """Denominator applied to losses.
190
+
191
+ This is usually the size of the targets tensor (omitting ensemble
192
+ dimensions). Alternatively, it is an override value passed to the
193
+ class constructor.
194
+
195
+ Args:
196
+ targets: a mtf.Tensor
197
+ num_microbatches: an integer - greater than one if the step has been
198
+ serialized into multiple microbatches to save memory.
199
+ Returns:
200
+ a float
201
+ """
202
+ ret = float(targets.shape.size) * num_microbatches
203
+ return float(ret)
204
+
205
+ def check_dataset(input_fn, params, global_step=None):
206
+ tf.enable_eager_execution()
207
+ if global_step is not None:
208
+ dataset = input_fn(params, global_step=global_step)
209
+ else:
210
+ dataset = input_fn(params)
211
+ dataset_iter = dataset.make_one_shot_iterator()
212
+ tensor, _ = next(dataset_iter)
213
+ enc = fetch_encoder(params)
214
+
215
+ for p in tensor[:1]:
216
+ txt = enc.decode(p)
217
+
218
+ print('-' * 50)
219
+ print(txt[:500], '\n\n...\n\n', txt[-500:])
220
+ print('-' * 50)
221
+ exit()
222
+
223
+ def auto_layout(graph, mesh_shape, logits, loss):
224
+ layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
225
+ print(f"Auto-selected layout:\n{layout_rules}\nRe-initialize graph with selected layout")
226
+ quit()
227
+
228
+ def auto_layout_and_mesh_shape(graph, num_cores, logits, loss):
229
+ layout_rules, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(graph, num_cores,
230
+ [logits, loss], max_mesh_shape_dimensions=4)
231
+ print(f"Num cores:\n{num_cores}\nAuto-selected layout:\n{layout_rules}\nAuto-selected mesh shape:\n{mesh_shape}" \
232
+ f"\nRe-initialize graph with selected layout & mesh shape")
233
+ quit()
234
+
235
+ def create_host_call(model_dir):
236
+ """Construct a host_call writing scalar summaries.
237
+
238
+ Borrowed from t2t.
239
+
240
+ Args:
241
+ model_dir: String containing path to train
242
+ Returns:
243
+ (fn, args) Pair to be called by TPUEstimator as the host_call.
244
+ """
245
+
246
+ graph = tf.get_default_graph()
247
+ # A list of (name, lowered tensor) tuples
248
+ summaries = graph.get_collection(mtf.utils.SCALAR_SUMMARIES_COLLECTION_KEY)
249
+
250
+ def maybe_cast(tensor):
251
+ assert tensor.shape.is_compatible_with([]), tensor.name
252
+ if tensor.dtype == tf.int64:
253
+ return tf.to_int32(tensor)
254
+ if tensor.dtype == tf.bfloat16:
255
+ return tf.cast(tensor, tf.float32)
256
+ return tensor
257
+
258
+ reshaped_tensors = [tf.reshape(maybe_cast(t), [1]) for _, t in summaries]
259
+
260
+ # When no supported summaries are found, don't create host_call. Otherwise,
261
+ # TPU outfeed queue would enqueue global_step while host_call doesn't dequeue
262
+ # it, eventually causing hang.
263
+ if not reshaped_tensors:
264
+ return None
265
+
266
+ def host_call_fn(global_step, *args):
267
+ """Training host call. Creates scalar summaries for training metrics."""
268
+ # This function is executed on the CPU and should not directly reference
269
+ # any Tensors in the rest of the `model_fn`. To pass Tensors from the
270
+ # model to the `model_fn`, provide as part of the `host_call`.
271
+ global_step = tf.cast(global_step[0], tf.int64)
272
+ with tf2.summary.create_file_writer(model_dir).as_default():
273
+ # We cannot directly use any tensor from summaries, because each
274
+ # tensor here must be a concat of multiple tensors from all shards.
275
+ # Therefore, we rely on the assumption that args will have the same
276
+ # length as summaries, and all tensors in args will have the same
277
+ # order of self._tup_summaries.
278
+ assert len(args) == len(summaries)
279
+ for i, tensor in enumerate(args):
280
+ name = summaries[i][0]
281
+ tf2.summary.scalar(name, tf.reduce_mean(tensor), step=global_step)
282
+ return tf.summary.all_v2_summary_ops()
283
+
284
+ global_step_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])
285
+ return host_call_fn, [global_step_t] + reshaped_tensors
286
+
287
+
288
+ def natural_sort(l):
289
+ convert = lambda text: int(text) if text.isdigit() else text.lower()
290
+ alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
291
+ return sorted(l, key = alphanum_key)
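Finally, natural_sort exists because checkpoint-style names sort badly as plain strings; splitting on digit runs and comparing the numeric parts as integers gives the intended order. A quick usage sketch (illustration only):

# assumes the repo's utils.py is importable
from utils import natural_sort

names = ["model.ckpt-9", "model.ckpt-100", "model.ckpt-20"]
print(sorted(names))        # ['model.ckpt-100', 'model.ckpt-20', 'model.ckpt-9']
print(natural_sort(names))  # ['model.ckpt-9', 'model.ckpt-20', 'model.ckpt-100']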