Upload 131 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitattributes +5 -0
- megatron/__init__.py +31 -0
- megatron/arguments.py +1018 -0
- megatron/checkpointing.py +675 -0
- megatron/data/Makefile +9 -0
- megatron/data/__init__.py +1 -0
- megatron/data/autoaugment.py +320 -0
- megatron/data/bert_dataset.py +234 -0
- megatron/data/biencoder_dataset_utils.py +208 -0
- megatron/data/blendable_dataset.py +68 -0
- megatron/data/data_samplers.py +199 -0
- megatron/data/dataset_utils.py +938 -0
- megatron/data/glm_dataset.py +377 -0
- megatron/data/gpt_dataset.py +430 -0
- megatron/data/helpers.cpp +717 -0
- megatron/data/helpers.cpython-38-x86_64-linux-gnu.so +0 -0
- megatron/data/helpers.cpython-39-x86_64-linux-gnu.so +0 -0
- megatron/data/ict_dataset.py +156 -0
- megatron/data/image_folder.py +302 -0
- megatron/data/indexed_dataset.py +576 -0
- megatron/data/orqa_wiki_dataset.py +205 -0
- megatron/data/realm_dataset_utils.py +198 -0
- megatron/data/realm_index.py +224 -0
- megatron/data/t5_dataset.py +270 -0
- megatron/data/test/test_indexed_dataset.py +125 -0
- megatron/data/test/test_preprocess_data.sh +10 -0
- megatron/data/vit_dataset.py +262 -0
- megatron/dist_signal_handler.py +81 -0
- megatron/fp16_deprecated/loss_scaler.py +39 -0
- megatron/fused_kernels/__init__.py +125 -0
- megatron/fused_kernels/build/.ninja_deps +0 -0
- megatron/fused_kernels/build/.ninja_log +99 -0
- megatron/fused_kernels/build/build.ninja +28 -0
- megatron/fused_kernels/build/fused_mix_prec_layer_norm_cuda.so +0 -0
- megatron/fused_kernels/build/layer_norm_cuda.o +0 -0
- megatron/fused_kernels/build/layer_norm_cuda_kernel.cuda.o +0 -0
- megatron/fused_kernels/build/scaled_masked_softmax.o +0 -0
- megatron/fused_kernels/build/scaled_masked_softmax_cuda.cuda.o +3 -0
- megatron/fused_kernels/build/scaled_masked_softmax_cuda.so +3 -0
- megatron/fused_kernels/build/scaled_softmax.o +0 -0
- megatron/fused_kernels/build/scaled_softmax_cuda.cuda.o +3 -0
- megatron/fused_kernels/build/scaled_softmax_cuda.so +3 -0
- megatron/fused_kernels/build/scaled_upper_triang_masked_softmax.o +0 -0
- megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.cuda.o +0 -0
- megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.so +3 -0
- megatron/fused_kernels/compat.h +31 -0
- megatron/fused_kernels/fused_weight_gradient_dense.cpp +47 -0
- megatron/fused_kernels/fused_weight_gradient_dense.cu +157 -0
- megatron/fused_kernels/layer_norm_cuda.cpp +201 -0
- megatron/fused_kernels/layer_norm_cuda_kernel.cu +832 -0
.gitattributes
CHANGED
@@ -36,3 +36,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 clue_data/csl/train.json filter=lfs diff=lfs merge=lfs -text
 clue_data/iflytek/train.json filter=lfs diff=lfs merge=lfs -text
 clue_data/ocnli/train.json filter=lfs diff=lfs merge=lfs -text
+megatron/fused_kernels/build/scaled_masked_softmax_cuda.cuda.o filter=lfs diff=lfs merge=lfs -text
+megatron/fused_kernels/build/scaled_masked_softmax_cuda.so filter=lfs diff=lfs merge=lfs -text
+megatron/fused_kernels/build/scaled_softmax_cuda.cuda.o filter=lfs diff=lfs merge=lfs -text
+megatron/fused_kernels/build/scaled_softmax_cuda.so filter=lfs diff=lfs merge=lfs -text
+megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.so filter=lfs diff=lfs merge=lfs -text
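
Aside (not part of the commit): each attribute line above routes a matching path through Git LFS, so the repository stores a small pointer file instead of the compiled binary. A minimal Python sketch, standard library only and assuming it runs from the repository root, that recovers which patterns are LFS-tracked:

# Illustrative sketch: list the patterns that .gitattributes routes
# through Git LFS. Not part of this commit.
from pathlib import Path

def lfs_patterns(attributes_file=".gitattributes"):
    patterns = []
    for line in Path(attributes_file).read_text().splitlines():
        rule = line.split("#")[0]           # drop trailing comments
        if "filter=lfs" in rule.split():    # an LFS rule sets filter=lfs
            patterns.append(rule.split()[0])
    return patterns

if __name__ == "__main__":
    print("\n".join(lfs_patterns()))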
megatron/__init__.py
ADDED
@@ -0,0 +1,31 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from .global_vars import get_args
+from .global_vars import get_current_global_batch_size
+from .global_vars import get_num_microbatches
+from .global_vars import get_signal_handler
+from .global_vars import update_num_microbatches
+from .global_vars import get_tokenizer
+from .global_vars import get_tensorboard_writer
+from .global_vars import get_adlr_autoresume
+from .global_vars import get_timers
+from .global_vars import get_global_memory_buffer
+from .initialize import initialize_megatron
+
+from .utils import (print_rank_0,
+                    is_last_rank,
+                    print_rank_last)
megatron/arguments.py
ADDED
@@ -0,0 +1,1018 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Megatron arguments."""
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
import os
|
20 |
+
|
21 |
+
import torch
|
22 |
+
|
23 |
+
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
|
24 |
+
"""Parse all arguments."""
|
25 |
+
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
|
26 |
+
allow_abbrev=False)
|
27 |
+
|
28 |
+
# Standard arguments.
|
29 |
+
parser = _add_network_size_args(parser)
|
30 |
+
parser = _add_regularization_args(parser)
|
31 |
+
parser = _add_training_args(parser)
|
32 |
+
parser = _add_initialization_args(parser)
|
33 |
+
parser = _add_learning_rate_args(parser)
|
34 |
+
parser = _add_checkpointing_args(parser)
|
35 |
+
parser = _add_mixed_precision_args(parser)
|
36 |
+
parser = _add_distributed_args(parser)
|
37 |
+
parser = _add_validation_args(parser)
|
38 |
+
parser = _add_data_args(parser)
|
39 |
+
parser = _add_autoresume_args(parser)
|
40 |
+
parser = _add_biencoder_args(parser)
|
41 |
+
parser = _add_vision_args(parser)
|
42 |
+
parser = _add_logging_args(parser)
|
43 |
+
parser = _add_inference_args(parser)
|
44 |
+
|
45 |
+
# Custom arguments.
|
46 |
+
if extra_args_provider is not None:
|
47 |
+
parser = extra_args_provider(parser)
|
48 |
+
|
49 |
+
# Parse.
|
50 |
+
if ignore_unknown_args:
|
51 |
+
args, _ = parser.parse_known_args()
|
52 |
+
else:
|
53 |
+
args = parser.parse_args()
|
54 |
+
|
55 |
+
# Args from environment
|
56 |
+
args.rank = int(os.getenv('RANK', '0'))
|
57 |
+
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
|
58 |
+
|
59 |
+
return args
|
60 |
+
|
61 |
+
def validate_args(args, defaults={}):
|
62 |
+
# Tensor model parallel size.
|
63 |
+
args.tensor_model_parallel_size = min(
|
64 |
+
args.tensor_model_parallel_size, args.world_size)
|
65 |
+
assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
|
66 |
+
' ({}) is not divisible by tensor model parallel size ({})'.format(
|
67 |
+
args.world_size, args.tensor_model_parallel_size)
|
68 |
+
# Pipeline model parallel size.
|
69 |
+
args.pipeline_model_parallel_size = min(
|
70 |
+
args.pipeline_model_parallel_size,
|
71 |
+
(args.world_size // args.tensor_model_parallel_size))
|
72 |
+
args.transformer_pipeline_model_parallel_size = (
|
73 |
+
args.pipeline_model_parallel_size - 1
|
74 |
+
if args.standalone_embedding_stage else
|
75 |
+
args.pipeline_model_parallel_size
|
76 |
+
)
|
77 |
+
# Checks.
|
78 |
+
model_parallel_size = args.pipeline_model_parallel_size * \
|
79 |
+
args.tensor_model_parallel_size
|
80 |
+
assert args.world_size % model_parallel_size == 0, 'world size is not'\
|
81 |
+
' divisible by tensor parallel size ({}) times pipeline parallel ' \
|
82 |
+
'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
|
83 |
+
args.pipeline_model_parallel_size)
|
84 |
+
args.data_parallel_size = args.world_size // model_parallel_size
|
85 |
+
if args.rank == 0:
|
86 |
+
print('using world size: {}, data-parallel-size: {}, '
|
87 |
+
'tensor-model-parallel size: {}, '
|
88 |
+
'pipeline-model-parallel size: {} '.format(
|
89 |
+
args.world_size, args.data_parallel_size,
|
90 |
+
args.tensor_model_parallel_size,
|
91 |
+
args.pipeline_model_parallel_size), flush=True)
|
92 |
+
if args.pipeline_model_parallel_size > 1:
|
93 |
+
if args.pipeline_model_parallel_split_rank is not None:
|
94 |
+
assert args.pipeline_model_parallel_split_rank < \
|
95 |
+
args.pipeline_model_parallel_size, 'split rank needs'\
|
96 |
+
' to be less than pipeline model parallel size ({})'.format(
|
97 |
+
args.pipeline_model_parallel_size)
|
98 |
+
if args.data_path:
|
99 |
+
# Dataset arguments
|
100 |
+
data_path = args.data_path
|
101 |
+
processed_data_path = []
|
102 |
+
for path in data_path:
|
103 |
+
files = os.listdir(path)
|
104 |
+
idx_files = [fn[:-4] for fn in files if fn.endswith('.idx')]
|
105 |
+
bin_files = [fn[:-4] for fn in files if fn.endswith('.bin')]
|
106 |
+
for idx_fn in idx_files:
|
107 |
+
if idx_fn in bin_files:
|
108 |
+
# add weight and data path
|
109 |
+
processed_data_path.append('1')
|
110 |
+
processed_data_path.append(os.path.join(path, idx_fn))
|
111 |
+
args.raw_data_path = data_path
|
112 |
+
args.data_path = processed_data_path
|
113 |
+
|
114 |
+
|
115 |
+
# Deprecated arguments
|
116 |
+
assert args.batch_size is None, '--batch-size argument is no longer ' \
|
117 |
+
'valid, use --micro-batch-size instead'
|
118 |
+
del args.batch_size
|
119 |
+
assert args.warmup is None, '--warmup argument is no longer valid, use ' \
|
120 |
+
'--lr-warmup-fraction instead'
|
121 |
+
del args.warmup
|
122 |
+
assert args.model_parallel_size is None, '--model-parallel-size is no ' \
|
123 |
+
'longer valid, use --tensor-model-parallel-size instead'
|
124 |
+
del args.model_parallel_size
|
125 |
+
|
126 |
+
if args.checkpoint_activations:
|
127 |
+
args.recompute_granularity = 'full'
|
128 |
+
args.recompute_method = 'uniform'
|
129 |
+
if args.rank == 0:
|
130 |
+
print('--checkpoint-activations is no longer valid, '
|
131 |
+
'use --recompute-granularity and --recompute-method instead. '
|
132 |
+
'Defaulting to recompute-granularity=full and recompute-method=uniform.')
|
133 |
+
del args.checkpoint_activations
|
134 |
+
|
135 |
+
if args.recompute_activations:
|
136 |
+
args.recompute_granularity = 'selective'
|
137 |
+
del args.recompute_activations
|
138 |
+
|
139 |
+
# Set input defaults.
|
140 |
+
for key in defaults:
|
141 |
+
# For default to be valid, it should not be provided in the
|
142 |
+
# arguments that are passed to the program. We check this by
|
143 |
+
# ensuring the arg is set to None.
|
144 |
+
if getattr(args, key) is not None:
|
145 |
+
if args.rank == 0:
|
146 |
+
print('WARNING: overriding default arguments for {key}:{v} \
|
147 |
+
with {key}:{v2}'.format(key=key, v=defaults[key],
|
148 |
+
v2=getattr(args, key)),
|
149 |
+
flush=True)
|
150 |
+
else:
|
151 |
+
setattr(args, key, defaults[key])
|
152 |
+
|
153 |
+
# Batch size.
|
154 |
+
assert args.micro_batch_size is not None
|
155 |
+
assert args.micro_batch_size > 0
|
156 |
+
if args.global_batch_size is None:
|
157 |
+
args.global_batch_size = args.micro_batch_size * args.data_parallel_size
|
158 |
+
if args.rank == 0:
|
159 |
+
print('setting global batch size to {}'.format(
|
160 |
+
args.global_batch_size), flush=True)
|
161 |
+
assert args.global_batch_size > 0
|
162 |
+
if args.num_layers_per_virtual_pipeline_stage is not None:
|
163 |
+
assert args.pipeline_model_parallel_size > 2, \
|
164 |
+
'pipeline-model-parallel size should be greater than 2 with ' \
|
165 |
+
'interleaved schedule'
|
166 |
+
assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
|
167 |
+
'number of layers is not divisible by number of layers per virtual ' \
|
168 |
+
'pipeline stage'
|
169 |
+
args.virtual_pipeline_model_parallel_size = \
|
170 |
+
(args.num_layers // args.transformer_pipeline_model_parallel_size) // \
|
171 |
+
args.num_layers_per_virtual_pipeline_stage
|
172 |
+
else:
|
173 |
+
args.virtual_pipeline_model_parallel_size = None
|
174 |
+
|
175 |
+
# Parameters dtype.
|
176 |
+
args.params_dtype = torch.float
|
177 |
+
if args.fp16:
|
178 |
+
assert not args.bf16
|
179 |
+
args.params_dtype = torch.half
|
180 |
+
if args.bf16:
|
181 |
+
assert not args.fp16
|
182 |
+
args.params_dtype = torch.bfloat16
|
183 |
+
# bfloat16 requires gradient accumulation and all-reduce to
|
184 |
+
# be done in fp32.
|
185 |
+
if not args.accumulate_allreduce_grads_in_fp32:
|
186 |
+
args.accumulate_allreduce_grads_in_fp32 = True
|
187 |
+
if args.rank == 0:
|
188 |
+
print('accumulate and all-reduce gradients in fp32 for '
|
189 |
+
'bfloat16 data type.', flush=True)
|
190 |
+
|
191 |
+
if args.rank == 0:
|
192 |
+
print('using {} for parameters ...'.format(args.params_dtype),
|
193 |
+
flush=True)
|
194 |
+
|
195 |
+
# If we do accumulation and all-reduces in fp32, we need to have local DDP
|
196 |
+
# and we should make sure use-contiguous-buffers-in-local-ddp is not off.
|
197 |
+
if args.accumulate_allreduce_grads_in_fp32:
|
198 |
+
assert args.DDP_impl == 'local'
|
199 |
+
assert args.use_contiguous_buffers_in_local_ddp
|
200 |
+
else:
|
201 |
+
if args.gradient_accumulation_fusion:
|
202 |
+
args.gradient_accumulation_fusion = False
|
203 |
+
if args.rank == 0:
|
204 |
+
print('Gradient accumulation fusion to linear layer weight '
|
205 |
+
'gradient computation is supported only with fp32 '
|
206 |
+
'gradient accumulation. Setting gradient_accumulation_fusion '
|
207 |
+
'to False', flush=True)
|
208 |
+
|
209 |
+
# If we use the distributed optimizer, we need to have local DDP
|
210 |
+
# and we should make sure use-contiguous-buffers-in-local-ddp is on.
|
211 |
+
if args.use_distributed_optimizer:
|
212 |
+
assert args.DDP_impl == 'local'
|
213 |
+
assert args.use_contiguous_buffers_in_local_ddp
|
214 |
+
|
215 |
+
# For torch DDP, we do not use contiguous buffer
|
216 |
+
if args.DDP_impl == 'torch':
|
217 |
+
args.use_contiguous_buffers_in_local_ddp = False
|
218 |
+
|
219 |
+
if args.dataloader_type is None:
|
220 |
+
args.dataloader_type = 'single'
|
221 |
+
|
222 |
+
# Consumed tokens.
|
223 |
+
args.consumed_train_samples = 0
|
224 |
+
args.consumed_valid_samples = 0
|
225 |
+
|
226 |
+
# Iteration-based training.
|
227 |
+
if args.train_iters:
|
228 |
+
# If we use iteration-based training, make sure the
|
229 |
+
# sample-based options are off.
|
230 |
+
assert args.train_samples is None, \
|
231 |
+
'expected iteration-based training'
|
232 |
+
assert args.lr_decay_samples is None, \
|
233 |
+
'expected iteration-based learning rate decay'
|
234 |
+
assert args.lr_warmup_samples == 0, \
|
235 |
+
'expected iteration-based learning rate warmup'
|
236 |
+
assert args.rampup_batch_size is None, \
|
237 |
+
'expected no batch-size rampup for iteration-based training'
|
238 |
+
if args.lr_warmup_fraction is not None:
|
239 |
+
assert args.lr_warmup_iters == 0, \
|
240 |
+
'can only specify one of lr-warmup-fraction and lr-warmup-iters'
|
241 |
+
|
242 |
+
# Sample-based training.
|
243 |
+
if args.train_samples:
|
244 |
+
# If we use sample-based training, make sure the
|
245 |
+
# iteration-based options are off.
|
246 |
+
assert args.train_iters is None, \
|
247 |
+
'expected sample-based training'
|
248 |
+
assert args.lr_decay_iters is None, \
|
249 |
+
'expected sample-based learning rate decay'
|
250 |
+
assert args.lr_warmup_iters == 0, \
|
251 |
+
'expected sample-based learnig rate warmup'
|
252 |
+
if args.lr_warmup_fraction is not None:
|
253 |
+
assert args.lr_warmup_samples == 0, \
|
254 |
+
'can only specify one of lr-warmup-fraction ' \
|
255 |
+
'and lr-warmup-samples'
|
256 |
+
|
257 |
+
# Check required arguments.
|
258 |
+
required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
|
259 |
+
'max_position_embeddings']
|
260 |
+
for req_arg in required_args:
|
261 |
+
_check_arg_is_not_none(args, req_arg)
|
262 |
+
|
263 |
+
# Checks.
|
264 |
+
if args.ffn_hidden_size is None:
|
265 |
+
args.ffn_hidden_size = 4 * args.hidden_size
|
266 |
+
|
267 |
+
if args.kv_channels is None:
|
268 |
+
assert args.hidden_size % args.num_attention_heads == 0
|
269 |
+
args.kv_channels = args.hidden_size // args.num_attention_heads
|
270 |
+
|
271 |
+
if args.seq_length is not None:
|
272 |
+
assert args.encoder_seq_length is None
|
273 |
+
args.encoder_seq_length = args.seq_length
|
274 |
+
else:
|
275 |
+
assert args.encoder_seq_length is not None
|
276 |
+
args.seq_length = args.encoder_seq_length
|
277 |
+
|
278 |
+
if args.seq_length is not None:
|
279 |
+
assert args.max_position_embeddings >= args.seq_length
|
280 |
+
if args.decoder_seq_length is not None:
|
281 |
+
assert args.max_position_embeddings >= args.decoder_seq_length
|
282 |
+
if args.lr is not None:
|
283 |
+
assert args.min_lr <= args.lr
|
284 |
+
if args.save is not None:
|
285 |
+
assert args.save_interval is not None
|
286 |
+
# Mixed precision checks.
|
287 |
+
if args.fp16_lm_cross_entropy:
|
288 |
+
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
|
289 |
+
if args.fp32_residual_connection:
|
290 |
+
assert args.fp16 or args.bf16, \
|
291 |
+
'residual connection in fp32 only supported when using fp16 or bf16.'
|
292 |
+
|
293 |
+
if args.weight_decay_incr_style == 'constant':
|
294 |
+
assert args.start_weight_decay is None
|
295 |
+
assert args.end_weight_decay is None
|
296 |
+
args.start_weight_decay = args.weight_decay
|
297 |
+
args.end_weight_decay = args.weight_decay
|
298 |
+
else:
|
299 |
+
assert args.start_weight_decay is not None
|
300 |
+
assert args.end_weight_decay is not None
|
301 |
+
|
302 |
+
TORCH_MAJOR = int(torch.__version__.split('.')[0])
|
303 |
+
TORCH_MINOR = int(torch.__version__.split('.')[1])
|
304 |
+
# Persistent fused layer norm.
|
305 |
+
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
|
306 |
+
args.no_persist_layer_norm = True
|
307 |
+
if args.rank == 0:
|
308 |
+
print('Persistent fused layer norm kernel is supported from '
|
309 |
+
'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
|
310 |
+
'Defaulting to no_persist_layer_norm=True')
|
311 |
+
|
312 |
+
# Activation recomputing.
|
313 |
+
if args.distribute_saved_activations:
|
314 |
+
assert args.tensor_model_parallel_size > 1, 'can distribute ' \
|
315 |
+
'recomputed activations only across tensor model ' \
|
316 |
+
'parallel groups'
|
317 |
+
assert args.recompute_granularity == 'full', \
|
318 |
+
'distributed recompute activations is only '\
|
319 |
+
'application to full recompute granularity'
|
320 |
+
assert args.recompute_method is not None, \
|
321 |
+
'for distributed recompute activations to work you '\
|
322 |
+
'need to use a recompute method '
|
323 |
+
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
|
324 |
+
'distributed recompute activations are supported for pytorch ' \
|
325 |
+
'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
|
326 |
+
'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
|
327 |
+
|
328 |
+
if args.recompute_granularity == 'selective':
|
329 |
+
assert args.recompute_method is None, \
|
330 |
+
'recompute method is not yet supported for ' \
|
331 |
+
'selective recomputing granularity'
|
332 |
+
|
333 |
+
# disable sequence parallelism when tp=1
|
334 |
+
# to avoid change in numerics when
|
335 |
+
# sequence_parallelism is enabled.
|
336 |
+
if args.tensor_model_parallel_size == 1:
|
337 |
+
args.sequence_parallel = False
|
338 |
+
|
339 |
+
# disable async_tensor_model_parallel_allreduce when
|
340 |
+
# model parallel memory optimization is enabled
|
341 |
+
if args.sequence_parallel:
|
342 |
+
args.async_tensor_model_parallel_allreduce = False
|
343 |
+
|
344 |
+
_print_args(args)
|
345 |
+
return args
|
346 |
+
|
347 |
+
|
348 |
+
def _print_args(args):
|
349 |
+
"""Print arguments."""
|
350 |
+
if args.rank == 0:
|
351 |
+
print('------------------------ arguments ------------------------',
|
352 |
+
flush=True)
|
353 |
+
str_list = []
|
354 |
+
for arg in vars(args):
|
355 |
+
dots = '.' * (48 - len(arg))
|
356 |
+
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
|
357 |
+
for arg in sorted(str_list, key=lambda x: x.lower()):
|
358 |
+
print(arg, flush=True)
|
359 |
+
print('-------------------- end of arguments ---------------------',
|
360 |
+
flush=True)
|
361 |
+
|
362 |
+
|
363 |
+
def _check_arg_is_not_none(args, arg):
|
364 |
+
assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
|
365 |
+
|
366 |
+
|
367 |
+
def _add_inference_args(parser):
|
368 |
+
group = parser.add_argument_group(title='inference')
|
369 |
+
|
370 |
+
group.add_argument('--inference-batch-times-seqlen-threshold',
|
371 |
+
type=int, default=512,
|
372 |
+
help='During inference, if batch-size times '
|
373 |
+
'sequence-length is smaller than this threshold '
|
374 |
+
'then we will not use pipelining, otherwise we will.')
|
375 |
+
|
376 |
+
return parser
|
377 |
+
|
378 |
+
|
379 |
+
def _add_network_size_args(parser):
|
380 |
+
group = parser.add_argument_group(title='network size')
|
381 |
+
|
382 |
+
group.add_argument('--num-layers', type=int, default=None,
|
383 |
+
help='Number of transformer layers.')
|
384 |
+
group.add_argument('--num-layers-decoder', type=int, default=None,
|
385 |
+
help='Number of transformer layers decoder.')
|
386 |
+
group.add_argument('--hidden-size', type=int, default=None,
|
387 |
+
help='Tansformer hidden size.')
|
388 |
+
group.add_argument('--ffn-hidden-size', type=int, default=None,
|
389 |
+
help='Transformer Feed-Forward Network hidden size. '
|
390 |
+
'This is set to 4*hidden-size if not provided')
|
391 |
+
group.add_argument('--num-attention-heads', type=int, default=None,
|
392 |
+
help='Number of transformer attention heads.')
|
393 |
+
group.add_argument('--kv-channels', type=int, default=None,
|
394 |
+
help='Projection weights dimension in multi-head '
|
395 |
+
'attention. This is set to '
|
396 |
+
' args.hidden_size // args.num_attention_heads '
|
397 |
+
'if not provided.')
|
398 |
+
group.add_argument('--max-position-embeddings', type=int, default=None,
|
399 |
+
help='Maximum number of position embeddings to use. '
|
400 |
+
'This is the size of position embedding.')
|
401 |
+
group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
|
402 |
+
help='Pad the vocab size to be divisible by this value.'
|
403 |
+
'This is added for computational efficieny reasons.')
|
404 |
+
group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
|
405 |
+
help='Layer norm epsilon.')
|
406 |
+
group.add_argument('--apply-residual-connection-post-layernorm',
|
407 |
+
action='store_true',
|
408 |
+
help='If set, use original BERT residula connection '
|
409 |
+
'ordering.')
|
410 |
+
group.add_argument('--openai-gelu', action='store_true',
|
411 |
+
help='Use OpenAIs GeLU implementation. This option'
|
412 |
+
'should not be used unless for backward compatibility'
|
413 |
+
'reasons.')
|
414 |
+
group.add_argument('--onnx-safe', type=bool, required=False,
|
415 |
+
help='Use workarounds for known problems with '
|
416 |
+
'Torch ONNX exporter')
|
417 |
+
group.add_argument('--bert-no-binary-head', action='store_false',
|
418 |
+
help='Disable BERT binary head.',
|
419 |
+
dest='bert_binary_head')
|
420 |
+
group.add_argument('--num-experts', type=int, default=None,
|
421 |
+
help='Number of Experts in Switch Transformer (None means no Switch)')
|
422 |
+
return parser
|
423 |
+
|
424 |
+
|
425 |
+
def _add_logging_args(parser):
|
426 |
+
group = parser.add_argument_group(title='logging')
|
427 |
+
|
428 |
+
group.add_argument('--log-params-norm', action='store_true',
|
429 |
+
help='If set, calculate and log parameters norm.')
|
430 |
+
group.add_argument('--log-num-zeros-in-grad', action='store_true',
|
431 |
+
help='If set, calculate and log the number of zeros in gradient.')
|
432 |
+
group.add_argument('--tensorboard-log-interval', type=int, default=1,
|
433 |
+
help='Report to tensorboard interval.')
|
434 |
+
group.add_argument('--tensorboard-queue-size', type=int, default=1000,
|
435 |
+
help='Size of the tensorboard queue for pending events '
|
436 |
+
'and summaries before one of the ‘add’ calls forces a '
|
437 |
+
'flush to disk.')
|
438 |
+
group.add_argument('--log-timers-to-tensorboard', action='store_true',
|
439 |
+
help='If set, write timers to tensorboard.')
|
440 |
+
group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
|
441 |
+
help='If set, write batch-size to tensorboard.')
|
442 |
+
group.add_argument('--no-log-learnig-rate-to-tensorboard',
|
443 |
+
action='store_false',
|
444 |
+
help='Disable learning rate logging to tensorboard.',
|
445 |
+
dest='log_learning_rate_to_tensorboard')
|
446 |
+
group.add_argument('--no-log-loss-scale-to-tensorboard',
|
447 |
+
action='store_false',
|
448 |
+
help='Disable loss-scale logging to tensorboard.',
|
449 |
+
dest='log_loss_scale_to_tensorboard')
|
450 |
+
group.add_argument('--log-validation-ppl-to-tensorboard',
|
451 |
+
action='store_true',
|
452 |
+
help='If set, write validation perplexity to '
|
453 |
+
'tensorboard.')
|
454 |
+
group.add_argument('--log-memory-to-tensorboard',
|
455 |
+
action='store_true',
|
456 |
+
help='Enable memory logging to tensorboard.')
|
457 |
+
group.add_argument('--log-world-size-to-tensorboard',
|
458 |
+
action='store_true',
|
459 |
+
help='Enable world size logging to tensorboard.')
|
460 |
+
|
461 |
+
return parser
|
462 |
+
|
463 |
+
|
464 |
+
def _add_regularization_args(parser):
|
465 |
+
group = parser.add_argument_group(title='regularization')
|
466 |
+
|
467 |
+
group.add_argument('--attention-dropout', type=float, default=0.1,
|
468 |
+
help='Post attention dropout probability.')
|
469 |
+
group.add_argument('--hidden-dropout', type=float, default=0.1,
|
470 |
+
help='Dropout probability for hidden state transformer.')
|
471 |
+
group.add_argument('--weight-decay', type=float, default=0.01,
|
472 |
+
help='Weight decay coefficient for L2 regularization.')
|
473 |
+
group.add_argument('--start-weight-decay', type=float,
|
474 |
+
help='Initial weight decay coefficient for L2 regularization.')
|
475 |
+
group.add_argument('--end-weight-decay', type=float,
|
476 |
+
help='End of run weight decay coefficient for L2 regularization.')
|
477 |
+
group.add_argument('--weight-decay-incr-style', type=str, default='constant',
|
478 |
+
choices=['constant', 'linear', 'cosine'],
|
479 |
+
help='Weight decay increment function.')
|
480 |
+
group.add_argument('--clip-grad', type=float, default=1.0,
|
481 |
+
help='Gradient clipping based on global L2 norm.')
|
482 |
+
group.add_argument('--adam-beta1', type=float, default=0.9,
|
483 |
+
help='First coefficient for computing running averages '
|
484 |
+
'of gradient and its square')
|
485 |
+
group.add_argument('--adam-beta2', type=float, default=0.999,
|
486 |
+
help='Second coefficient for computing running averages '
|
487 |
+
'of gradient and its square')
|
488 |
+
group.add_argument('--adam-eps', type=float, default=1e-08,
|
489 |
+
help='Term added to the denominator to improve'
|
490 |
+
'numerical stability')
|
491 |
+
group.add_argument('--sgd-momentum', type=float, default=0.9,
|
492 |
+
help='Momentum factor for sgd')
|
493 |
+
|
494 |
+
return parser
|
495 |
+
|
496 |
+
|
497 |
+
def _add_training_args(parser):
|
498 |
+
group = parser.add_argument_group(title='training')
|
499 |
+
|
500 |
+
group.add_argument('--micro-batch-size', type=int, default=None,
|
501 |
+
help='Batch size per model instance (local batch size). '
|
502 |
+
'Global batch size is local batch size times data '
|
503 |
+
'parallel size times number of micro batches.')
|
504 |
+
group.add_argument('--batch-size', type=int, default=None,
|
505 |
+
help='Old batch size parameter, do not use. '
|
506 |
+
'Use --micro-batch-size instead')
|
507 |
+
group.add_argument('--global-batch-size', type=int, default=None,
|
508 |
+
help='Training batch size. If set, it should be a '
|
509 |
+
'multiple of micro-batch-size times data-parallel-size. '
|
510 |
+
'If this value is None, then '
|
511 |
+
'use micro-batch-size * data-parallel-size as the '
|
512 |
+
'global batch size. This choice will result in 1 for '
|
513 |
+
'number of micro-batches.')
|
514 |
+
group.add_argument('--rampup-batch-size', nargs='*', default=None,
|
515 |
+
help='Batch size ramp up with the following values:'
|
516 |
+
' --rampup-batch-size <start batch size> '
|
517 |
+
' <batch size incerement> '
|
518 |
+
' <ramp-up samples> '
|
519 |
+
'For example:'
|
520 |
+
' --rampup-batch-size 16 8 300000 \ '
|
521 |
+
' --global-batch-size 1024'
|
522 |
+
'will start with global batch size 16 and over '
|
523 |
+
' (1024 - 16) / 8 = 126 intervals will increase'
|
524 |
+
'the batch size linearly to 1024. In each interval'
|
525 |
+
'we will use approximately 300000 / 126 = 2380 samples.')
|
526 |
+
group.add_argument('--recompute-activations', action='store_true',
|
527 |
+
help='recompute activation to allow for training '
|
528 |
+
'with larger models, sequences, and batch sizes.')
|
529 |
+
group.add_argument('--recompute-granularity', type=str, default=None,
|
530 |
+
choices=['full', 'selective'],
|
531 |
+
help='Checkpoint activations to allow for training '
|
532 |
+
'with larger models, sequences, and batch sizes. '
|
533 |
+
'It is supported at two granularities 1) full: '
|
534 |
+
'whole transformer layer is recomputed, '
|
535 |
+
'2) selective: core attention part of the transformer '
|
536 |
+
'layer is recomputed.')
|
537 |
+
group.add_argument('--distribute-saved-activations',
|
538 |
+
action='store_true',
|
539 |
+
help='If set, distribute recomputed activations '
|
540 |
+
'across model parallel group.')
|
541 |
+
group.add_argument('--recompute-method', type=str, default=None,
|
542 |
+
choices=['uniform', 'block'],
|
543 |
+
help='1) uniform: uniformly divide the total number of '
|
544 |
+
'Transformer layers and recompute the input activation of '
|
545 |
+
'each divided chunk at specified granularity, '
|
546 |
+
'2) recompute the input activations of only a set number of '
|
547 |
+
'individual Transformer layers per pipeline stage and do the '
|
548 |
+
'rest without any recomputing at specified granularity'
|
549 |
+
'default) do not apply activations recompute to any layers')
|
550 |
+
group.add_argument('--recompute-num-layers', type=int, default=1,
|
551 |
+
help='1) uniform: the number of Transformer layers in each '
|
552 |
+
'uniformly divided recompute unit, '
|
553 |
+
'2) block: the number of individual Transformer layers '
|
554 |
+
'to recompute within each pipeline stage.')
|
555 |
+
|
556 |
+
# deprecated
|
557 |
+
group.add_argument('--checkpoint-activations', action='store_true',
|
558 |
+
help='Checkpoint activation to allow for training '
|
559 |
+
'with larger models, sequences, and batch sizes.')
|
560 |
+
group.add_argument('--train-iters', type=int, default=None,
|
561 |
+
help='Total number of iterations to train over all '
|
562 |
+
'training runs. Note that either train-iters or '
|
563 |
+
'train-samples should be provided.')
|
564 |
+
group.add_argument('--train-samples', type=int, default=None,
|
565 |
+
help='Total number of samples to train over all '
|
566 |
+
'training runs. Note that either train-iters or '
|
567 |
+
'train-samples should be provided.')
|
568 |
+
group.add_argument('--log-interval', type=int, default=100,
|
569 |
+
help='Report loss and timing interval.')
|
570 |
+
group.add_argument('--exit-interval', type=int, default=None,
|
571 |
+
help='Exit the program after the iteration is divisible '
|
572 |
+
'by this value.')
|
573 |
+
group.add_argument('--exit-duration-in-mins', type=int, default=None,
|
574 |
+
help='Exit the program after this many minutes.')
|
575 |
+
group.add_argument('--exit-signal-handler', action='store_true',
|
576 |
+
help='Dynamically save the checkpoint and shutdown the '
|
577 |
+
'training if SIGTERM is received')
|
578 |
+
group.add_argument('--tensorboard-dir', type=str, default=None,
|
579 |
+
help='Write TensorBoard logs to this directory.')
|
580 |
+
group.add_argument('--no-masked-softmax-fusion',
|
581 |
+
action='store_false',
|
582 |
+
help='Disable fusion of query_key_value scaling, '
|
583 |
+
'masking, and softmax.',
|
584 |
+
dest='masked_softmax_fusion')
|
585 |
+
group.add_argument('--no-bias-gelu-fusion', action='store_false',
|
586 |
+
help='Disable bias and gelu fusion.',
|
587 |
+
dest='bias_gelu_fusion')
|
588 |
+
group.add_argument('--no-bias-dropout-fusion', action='store_false',
|
589 |
+
help='Disable bias and dropout fusion.',
|
590 |
+
dest='bias_dropout_fusion')
|
591 |
+
group.add_argument('--optimizer', type=str, default='adam',
|
592 |
+
choices=['adam', 'sgd'],
|
593 |
+
help='Optimizer function')
|
594 |
+
group.add_argument('--dataloader-type', type=str, default=None,
|
595 |
+
choices=['single', 'cyclic'],
|
596 |
+
help='Single pass vs multiple pass data loader')
|
597 |
+
group.add_argument('--no-async-tensor-model-parallel-allreduce',
|
598 |
+
action='store_false',
|
599 |
+
help='Disable asynchronous execution of '
|
600 |
+
'tensor-model-parallel all-reduce with weight '
|
601 |
+
'gradient compuation of a column-linear layer.',
|
602 |
+
dest='async_tensor_model_parallel_allreduce')
|
603 |
+
group.add_argument('--no-persist-layer-norm', action='store_true',
|
604 |
+
help='Disable using persistent fused layer norm kernel. '
|
605 |
+
'This kernel supports only a set of hidden sizes. Please '
|
606 |
+
'check persist_ln_hidden_sizes if your hidden '
|
607 |
+
'size is supported.')
|
608 |
+
group.add_argument('--sequence-parallel', action='store_true',
|
609 |
+
help='Enable sequence parallel optimization.')
|
610 |
+
group.add_argument('--no-gradient-accumulation-fusion',
|
611 |
+
action='store_false',
|
612 |
+
help='Disable fusing gradient accumulation to weight '
|
613 |
+
'gradient computation of linear layers',
|
614 |
+
dest='gradient_accumulation_fusion')
|
615 |
+
return parser
|
616 |
+
|
617 |
+
|
618 |
+
def _add_initialization_args(parser):
|
619 |
+
group = parser.add_argument_group(title='initialization')
|
620 |
+
|
621 |
+
group.add_argument('--seed', type=int, default=1234,
|
622 |
+
help='Random seed used for python, numpy, '
|
623 |
+
'pytorch, and cuda.')
|
624 |
+
group.add_argument('--data-parallel-random-init', action='store_true',
|
625 |
+
help='Enable random initialization of params '
|
626 |
+
'across data parallel ranks')
|
627 |
+
group.add_argument('--init-method-std', type=float, default=0.02,
|
628 |
+
help='Standard deviation of the zero mean normal '
|
629 |
+
'distribution used for weight initialization.')
|
630 |
+
group.add_argument('--init-method-xavier-uniform', action='store_true',
|
631 |
+
help='Enable Xavier uniform parameter initialization')
|
632 |
+
|
633 |
+
return parser
|
634 |
+
|
635 |
+
|
636 |
+
def _add_learning_rate_args(parser):
|
637 |
+
group = parser.add_argument_group(title='learning rate')
|
638 |
+
|
639 |
+
group.add_argument('--lr', type=float, default=None,
|
640 |
+
help='Initial learning rate. Depending on decay style '
|
641 |
+
'and initial warmup, the learing rate at each '
|
642 |
+
'iteration would be different.')
|
643 |
+
group.add_argument('--lr-decay-style', type=str, default='linear',
|
644 |
+
choices=['constant', 'linear', 'cosine'],
|
645 |
+
help='Learning rate decay function.')
|
646 |
+
group.add_argument('--lr-decay-iters', type=int, default=None,
|
647 |
+
help='number of iterations to decay learning rate over,'
|
648 |
+
' If None defaults to `--train-iters`')
|
649 |
+
group.add_argument('--lr-decay-samples', type=int, default=None,
|
650 |
+
help='number of samples to decay learning rate over,'
|
651 |
+
' If None defaults to `--train-samples`')
|
652 |
+
group.add_argument('--lr-warmup-fraction', type=float, default=None,
|
653 |
+
help='fraction of lr-warmup-(iters/samples) to use '
|
654 |
+
'for warmup (as a float)')
|
655 |
+
group.add_argument('--lr-warmup-iters', type=int, default=0,
|
656 |
+
help='number of iterations to linearly warmup '
|
657 |
+
'learning rate over.')
|
658 |
+
group.add_argument('--lr-warmup-samples', type=int, default=0,
|
659 |
+
help='number of samples to linearly warmup '
|
660 |
+
'learning rate over.')
|
661 |
+
group.add_argument('--warmup', type=int, default=None,
|
662 |
+
help='Old lr warmup argument, do not use. Use one of the'
|
663 |
+
'--lr-warmup-* arguments above')
|
664 |
+
group.add_argument('--min-lr', type=float, default=0.0,
|
665 |
+
help='Minumum value for learning rate. The scheduler'
|
666 |
+
'clip values below this threshold.')
|
667 |
+
group.add_argument('--override-opt_param-scheduler', action='store_true',
|
668 |
+
help='Reset the values of the scheduler (learning rate,'
|
669 |
+
'warmup iterations, minimum learning rate, maximum '
|
670 |
+
'number of iterations, and decay style from input '
|
671 |
+
'arguments and ignore values from checkpoints. Note'
|
672 |
+
'that all the above values will be reset.')
|
673 |
+
group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
|
674 |
+
help='Use checkpoint to set the values of the scheduler '
|
675 |
+
'(learning rate, warmup iterations, minimum learning '
|
676 |
+
'rate, maximum number of iterations, and decay style '
|
677 |
+
'from checkpoint and ignore input arguments.')
|
678 |
+
|
679 |
+
return parser
|
680 |
+
|
681 |
+
|
682 |
+
def _add_checkpointing_args(parser):
|
683 |
+
group = parser.add_argument_group(title='checkpointing')
|
684 |
+
|
685 |
+
group.add_argument('--save', type=str, default=None,
|
686 |
+
help='Output directory to save checkpoints to.')
|
687 |
+
group.add_argument('--save-interval', type=int, default=None,
|
688 |
+
help='Number of iterations between checkpoint saves.')
|
689 |
+
group.add_argument('--no-save-optim', action='store_true', default=None,
|
690 |
+
help='Do not save current optimizer.')
|
691 |
+
group.add_argument('--no-save-rng', action='store_true', default=None,
|
692 |
+
help='Do not save current rng state.')
|
693 |
+
group.add_argument('--load', type=str, default=None,
|
694 |
+
help='Directory containing a model checkpoint.')
|
695 |
+
group.add_argument('--no-load-optim', action='store_true', default=None,
|
696 |
+
help='Do not load optimizer when loading checkpoint.')
|
697 |
+
group.add_argument('--no-load-rng', action='store_true', default=None,
|
698 |
+
help='Do not load rng state when loading checkpoint.')
|
699 |
+
group.add_argument('--finetune', action='store_true',
|
700 |
+
help='Load model for finetuning. Do not load optimizer '
|
701 |
+
'or rng state from checkpoint and set iteration to 0. '
|
702 |
+
'Assumed when loading a release checkpoint.')
|
703 |
+
group.add_argument('--no-initialization', action='store_false',
|
704 |
+
help='Do not perform initialization when building model, '
|
705 |
+
'can reduce startup time when definitely loading from a '
|
706 |
+
'checkpoint',
|
707 |
+
dest='perform_initialization')
|
708 |
+
group.add_argument('--use-checkpoint-args', action='store_true',
|
709 |
+
help='Override any command line arguments with arguments '
|
710 |
+
'from the checkpoint')
|
711 |
+
|
712 |
+
return parser
|
713 |
+
|
714 |
+
|
715 |
+
def _add_mixed_precision_args(parser):
|
716 |
+
group = parser.add_argument_group(title='mixed precision')
|
717 |
+
|
718 |
+
group.add_argument('--fp16', action='store_true',
|
719 |
+
help='Run model in fp16 mode.')
|
720 |
+
group.add_argument('--bf16', action='store_true',
|
721 |
+
help='Run model in bfloat16 mode.')
|
722 |
+
group.add_argument('--loss-scale', type=float, default=None,
|
723 |
+
help='Static loss scaling, positive power of 2 '
|
724 |
+
'values can improve fp16 convergence. If None, dynamic'
|
725 |
+
'loss scaling is used.')
|
726 |
+
group.add_argument('--initial-loss-scale', type=float, default=2**32,
|
727 |
+
help='Initial loss-scale for dynamic loss scaling.')
|
728 |
+
group.add_argument('--min-loss-scale', type=float, default=1.0,
|
729 |
+
help='Minimum loss scale for dynamic loss scale.')
|
730 |
+
group.add_argument('--loss-scale-window', type=float, default=1000,
|
731 |
+
help='Window over which to raise/lower dynamic scale.')
|
732 |
+
group.add_argument('--hysteresis', type=int, default=2,
|
733 |
+
help='hysteresis for dynamic loss scaling')
|
734 |
+
group.add_argument('--fp32-residual-connection', action='store_true',
|
735 |
+
help='Move residual connections to fp32.')
|
736 |
+
group.add_argument('--no-query-key-layer-scaling', action='store_false',
|
737 |
+
help='Do not scale Q * K^T by 1 / layer-number.',
|
738 |
+
dest='apply_query_key_layer_scaling')
|
739 |
+
group.add_argument('--attention-softmax-in-fp32', action='store_true',
|
740 |
+
help='Run attention masking and softmax in fp32. '
|
741 |
+
'This flag is ignored unless '
|
742 |
+
'--no-query-key-layer-scaling is specified.')
|
743 |
+
group.add_argument('--accumulate-allreduce-grads-in-fp32',
|
744 |
+
action='store_true',
|
745 |
+
help='Gradient accumulation and all-reduce in fp32.')
|
746 |
+
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
|
747 |
+
help='Move the cross entropy unreduced loss calculation'
|
748 |
+
'for lm head to fp16.')
|
749 |
+
|
750 |
+
return parser
|
751 |
+
|
752 |
+
|
753 |
+
def _add_distributed_args(parser):
|
754 |
+
group = parser.add_argument_group(title='distributed')
|
755 |
+
|
756 |
+
group.add_argument('--tensor-model-parallel-size', type=int, default=1,
|
757 |
+
help='Degree of tensor model parallelism.')
|
758 |
+
group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
|
759 |
+
help='Degree of pipeline model parallelism.')
|
760 |
+
group.add_argument('--pipeline-model-parallel-split-rank',
|
761 |
+
type=int, default=None,
|
762 |
+
help='Rank where encoder and decoder should be split.')
|
763 |
+
group.add_argument('--model-parallel-size', type=int, default=None,
|
764 |
+
help='Old model parallel argument, do not use. Use '
|
765 |
+
'--tensor-model-parallel-size instead.')
|
766 |
+
group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
|
767 |
+
help='Number of layers per virtual pipeline stage')
|
768 |
+
group.add_argument('--distributed-backend', default='nccl',
|
769 |
+
choices=['nccl', 'gloo'],
|
770 |
+
help='Which backend to use for distributed training.')
|
771 |
+
group.add_argument('--DDP-impl', default='local',
|
772 |
+
choices=['local', 'torch'],
|
773 |
+
help='which DistributedDataParallel implementation '
|
774 |
+
'to use.')
|
775 |
+
group.add_argument('--no-contiguous-buffers-in-local-ddp',
|
776 |
+
action='store_false', help='If set, dont use '
|
777 |
+
'contiguous buffer in local DDP.',
|
778 |
+
dest='use_contiguous_buffers_in_local_ddp')
|
779 |
+
group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
|
780 |
+
help='Use scatter/gather to optimize communication of tensors in pipeline',
|
781 |
+
dest='scatter_gather_tensors_in_pipeline')
|
782 |
+
group.add_argument('--local_rank', type=int, default=None,
|
783 |
+
help='local rank passed from distributed launcher.')
|
784 |
+
group.add_argument('--lazy-mpu-init', type=bool, required=False,
|
785 |
+
help='If set to True, initialize_megatron() '
|
786 |
+
'skips DDP initialization and returns function to '
|
787 |
+
'complete it instead.Also turns on '
|
788 |
+
'--use-cpu-initialization flag. This is for '
|
789 |
+
'external DDP manager.' )
|
790 |
+
group.add_argument('--use-cpu-initialization', action='store_true',
|
791 |
+
default=None, help='If set, affine parallel weights '
|
792 |
+
'initialization uses CPU' )
|
793 |
+
group.add_argument('--empty-unused-memory-level', default=0, type=int,
|
794 |
+
choices=[0, 1, 2],
|
795 |
+
help='Call torch.cuda.empty_cache() each iteration '
|
796 |
+
'(training and eval), to reduce fragmentation.'
|
797 |
+
'0=off, 1=moderate, 2=aggressive.')
|
798 |
+
group.add_argument('--standalone-embedding-stage', action='store_true',
|
799 |
+
default=False, help='If set, *input* embedding layer '
|
800 |
+
'is placed on its own pipeline stage, without any '
|
801 |
+
'transformer layers. (For T5, this flag currently only '
|
802 |
+
'affects the encoder embedding.)')
|
803 |
+
group.add_argument('--use-distributed-optimizer', action='store_true',
|
804 |
+
help='Use distributed optimizer.')
|
805 |
+
|
806 |
+
return parser
|
807 |
+
|
808 |
+
|
809 |
+
def _add_validation_args(parser):
|
810 |
+
group = parser.add_argument_group(title='validation')
|
811 |
+
|
812 |
+
group.add_argument('--eval-iters', type=int, default=100,
|
813 |
+
help='Number of iterations to run for evaluation'
|
814 |
+
'validation/test for.')
|
815 |
+
group.add_argument('--eval-interval', type=int, default=1000,
|
816 |
+
help='Interval between running evaluation on '
|
817 |
+
'validation set.')
|
818 |
+
|
819 |
+
return parser
|
820 |
+
|
821 |
+
|
822 |
+
def _add_data_args(parser):
|
823 |
+
group = parser.add_argument_group(title='data and dataloader')
|
824 |
+
|
825 |
+
group.add_argument('--data-path', nargs='*', default=None,
|
826 |
+
help='Path to the training dataset. Accepted format:'
|
827 |
+
'1) a single data path, 2) multiple datasets in the'
|
828 |
+
'form: dataset1-weight dataset1-path dataset2-weight '
|
829 |
+
'dataset2-path ...')
|
830 |
+
group.add_argument('--split', type=str, default='969, 30, 1',
|
831 |
+
help='Comma-separated list of proportions for training,'
|
832 |
+
' validation, and test split. For example the split '
|
833 |
+
'`90,5,5` will use 90%% of data for training, 5%% for '
|
834 |
+
'validation and 5%% for test.')
|
835 |
+
group.add_argument('--vocab-file', type=str, default=None,
|
836 |
+
help='Path to the vocab file.')
|
837 |
+
group.add_argument('--merge-file', type=str, default=None,
|
838 |
+
help='Path to the BPE merge file.')
|
839 |
+
group.add_argument('--vocab-extra-ids', type=int, default=0,
|
840 |
+
help='Number of additional vocabulary tokens. '
|
841 |
+
'They are used for span masking in the T5 model')
|
842 |
+
group.add_argument('--seq-length', type=int, default=None,
|
843 |
+
help='Maximum sequence length to process.')
|
844 |
+
group.add_argument('--encoder-seq-length', type=int, default=None,
|
845 |
+
help='Maximum encoder sequence length to process.'
|
846 |
+
'This should be exclusive of --seq-length')
|
847 |
+
group.add_argument('--decoder-seq-length', type=int, default=None,
|
848 |
+
help="Maximum decoder sequence length to process.")
|
849 |
+
group.add_argument('--retriever-seq-length', type=int, default=256,
|
850 |
+
help='Maximum sequence length for the biencoder model '
|
851 |
+
' for retriever')
|
852 |
+
group.add_argument('--sample-rate', type=float, default=1.0,
|
853 |
+
help='sample rate for training data. Supposed to be 0 '
|
854 |
+
' < sample_rate < 1')
|
855 |
+
group.add_argument('--mask-prob', type=float, default=0.15,
|
856 |
+
help='Probability of replacing a token with mask.')
|
857 |
+
group.add_argument('--short-seq-prob', type=float, default=0.1,
|
858 |
+
help='Probability of producing a short sequence.')
|
859 |
+
group.add_argument('--mmap-warmup', action='store_true',
|
860 |
+
help='Warm up mmap files.')
|
861 |
+
group.add_argument('--num-workers', type=int, default=2,
|
862 |
+
help="Dataloader number of workers.")
|
863 |
+
group.add_argument('--tokenizer-type', type=str,
|
864 |
+
default=None,
|
865 |
+
choices=['BertWordPieceLowerCase',
|
866 |
+
'BertWordPieceCase',
|
867 |
+
'GPT2BPETokenizer'],
|
868 |
+
help='What type of tokenizer to use.')
|
869 |
+
group.add_argument('--data-impl', type=str, default='infer',
|
870 |
+
choices=['lazy', 'cached', 'mmap', 'infer'],
|
871 |
+
help='Implementation of indexed datasets.')
|
872 |
+
group.add_argument('--reset-position-ids', action='store_true',
|
873 |
+
help='Reset posistion ids after end-of-document token.')
|
874 |
+
group.add_argument('--reset-attention-mask', action='store_true',
|
875 |
+
help='Reset self attention maske after '
|
876 |
+
                            'end-of-document token.')
     group.add_argument('--eod-mask-loss', action='store_true',
                        help='Mask loss for the end of document tokens.')

     return parser


 def _add_autoresume_args(parser):
     group = parser.add_argument_group(title='autoresume')

     group.add_argument('--adlr-autoresume', action='store_true',
                        help='Enable autoresume on adlr cluster.')
     group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                        help='Interval over which to check for the autoresume '
                        'termination signal')

     return parser


 def _add_biencoder_args(parser):
     group = parser.add_argument_group(title='biencoder')

     # network size
     group.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings to be used in ICT and '
                        'REALM (paper default: 128)')
     group.add_argument('--biencoder-projection-dim', type=int, default=0,
                        help='Size of projection head used in biencoder '
                        '(paper default: 128)')
     group.add_argument('--biencoder-shared-query-context-model', action='store_true',
                        help='Whether to share the parameters of the query '
                        'and context models or not')

     # checkpointing
     group.add_argument('--ict-load', type=str, default=None,
                        help='Directory containing an ICTBertModel checkpoint')
     group.add_argument('--bert-load', type=str, default=None,
                        help='Directory containing a BertModel checkpoint '
                        '(needed to start ICT and REALM)')

     # data
     group.add_argument('--titles-data-path', type=str, default=None,
                        help='Path to titles dataset used for ICT')
     group.add_argument('--query-in-block-prob', type=float, default=0.1,
                        help='Probability of keeping query in block for '
                        'ICT dataset')
     group.add_argument('--use-one-sent-docs', action='store_true',
                        help='Whether to use one sentence documents in ICT')
     group.add_argument('--evidence-data-path', type=str, default=None,
                        help='Path to Wikipedia Evidence from DPR paper')

     # training
     group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                        default=[], help="Which top-k accuracies to report "
                        "(e.g. '1 5 20')")
     group.add_argument('--retriever-score-scaling', action='store_true',
                        help='Whether to scale retriever scores by inverse '
                        'square root of hidden size')

     # faiss index
     group.add_argument('--block-data-path', type=str, default=None,
                        help='Where to save/load BlockData to/from')
     group.add_argument('--embedding-path', type=str, default=None,
                        help='Where to save/load Open-Retrieval Embedding '
                        'data to/from')

     # indexer
     group.add_argument('--indexer-batch-size', type=int, default=128,
                        help='How large of batches to use when doing indexing '
                        'jobs')
     group.add_argument('--indexer-log-interval', type=int, default=1000,
                        help='After how many batches should the indexer '
                        'report progress')
     return parser


 def _add_vision_args(parser):
     group = parser.add_argument_group(title="vision")

     # general vision arguments
     group.add_argument('--num-classes', type=int, default=1000,
                        help='number of classes in vision classification task')
     group.add_argument('--img-h', type=int, default=224,
                        help='Image height for vision classification task')
     group.add_argument('--img-w', type=int, default=224,
                        help='Image width for vision classification task')
     group.add_argument('--num-channels', type=int, default=3,
                        help='Number of channels in input image data')
     group.add_argument('--patch-dim', type=int, default=16,
                        help='patch dimension')
     group.add_argument('--classes-fraction', type=float, default=1.0,
                        help='training with fraction of classes.')
     group.add_argument('--data-per-class-fraction', type=float, default=1.0,
                        help='training with fraction of data per class.')
     group.add_argument('--no-data-sharding', action='store_false',
                        help='Disable data sharding.',
                        dest='data_sharding')
     group.add_argument('--head-lr-mult', type=float, default=1.0,
                        help='learning rate multiplier for head during finetuning')

     # pretraining type and backbone selection
     group.add_argument('--vision-pretraining', action='store_true',
                        help='flag to indicate vision pretraining')
     group.add_argument('--vision-pretraining-type', type=str, default='classify',
                        choices=['classify', 'inpaint', 'dino'],
                        help='pretraining objectives')
     group.add_argument('--vision-backbone-type', type=str, default='vit',
                        choices=['vit', 'mit', 'swin'],
                        help='backbone type')
     group.add_argument('--swin-backbone-type', type=str, default='tiny',
                        choices=['tiny', 'base', 'h3'],
                        help='swin backbone type')

     # inpainting arguments
     group.add_argument('--mask-type', type=str, default='random',
                        choices=['random', 'row'],
                        help='mask types')
     group.add_argument('--mask-factor', type=float, default=1.0,
                        help='mask size scaling parameter')

     # dino arguments
     group.add_argument('--iter-per-epoch', type=int, default=1250,
                        help='iterations per epoch')
     group.add_argument('--dino-local-img-size', type=int, default=96,
                        help='Image size for dino local crops')
     group.add_argument('--dino-local-crops-number', type=int, default=10,
                        help='Number of local crops')
     group.add_argument('--dino-head-hidden-size', type=int, default=2048,
                        help='Hidden dimension size in dino head')
     group.add_argument('--dino-bottleneck-size', type=int, default=256,
                        help='Bottleneck dimension in dino head')
     group.add_argument('--dino-freeze-last-layer', type=float, default=1,
                        help='Freezing last layer weights')
     group.add_argument('--dino-norm-last-layer', action='store_true',
                        help='Whether to norm the last layer in the dino head.')
     group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
                        help='warmup teacher temperature')
     group.add_argument('--dino-teacher-temp', type=float, default=0.07,
                        help='teacher temperature')
     group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
                        help='warmup teacher temperature epochs')

     return parser
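Aside: each `_add_*_args` helper above attaches one argument group and returns the parser, so the helpers compose by chaining. A minimal sketch of that pattern, assuming only the functions defined in this file (`build_parser` and the argument values are hypothetical, not part of the diff):

# Sketch only: compose the argument-group helpers defined above.
import argparse

def build_parser():
    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)
    # Each helper registers one group on the parser and returns it.
    for add_group in (_add_autoresume_args, _add_biencoder_args,
                      _add_vision_args):
        parser = add_group(parser)
    return parser

args = build_parser().parse_args(['--img-h', '256', '--img-w', '256'])
assert args.img_h == args.img_w == 256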
megatron/checkpointing.py
ADDED
@@ -0,0 +1,675 @@
 # coding=utf-8
 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 """Input/output checkpointing."""

 import os
 import random
 import sys
 import numpy as np

 import torch

 from megatron import (mpu,
                       update_num_microbatches)
 from .global_vars import get_args
 from .utils import (unwrap_model,
                     print_rank_0)


 _CHECKPOINT_VERSION = None

 def set_checkpoint_version(value):
     global _CHECKPOINT_VERSION
     if _CHECKPOINT_VERSION is not None:
         assert _CHECKPOINT_VERSION == value, \
             "checkpoint versions do not match"
     _CHECKPOINT_VERSION = value

 def get_checkpoint_version():
     global _CHECKPOINT_VERSION
     return _CHECKPOINT_VERSION

 def check_checkpoint_args(checkpoint_args):
     """Ensure fixed arguments for a model are the same for the input
     arguments and the ones retrieved from the checkpoint."""
     args = get_args()

     def _compare(arg_name, old_arg_name=None):
         if old_arg_name is not None:
             checkpoint_value = getattr(checkpoint_args, old_arg_name)
         else:
             checkpoint_value = getattr(checkpoint_args, arg_name)
         args_value = getattr(args, arg_name)
         error_message = '{} value from checkpoint ({}) is not equal to the ' \
                         'input argument value ({}).'.format(
                             arg_name, checkpoint_value, args_value)
         assert checkpoint_value == args_value, error_message

     _compare('num_layers')
     _compare('hidden_size')
     _compare('num_attention_heads')
     if args.vocab_file:
         _compare('max_position_embeddings')
         _compare('make_vocab_size_divisible_by')
         _compare('padded_vocab_size')
         _compare('tokenizer_type')
     if args.data_parallel_random_init:
         _compare('data_parallel_random_init')
     if get_checkpoint_version() < 3.0:
         _compare('tensor_model_parallel_size',
                  old_arg_name='model_parallel_size')
     if get_checkpoint_version() >= 3.0:
         _compare('tensor_model_parallel_size')
         _compare('pipeline_model_parallel_size')

 def ensure_directory_exists(filename):
     """Build filename's path if it does not already exist."""
     dirname = os.path.dirname(filename)
     if not os.path.exists(dirname):
         os.makedirs(dirname)


 def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False,
                          pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
     """Determine the directory name for this rank's checkpoint."""
     if release:
         directory = 'release'
     else:
         directory = 'iter_{:07d}'.format(iteration)

     # Use both the tensor and pipeline MP rank.
     if pipeline_parallel is None:
         pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
     if tensor_rank is None:
         tensor_rank = mpu.get_tensor_model_parallel_rank()
     if pipeline_rank is None:
         pipeline_rank = mpu.get_pipeline_model_parallel_rank()

     # Use both the tensor and pipeline MP rank. If using the distributed
     # optimizer, then the optimizer's path must additionally include the
     # data parallel rank.
     if not pipeline_parallel:
         common_path = os.path.join(checkpoints_path, directory,
                                    f'mp_rank_{tensor_rank:02d}')
     else:
         common_path = os.path.join(checkpoints_path, directory,
                                    f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')

     if use_distributed_optimizer:
         model_name = os.path.join(common_path, "model_rng.pt")
         optim_name = os.path.join(
             common_path + "_%03d" % mpu.get_data_parallel_rank(),
             "optim.pt")
     else:
         model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
     return model_name, optim_name

 def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimizer, release=False):
     """Finds the checkpoint for rank 0 without knowing if we are using
     pipeline parallelism or not.

     Since the checkpoint naming scheme changes if pipeline parallelism
     is present, we need to look for both naming schemes if we don't
     know if the checkpoint has pipeline parallelism.
     """

     # Look for checkpoint with no pipelining
     filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
                                      pipeline_parallel=False,
                                      tensor_rank=0, pipeline_rank=0)
     if os.path.isfile(filenames[0]):
         return filenames

     # Look for checkpoint with pipelining
     filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
                                      pipeline_parallel=True,
                                      tensor_rank=0, pipeline_rank=0)
     if os.path.isfile(filenames[0]):
         return filenames

     return None, None

 def get_checkpoint_tracker_filename(checkpoints_path):
     """Tracker file records the latest checkpoint during
     training to restart from."""
     return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


 def read_metadata(tracker_filename):
     # Read the tracker file and either set the iteration or
     # mark it as a release checkpoint.
     iteration = 0
     release = False
     with open(tracker_filename, 'r') as f:
         metastring = f.read().strip()
         try:
             iteration = int(metastring)
         except ValueError:
             release = metastring == 'release'
             if not release:
                 print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                     tracker_filename))
                 sys.exit()
     assert iteration > 0 or release, 'error parsing metadata file {}'.format(
         tracker_filename)

     # Get the max iteration retrieved across the ranks.
     if torch.distributed.is_initialized():
         iters_cuda = torch.cuda.LongTensor([iteration])
         torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
         max_iter = iters_cuda[0].item()

         # We should now all have the same iteration.
         # If not, print a warning and choose the maximum
         # iteration across all ranks.
         if iteration != max_iter:
             print('WARNING: on rank {} found iteration {} in the '
                   'metadata while max iteration across the ranks '
                   'is {}, replacing it with max iteration.'.format(
                       torch.distributed.get_rank(), iteration, max_iter),
                   flush=True)
     else:
         # When loading a checkpoint outside of training (for example,
         # when editing it), we might not have torch distributed
         # initialized; in this case, just assume we have the latest.
         max_iter = iteration
     return max_iter, release


 def get_rng_state():
     """Collect rng state across data parallel ranks."""
     args = get_args()
     rng_state = {
         'random_rng_state': random.getstate(),
         'np_rng_state': np.random.get_state(),
         'torch_rng_state': torch.get_rng_state(),
         'cuda_rng_state': torch.cuda.get_rng_state(),
         'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}

     rng_state_list = None
     if torch.distributed.is_initialized() and \
             mpu.get_data_parallel_world_size() > 1 and \
             args.data_parallel_random_init:
         rng_state_list = \
             [None for i in range(mpu.get_data_parallel_world_size())]
         torch.distributed.all_gather_object(
             rng_state_list,
             rng_state,
             group=mpu.get_data_parallel_group())
     else:
         rng_state_list = [rng_state]

     return rng_state_list


 def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     """Save a model checkpoint."""
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
     model = unwrap_model(model)

     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))

     # Collect rng state across data parallel ranks.
     rng_state = get_rng_state()

     # Checkpoint file names.
     model_checkpoint_name, optim_checkpoint_name = \
         get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)

     # Collect args, model, RNG.
     model_state_dict = {}
     if not torch.distributed.is_initialized() \
             or mpu.get_data_parallel_rank() == 0:

         # Arguments, iteration, and model.
         model_state_dict['args'] = args
         model_state_dict['checkpoint_version'] = 3.0
         model_state_dict['iteration'] = iteration
         if len(model) == 1:
             model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
         else:
             for i in range(len(model)):
                 mpu.set_virtual_pipeline_model_parallel_rank(i)
                 model_state_dict['model%d' % i] = \
                     model[i].state_dict_for_save_checkpoint()

         # RNG states.
         if not args.no_save_rng:
             model_state_dict["rng_state"] = rng_state

     # Collect optimizer state. (Optimizer is saved separately from the model,
     # due to the conflicting data pattern when using the distributed optimizer.)
     optim_state_dict = {}
     if not args.no_save_optim \
             and (not torch.distributed.is_initialized()
                  or mpu.get_data_parallel_rank() == 0
                  or args.use_distributed_optimizer):

         # Optimizer stuff.
         if optimizer is not None:
             optim_state_dict['optimizer'] = optimizer.state_dict()
         if opt_param_scheduler is not None:
             optim_state_dict['opt_param_scheduler'] = \
                 opt_param_scheduler.state_dict()

     # Save.
     if args.use_distributed_optimizer:
         # Save model separate from optimizer.
         if model_state_dict:
             ensure_directory_exists(model_checkpoint_name)
             torch.save(model_state_dict, model_checkpoint_name)
         if optim_state_dict:
             ensure_directory_exists(optim_checkpoint_name)
             torch.save(optim_state_dict, optim_checkpoint_name)
     else:
         # Save model and optimizer together.
         state_dict = {**model_state_dict, **optim_state_dict}
         if state_dict:  # only saves if populated (i.e., inherits conditions above)
             ensure_directory_exists(model_checkpoint_name)
             torch.save(state_dict, model_checkpoint_name)

     # Wait so everyone is done (necessary)
     if torch.distributed.is_initialized():
         torch.distributed.barrier()

     print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))

     # And update the latest iteration
     if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
         tracker_filename = get_checkpoint_tracker_filename(args.save)
         with open(tracker_filename, 'w') as f:
             f.write(str(iteration))

     # Wait so everyone is done (not necessary)
     if torch.distributed.is_initialized():
         torch.distributed.barrier()

 def _transpose_first_dim(t, num_splits, num_splits_first, model):
     input_shape = t.size()
     # We use a self_attention module but the values extracted aren't
     # specific to self attention, so this should work for cross attention as well.
     while hasattr(model, 'module'):
         model = model.module
     attention_module = model.language_model.encoder.layers[0].self_attention
     hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
     num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
     if num_splits_first:
         """[num_splits * np * hn, h]
         -->(view) [num_splits, np, hn, h]
         -->(transpose) [np, num_splits, hn, h]
         -->(view) [np * num_splits * hn, h]"""

         intermediate_shape = \
             (num_splits, num_attention_heads_per_partition,
              hidden_size_per_attention_head) + input_shape[1:]

         t = t.view(*intermediate_shape)
         t = t.transpose(0, 1).contiguous()
     else:
         """[np * hn * num_splits, h]
         -->(view) [np, hn, num_splits, h]
         -->(transpose) [np, num_splits, hn, h]
         -->(view) [np * num_splits * hn, h]"""

         intermediate_shape = \
             (num_attention_heads_per_partition,
              hidden_size_per_attention_head, num_splits) +\
             input_shape[1:]

         t = t.view(*intermediate_shape)
         t = t.transpose(1, 2).contiguous()
     t = t.view(*input_shape)

     return t

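 # Worked example of the reordering above (illustrative only, not part of
 # this module): with np=2 heads, hn=1, h=1 and num_splits=3, a version-0
 # checkpoint stores rows split-major as [s0h0, s0h1, s1h0, s1h1, s2h0, s2h1],
 # and the view/transpose/view regroups them head-major:
 #
 #   import torch
 #   t = torch.arange(6.).view(6, 1)             # [num_splits * np * hn, h]
 #   t = t.view(3, 2, 1, 1).transpose(0, 1)      # [np, num_splits, hn, h]
 #   print(t.contiguous().view(6, 1).squeeze())  # tensor([0., 2., 4., 1., 3., 5.])
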
 def fix_query_key_value_ordering(model, checkpoint_version):
     """Fix up query/key/value matrix ordering if checkpoint
     version is smaller than 2.0.
     """
     if checkpoint_version < 2.0:
         if isinstance(model, list):
             assert len(model) == 1
             model = model[0]
         for name, param in model.named_parameters():
             if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
                 if checkpoint_version == 0:
                     fixed_param = _transpose_first_dim(param.data, 3, True, model)
                 elif checkpoint_version == 1.0:
                     fixed_param = _transpose_first_dim(param.data, 3, False, model)
                 else:
                     print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                     sys.exit()
                 param.data.copy_(fixed_param)
             if name.endswith(('.key_value.weight', '.key_value.bias')):
                 if checkpoint_version == 0:
                     fixed_param = _transpose_first_dim(param.data, 2, True, model)
                 elif checkpoint_version == 1.0:
                     fixed_param = _transpose_first_dim(param.data, 2, False, model)
                 else:
                     print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                     sys.exit()
                 param.data.copy_(fixed_param)
         print_rank_0(" successfully fixed query-key-values ordering for"
                      " checkpoint version {}".format(checkpoint_version))

 def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False):
     """Load the base state_dict from the given directory.

     If rank0 is true, just loads the rank 0 checkpoint, ignoring arguments.
     """

     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)

     # If no tracker file, return nothing
     if not os.path.isfile(tracker_filename):
         if not rank0:
             print_rank_0('WARNING: could not find the metadata file {} '.format(
                 tracker_filename))
             print_rank_0('    will not load any checkpoints and will start from '
                          'random')
         return None, None, False

     # Otherwise, read the tracker file and either set the iteration or
     # mark it as a release checkpoint.
     iteration, release = read_metadata(tracker_filename)

     # Checkpoint.
     if rank0:
         checkpoint_names = find_checkpoint_rank_0(load_dir, iteration, use_distributed_optimizer,
                                                   release)
     else:
         checkpoint_names = get_checkpoint_names(load_dir, iteration, use_distributed_optimizer,
                                                 release)
         if release:
             print_rank_0(f' loading release checkpoint from {load_dir}')
         else:
             print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')

     model_checkpoint_name, optim_checkpoint_name = checkpoint_names

     # Load the checkpoint.
     try:
         model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
         if use_distributed_optimizer:
             optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
         else:
             optim_state_dict = model_state_dict
     except ModuleNotFoundError:
         from megatron.fp16_deprecated import loss_scaler
         # For backward compatibility.
         if not rank0:
             print_rank_0(' > deserializing using the old code structure ...')
         sys.modules['fp16.loss_scaler'] = sys.modules[
             'megatron.fp16_deprecated.loss_scaler']
         sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
             'megatron.fp16_deprecated.loss_scaler']
         model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
         optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
         sys.modules.pop('fp16.loss_scaler', None)
         sys.modules.pop('megatron.fp16.loss_scaler', None)
     except BaseException as e:
         print_rank_0('could not load the checkpoint')
         print_rank_0(e)
         sys.exit()

     return model_state_dict, optim_state_dict, release

 def load_args_from_checkpoint(args, load_arg='load'):
     """Set required arguments from the checkpoint specified in the
     arguments.

     Will overwrite arguments that have a non-None default value, but
     will leave any arguments that default to None as set.

     Returns the same args NameSpace with the new values added/updated.

     If no checkpoint is specified in args, or if the checkpoint is
     there but invalid, the arguments will not be modified.
     """
     load_dir = getattr(args, load_arg)

     if load_dir is None:
         print_rank_0('No load directory specified, using provided arguments.')
         return args

     model_state_dict, optim_state_dict, release = \
         _load_base_checkpoint(load_dir,
                               use_distributed_optimizer=args.use_distributed_optimizer,
                               rank0=True)

     # For args we only care about the model state dict
     state_dict = model_state_dict

     if not state_dict:
         print_rank_0('Checkpoint not found to provide arguments, using provided arguments.')
         return args

     if 'args' not in state_dict:
         print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.')
         return args

     checkpoint_args = state_dict['args']
     checkpoint_version = state_dict.get('checkpoint_version', 0)
     args.iteration = state_dict['iteration']

     def _set_arg(arg_name, old_arg_name=None, force=False):
         if not force and getattr(args, arg_name, None) is not None:
             return

         if old_arg_name is not None:
             checkpoint_value = getattr(checkpoint_args, old_arg_name, None)
         else:
             checkpoint_value = getattr(checkpoint_args, arg_name, None)

         if checkpoint_value is not None:
             print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
             setattr(args, arg_name, checkpoint_value)

     _set_arg('num_layers')
     _set_arg('hidden_size')
     _set_arg('ffn_hidden_size')
     _set_arg('seq_length')
     _set_arg('num_attention_heads')
     _set_arg('kv_channels')
     _set_arg('max_position_embeddings')
     _set_arg('tokenizer_type')
     _set_arg('padded_vocab_size')
     if checkpoint_version < 3.0:
         _set_arg('tensor_model_parallel_size',
                  'model_parallel_size')
     else:
         _set_arg('tensor_model_parallel_size', force=True)
         _set_arg('pipeline_model_parallel_size', force=True)
         _set_arg('num_layers_per_virtual_pipeline_stage')
     return args


 def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     """Load a model checkpoint and return the iteration.

     strict (bool): whether to strictly enforce that the keys in
         :attr:`state_dict` of the checkpoint match the names of
         parameters and buffers in model.
     """
     args = get_args()
     load_dir = getattr(args, load_arg)

     model = unwrap_model(model)

     model_state_dict, optim_state_dict, release = \
         _load_base_checkpoint(load_dir,
                               use_distributed_optimizer=args.use_distributed_optimizer,
                               rank0=False)

     if model_state_dict is None:
         return 0

     # Set checkpoint version.
     set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))

     # Set iteration.
     if args.finetune or release:
         iteration = 0
     else:
         try:
             iteration = model_state_dict['iteration']
         except KeyError:
             try:  # Backward compatible with older checkpoints
                 iteration = model_state_dict['total_iters']
             except KeyError:
                 print_rank_0('A metadata file exists but unable to load '
                              'iteration from checkpoint {}, exiting'.format(
                                  load_dir))
                 sys.exit()

     # Check arguments.
     assert args.consumed_train_samples == 0
     assert args.consumed_valid_samples == 0
     if 'args' in model_state_dict:
         checkpoint_args = model_state_dict['args']
         check_checkpoint_args(checkpoint_args)
         args.consumed_train_samples = getattr(checkpoint_args,
                                               'consumed_train_samples', 0)
         update_num_microbatches(consumed_samples=args.consumed_train_samples)
         args.consumed_valid_samples = getattr(checkpoint_args,
                                               'consumed_valid_samples', 0)
     else:
         print_rank_0('could not find arguments in the checkpoint ...')

     # Model.
     if len(model) == 1:
         model[0].load_state_dict(model_state_dict['model'], strict=strict)
     else:
         for i in range(len(model)):
             mpu.set_virtual_pipeline_model_parallel_rank(i)
             model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)

     # Fix up query/key/value matrix ordering if needed.
     checkpoint_version = get_checkpoint_version()
     print_rank_0(f' checkpoint version {checkpoint_version}')
     fix_query_key_value_ordering(model, checkpoint_version)

     # Optimizer.
     if not release and not args.finetune and not args.no_load_optim:
         try:
             if optimizer is not None:
                 optimizer.load_state_dict(optim_state_dict['optimizer'])
             if opt_param_scheduler is not None:
                 if 'lr_scheduler' in optim_state_dict:  # backward compatibility
                     opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
                 else:
                     opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
         except KeyError:
             print_rank_0('Unable to load optimizer from checkpoint {}. '
                          'Specify --no-load-optim or --finetune to prevent '
                          'attempting to load the optimizer state, '
                          'exiting ...'.format(load_dir))
             sys.exit()

     # rng states.
     if not release and not args.finetune and not args.no_load_rng:
         try:
             if 'rng_state' in model_state_dict:
                 # access rng_state for data parallel rank
                 if args.data_parallel_random_init:
                     rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
                 else:
                     rng_state = model_state_dict['rng_state'][0]
                 random.setstate(rng_state['random_rng_state'])
                 np.random.set_state(rng_state['np_rng_state'])
                 torch.set_rng_state(rng_state['torch_rng_state'])
                 torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
                 # Check for empty states array
                 if not rng_state['rng_tracker_states']:
                     raise KeyError
                 mpu.get_cuda_rng_tracker().set_states(
                     rng_state['rng_tracker_states'])
             else:  # backward compatibility
                 random.setstate(model_state_dict['random_rng_state'])
                 np.random.set_state(model_state_dict['np_rng_state'])
                 torch.set_rng_state(model_state_dict['torch_rng_state'])
                 torch.cuda.set_rng_state(model_state_dict['cuda_rng_state'])
                 # Check for empty states array
                 if not model_state_dict['rng_tracker_states']:
                     raise KeyError
                 mpu.get_cuda_rng_tracker().set_states(
                     model_state_dict['rng_tracker_states'])
         except KeyError:
             print_rank_0('Unable to load rng state from checkpoint {}. '
                          'Specify --no-load-rng or --finetune to prevent '
                          'attempting to load the rng state, '
                          'exiting ...'.format(load_dir))
             sys.exit()

     # Some utilities want to load a checkpoint without distributed being initialized
     if torch.distributed.is_initialized():
         torch.distributed.barrier()

     print_rank_0(f' successfully loaded checkpoint from {args.load} '
                  f'at iteration {iteration}')

     return iteration


 def load_biencoder_checkpoint(model, only_query_model=False,
                               only_context_model=False, custom_load_path=None):
     """Selectively load retrieval models for indexing/retrieving
     from saved checkpoints."""

     args = get_args()

     model = unwrap_model(model)

     load_path = custom_load_path if custom_load_path is not None else args.load

     tracker_filename = get_checkpoint_tracker_filename(load_path)
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())

     checkpoint_name, _ = get_checkpoint_names(load_path, iteration,
                                               args.use_distributed_optimizer,
                                               release=False)

     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))

     state_dict = torch.load(checkpoint_name, map_location='cpu')
     ret_state_dict = state_dict['model']

     if only_query_model:
         ret_state_dict.pop('context_model')
     if only_context_model:
         ret_state_dict.pop('query_model')

     assert len(model) == 1
     model[0].load_state_dict(ret_state_dict)
     torch.distributed.barrier()

     if mpu.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))

     return model
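To make the naming scheme above concrete: without pipeline parallelism and without the distributed optimizer, tensor-parallel rank 0 at iteration 500 reads and writes `<save>/iter_0000500/mp_rank_00/model_optim_rng.pt`, while the tracker file next to the iteration directories records the latest iteration. A standalone sketch reproducing just that path logic (it mirrors get_checkpoint_names for this one case; the function name and example paths are illustrative):

# Standalone sketch of the checkpoint path layout used above.
import os

def sketch_checkpoint_path(save_dir, iteration, tensor_rank=0, release=False):
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)
    common_path = os.path.join(save_dir, directory, f'mp_rank_{tensor_rank:02d}')
    return os.path.join(common_path, 'model_optim_rng.pt')

print(sketch_checkpoint_path('/checkpoints/gpt', 500))
# -> /checkpoints/gpt/iter_0000500/mp_rank_00/model_optim_rng.pt
# The tracker file sits alongside the iteration directories:
print(os.path.join('/checkpoints/gpt', 'latest_checkpointed_iteration.txt'))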
megatron/data/Makefile
ADDED
@@ -0,0 +1,9 @@
CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = helpers
LIBEXT = $(shell python3-config --extension-suffix)

default: $(LIBNAME)$(LIBEXT)

%$(LIBEXT): %.cpp
	$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
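For reference, this Makefile compiles `helpers.cpp` into a Python extension via pybind11, so `make` run in `megatron/data/` produces the `helpers*.so` files seen in this upload. A hedged Python sketch of triggering that build on demand before importing the module (the build-on-import fallback here is illustrative, not part of this Makefile):

# Illustrative only: build the pybind11 helpers extension if it is missing,
# then import it. Assumes `make` and pybind11 are installed.
import importlib
import subprocess

def ensure_helpers_built(data_dir='megatron/data'):
    try:
        return importlib.import_module('megatron.data.helpers')
    except ImportError:
        subprocess.run(['make'], cwd=data_dir, check=True)
        return importlib.import_module('megatron.data.helpers')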
megatron/data/__init__.py
ADDED
@@ -0,0 +1 @@
from . import indexed_dataset
megatron/data/autoaugment.py
ADDED
@@ -0,0 +1,320 @@
 """AutoAugment data augmentation policy for ImageNet.

 -- Begin license text.

 MIT License

 Copyright (c) 2018 Philip Popien

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
 in the Software without restriction, including without limitation the rights
 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 copies of the Software, and to permit persons to whom the Software is
 furnished to do so, subject to the following conditions:

 The above copyright notice and this permission notice shall be included in all
 copies or substantial portions of the Software.

 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.

 -- End license text.

 Code adapted from https://github.com/DeepVoltaire/AutoAugment.

 This module implements the fixed AutoAugment data augmentation policy for
 ImageNet provided in Appendix A, Table 9 of reference [1]. It does not
 include any of the search code for augmentation policies.

 Reference:
 [1] https://arxiv.org/abs/1805.09501
 """

 import random

 import numpy as np
 from PIL import Image
 from PIL import ImageEnhance
 from PIL import ImageOps

 _MAX_LEVEL = 10  # Maximum integer strength of an augmentation, if applicable.


 class ImageNetPolicy:
     """Definition of an ImageNetPolicy.

     Implements a fixed AutoAugment data augmentation policy targeted at
     ImageNet training by randomly applying at runtime one of the 25 pre-defined
     data augmentation sub-policies provided in Reference [1].

     Usage example as a Pytorch Transform:
     >>> transform = transforms.Compose([transforms.Resize(256),
     >>>                                 ImageNetPolicy(),
     >>>                                 transforms.ToTensor()])
     """

     def __init__(self, fillcolor=(128, 128, 128)):
         """Initialize an ImageNetPolicy.

         Args:
             fillcolor (tuple): RGB color components of the color to be used
                 for filling when needed (default: (128, 128, 128), which
                 corresponds to gray).
         """
         # Instantiate a list of sub-policies.
         # Each entry of the list is a SubPolicy which consists of
         # two augmentation operations,
         # each of those parametrized as operation, probability, magnitude.
         # Those two operations are applied sequentially on the image upon call.
         self.policies = [
             SubPolicy("posterize", 0.4, 8, "rotate", 0.6, 9, fillcolor),
             SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
             SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
             SubPolicy("posterize", 0.6, 7, "posterize", 0.6, 6, fillcolor),
             SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
             SubPolicy("equalize", 0.4, 4, "rotate", 0.8, 8, fillcolor),
             SubPolicy("solarize", 0.6, 3, "equalize", 0.6, 7, fillcolor),
             SubPolicy("posterize", 0.8, 5, "equalize", 1.0, 2, fillcolor),
             SubPolicy("rotate", 0.2, 3, "solarize", 0.6, 8, fillcolor),
             SubPolicy("equalize", 0.6, 8, "posterize", 0.4, 6, fillcolor),
             SubPolicy("rotate", 0.8, 8, "color", 0.4, 0, fillcolor),
             SubPolicy("rotate", 0.4, 9, "equalize", 0.6, 2, fillcolor),
             SubPolicy("equalize", 0.0, 7, "equalize", 0.8, 8, fillcolor),
             SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
             SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
             SubPolicy("rotate", 0.8, 8, "color", 1.0, 2, fillcolor),
             SubPolicy("color", 0.8, 8, "solarize", 0.8, 7, fillcolor),
             SubPolicy("sharpness", 0.4, 7, "invert", 0.6, 8, fillcolor),
             SubPolicy("shearX", 0.6, 5, "equalize", 1.0, 9, fillcolor),
             SubPolicy("color", 0.4, 0, "equalize", 0.6, 3, fillcolor),
             SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
             SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
             SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
             SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
             SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
         ]

     def __call__(self, img):
         """Define call method for ImageNetPolicy class."""
         policy_idx = random.randint(0, len(self.policies) - 1)
         return self.policies[policy_idx](img)

     def __repr__(self):
         """Define repr method for ImageNetPolicy class."""
         return "ImageNetPolicy"


 class SubPolicy:
     """Definition of a SubPolicy.

     A SubPolicy consists of two augmentation operations,
     each of those parametrized as operation, probability, magnitude.
     The two operations are applied sequentially on the image upon call.
     """

     def __init__(
         self,
         operation1,
         probability1,
         magnitude_idx1,
         operation2,
         probability2,
         magnitude_idx2,
         fillcolor,
     ):
         """Initialize a SubPolicy.

         Args:
             operation1 (str): Key specifying the first augmentation operation.
                 There are fourteen key values altogether (see supported_ops
                 below listing the supported operations).
             probability1 (float): Probability within [0., 1.] of applying the
                 first augmentation operation.
             magnitude_idx1 (int): Integer specifying the strength of the first
                 operation as an index further used to derive the magnitude
                 from a range of possible values.
             operation2 (str): Key specifying the second augmentation operation.
             probability2 (float): Probability within [0., 1.] of applying the
                 second augmentation operation.
             magnitude_idx2 (int): Integer specifying the strength of the
                 second operation as an index further used to derive the
                 magnitude from a range of possible values.
             fillcolor (tuple): RGB color components of the color to be used
                 for filling.
         Returns:
         """
         # List of supported operations for operation1 and operation2.
         supported_ops = [
             "shearX",
             "shearY",
             "translateX",
             "translateY",
             "rotate",
             "color",
             "posterize",
             "solarize",
             "contrast",
             "sharpness",
             "brightness",
             "autocontrast",
             "equalize",
             "invert",
         ]
         assert (operation1 in supported_ops) and (
             operation2 in supported_ops
         ), "SubPolicy: one of operation1 or operation2 refers to an unsupported operation."

         assert (
             0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0
         ), "SubPolicy: probability1 and probability2 should be within [0., 1.]."

         assert (
             isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10
         ), "SubPolicy: magnitude_idx1 should be specified as an integer within [0, 10]."

         assert (
             isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10
         ), "SubPolicy: magnitude_idx2 should be specified as an integer within [0, 10]."

         # Define a dictionary where each key refers to a specific type of
         # augmentation and the corresponding value is a range of ten possible
         # magnitude values for that augmentation.
         num_levels = _MAX_LEVEL + 1
         ranges = {
             "shearX": np.linspace(0, 0.3, num_levels),
             "shearY": np.linspace(0, 0.3, num_levels),
             "translateX": np.linspace(0, 150 / 331, num_levels),
             "translateY": np.linspace(0, 150 / 331, num_levels),
             "rotate": np.linspace(0, 30, num_levels),
             "color": np.linspace(0.0, 0.9, num_levels),
             # The builtin int replaces the deprecated np.int alias here,
             # preserving the original behavior on current numpy releases.
             "posterize": np.round(np.linspace(8, 4, num_levels), 0).astype(int),
             "solarize": np.linspace(256, 0, num_levels),  # range [0, 256]
             "contrast": np.linspace(0.0, 0.9, num_levels),
             "sharpness": np.linspace(0.0, 0.9, num_levels),
             "brightness": np.linspace(0.0, 0.9, num_levels),
             # These augmentations don't use the magnitude parameter.
             "autocontrast": [0] * num_levels,
             "equalize": [0] * num_levels,
             "invert": [0] * num_levels,
         }

         def rotate_with_fill(img, magnitude):
             """Define rotation transformation with fill.

             The input image is first rotated, then it is blended together with
             a gray mask of the same size. Note that fillcolor as defined
             elsewhere in this module doesn't apply here.

             Args:
                 magnitude (float): rotation angle in degrees.
             Returns:
                 rotated_filled (PIL Image): rotated image with gray filling for
                     disoccluded areas unveiled by the rotation.
             """
             rotated = img.convert("RGBA").rotate(magnitude)
             rotated_filled = Image.composite(
                 rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated
             )
             return rotated_filled.convert(img.mode)

         # Define a dictionary of augmentation functions where each key refers
         # to a specific type of augmentation and the corresponding value defines
         # the augmentation itself using a lambda function.
         # pylint: disable=unnecessary-lambda
         func_dict = {
             "shearX": lambda img, magnitude: img.transform(
                 img.size,
                 Image.AFFINE,
                 (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                 Image.BICUBIC,
                 fillcolor=fillcolor,
             ),
             "shearY": lambda img, magnitude: img.transform(
                 img.size,
                 Image.AFFINE,
                 (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                 Image.BICUBIC,
                 fillcolor=fillcolor,
             ),
             "translateX": lambda img, magnitude: img.transform(
                 img.size,
                 Image.AFFINE,
                 (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                 fillcolor=fillcolor,
             ),
             "translateY": lambda img, magnitude: img.transform(
                 img.size,
                 Image.AFFINE,
                 (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                 fillcolor=fillcolor,
             ),
             "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
             "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
                 1 + magnitude * random.choice([-1, 1])
             ),
             "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
             "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
             "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                 1 + magnitude * random.choice([-1, 1])
             ),
             "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                 1 + magnitude * random.choice([-1, 1])
             ),
             "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                 1 + magnitude * random.choice([-1, 1])
             ),
             "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
             "equalize": lambda img, magnitude: ImageOps.equalize(img),
             "invert": lambda img, magnitude: ImageOps.invert(img),
         }

         # Store probability, function and magnitude of the first augmentation
         # for the sub-policy.
         self.probability1 = probability1
         self.operation1 = func_dict[operation1]
         self.magnitude1 = ranges[operation1][magnitude_idx1]

         # Store probability, function and magnitude of the second augmentation
         # for the sub-policy.
         self.probability2 = probability2
         self.operation2 = func_dict[operation2]
         self.magnitude2 = ranges[operation2][magnitude_idx2]

     def __call__(self, img):
         """Define call method for SubPolicy class."""
         # Randomly apply operation 1.
         if random.random() < self.probability1:
             img = self.operation1(img, self.magnitude1)

         # Randomly apply operation 2.
         if random.random() < self.probability2:
             img = self.operation2(img, self.magnitude2)

         return img
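A runnable sketch of the usage pattern from the ImageNetPolicy docstring, applied to a synthetic image (torchvision's transforms are assumed available, as in the docstring example; the image contents are random and purely illustrative):

# Sketch: apply the fixed AutoAugment ImageNet policy inside a standard
# torchvision pipeline on a synthetic RGB image.
import numpy as np
from PIL import Image
from torchvision import transforms

policy = ImageNetPolicy()  # defined above
transform = transforms.Compose([
    transforms.Resize(256),
    policy,              # one of the 25 sub-policies, chosen at random per call
    transforms.ToTensor(),
])

img = Image.fromarray(np.random.randint(0, 255, (300, 300, 3), dtype=np.uint8))
tensor = transform(img)
print(tensor.shape)  # torch.Size([3, 256, 256]) for a square input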
megatron/data/bert_dataset.py
ADDED
@@ -0,0 +1,234 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""BERT Style dataset."""
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from megatron import (
|
22 |
+
get_args,
|
23 |
+
get_tokenizer,
|
24 |
+
mpu,
|
25 |
+
print_rank_0
|
26 |
+
)
|
27 |
+
from megatron.data.dataset_utils import (
|
28 |
+
get_samples_mapping,
|
29 |
+
get_a_and_b_segments,
|
30 |
+
truncate_segments,
|
31 |
+
create_tokens_and_tokentypes,
|
32 |
+
create_masked_lm_predictions
|
33 |
+
)
|
34 |
+
|
35 |
+
class DummyBertDataset(torch.utils.data.Dataset):
|
36 |
+
def __init__(self, name, num_samples, max_seq_length):
|
37 |
+
self.name = name
|
38 |
+
self.num_samples = num_samples
|
39 |
+
self.max_seq_length = max_seq_length
|
40 |
+
self.np_rng = np.random.RandomState(seed=0)
|
41 |
+
# self.token_nps = np_rng.randint(1000, 2000, (self.num_samples, 512))
|
42 |
+
# Vocab stuff.
|
43 |
+
tokenizer = get_tokenizer()
|
44 |
+
self.vocab_id_list = list(tokenizer.inv_vocab.keys())
|
45 |
+
self.vocab_id_to_token_dict = tokenizer.inv_vocab
|
46 |
+
self.cls_id = tokenizer.cls
|
47 |
+
self.sep_id = tokenizer.sep
|
48 |
+
self.mask_id = tokenizer.mask
|
49 |
+
self.pad_id = tokenizer.pad
|
50 |
+
|
51 |
+
def __len__(self):
|
52 |
+
return self.num_samples
|
53 |
+
|
54 |
+
def __getitem__(self, idx):
|
55 |
+
tokens = self.np_rng.randint(1000, 2000, self.max_seq_length)
|
56 |
+
masked_position = np.arange(int(tokens.shape[0] * 0.15))
|
57 |
+
tokens = tokens.astype(np.int64)
|
58 |
+
labels = tokens[masked_position]
|
59 |
+
label_np = np.full_like(tokens, -1)
|
60 |
+
label_np[masked_position] = labels
|
61 |
+
tokens[masked_position] = self.mask_id
|
62 |
+
loss_mask_np = np.zeros_like(tokens)
|
63 |
+
loss_mask_np[masked_position] = 1
|
64 |
+
train_sample = {
|
65 |
+
'text': tokens,
|
66 |
+
'types': np.zeros_like(tokens),
|
67 |
+
'labels': label_np,
|
68 |
+
'is_random': 0,
|
69 |
+
'loss_mask': loss_mask_np,
|
70 |
+
'padding_mask': np.ones_like(tokens),
|
71 |
+
'truncated': 0
|
72 |
+
}
|
73 |
+
return train_sample
|
+
+class BertDataset(torch.utils.data.Dataset):
+
+    def __init__(self, name, indexed_dataset, data_prefix,
+                 num_epochs, max_num_samples, masked_lm_prob,
+                 max_seq_length, short_seq_prob, seed, binary_head):
+
+        # Params to store.
+        self.name = name
+        self.seed = seed
+        self.masked_lm_prob = masked_lm_prob
+        self.max_seq_length = max_seq_length
+        self.binary_head = binary_head
+
+        # Dataset.
+        self.indexed_dataset = indexed_dataset
+
+        # Build the samples mapping.
+        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
+                                                   data_prefix,
+                                                   num_epochs,
+                                                   max_num_samples,
+                                                   self.max_seq_length - 3,  # account for added tokens
+                                                   short_seq_prob,
+                                                   self.seed,
+                                                   self.name,
+                                                   self.binary_head)
+
+        # Vocab stuff.
+        tokenizer = get_tokenizer()
+        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
+        self.vocab_id_to_token_dict = tokenizer.inv_vocab
+        self.cls_id = tokenizer.cls
+        self.sep_id = tokenizer.sep
+        self.mask_id = tokenizer.mask
+        self.pad_id = tokenizer.pad
+
+    def __len__(self):
+        return self.samples_mapping.shape[0]
+
+    def __getitem__(self, idx):
+        start_idx, end_idx, seq_length = self.samples_mapping[idx]
+        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
+        # Note that this rng state should be numpy and not python since
+        # python randint is inclusive whereas the numpy one is exclusive.
+        # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
+        np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
+        return build_training_sample(sample, seq_length,
+                                     self.max_seq_length,  # needed for padding
+                                     self.vocab_id_list,
+                                     self.vocab_id_to_token_dict,
+                                     self.cls_id, self.sep_id,
+                                     self.mask_id, self.pad_id,
+                                     self.masked_lm_prob, np_rng,
+                                     self.binary_head)
+
+
+def build_training_sample(sample,
+                          target_seq_length, max_seq_length,
+                          vocab_id_list, vocab_id_to_token_dict,
+                          cls_id, sep_id, mask_id, pad_id,
+                          masked_lm_prob, np_rng, binary_head):
+    """Build training sample.
+
+    Arguments:
+        sample: A list of sentences in which each sentence is a list of token ids.
+        target_seq_length: Desired sequence length.
+        max_seq_length: Maximum length of the sequence. All values are padded to
+            this length.
+        vocab_id_list: List of vocabulary ids. Used to pick a random id.
+        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
+        cls_id: Start of example id.
+        sep_id: Separator id.
+        mask_id: Mask token id.
+        pad_id: Padding token id.
+        masked_lm_prob: Probability to mask tokens.
+        np_rng: Random number generator. Note that this rng state should be
+            numpy and not python since python randint is inclusive for
+            the upper bound whereas the numpy one is exclusive.
+    """
+
+    if binary_head:
+        # We assume that we have at least two sentences in the sample.
+        assert len(sample) > 1
+    assert target_seq_length <= max_seq_length
+
+    # Divide sample into two segments (A and B).
+    if binary_head:
+        tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
+                                                                  np_rng)
+    else:
+        tokens_a = []
+        for j in range(len(sample)):
+            tokens_a.extend(sample[j])
+        tokens_b = []
+        is_next_random = False
+
+    # Truncate to `target_sequence_length`.
+    max_num_tokens = target_seq_length
+    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
+                                  len(tokens_b), max_num_tokens, np_rng)
+
+    # Build tokens and tokentypes.
+    tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
+                                                      cls_id, sep_id)
+
+    # Masking.
+    max_predictions_per_seq = masked_lm_prob * max_num_tokens
+    (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
+        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
+        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
+
+    # Padding.
+    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
+        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
+                                   masked_labels, pad_id, max_seq_length)
+
+    train_sample = {
+        'text': tokens_np,
+        'types': tokentypes_np,
+        'labels': labels_np,
+        'is_random': int(is_next_random),
+        'loss_mask': loss_mask_np,
+        'padding_mask': padding_mask_np,
+        'truncated': int(truncated)}
+    return train_sample
+
+
+def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
+                             masked_labels, pad_id, max_seq_length):
+    """Pad sequences and convert them to numpy."""
+
+    # Some checks.
+    num_tokens = len(tokens)
+    padding_length = max_seq_length - num_tokens
+    assert padding_length >= 0
+    assert len(tokentypes) == num_tokens
+    assert len(masked_positions) == len(masked_labels)
+
+    # Tokens and token types.
+    filler = [pad_id] * padding_length
+    tokens_np = np.array(tokens + filler, dtype=np.int64)
+    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
+
+    # Padding mask.
+    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
+                               dtype=np.int64)
+
+    # Labels and loss mask.
+    labels = [-1] * max_seq_length
+    loss_mask = [0] * max_seq_length
+    for i in range(len(masked_positions)):
+        assert masked_positions[i] < num_tokens
+        labels[masked_positions[i]] = masked_labels[i]
+        loss_mask[masked_positions[i]] = 1
+    labels_np = np.array(labels, dtype=np.int64)
+    loss_mask_np = np.array(loss_mask, dtype=np.int64)
+
+    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
megatron/data/biencoder_dataset_utils.py
ADDED
@@ -0,0 +1,208 @@
+import os
+import time
+
+import numpy as np
+import torch
+
+from megatron import get_args, get_tokenizer, mpu, print_rank_0
+from megatron.data.dataset_utils import create_masked_lm_predictions, \
+    pad_and_convert_to_numpy
+from megatron.data.data_samplers import MegatronPretrainingSampler
+
+def make_attention_mask(source_block, target_block):
+    """
+    Returns a 2-dimensional (2-D) attention mask
+    :param source_block: 1-D array
+    :param target_block: 1-D array
+    """
+    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
+    mask = mask.astype(np.int64)
+    # (source_length, target_length)
+    return mask
+
+def get_one_epoch_dataloader(dataset, micro_batch_size=None):
+    """Specifically one epoch to be used in an indexing job."""
+    args = get_args()
+
+    if micro_batch_size is None:
+        micro_batch_size = args.micro_batch_size
+    num_workers = args.num_workers
+
+    # Use megatron's sampler with consumed samples set to 0 as
+    # this is only for evaluation and we don't intend to resume halfway.
+    # Also, set drop_last to False as we don't intend to drop
+    # the last batch.
+    batch_sampler = MegatronPretrainingSampler(
+        total_samples=len(dataset),
+        consumed_samples=0,
+        micro_batch_size=micro_batch_size,
+        data_parallel_rank=mpu.get_data_parallel_rank(),
+        data_parallel_size=mpu.get_data_parallel_world_size(),
+        drop_last=False)
+
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
+
+
+def get_ict_batch(data_iterator):
+    # Items and their type.
+    keys = ['query_tokens', 'query_mask',
+            'context_tokens', 'context_mask', 'block_data']
+    datatype = torch.int64
+
+    # Broadcast data.
+    if data_iterator is None:
+        data = None
+    else:
+        data = next(data_iterator)
+    data_b = mpu.broadcast_data(keys, data, datatype)
+
+    # Unpack.
+    query_tokens = data_b['query_tokens'].long()
+    query_mask = data_b['query_mask'] < 0.5
+    context_tokens = data_b['context_tokens'].long()
+    context_mask = data_b['context_mask'] < 0.5
+    block_indices = data_b['block_data'].long()
+
+    return query_tokens, query_mask,\
+        context_tokens, context_mask, block_indices
+
+
+def join_str_list(str_list):
+    """Join a list of strings, handling spaces appropriately."""
+    result = ""
+    for s in str_list:
+        if s.startswith("##"):
+            result += s[2:]
+        else:
+            result += " " + s
+    return result
+
+
+class BlockSampleData(object):
+    """A struct for fully describing a fixed-size block of data as used in REALM
+
+    :param start_idx: for the first sentence of the block
+    :param end_idx: for the last sentence of the block (may be partially truncated in sample construction)
+    :param doc_idx: the index of the document from which the block comes in the original indexed dataset
+    :param block_idx: a unique integer identifier given to every block.
+    """
+    def __init__(self, start_idx, end_idx, doc_idx, block_idx):
+        self.start_idx = start_idx
+        self.end_idx = end_idx
+        self.doc_idx = doc_idx
+        self.block_idx = block_idx
+
+    def as_array(self):
+        return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
+
+    def as_tuple(self):
+        return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
+
+
+class BlockSamplesMapping(object):
+    def __init__(self, mapping_array):
+        # make sure that the array is compatible with BlockSampleData
+        assert mapping_array.shape[1] == 4
+        self.mapping_array = mapping_array
+
+    def __len__(self):
+        return self.mapping_array.shape[0]
+
+    def __getitem__(self, idx):
+        """Get the data associated with an indexed sample."""
+        sample_data = BlockSampleData(*self.mapping_array[idx])
+        return sample_data
+
+
+def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
+                              max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
+    """Get samples mapping for a dataset over fixed size blocks. This function also requires
+    a dataset of the titles for the source documents since their lengths must be taken into account.
+
+    :return: samples_mapping (BlockSamplesMapping)
+    """
+
+    if not num_epochs:
+        if not max_num_samples:
+            raise ValueError("Need to specify either max_num_samples "
+                             "or num_epochs")
+        num_epochs = np.iinfo(np.int32).max - 1
+    if not max_num_samples:
+        max_num_samples = np.iinfo(np.int64).max - 1
+
+    # Filename of the index mapping
+    indexmap_filename = data_prefix
+    indexmap_filename += '_{}_indexmap'.format(name)
+    if num_epochs != (np.iinfo(np.int32).max - 1):
+        indexmap_filename += '_{}ep'.format(num_epochs)
+    if max_num_samples != (np.iinfo(np.int64).max - 1):
+        indexmap_filename += '_{}mns'.format(max_num_samples)
+    indexmap_filename += '_{}msl'.format(max_seq_length)
+    indexmap_filename += '_{}s'.format(seed)
+    if use_one_sent_docs:
+        indexmap_filename += '_1sentok'
+    indexmap_filename += '.npy'
+
+    # Build the indexed mapping if it does not exist.
+    if mpu.get_data_parallel_rank() == 0 and \
+            not os.path.isfile(indexmap_filename):
+        print(' > WARNING: could not find index map file {}, building '
+              'the indices on rank 0 ...'.format(indexmap_filename))
+
+        # Make sure the types match the helpers input types.
+        assert block_dataset.doc_idx.dtype == np.int64
+        assert block_dataset.sizes.dtype == np.int32
+
+        # Build samples mapping
+        verbose = torch.distributed.get_rank() == 0
+        start_time = time.time()
+        print_rank_0(' > building samples index mapping for {} ...'.format(
+            name))
+
+        from megatron.data import helpers
+        mapping_array = helpers.build_blocks_mapping(
+            block_dataset.doc_idx,
+            block_dataset.sizes,
+            title_dataset.sizes,
+            num_epochs,
+            max_num_samples,
+            max_seq_length - 3,  # account for added tokens
+            seed,
+            verbose,
+            use_one_sent_docs)
+
+        print_rank_0(' > done building samples index mapping')
+        np.save(indexmap_filename, mapping_array, allow_pickle=True)
+        print_rank_0(' > saved the index mapping in {}'.format(
+            indexmap_filename))
+        # Make sure all the ranks have built the mapping
+        print_rank_0(' > elapsed time to build and save samples mapping '
+                     '(seconds): {:4f}'.format(
+                         time.time() - start_time))
+
+    # This should be a barrier but nccl barrier assumes
+    # device_index=rank which is not the case for the model
+    # parallel case
+    counts = torch.cuda.LongTensor([1])
+    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+    assert counts[0].item() == torch.distributed.get_world_size(
+        group=mpu.get_data_parallel_group())
+
+    # Load indexed dataset.
+    print_rank_0(' > loading indexed mapping from {}'.format(
+        indexmap_filename))
+    start_time = time.time()
+
+    mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
+    samples_mapping = BlockSamplesMapping(mapping_array)
+
+    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
+        time.time() - start_time))
+    print_rank_0('    total number of samples: {}'.format(
+        mapping_array.shape[0]))
+
+    return samples_mapping
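`make_attention_mask` above is just an outer product of the two padding indicators: position (i, j) is attendable only when both source token i and target token j are real (id >= 1). A small self-contained illustration (arrays are made up):

    import numpy as np

    source = np.array([5, 9, 0])        # last position is padding
    target = np.array([7, 3, 2, 0])
    mask = ((target[None, :] >= 1) * (source[:, None] >= 1)).astype(np.int64)
    print(mask)
    # [[1 1 1 0]
    #  [1 1 1 0]
    #  [0 0 0 0]]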
megatron/data/blendable_dataset.py
ADDED
@@ -0,0 +1,68 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Blendable dataset."""
+
+import time
+
+import numpy as np
+import torch
+
+from megatron import print_rank_0
+from megatron import mpu
+
+
+class BlendableDataset(torch.utils.data.Dataset):
+
+    def __init__(self, datasets, weights):
+
+        self.datasets = datasets
+        num_datasets = len(datasets)
+        assert num_datasets == len(weights)
+
+        self.size = 0
+        for dataset in self.datasets:
+            self.size += len(dataset)
+
+        # Normalize weights.
+        weights = np.array(weights, dtype=np.float64)
+        sum_weights = np.sum(weights)
+        assert sum_weights > 0.0
+        weights /= sum_weights
+
+        # Build indices.
+        start_time = time.time()
+        assert num_datasets < 255
+        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
+        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
+
+        from megatron.data import helpers
+        helpers.build_blending_indices(self.dataset_index,
+                                       self.dataset_sample_index,
+                                       weights, num_datasets, self.size,
+                                       torch.distributed.get_rank() == 0)
+        print_rank_0('> elapsed time for building blendable dataset indices: '
+                     '{:.2f} (sec)'.format(time.time() - start_time))
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, idx):
+        dataset_idx = self.dataset_index[idx]
+        sample_idx = self.dataset_sample_index[idx]
+        return self.datasets[dataset_idx][sample_idx]
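The index construction in `BlendableDataset.__init__` is delegated to the compiled `helpers.build_blending_indices`. As an illustration only (a pure-numpy sketch, not the exact C++ code), the greedy idea is: for each output slot, pick the dataset that is most under-sampled relative to its normalized weight.

    import numpy as np

    def build_blending_indices_py(weights, size):
        # Illustrative greedy blending; all names here are hypothetical.
        weights = np.asarray(weights, dtype=np.float64)
        weights = weights / weights.sum()
        num_datasets = len(weights)
        dataset_index = np.zeros(size, dtype=np.uint8)
        dataset_sample_index = np.zeros(size, dtype=np.int64)
        current_samples = np.zeros(num_datasets, dtype=np.int64)
        for idx in range(size):
            # Gap between the target share and what each dataset got so far.
            errors = weights * (idx + 1) - current_samples
            chosen = int(np.argmax(errors))
            dataset_index[idx] = chosen
            dataset_sample_index[idx] = current_samples[chosen]
            current_samples[chosen] += 1
        return dataset_index, dataset_sample_index

    d_idx, s_idx = build_blending_indices_py([0.7, 0.3], 10)
    print(d_idx)  # roughly seven 0s and three 1s, interleaved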
megatron/data/data_samplers.py
ADDED
@@ -0,0 +1,199 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dataloaders."""
+
+
+import random
+import torch
+import numpy as np
+from torch.utils.data import Dataset
+from megatron import get_args
+from megatron import mpu
+
+
+def build_pretraining_data_loader(dataset, consumed_samples):
+    """Build dataloader given an input dataset."""
+
+    if dataset is None:
+        return None
+    args = get_args()
+
+    # Megatron sampler
+    if args.dataloader_type == 'single':
+        batch_sampler = MegatronPretrainingSampler(
+            total_samples=len(dataset),
+            consumed_samples=consumed_samples,
+            micro_batch_size=args.micro_batch_size,
+            data_parallel_rank=mpu.get_data_parallel_rank(),
+            data_parallel_size=mpu.get_data_parallel_world_size())
+    elif args.dataloader_type == 'cyclic':
+        batch_sampler = MegatronPretrainingRandomSampler(
+            dataset,
+            total_samples=len(dataset),
+            consumed_samples=consumed_samples,
+            micro_batch_size=args.micro_batch_size,
+            data_parallel_rank=mpu.get_data_parallel_rank(),
+            data_parallel_size=mpu.get_data_parallel_world_size(),
+            data_sharding=args.data_sharding)
+    else:
+        raise Exception('{} dataloader type is not supported.'.format(
+            args.dataloader_type))
+
+    # Torch dataloader.
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=args.num_workers,
+                                       pin_memory=True)
+
+class MegatronPretrainingSampler:
+
+    def __init__(self, total_samples, consumed_samples, micro_batch_size,
+                 data_parallel_rank, data_parallel_size, drop_last=True):
+        # Keep a copy of input params for later use.
+        self.total_samples = total_samples
+        self.consumed_samples = consumed_samples
+        self.micro_batch_size = micro_batch_size
+        self.data_parallel_rank = data_parallel_rank
+        self.micro_batch_times_data_parallel_size = \
+            self.micro_batch_size * data_parallel_size
+        self.drop_last = drop_last
+
+        # Sanity checks.
+        assert self.total_samples > 0, \
+            'no sample to consume: {}'.format(self.total_samples)
+        assert self.consumed_samples < self.total_samples, \
+            'no samples left to consume: {}, {}'.format(self.consumed_samples,
+                                                        self.total_samples)
+        assert self.micro_batch_size > 0
+        assert data_parallel_size > 0
+        assert self.data_parallel_rank < data_parallel_size, \
+            'data_parallel_rank should be smaller than data size: {}, ' \
+            '{}'.format(self.data_parallel_rank, data_parallel_size)
+
+    def __len__(self):
+        return self.total_samples
+
+    def get_start_end_idx(self):
+        start_idx = self.data_parallel_rank * self.micro_batch_size
+        end_idx = start_idx + self.micro_batch_size
+        return start_idx, end_idx
+
+    def __iter__(self):
+        batch = []
+        # The last batch will be dropped unless drop_last is set to False.
+        for idx in range(self.consumed_samples, self.total_samples):
+            batch.append(idx)
+            if len(batch) == self.micro_batch_times_data_parallel_size:
+                start_idx, end_idx = self.get_start_end_idx()
+                yield batch[start_idx:end_idx]
+                batch = []
+
+        # Check the last partial batch and see if drop_last is set.
+        if len(batch) > 0 and not self.drop_last:
+            start_idx, end_idx = self.get_start_end_idx()
+            yield batch[start_idx:end_idx]
+
+
+class RandomSeedDataset(Dataset):
+
+    def __init__(self, dataset):
+        args = get_args()
+        self.base_seed = args.seed
+        self.curr_seed = args.seed
+        self.dataset = dataset
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def set_epoch(self, epoch):
+        self.curr_seed = self.base_seed + epoch
+
+    def __getitem__(self, idx):
+        seed = idx + self.curr_seed
+        torch.manual_seed(seed)
+        random.seed(seed)
+        np.random.seed(seed)
+        return self.dataset[idx]
+
+
+class MegatronPretrainingRandomSampler:
+
+    def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
+                 data_parallel_rank, data_parallel_size, data_sharding):
+        # Keep a copy of input params for later use.
+        self.dataset = dataset
+        self.total_samples = total_samples
+        self.consumed_samples = consumed_samples
+        self.micro_batch_size = micro_batch_size
+        self.data_parallel_rank = data_parallel_rank
+        self.data_parallel_size = data_parallel_size
+        self.data_sharding = data_sharding
+        self.micro_batch_times_data_parallel_size = \
+            self.micro_batch_size * data_parallel_size
+        self.last_batch_size = \
+            self.total_samples % self.micro_batch_times_data_parallel_size
+
+        # Sanity checks.
+        assert self.total_samples > 0, \
+            'no sample to consume: {}'.format(self.total_samples)
+        assert self.micro_batch_size > 0
+        assert data_parallel_size > 0
+        assert self.data_parallel_rank < data_parallel_size, \
+            'data_parallel_rank should be smaller than data size: {}, ' \
+            '{}'.format(self.data_parallel_rank, data_parallel_size)
+
+    def __len__(self):
+        return self.total_samples
+
+    def __iter__(self):
+        active_total_samples = self.total_samples - self.last_batch_size
+        self.epoch = self.consumed_samples // active_total_samples
+        current_epoch_samples = self.consumed_samples % active_total_samples
+        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
+
+        if isinstance(self.dataset, RandomSeedDataset):
+            self.dataset.set_epoch(self.epoch)
+
+        # data sharding and random sampling
+        if self.data_sharding:
+            bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
+                          * self.micro_batch_size
+            bucket_offset = current_epoch_samples // self.data_parallel_size
+            start_idx = self.data_parallel_rank * bucket_size
+
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            random_idx = torch.randperm(bucket_size, generator=g).tolist()
+            idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
+        else:
+            full_bucket_size = (self.total_samples // self.micro_batch_size) \
+                               * self.micro_batch_size
+            full_bucket_offset = current_epoch_samples
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            idx_range_total = \
+                torch.randperm(full_bucket_size, generator=g).tolist()
+            idx_range_active = idx_range_total[full_bucket_offset:]
+            idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]
+
+        batch = []
+        # The last batch will be dropped if it is not complete.
+        for idx in idx_range:
+            batch.append(idx)
+            if len(batch) == self.micro_batch_size:
+                self.consumed_samples += self.micro_batch_times_data_parallel_size
+                yield batch
+                batch = []
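To see how `MegatronPretrainingSampler` shards a global batch, here is a minimal standalone sketch of its `get_start_end_idx` arithmetic (rank count and batch size are made up): every data-parallel rank walks the same global index window and takes its own contiguous `micro_batch_size` slice.

    micro_batch_size = 2
    data_parallel_size = 3  # hypothetical
    global_batch = list(range(micro_batch_size * data_parallel_size))  # [0..5]
    for rank in range(data_parallel_size):
        start = rank * micro_batch_size
        print(rank, global_batch[start:start + micro_batch_size])
    # 0 [0, 1]
    # 1 [2, 3]
    # 2 [4, 5]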
megatron/data/dataset_utils.py
ADDED
@@ -0,0 +1,938 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Most of the code here has been copied from:
+#   https://github.com/google-research/albert/blob/master/create_pretraining_data.py
+# with some modifications.
+
+import math
+import os
+import time
+import collections
+
+import numpy as np
+import torch
+import random
+
+from megatron import (
+    get_tokenizer,
+    get_args,
+    mpu,
+    print_rank_0
+)
+from megatron.data.blendable_dataset import BlendableDataset
+from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
+
+DSET_TYPE_BERT = 'standard_bert'
+DSET_TYPE_ICT = 'ict'
+DSET_TYPE_T5 = 't5'
+DSET_TYPE_GLM = 'glm'
+
+DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5, DSET_TYPE_GLM]
+
+
+def get_datasets_weights_and_num_samples(data_prefix,
+                                         train_valid_test_num_samples):
+
+    # The data prefix should be in the format of:
+    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
+    assert len(data_prefix) % 2 == 0
+    num_datasets = len(data_prefix) // 2
+    weights = [0] * num_datasets
+    prefixes = [0] * num_datasets
+    for i in range(num_datasets):
+        weights[i] = float(data_prefix[2 * i])
+        prefixes[i] = (data_prefix[2 * i + 1]).strip()
+    # Normalize weights
+    weight_sum = 0.0
+    for weight in weights:
+        weight_sum += weight
+    assert weight_sum > 0.0
+    weights = [weight / weight_sum for weight in weights]
+
+    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
+    # not uniformly distribute the number of samples, we still have
+    # samples left to feed to the network.
+    datasets_train_valid_test_num_samples = []
+    for weight in weights:
+        datasets_train_valid_test_num_samples.append(
+            [int(math.ceil(val * weight * 1.005))
+             for val in train_valid_test_num_samples])
+
+    return prefixes, weights, datasets_train_valid_test_num_samples
+
+
+def compile_helper():
+    """Compile helper function at runtime. Make sure this
+    is invoked on a single process."""
+    import os
+    import subprocess
+    path = os.path.abspath(os.path.dirname(__file__))
+    ret = subprocess.run(['make', '-C', path])
+    if ret.returncode != 0:
+        print("Making C++ dataset helpers module failed, exiting.")
+        import sys
+        sys.exit(1)
+
+
+def get_a_and_b_segments(sample, np_rng):
+    """Divide sample into a and b segments."""
+
+    # Number of sentences in the sample.
+    n_sentences = len(sample)
+    # Make sure we always have two sentences.
+    assert n_sentences > 1, 'make sure each sample has at least two sentences.'
+
+    # First part:
+    # `a_end` is how many sentences go into the `A`.
+    a_end = 1
+    if n_sentences >= 3:
+        # Note that randint in numpy is exclusive.
+        a_end = np_rng.randint(1, n_sentences)
+    tokens_a = []
+    for j in range(a_end):
+        tokens_a.extend(sample[j])
+
+    # Second part:
+    tokens_b = []
+    for j in range(a_end, n_sentences):
+        tokens_b.extend(sample[j])
+
+    # Random next:
+    is_next_random = False
+    if np_rng.random() < 0.5:
+        is_next_random = True
+        tokens_a, tokens_b = tokens_b, tokens_a
+
+    return tokens_a, tokens_b, is_next_random
+
+
+def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
+    """Truncates a pair of sequences to a maximum sequence length."""
+    # print(len_a, len_b, max_num_tokens)
+    assert len_a > 0
+    if len_a + len_b <= max_num_tokens:
+        return False
+    while len_a + len_b > max_num_tokens:
+        if len_a > len_b:
+            len_a -= 1
+            tokens = tokens_a
+        else:
+            len_b -= 1
+            tokens = tokens_b
+        if np_rng.random() < 0.5:
+            del tokens[0]
+        else:
+            tokens.pop()
+    return True
+
+
+def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
+    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
+
+    tokens = []
+    tokentypes = []
+    # [CLS].
+    tokens.append(cls_id)
+    tokentypes.append(0)
+    # Segment A.
+    for token in tokens_a:
+        tokens.append(token)
+        tokentypes.append(0)
+    # [SEP].
+    tokens.append(sep_id)
+    tokentypes.append(0)
+    # Segment B.
+    for token in tokens_b:
+        tokens.append(token)
+        tokentypes.append(1)
+    if tokens_b:
+        # [SEP].
+        tokens.append(sep_id)
+        tokentypes.append(1)
+
+    return tokens, tokentypes
+
+def create_tokens(tokens_a, tokens_b, cls_id, sep_id):
+    """Merge segments A and B and add [CLS] and [SEP]."""
+
+    tokens = []
+    # [CLS].
+    tokens.append(cls_id)
+    # Segment A.
+    for token in tokens_a:
+        tokens.append(token)
+    # [SEP].
+    tokens.append(sep_id)
+    # Segment B.
+    for token in tokens_b:
+        tokens.append(token)
+    if tokens_b:
+        # [SEP].
+        tokens.append(sep_id)
+
+    return tokens
+
+MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
+                                          ["index", "label"])
+
+
+def is_start_piece(piece):
+    """Check if the current word piece is the starting piece (BERT)."""
+    # When a word has been split into
+    # WordPieces, the first token does not have any marker and any subsequent
+    # tokens are prefixed with ##. So whenever we see the ## token, we
+    # append it to the previous set of word indexes.
+    return not piece.startswith("##")
+
+
+def create_masked_lm_predictions(tokens,
+                                 vocab_id_list, vocab_id_to_token_dict,
+                                 masked_lm_prob,
+                                 cls_id, sep_id, mask_id,
+                                 max_predictions_per_seq,
+                                 np_rng,
+                                 max_ngrams=3,
+                                 do_whole_word_mask=True,
+                                 favor_longer_ngram=False,
+                                 do_permutation=False,
+                                 geometric_dist=False,
+                                 masking_style="bert"):
+    """Creates the predictions for the masked LM objective.
+    Note: Tokens here are vocab ids and not text tokens."""
+
+    cand_indexes = []
+    # Note(mingdachen): We create a list for recording if the piece is
+    # the starting piece of the current token, where 1 means true, so that
+    # on-the-fly whole word masking is possible.
+    token_boundary = [0] * len(tokens)
+
+    for (i, token) in enumerate(tokens):
+        if token == cls_id or token == sep_id:
+            token_boundary[i] = 1
+            continue
+        # Whole Word Masking means that we mask all of the wordpieces
+        # corresponding to an original word.
+        #
+        # Note that Whole Word Masking does *not* change the training code
+        # at all -- we still predict each WordPiece independently, softmaxed
+        # over the entire vocabulary.
+        if (do_whole_word_mask and len(cand_indexes) >= 1 and
+                not is_start_piece(vocab_id_to_token_dict[token])):
+            cand_indexes[-1].append(i)
+        else:
+            cand_indexes.append([i])
+            if is_start_piece(vocab_id_to_token_dict[token]):
+                token_boundary[i] = 1
+
+    output_tokens = list(tokens)
+
+    masked_lm_positions = []
+    masked_lm_labels = []
+
+    if masked_lm_prob == 0:
+        return (output_tokens, masked_lm_positions,
+                masked_lm_labels, token_boundary)
+
+    num_to_predict = min(max_predictions_per_seq,
+                         max(1, int(round(len(tokens) * masked_lm_prob))))
+
+    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
+    if not geometric_dist:
+        # Note(mingdachen):
+        # By default, we set the probabilities to favor shorter ngram sequences.
+        pvals = 1. / np.arange(1, max_ngrams + 1)
+        pvals /= pvals.sum(keepdims=True)
+        if favor_longer_ngram:
+            pvals = pvals[::-1]
+
+    ngram_indexes = []
+    for idx in range(len(cand_indexes)):
+        ngram_index = []
+        for n in ngrams:
+            ngram_index.append(cand_indexes[idx:idx + n])
+        ngram_indexes.append(ngram_index)
+
+    np_rng.shuffle(ngram_indexes)
+
+    (masked_lms, masked_spans) = ([], [])
+    covered_indexes = set()
+    for cand_index_set in ngram_indexes:
+        if len(masked_lms) >= num_to_predict:
+            break
+        if not cand_index_set:
+            continue
+        # Note(mingdachen):
+        # Skip the current piece if it is covered by lm masking or previous ngrams.
+        for index_set in cand_index_set[0]:
+            for index in index_set:
+                if index in covered_indexes:
+                    continue
+
+        if not geometric_dist:
+            n = np_rng.choice(ngrams[:len(cand_index_set)],
+                              p=pvals[:len(cand_index_set)] /
+                              pvals[:len(cand_index_set)].sum(keepdims=True))
+        else:
+            # Sampling "n" from the geometric distribution and clipping it to
+            # the max_ngrams. Using p=0.2 default from the SpanBERT paper
+            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
+            n = min(np_rng.geometric(0.2), max_ngrams)
+
+        index_set = sum(cand_index_set[n - 1], [])
+        n -= 1
+        # Note(mingdachen):
+        # Repeatedly look for a candidate that does not exceed the
+        # maximum number of predictions by trying shorter ngrams.
+        while len(masked_lms) + len(index_set) > num_to_predict:
+            if n == 0:
+                break
+            index_set = sum(cand_index_set[n - 1], [])
+            n -= 1
+        # If adding a whole-word mask would exceed the maximum number of
+        # predictions, then just skip this candidate.
+        if len(masked_lms) + len(index_set) > num_to_predict:
+            continue
+        is_any_index_covered = False
+        for index in index_set:
+            if index in covered_indexes:
+                is_any_index_covered = True
+                break
+        if is_any_index_covered:
+            continue
+        for index in index_set:
+            covered_indexes.add(index)
+            masked_token = None
+            if masking_style == "bert":
+                # 80% of the time, replace with [MASK]
+                if np_rng.random() < 0.8:
+                    masked_token = mask_id
+                else:
+                    # 10% of the time, keep original
+                    if np_rng.random() < 0.5:
+                        masked_token = tokens[index]
+                    # 10% of the time, replace with random word
+                    else:
+                        masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
+            elif masking_style == "t5":
+                masked_token = mask_id
+            else:
+                raise ValueError("invalid value of masking style")
+
+            output_tokens[index] = masked_token
+            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+
+        masked_spans.append(MaskedLmInstance(
+            index=index_set,
+            label=[tokens[index] for index in index_set]))
+
+    assert len(masked_lms) <= num_to_predict
+    np_rng.shuffle(ngram_indexes)
+
+    select_indexes = set()
+    if do_permutation:
+        for cand_index_set in ngram_indexes:
+            if len(select_indexes) >= num_to_predict:
+                break
+            if not cand_index_set:
+                continue
+            # Note(mingdachen):
+            # Skip the current piece if it is covered by lm masking or previous ngrams.
+            for index_set in cand_index_set[0]:
+                for index in index_set:
+                    if index in covered_indexes or index in select_indexes:
+                        continue
+
+            n = np.random.choice(ngrams[:len(cand_index_set)],
+                                 p=pvals[:len(cand_index_set)] /
+                                 pvals[:len(cand_index_set)].sum(keepdims=True))
+            index_set = sum(cand_index_set[n - 1], [])
+            n -= 1
+
+            while len(select_indexes) + len(index_set) > num_to_predict:
+                if n == 0:
+                    break
+                index_set = sum(cand_index_set[n - 1], [])
+                n -= 1
+            # If adding a whole-word mask would exceed the maximum number of
+            # predictions, then just skip this candidate.
+            if len(select_indexes) + len(index_set) > num_to_predict:
+                continue
+            is_any_index_covered = False
+            for index in index_set:
+                if index in covered_indexes or index in select_indexes:
+                    is_any_index_covered = True
+                    break
+            if is_any_index_covered:
+                continue
+            for index in index_set:
+                select_indexes.add(index)
+        assert len(select_indexes) <= num_to_predict
+
+        select_indexes = sorted(select_indexes)
+        permute_indexes = list(select_indexes)
+        np_rng.shuffle(permute_indexes)
+        orig_token = list(output_tokens)
+
+        for src_i, tgt_i in zip(select_indexes, permute_indexes):
+            output_tokens[src_i] = orig_token[tgt_i]
+            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))
+
+    masked_lms = sorted(masked_lms, key=lambda x: x.index)
+    # Sort the spans by the index of the first span
+    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])
+
+    for p in masked_lms:
+        masked_lm_positions.append(p.index)
+        masked_lm_labels.append(p.label)
+    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)
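The BERT-style 80/10/10 rule in `create_masked_lm_predictions` above can be exercised directly. A hedged sketch (it assumes `megatron` and its dependencies import cleanly; the toy vocabulary and ids are made up):

    import numpy as np
    from megatron.data.dataset_utils import create_masked_lm_predictions

    inv_vocab = {0: '[PAD]', 1: '[CLS]', 2: '[SEP]', 3: '[MASK]',
                 4: 'the', 5: 'cat', 6: 'sat', 7: '##s', 8: 'dog'}
    tokens = [1, 4, 5, 7, 6, 2]  # [CLS] the cat ##s sat [SEP]
    np_rng = np.random.RandomState(seed=0)

    out_tokens, positions, labels, boundary, spans = create_masked_lm_predictions(
        tokens, list(inv_vocab.keys()), inv_vocab, masked_lm_prob=0.15,
        cls_id=1, sep_id=2, mask_id=3, max_predictions_per_seq=2, np_rng=np_rng)
    print(positions, labels)  # masked positions and the original ids to predict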
+
+
+def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
+                             masked_labels, pad_id, max_seq_length):
+    """Pad sequences and convert them to numpy."""
+
+    # Some checks.
+    num_tokens = len(tokens)
+    padding_length = max_seq_length - num_tokens
+    assert padding_length >= 0
+    assert len(tokentypes) == num_tokens
+    assert len(masked_positions) == len(masked_labels)
+
+    # Tokens and token types.
+    filler = [pad_id] * padding_length
+    tokens_np = np.array(tokens + filler, dtype=np.int64)
+    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
+
+    # Padding mask.
+    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
+                               dtype=np.int64)
+
+    # Labels and loss mask.
+    labels = [-1] * max_seq_length
+    loss_mask = [0] * max_seq_length
+    for i in range(len(masked_positions)):
+        assert masked_positions[i] < num_tokens
+        labels[masked_positions[i]] = masked_labels[i]
+        loss_mask[masked_positions[i]] = 1
+    labels_np = np.array(labels, dtype=np.int64)
+    loss_mask_np = np.array(loss_mask, dtype=np.int64)
+
+    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
+
+
+def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                    train_valid_test_num_samples,
+                                    max_seq_length,
+                                    masked_lm_prob, short_seq_prob, seed,
+                                    skip_warmup, binary_head=False,
+                                    max_seq_length_dec=None,
+                                    dataset_type='standard_bert'):
+    if len(data_prefix) == 1:
+        return _build_train_valid_test_datasets(data_prefix[0],
+                                                data_impl, splits_string,
+                                                train_valid_test_num_samples,
+                                                max_seq_length, masked_lm_prob,
+                                                short_seq_prob, seed,
+                                                skip_warmup,
+                                                binary_head,
+                                                max_seq_length_dec,
+                                                dataset_type=dataset_type)
+    # Blending dataset.
+    # Parse the values.
+    output = get_datasets_weights_and_num_samples(data_prefix,
+                                                  train_valid_test_num_samples)
+    prefixes, weights, datasets_train_valid_test_num_samples = output
+
+    # Build individual datasets.
+    train_datasets = []
+    valid_datasets = []
+    test_datasets = []
+    for i in range(len(prefixes)):
+        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
+            prefixes[i], data_impl, splits_string,
+            datasets_train_valid_test_num_samples[i],
+            max_seq_length, masked_lm_prob, short_seq_prob,
+            seed, skip_warmup, binary_head, max_seq_length_dec, dataset_type=dataset_type)
+        if train_ds:
+            train_datasets.append(train_ds)
+        if valid_ds:
+            valid_datasets.append(valid_ds)
+        if test_ds:
+            test_datasets.append(test_ds)
+
+    # Blend.
+    blending_train_dataset = None
+    if train_datasets:
+        blending_train_dataset = BlendableDataset(train_datasets, weights)
+    blending_valid_dataset = None
+    if valid_datasets:
+        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+    blending_test_dataset = None
+    if test_datasets:
+        blending_test_dataset = BlendableDataset(test_datasets, weights)
+
+    return (blending_train_dataset, blending_valid_dataset,
+            blending_test_dataset)
+
+
+def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                     train_valid_test_num_samples,
+                                     max_seq_length,
+                                     masked_lm_prob, short_seq_prob, seed,
+                                     skip_warmup, binary_head,
+                                     max_seq_length_dec,
+                                     dataset_type='standard_bert'):
+
+    if dataset_type not in DSET_TYPES:
+        raise ValueError("Invalid dataset_type: ", dataset_type)
+
+    # Indexed dataset.
+    indexed_dataset = get_indexed_dataset_(data_prefix,
+                                           data_impl,
+                                           skip_warmup)
+
+    if dataset_type == DSET_TYPE_ICT:
+        args = get_args()
+        title_dataset = get_indexed_dataset_(args.titles_data_path,
+                                             data_impl,
+                                             skip_warmup)
+
+    # Get start and end indices of train/valid/test into doc-idx
+    # Note that doc-idx is designed to be num-docs + 1 so we can
+    # easily iterate over it.
+    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
+    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
+
+    # Print stats about the splits.
+    print_rank_0(' > dataset split:')
+
+    def print_split_stats(name, index):
+        print_rank_0('    {}:'.format(name))
+        print_rank_0('     document indices in [{}, {}) total of {} '
+                     'documents'.format(splits[index], splits[index + 1],
+                                        splits[index + 1] - splits[index]))
+        start_index = indexed_dataset.doc_idx[splits[index]]
+        end_index = indexed_dataset.doc_idx[splits[index + 1]]
+        print_rank_0('     sentence indices in [{}, {}) total of {} '
+                     'sentences'.format(start_index, end_index,
+                                        end_index - start_index))
+    print_split_stats('train', 0)
+    print_split_stats('validation', 1)
+    print_split_stats('test', 2)
+
+    def build_dataset(index, name):
+        from megatron.data.bert_dataset import BertDataset
+        from megatron.data.ict_dataset import ICTDataset
+        from megatron.data.t5_dataset import T5Dataset
+        from megatron.data.glm_dataset import GlmDataset
+        dataset = None
+        if splits[index + 1] > splits[index]:
+            # Get the pointer to the original doc-idx so we can set it later.
+            doc_idx_ptr = indexed_dataset.get_doc_idx()
+            # Slice the doc-idx
+            start_index = splits[index]
+            # Add +1 so we can index into the dataset to get the upper bound.
+            end_index = splits[index + 1] + 1
+            # New doc_idx view.
+            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
+            # Build the dataset accordingly.
+            kwargs = dict(
+                name=name,
+                data_prefix=data_prefix,
+                num_epochs=None,
+                max_num_samples=train_valid_test_num_samples[index],
+                max_seq_length=max_seq_length,
+                seed=seed,
+            )
+
+            if dataset_type == DSET_TYPE_ICT:
+                args = get_args()
+                dataset = ICTDataset(
+                    block_dataset=indexed_dataset,
+                    title_dataset=title_dataset,
+                    query_in_block_prob=args.query_in_block_prob,
+                    use_one_sent_docs=args.use_one_sent_docs,
+                    binary_head=binary_head,
+                    **kwargs
+                )
+            elif dataset_type == DSET_TYPE_T5:
+                dataset = T5Dataset(
+                    indexed_dataset=indexed_dataset,
+                    masked_lm_prob=masked_lm_prob,
+                    max_seq_length_dec=max_seq_length_dec,
+                    short_seq_prob=short_seq_prob,
+                    **kwargs
+                )
+            elif dataset_type == DSET_TYPE_BERT:
+                dataset = BertDataset(
+                    indexed_dataset=indexed_dataset,
+                    masked_lm_prob=masked_lm_prob,
+                    short_seq_prob=short_seq_prob,
+                    binary_head=binary_head,
+                    **kwargs
+                )
+            elif dataset_type == DSET_TYPE_GLM:
+                dataset = GlmDataset(
+                    indexed_dataset=indexed_dataset,
+                    masked_lm_prob=masked_lm_prob,
+                    short_seq_prob=short_seq_prob,
+                    binary_head=binary_head,
+                    **kwargs
+                )
+            else:
+                raise NotImplementedError("Dataset type not fully implemented.")
+
+            # Set the original pointer so dataset remains the main dataset.
+            indexed_dataset.set_doc_idx(doc_idx_ptr)
+            # Checks.
+            assert indexed_dataset.doc_idx[0] == 0
+            assert indexed_dataset.doc_idx.shape[0] == \
+                (total_num_of_documents + 1)
+        return dataset
+
+    train_dataset = build_dataset(0, 'train')
+    valid_dataset = build_dataset(1, 'valid')
+    test_dataset = build_dataset(2, 'test')
+
+    return (train_dataset, valid_dataset, test_dataset)
+
+
+def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
+
+    print_rank_0(' > building dataset index ...')
+
+    start_time = time.time()
+    indexed_dataset = make_indexed_dataset(data_prefix,
+                                           data_impl,
+                                           skip_warmup)
+    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
+    print_rank_0(' > finished creating indexed dataset in {:4f} '
+                 'seconds'.format(time.time() - start_time))
+
+    print_rank_0(' > indexed dataset stats:')
+    print_rank_0('    number of documents: {}'.format(
+        indexed_dataset.doc_idx.shape[0] - 1))
+    print_rank_0('    number of sentences: {}'.format(
+        indexed_dataset.sizes.shape[0]))
+
+    return indexed_dataset
+
+
+def get_train_valid_test_split_(splits_string, size):
+    """Get dataset splits from a comma or '/' separated string list."""
+
+    splits = []
+    if splits_string.find(',') != -1:
+        splits = [float(s) for s in splits_string.split(',')]
+    elif splits_string.find('/') != -1:
+        splits = [float(s) for s in splits_string.split('/')]
+    else:
+        splits = [float(splits_string)]
+    while len(splits) < 3:
+        splits.append(0.)
+    splits = splits[:3]
+    splits_sum = sum(splits)
+    assert splits_sum > 0.0
+    splits = [split / splits_sum for split in splits]
+    splits_index = [0]
+    for index, split in enumerate(splits):
+        splits_index.append(splits_index[index] +
+                            int(round(split * float(size))))
+    diff = splits_index[-1] - size
+    for index in range(1, len(splits_index)):
+        splits_index[index] -= diff
+    assert len(splits_index) == 4
+    assert splits_index[-1] == size
+    return splits_index
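A worked example of the split arithmetic in `get_train_valid_test_split_` above (values chosen for illustration): '949,50,1' normalizes to 0.949/0.05/0.001, and the final adjustment pins the last boundary to `size`. A hedged usage sketch, assuming the module imports cleanly:

    from megatron.data.dataset_utils import get_train_valid_test_split_

    print(get_train_valid_test_split_('949,50,1', 10000))
    # [0, 9490, 9990, 10000]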
+
+def get_samples_mapping(indexed_dataset,
+                        data_prefix,
+                        num_epochs,
+                        max_num_samples,
+                        max_seq_length,
+                        short_seq_prob,
+                        seed,
+                        name,
+                        binary_head):
+    """Get a list that maps a sample index to a starting sentence index, end sentence index, and length."""
+
+    if not num_epochs:
+        if not max_num_samples:
+            raise ValueError("Need to specify either max_num_samples "
+                             "or num_epochs")
+        num_epochs = np.iinfo(np.int32).max - 1
+    if not max_num_samples:
+        max_num_samples = np.iinfo(np.int64).max - 1
+
+    # Filename of the index mapping
+    indexmap_filename = data_prefix
+    indexmap_filename += '_{}_indexmap'.format(name)
+    if num_epochs != (np.iinfo(np.int32).max - 1):
+        indexmap_filename += '_{}ep'.format(num_epochs)
+    if max_num_samples != (np.iinfo(np.int64).max - 1):
+        indexmap_filename += '_{}mns'.format(max_num_samples)
+    indexmap_filename += '_{}msl'.format(max_seq_length)
+    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
+    indexmap_filename += '_{}s'.format(seed)
+    indexmap_filename += '.npy'
+
+    # Build the indexed mapping if it does not exist.
+    if torch.distributed.get_rank() == 0 and \
+            not os.path.isfile(indexmap_filename):
+        print(' > WARNING: could not find index map file {}, building '
+              'the indices on rank 0 ...'.format(indexmap_filename))
+
+        # Make sure the types match the helpers input types.
+        assert indexed_dataset.doc_idx.dtype == np.int64
+        assert indexed_dataset.sizes.dtype == np.int32
+
+        # Build samples mapping
+        verbose = torch.distributed.get_rank() == 0
+        start_time = time.time()
+        print_rank_0(' > building samples index mapping for {} ...'.format(
+            name))
+        # First compile and then import.
+        from megatron.data import helpers
+        samples_mapping = helpers.build_mapping(
+            indexed_dataset.doc_idx,
+            indexed_dataset.sizes,
+            num_epochs,
+            max_num_samples,
+            max_seq_length,
+            short_seq_prob,
+            seed,
+            verbose,
+            2 if binary_head else 1)
+        print_rank_0(' > done building samples index mapping')
+        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
+        print_rank_0(' > saved the index mapping in {}'.format(
+            indexmap_filename))
+        # Make sure all the ranks have built the mapping
+        print_rank_0(' > elapsed time to build and save samples mapping '
+                     '(seconds): {:4f}'.format(
+                         time.time() - start_time))
+    # This should be a barrier but nccl barrier assumes
+    # device_index=rank which is not the case for the model
+    # parallel case
+    counts = torch.cuda.LongTensor([1])
+    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
+    assert counts[0].item() == (
+        torch.distributed.get_world_size() //
+        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
+
+    # Load indexed dataset.
+    print_rank_0(' > loading indexed mapping from {}'.format(
+        indexmap_filename))
+    start_time = time.time()
+    samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
+    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
+        time.time() - start_time))
+    print_rank_0('    total number of samples: {}'.format(
+        samples_mapping.shape[0]))
+
+    return samples_mapping
+
+
+class MaskEncoder(object):
+    def __init__(self):
+        tokenizer = get_tokenizer()
+        self.vocab_size = tokenizer.vocab_size
+        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
+        self.vocab_id_to_token_dict = tokenizer.inv_vocab
+        self.cls_id = tokenizer.cls
+        self.sep_id = tokenizer.sep
+        self.mask_id = tokenizer.mask
+        self.pad_id = tokenizer.pad
+
+        import jieba_fast
+        self.zh_tokenizer = jieba_fast.lcut
+        self.random_ratio = 0
+
+    def word_starts(self, source):
+        raw_tokens = [self.vocab_id_to_token_dict[i] for i in source.tolist()]
+        words = [raw_tokens[0]] + self.zh_tokenizer(''.join(raw_tokens[1:-1]), HMM=True) + [raw_tokens[-1]]
+
+        def _is_chinese_char(c):
+            """Checks whether c is the codepoint of a CJK character."""
+            # This defines a "chinese character" as anything in the CJK Unicode block:
+            #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+            #
+            # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+            # despite its name. The modern Korean Hangul alphabet is a different block,
+            # as are Japanese Hiragana and Katakana. Those alphabets are used to write
+            # space-separated words, so they are not treated specially and are handled
+            # like all of the other languages.
+            if len(c) > 1:
+                return all([_is_chinese_char(c_i) for c_i in c])
+            cp = ord(c)
+            if ((cp >= 0x4E00 and cp <= 0x9FFF) or
+                    (cp >= 0x3400 and cp <= 0x4DBF) or
+                    (cp >= 0x20000 and cp <= 0x2A6DF) or
+                    (cp >= 0x2A700 and cp <= 0x2B73F) or
+                    (cp >= 0x2B740 and cp <= 0x2B81F) or
+                    (cp >= 0x2B820 and cp <= 0x2CEAF) or
+                    (cp >= 0xF900 and cp <= 0xFAFF) or
+                    (cp >= 0x2F800 and cp <= 0x2FA1F)):
+                return True
+
+            return False
+
+        def align_linear(atokens, btokens):
+            a2c = []
+            c2b = []
+            a2b = []
+            length = 0
+            for tok in atokens:
+                a2c.append([length + i for i in range(len(tok))])
+                length += len(tok)
+            for i, tok in enumerate(btokens):
+                c2b.extend([i for _ in range(len(tok))])
+
+            for i, amap in enumerate(a2c):
+                bmap = [c2b[ci] for ci in amap]
+                a2b.append(list(set(bmap)))
+            return a2b
|
812 |
+
|
813 |
+
raw_to_word_align = align_linear(raw_tokens, words)
|
814 |
+
is_word_start = torch.zeros(source.size())
|
815 |
+
word_starts = []
|
816 |
+
skip_cur_word = True
|
817 |
+
for i in range(1, len(raw_to_word_align)):
|
818 |
+
if raw_to_word_align[i-1] == raw_to_word_align[i]:
|
819 |
+
# not a word start, as they align to the same word
|
820 |
+
if not skip_cur_word and not _is_chinese_char(raw_tokens[i]):
|
821 |
+
word_starts.pop(-1)
|
822 |
+
skip_cur_word = True
|
823 |
+
continue
|
824 |
+
else:
|
825 |
+
is_word_start[i] = 1
|
826 |
+
if _is_chinese_char(raw_tokens[i]):
|
827 |
+
word_starts.append(i)
|
828 |
+
skip_cur_word = False
|
829 |
+
is_word_start[0] = 0
|
830 |
+
is_word_start[-1] = 0
|
831 |
+
word_starts = torch.tensor(word_starts).long().view(-1, 1)
|
832 |
+
return is_word_start, word_starts
|
833 |
+
|
834 |
+
def add_whole_word_mask(self, source, p, replace_length=1):
|
835 |
+
is_word_start, word_starts = self.word_starts(source)
|
836 |
+
num_to_mask_word = int(math.ceil(word_starts.size(0) * p))
|
837 |
+
num_to_mask_char = int(math.ceil(word_starts.size(0) * p * 0.1))
|
838 |
+
num_to_mask = num_to_mask_word + num_to_mask_char
|
839 |
+
if num_to_mask > word_starts.size(0):
|
840 |
+
word_starts = is_word_start.nonzero(as_tuple=False)
|
841 |
+
num_inserts = 0
|
842 |
+
if num_to_mask == 0:
|
843 |
+
return source
|
844 |
+
|
845 |
+
lengths = torch.ones((num_to_mask,)).long()
|
846 |
+
assert is_word_start[-1] == 0
|
847 |
+
indices = word_starts[
|
848 |
+
torch.randperm(word_starts.size(0))[:num_to_mask]
|
849 |
+
].squeeze(1)
|
850 |
+
if len(indices) < num_to_mask:
|
851 |
+
num_to_mask = len(indices)
|
852 |
+
|
853 |
+
mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
|
854 |
+
source_length = source.size(0)
|
855 |
+
assert source_length - 1 not in indices
|
856 |
+
to_keep = torch.ones(source_length, dtype=torch.bool)
|
857 |
+
is_word_start[
|
858 |
+
-1
|
859 |
+
] = 255 # acts as a long length, so spans don't go over the end of doc
|
860 |
+
if replace_length == 0:
|
861 |
+
to_keep[indices] = 0
|
862 |
+
else:
|
863 |
+
# keep index, but replace it with [MASK]
|
864 |
+
# print(source.size(), word_starts.size(), indices.size(), mask_random.size())
|
865 |
+
# try:
|
866 |
+
source[indices] = self.mask_id
|
867 |
+
source[indices[mask_random]] = torch.randint(
|
868 |
+
1, self.vocab_size, size=(mask_random.sum(),)
|
869 |
+
)
|
870 |
+
# except:
|
871 |
+
# print(source)
|
872 |
+
# print(indices)
|
873 |
+
# print(mask_random)
|
874 |
+
# print()
|
875 |
+
# sorted_indices = torch.sort(indices)[0]
|
876 |
+
# continue_mask_pos = ((sorted_indices + 1)[:-1] == sorted_indices[1:])
|
877 |
+
# continue_mask_indices = sorted_indices[1:][continue_mask_pos]
|
878 |
+
# to_keep[continue_mask_indices] = 0
|
879 |
+
|
880 |
+
# for char indices, we already masked, the following loop handles word mask
|
881 |
+
indices = indices[:num_to_mask_word]
|
882 |
+
mask_random = mask_random[:num_to_mask_word]
|
883 |
+
while indices.size(0) > 0:
|
884 |
+
uncompleted = is_word_start[indices + 1] == 0
|
885 |
+
indices = indices[uncompleted] + 1
|
886 |
+
mask_random = mask_random[uncompleted]
|
887 |
+
if replace_length != -1:
|
888 |
+
# delete token
|
889 |
+
to_keep[indices] = 0
|
890 |
+
else:
|
891 |
+
# keep index, but replace it with [MASK]
|
892 |
+
source[indices] = self.mask_id
|
893 |
+
source[indices[mask_random]] = torch.randint(
|
894 |
+
1, self.vocab_size, size=(mask_random.sum(),)
|
895 |
+
)
|
896 |
+
|
897 |
+
assert source_length - 1 not in indices
|
898 |
+
source = source[to_keep]
|
899 |
+
|
900 |
+
return source
|
901 |
+
|
902 |
+
def shif_chinese_word(self, tokens, tokens_bf_mask):
|
903 |
+
assert len(tokens) == len(tokens_bf_mask), 'length must be equal in this function'
|
904 |
+
buff_list = []
|
905 |
+
buff_list_index = []
|
906 |
+
for i in range(len(tokens)):
|
907 |
+
if tokens[i] == tokens_bf_mask[i]:
|
908 |
+
if len(buff_list) == 0:
|
909 |
+
continue
|
910 |
+
else:
|
911 |
+
if len(buff_list) != 1:
|
912 |
+
random.shuffle(buff_list)
|
913 |
+
tokens[buff_list_index[0] : buff_list_index[-1]+1] = buff_list
|
914 |
+
buff_list = []
|
915 |
+
buff_list_index = []
|
916 |
+
else:
|
917 |
+
buff_list.append(tokens_bf_mask[i])
|
918 |
+
buff_list_index.append(i)
|
919 |
+
|
920 |
+
return tokens
|
921 |
+
|
922 |
+
def mass_style_mask(self, tokens):
|
923 |
+
tokens = tokens[:]
|
924 |
+
p = random.uniform(0.3, 0.5)
|
925 |
+
num_to_mask = int(len(tokens) * p)
|
926 |
+
start_index = int((1 - p) / 2 * len(tokens))
|
927 |
+
tokens[start_index : start_index + num_to_mask] = [self.mask_id] * num_to_mask
|
928 |
+
|
929 |
+
return tokens
|
930 |
+
|
931 |
+
def delete_chinese_word(self, tokens, tokens_bf_mask):
|
932 |
+
return_tokens = []
|
933 |
+
assert len(tokens) == len(tokens_bf_mask), 'length must be equal in this function'
|
934 |
+
for i in range(len(tokens)):
|
935 |
+
if tokens[i] == tokens_bf_mask[i]:
|
936 |
+
return_tokens.append(tokens[i])
|
937 |
+
|
938 |
+
return return_tokens
|
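
Note: the split logic at the top of this hunk normalizes the three weights, accumulates rounded boundaries, and then shifts every boundary so the last one lands exactly on `size`. A minimal standalone sketch of the same arithmetic (illustrative only; `demo_split` is a made-up name, not part of this repo):

    def demo_split(weights, size):
        # Normalize the three weights, then accumulate rounded boundaries.
        total = sum(weights)
        weights = [w / total for w in weights]
        index = [0]
        for i, w in enumerate(weights):
            index.append(index[i] + int(round(w * float(size))))
        # Shift so the last boundary lands exactly on `size`.
        diff = index[-1] - size
        return [index[0]] + [x - diff for x in index[1:]]

    # A 980/10/10 split over 10000 documents -> [0, 9800, 9900, 10000]
    print(demo_split([980., 10., 10.], 10000))
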
megatron/data/glm_dataset.py
ADDED
@@ -0,0 +1,377 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""GLM style dataset."""
+
+import numpy as np
+import torch
+
+from megatron import (
+    get_args,
+    get_tokenizer,
+    mpu,
+    print_rank_0
+)
+from megatron.data.dataset_utils import (
+    get_samples_mapping,
+    get_a_and_b_segments,
+    truncate_segments,
+    create_tokens_and_tokentypes,
+    create_tokens,
+    create_masked_lm_predictions,
+    MaskEncoder
+)
+
+
+class DummyBertDataset(torch.utils.data.Dataset):
+    def __init__(self, name, num_samples, max_seq_length):
+        self.name = name
+        self.num_samples = num_samples
+        self.max_seq_length = max_seq_length
+        self.np_rng = np.random.RandomState(seed=0)
+        # self.token_nps = np_rng.randint(1000, 2000, (self.num_samples, 512))
+        # Vocab stuff.
+        tokenizer = get_tokenizer()
+        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
+        self.vocab_id_to_token_dict = tokenizer.inv_vocab
+        self.cls_id = tokenizer.cls
+        self.sep_id = tokenizer.sep
+        self.mask_id = tokenizer.mask
+        self.pad_id = tokenizer.pad
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        tokens = self.np_rng.randint(1000, 2000, self.max_seq_length)
+        masked_position = np.arange(int(tokens.shape[0] * 0.15))
+        tokens = tokens.astype(np.int64)
+        labels = tokens[masked_position]
+        label_np = np.full_like(tokens, -1)
+        label_np[masked_position] = labels
+        tokens[masked_position] = self.mask_id
+        loss_mask_np = np.zeros_like(tokens)
+        loss_mask_np[masked_position] = 1
+        train_sample = {
+            'text': tokens,
+            'types': np.zeros_like(tokens),
+            'labels': label_np,
+            'is_random': 0,
+            'loss_mask': loss_mask_np,
+            'padding_mask': np.ones_like(tokens),
+            'truncated': 0
+        }
+        return train_sample
+
+
+class GlmDataset(torch.utils.data.Dataset):
+
+    def __init__(self, name, indexed_dataset, data_prefix,
+                 num_epochs, max_num_samples, masked_lm_prob,
+                 max_seq_length, short_seq_prob, seed, binary_head):
+
+        # Params to store.
+        self.name = name
+        self.seed = seed
+        self.masked_lm_prob = masked_lm_prob
+        self.max_seq_length = max_seq_length
+        self.binary_head = binary_head
+
+        # Dataset.
+        self.indexed_dataset = indexed_dataset
+
+        # Build the samples mapping.
+        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
+                                                   data_prefix,
+                                                   num_epochs,
+                                                   max_num_samples,
+                                                   self.max_seq_length - 3,  # account for added tokens
+                                                   short_seq_prob,
+                                                   self.seed,
+                                                   self.name,
+                                                   self.binary_head)
+
+        # Vocab stuff.
+        tokenizer = get_tokenizer()
+        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
+        self.vocab_id_to_token_dict = tokenizer.inv_vocab
+        self.cls_id = tokenizer.cls
+        self.sep_id = tokenizer.sep
+        self.mask_id = tokenizer.mask
+        self.pad_id = tokenizer.pad
+
+    def __len__(self):
+        return self.samples_mapping.shape[0]
+
+    def __getitem__(self, idx):
+        start_idx, end_idx, seq_length = self.samples_mapping[idx]
+        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
+        # Note that this rng state should be numpy and not python since
+        # python randint is inclusive whereas the numpy one is exclusive.
+        # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
+        np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
+        return build_training_sample(sample, seq_length,
+                                     self.max_seq_length,  # needed for padding
+                                     self.vocab_id_list,
+                                     self.vocab_id_to_token_dict,
+                                     self.cls_id, self.sep_id,
+                                     self.mask_id, self.pad_id,
+                                     self.masked_lm_prob, np_rng,
+                                     self.binary_head)
+
+
+def sent_level_task(binary_head, sample, target_seq_length, max_seq_length, np_rng):
+    if binary_head:
+        # We assume that we have at least two sentences in the sample
+        assert len(sample) > 1
+    assert target_seq_length <= max_seq_length
+    # Divide sample into two segments (A and B).
+    if binary_head:
+        tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
+    else:
+        tokens_a = []
+        for j in range(len(sample)):
+            tokens_a.extend(sample[j])
+        tokens_b = []
+        is_next_random = False
+    # Truncate to `target_sequence_length`.
+
+    max_num_tokens = target_seq_length
+    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
+                                  len(tokens_b), max_num_tokens, np_rng)
+    return is_next_random, truncated, max_num_tokens, tokens_a, tokens_b
+
+
+def generate_decoder_input_and_output(tokens, pad_id, sep_id):
+    """
+    decoder input  [SEP] [CLS] A B C D
+    decoder output [CLS] A B C D E
+    """
+
+    decoder_output = tokens[:]
+    decoder_input = [0] * len(decoder_output)
+    decoder_input[0] = sep_id  # match the preprocessing in fairseq
+    decoder_input[1:] = decoder_output[:-1]
+
+    """
+    decoder input  [CLS] A B C D [SEP]
+    decoder output A B C D [SEP] [PAD]
+    """
+
+    # decoder_input = tokens[:]
+    # decoder_output = [0] * len(decoder_input)
+    # decoder_output[:-1] = decoder_input[1:]
+    # decoder_output[-1] = pad_id
+
+    return decoder_input, decoder_output
+
+
+def build_training_sample(sample,
+                          target_seq_length, max_seq_length,
+                          vocab_id_list, vocab_id_to_token_dict,
+                          cls_id, sep_id, mask_id, pad_id,
+                          masked_lm_prob, np_rng, binary_head):
+
+    """
+    sent-level task
+    """
+    is_next_random, truncated, max_num_tokens, tokens_a, tokens_b = sent_level_task(
+        binary_head, sample, target_seq_length, max_seq_length, np_rng)
+    tokens_bf_mask = create_tokens(tokens_a, tokens_b, cls_id, sep_id)
+    if is_next_random:
+        raw_tokens = create_tokens(tokens_b, tokens_a, cls_id, sep_id)
+    else:
+        raw_tokens = tokens_bf_mask[:]
+
+    """
+    decoder input and output
+    """
+    decoder_input, decoder_output = generate_decoder_input_and_output(raw_tokens, pad_id, sep_id)
+
+    # important part
+    encoder_loss_flag = 0
+    decoder_loss_flag = 0
+    sent_loss_flag = 1
+    encoder_rng = torch.rand(1).item()
+    me = MaskEncoder()
+    if encoder_rng < 1.1:
+        # only train with encoder and decoder
+        # Masking.
+        if 0:
+            max_predictions_per_seq = masked_lm_prob * max_num_tokens
+            (tokens, _, _, _, _) = create_masked_lm_predictions(
+                tokens_bf_mask, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
+                cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng, masking_style="t5")
+        if 1:
+            tokens = torch.LongTensor(tokens_bf_mask)
+            tokens = me.add_whole_word_mask(tokens, 0.15, -1)
+            tokens = tokens.tolist()
+            shift_rng = torch.rand(1).item()
+            if shift_rng < 0.0:
+                tokens = me.shif_chinese_word(tokens, tokens_bf_mask)
+        encoder_loss_flag = 1
+        decoder_loss_flag = 1
+    else:
+        # train only with decoder
+        tokens = torch.LongTensor(tokens_bf_mask)
+        decoder_rng = torch.rand(1).item()
+        if decoder_rng < 0.4:
+            # WWM mask 30% of words
+            tokens = me.add_whole_word_mask(tokens, 0.3, -1)
+            tokens = tokens.tolist()
+        if decoder_rng >= 0.4 and decoder_rng < 0.6:
+            # MASS mask style
+            tokens = me.mass_style_mask(tokens_bf_mask)
+        if decoder_rng > 0.6:
+            # delete tokens
+            tokens = me.add_whole_word_mask(tokens, 0.3, -1)
+            tokens = tokens.tolist()
+            tokens = me.delete_chinese_word(tokens, tokens_bf_mask)
+            tmp_tt = get_tokenizer()
+            # print("encoder ori input", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens_bf_mask)))
+            # print("encoder input    ", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens)))
+            # print("------\n\n")
+
+        decoder_loss_flag = 1
+
+    # tmp_tt = get_tokenizer()
+    # print("encoder ori input", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens_bf_mask)))
+    # print("encoder input    ", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens)))
+    # print("decoder input    ", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(decoder_input)))
+    # print("decoder output   ", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(decoder_output)))
+
+    tokentypes = []
+    encoder_labels = []
+    encoder_labels_mask = []
+    padding_mask = []
+    apppend_type_id = 0
+
+    if len(tokens) == len(tokens_bf_mask):
+        # encoder and decoder can train together
+        for index in range(len(tokens)):
+            padding_mask.append(1)
+            # generate token types
+            if tokens[index] == sep_id:
+                apppend_type_id = 1
+            tokentypes.append(apppend_type_id)
+
+            if tokens[index] == tokens_bf_mask[index]:
+                encoder_labels.append(-1)
+                encoder_labels_mask.append(0)
+            else:
+                encoder_labels.append(tokens_bf_mask[index])
+                encoder_labels_mask.append(1)
+    else:
+        # only train decoder
+        for index in range(len(tokens)):
+            padding_mask.append(1)
+            if tokens[index] == sep_id:
+                apppend_type_id = 1
+            tokentypes.append(apppend_type_id)
+            encoder_labels.append(-1)
+            encoder_labels_mask.append(0)
+
+    tokens_np = pad_and_convert_to_numpy_light(tokens, max_seq_length, pad_id)
+    tokentypes_np = pad_and_convert_to_numpy_light(tokentypes, max_seq_length, pad_id)
+    padding_mask_np = pad_and_convert_to_numpy_light(padding_mask, max_seq_length, pad_id)
+    encoder_labels_np = pad_and_convert_to_numpy_light(encoder_labels, max_seq_length, -1)
+    encoder_labels_mask_np = pad_and_convert_to_numpy_light(encoder_labels_mask, max_seq_length, pad_id)
+    decoder_input_np = pad_and_convert_to_numpy_light(decoder_input, max_seq_length, pad_id)
+    decoder_output_np = pad_and_convert_to_numpy_light(decoder_output, max_seq_length, pad_id)
+
+    # print(tokens_np)
+    # print(encoder_labels_np)
+    # print(padding_mask_np)
+    # print(encoder_labels_mask_np)
+
+    # generate tokentypes
+    train_sample = {
+        'text': tokens_np,                          # encoder_input
+        'types': tokentypes_np,                     # token_type
+        'is_random': int(is_next_random),           # sop_labels
+        'truncated': int(truncated),                # if truncated
+        'labels': encoder_labels_np,                # encoder_labels
+        'loss_mask': encoder_labels_mask_np,        # mlm_mask
+        'padding_mask': padding_mask_np,            # padding_mask
+        'decoder_input': decoder_input_np,          # decoder_input
+        'decoder_output': decoder_output_np,        # decoder_output
+        'encoder_loss_flag': int(encoder_loss_flag),
+        'decoder_loss_flag': int(decoder_loss_flag),
+        'sent_loss_flag': int(sent_loss_flag),
+    }
+
+    # print(tokens_np.shape)
+    # print(tokens_np)
+    # print(tokentypes_np.shape)
+    # print(tokentypes_np)
+    # print(encoder_labels_np.shape)
+    # print(encoder_labels_np)
+    # print(encoder_labels_mask_np.shape)
+    # print(encoder_labels_mask_np)
+    # print(padding_mask_np.shape)
+    # print(padding_mask_np)
+    # print(decoder_input_np.shape)
+    # print(decoder_input_np)
+    # print(decoder_output_np.shape)
+    # print(decoder_output_np)
+    # print("=====\n\n\n")
+    # import sys;sys.exit(0)
+    return train_sample
+
+
+def pad_and_convert_to_numpy_light(tokens, max_seq_length, pad_id):
+    padding_length = max_seq_length - len(tokens)
+    assert padding_length >= 0
+    filler = [pad_id] * padding_length
+    tokens_np = np.array(tokens + filler, dtype=np.int64)
+    return tokens_np
+
+
+def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
+                             masked_labels, pad_id, max_seq_length):
+    """Pad sequences and convert them to numpy."""
+
+    # Some checks.
+    num_tokens = len(tokens)
+    padding_length = max_seq_length - num_tokens
+    assert padding_length >= 0
+    assert len(tokentypes) == num_tokens
+    assert len(masked_positions) == len(masked_labels)
+
+    # Tokens and token types.
+    filler = [pad_id] * padding_length
+    tokens_np = np.array(tokens + filler, dtype=np.int64)
+    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
+
+    # Padding mask.
+    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
+                               dtype=np.int64)
+
+    # Labels and loss mask.
+    labels = [-1] * max_seq_length
+    loss_mask = [0] * max_seq_length
+    for i in range(len(masked_positions)):
+        assert masked_positions[i] < num_tokens
+        labels[masked_positions[i]] = masked_labels[i]
+        loss_mask[masked_positions[i]] = 1
+    labels_np = np.array(labels, dtype=np.int64)
+    loss_mask_np = np.array(loss_mask, dtype=np.int64)
+
+    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
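
Note: `generate_decoder_input_and_output` above builds the decoder target as the unshifted token sequence, and the decoder input as the same sequence shifted right by one with `sep_id` prepended. A tiny worked example of that shift (the token ids are invented for illustration):

    tokens = [101, 7, 8, 9, 102]      # hypothetical [CLS] A B C [SEP]
    sep_id = 102
    decoder_output = tokens[:]
    decoder_input = [0] * len(decoder_output)
    decoder_input[0] = sep_id          # prepend [SEP], as in the code above
    decoder_input[1:] = decoder_output[:-1]
    print(decoder_input)   # [102, 101, 7, 8, 9]
    print(decoder_output)  # [101, 7, 8, 9, 102]
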
megatron/data/gpt_dataset.py
ADDED
@@ -0,0 +1,430 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""GPT style dataset."""
+
+import os
+import time
+
+import numpy as np
+import torch
+
+from megatron import mpu, print_rank_0
+from megatron.data.blendable_dataset import BlendableDataset
+from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
+from megatron.data.dataset_utils import get_train_valid_test_split_
+from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
+
+
+def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                    train_valid_test_num_samples,
+                                    seq_length, seed, skip_warmup):
+    """Build train, valid, and test datasets."""
+
+    # Single dataset.
+    if len(data_prefix) == 1:
+        return _build_train_valid_test_datasets(data_prefix[0],
+                                                data_impl, splits_string,
+                                                train_valid_test_num_samples,
+                                                seq_length, seed, skip_warmup)
+
+    # Blending dataset.
+    # Parse the values.
+    output = get_datasets_weights_and_num_samples(data_prefix,
+                                                  train_valid_test_num_samples)
+    prefixes, weights, datasets_train_valid_test_num_samples = output
+
+    # Build individual datasets.
+    train_datasets = []
+    valid_datasets = []
+    test_datasets = []
+    for i in range(len(prefixes)):
+        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
+            prefixes[i], data_impl, splits_string,
+            datasets_train_valid_test_num_samples[i],
+            seq_length, seed, skip_warmup)
+        if train_ds:
+            train_datasets.append(train_ds)
+        if valid_ds:
+            valid_datasets.append(valid_ds)
+        if test_ds:
+            test_datasets.append(test_ds)
+
+    # Blend.
+    blending_train_dataset = None
+    if train_datasets:
+        blending_train_dataset = BlendableDataset(train_datasets, weights)
+    blending_valid_dataset = None
+    if valid_datasets:
+        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+    blending_test_dataset = None
+    if test_datasets:
+        blending_test_dataset = BlendableDataset(test_datasets, weights)
+
+    return (blending_train_dataset, blending_valid_dataset,
+            blending_test_dataset)
+
+
+def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                     train_valid_test_num_samples,
+                                     seq_length, seed, skip_warmup):
+    """Build train, valid, and test datasets."""
+
+    # Indexed dataset.
+    indexed_dataset = get_indexed_dataset_(data_prefix,
+                                           data_impl,
+                                           skip_warmup)
+
+    total_num_of_documents = indexed_dataset.sizes.shape[0]
+    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
+
+    # Print stats about the splits.
+    print_rank_0(' > dataset split:')
+
+    def print_split_stats(name, index):
+        print_rank_0('    {}:'.format(name))
+        print_rank_0('     document indices in [{}, {}) total of {} '
+                     'documents'.format(splits[index], splits[index + 1],
+                                        splits[index + 1] - splits[index]))
+    print_split_stats('train', 0)
+    print_split_stats('validation', 1)
+    print_split_stats('test', 2)
+
+    def build_dataset(index, name):
+        dataset = None
+        if splits[index + 1] > splits[index]:
+            documents = np.arange(start=splits[index], stop=splits[index + 1],
+                                  step=1, dtype=np.int32)
+            dataset = GPTDataset(name, data_prefix,
+                                 documents, indexed_dataset,
+                                 train_valid_test_num_samples[index],
+                                 seq_length, seed)
+        return dataset
+
+    train_dataset = build_dataset(0, 'train')
+    valid_dataset = build_dataset(1, 'valid')
+    test_dataset = build_dataset(2, 'test')
+
+    return (train_dataset, valid_dataset, test_dataset)
+
+
+def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
+    """Build indexed dataset."""
+    print_rank_0(' > building dataset index ...')
+
+    start_time = time.time()
+    indexed_dataset = make_indexed_dataset(data_prefix,
+                                           data_impl,
+                                           skip_warmup)
+    print_rank_0(' > finished creating indexed dataset in {:4f} '
+                 'seconds'.format(time.time() - start_time))
+    print_rank_0('    number of documents: {}'.format(
+        indexed_dataset.sizes.shape[0]))
+
+    return indexed_dataset
+
+
+class GPTDataset(torch.utils.data.Dataset):
+
+    def __init__(self, name, data_prefix, documents, indexed_dataset,
+                 num_samples, seq_length, seed):
+
+        self.name = name
+        self.indexed_dataset = indexed_dataset
+
+        # Checks
+        assert np.min(documents) >= 0
+        assert np.max(documents) < indexed_dataset.sizes.shape[0]
+
+        # Build index mappings.
+        self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
+            self.name, data_prefix, documents, self.indexed_dataset.sizes,
+            num_samples, seq_length, seed)
+
+    def __len__(self):
+        # -1 is due to data structure used to retrieve the index:
+        #    sample i --> [sample_idx[i], sample_idx[i+1])
+        return self.sample_idx.shape[0] - 1
+
+    def __getitem__(self, idx):
+        # Get the shuffled index.
+        idx = self.shuffle_idx[idx]
+        # Start and end documents and offsets.
+        doc_index_f = self.sample_idx[idx][0]
+        doc_index_l = self.sample_idx[idx + 1][0]
+        offset_f = self.sample_idx[idx][1]
+        offset_l = self.sample_idx[idx + 1][1]
+        # If we are within the same document, just extract the chunk.
+        if doc_index_f == doc_index_l:
+            sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
+                                              offset=offset_f,
+                                              length=offset_l - offset_f + 1)
+        else:
+            # Otherwise, get the rest of the initial document.
+            sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
+                                                    offset=offset_f)]
+            # Loop over all in between documents and add the entire document.
+            for i in range(doc_index_f + 1, doc_index_l):
+                sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
+            # And finally add the relevant portion of last document.
+            sample_list.append(self.indexed_dataset.get(
+                self.doc_idx[doc_index_l],
+                length=offset_l + 1))
+            sample = np.concatenate(sample_list)
+
+        return {'text': np.array(sample, dtype=np.int64)}
+
+
+def _build_index_mappings(name, data_prefix, documents, sizes,
+                          num_samples, seq_length, seed):
+    """Build doc-idx, sample-idx, and shuffle-idx.
+    doc-idx: is an array (ordered) of documents to be used in training.
+    sample-idx: is the start document index and document offset for each
+       training sample.
+    shuffle-idx: maps the sample index into a random index into sample-idx.
+    """
+    # Number of tokens in each epoch and number of required epochs.
+    tokens_per_epoch = _num_tokens(documents, sizes)
+    num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
+    # rng state
+    np_rng = np.random.RandomState(seed=seed)
+
+    # Filename of the index mappings.
+    _filename = data_prefix
+    _filename += '_{}_indexmap'.format(name)
+    _filename += '_{}ns'.format(num_samples)
+    _filename += '_{}sl'.format(seq_length)
+    _filename += '_{}s'.format(seed)
+    doc_idx_filename = _filename + '_doc_idx.npy'
+    sample_idx_filename = _filename + '_sample_idx.npy'
+    shuffle_idx_filename = _filename + '_shuffle_idx.npy'
+
+    # Build the indexed mapping if it does not exist.
+    if torch.distributed.get_rank() == 0:
+        if (not os.path.isfile(doc_idx_filename)) or \
+           (not os.path.isfile(sample_idx_filename)) or \
+           (not os.path.isfile(shuffle_idx_filename)):
+
+            print_rank_0(' > WARNING: could not find index map files, building '
+                         'the indices on rank 0 ...')
+
+            # For the last epoch, decide whether include the entire epoch
+            # in the global shuffle or not.
+
+            # If we need only one epoch, then separating out the last epoch
+            # does not mean anything.
+            if num_epochs == 1:
+                separate_last_epoch = False
+                print(' > only one epoch required, setting '
+                      'separate_last_epoch to False', flush=True)
+
+            else:
+                # Get the number of samples for the last epoch
+                num_samples_from_epochs_minus_one = (
+                    (num_epochs - 1) * tokens_per_epoch - 1) // seq_length
+                last_epoch_num_samples = num_samples - \
+                    num_samples_from_epochs_minus_one
+                assert last_epoch_num_samples >= 0, \
+                    'last epoch number of samples should be non-negative.'
+                num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
+                assert last_epoch_num_samples < (num_samples_per_epoch + 1), \
+                    'last epoch number of samples exceeded max value.'
+                # If we have less than 80% of the samples for the last epoch,
+                # separate out the epoch and treat it differently.
+                # Note: the 80% number is just based on common sense and can
+                # be adjusted if needed.
+                separate_last_epoch = (last_epoch_num_samples <
+                                       int(0.80 * num_samples_per_epoch))
+                if separate_last_epoch:
+                    string = ' > last epoch number of samples ({}) is smaller '\
+                             'than 80% of number of samples per epoch ({}), '\
+                             'setting separate_last_epoch to True'
+                else:
+                    string = ' > last epoch number of samples ({}) is larger '\
+                             'than 80% of number of samples per epoch ({}), '\
+                             'setting separate_last_epoch to False'
+                print(string.format(last_epoch_num_samples,
+                                    num_samples_per_epoch), flush=True)
+
+            # doc-idx.
+            start_time = time.time()
+            doc_idx = _build_doc_idx(documents, num_epochs, np_rng,
+                                     separate_last_epoch)
+            np.save(doc_idx_filename, doc_idx, allow_pickle=True)
+            print_rank_0(' > elapsed time to build and save doc-idx mapping '
+                         '(seconds): {:4f}'.format(time.time() - start_time))
+            # sample-idx.
+            start_time = time.time()
+            # Use C++ implementation for speed.
+            # First compile and then import.
+            from megatron.data import helpers
+            assert doc_idx.dtype == np.int32
+            assert sizes.dtype == np.int32
+            sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
+                                                  num_epochs, tokens_per_epoch)
+            # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
+            #                                num_epochs, tokens_per_epoch)
+            np.save(sample_idx_filename, sample_idx, allow_pickle=True)
+            print_rank_0(' > elapsed time to build and save sample-idx mapping '
+                         '(seconds): {:4f}'.format(time.time() - start_time))
+            # shuffle-idx.
+            start_time = time.time()
+            # -1 is due to data structure used to retrieve the index:
+            #    sample i --> [sample_idx[i], sample_idx[i+1])
+            if separate_last_epoch:
+                num_samples_ = num_samples_from_epochs_minus_one
+            else:
+                num_samples_ = sample_idx.shape[0] - 1
+            shuffle_idx = _build_shuffle_idx(num_samples_,
+                                             sample_idx.shape[0] - 1, np_rng)
+            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
+            print_rank_0(' > elapsed time to build and save shuffle-idx mapping'
+                         ' (seconds): {:4f}'.format(time.time() - start_time))
+
+    # This should be a barrier but nccl barrier assumes
+    # device_index=rank which is not the case for model
+    # parallel case
+    counts = torch.cuda.LongTensor([1])
+    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
+    assert counts[0].item() == (
+        torch.distributed.get_world_size() //
+        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
+
+    # Load mappings.
+    start_time = time.time()
+    print_rank_0(' > loading doc-idx mapping from {}'.format(
+        doc_idx_filename))
+    doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r')
+    print_rank_0(' > loading sample-idx mapping from {}'.format(
+        sample_idx_filename))
+    sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')
+    print_rank_0(' > loading shuffle-idx mapping from {}'.format(
+        shuffle_idx_filename))
+    shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r')
+    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
+        time.time() - start_time))
+    print_rank_0('    total number of samples: {}'.format(
+        sample_idx.shape[0]))
+    print_rank_0('    total number of epochs: {}'.format(num_epochs))
+
+    return doc_idx, sample_idx, shuffle_idx
+
+
+def _num_tokens(documents, sizes):
+    """Total number of tokens in the dataset."""
+    return np.sum(sizes[documents])
+
+
+def _num_epochs(tokens_per_epoch, seq_length, num_samples):
+    """Based on number of samples and sequence length, calculate how many
+    epochs will be needed."""
+    num_epochs = 0
+    total_tokens = 0
+    while True:
+        num_epochs += 1
+        total_tokens += tokens_per_epoch
+        # -1 is because we need to retrieve seq_length + 1 token each time
+        # but the last token will overlap with the first token of the next
+        # sample except for the last sample.
+        if ((total_tokens - 1) // seq_length) >= num_samples:
+            return num_epochs
+
+
+def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
+    """Build an array with length = number-of-epochs * number-of-documents.
+    Each index is mapped to a corresponding document."""
+    if not separate_last_epoch or num_epochs == 1:
+        doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
+        doc_idx[:] = documents
+        doc_idx = doc_idx.reshape(-1)
+        doc_idx = doc_idx.astype(np.int32)
+        np_rng.shuffle(doc_idx)
+        return doc_idx
+
+    doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False)
+    doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
+    return np.concatenate((doc_idx_first, doc_idx_last))
+
+
+def _build_sample_idx(sizes, doc_idx, seq_length,
+                      num_epochs, tokens_per_epoch):
+    """Sample index mapping is a 2D array with sizes
+    [number-of-samples + 1, 2] where [..., 0] contains
+    the index into `doc_idx` and [..., 1] is the
+    starting offset in that document."""
+
+    # Total number of samples. For -1 see comments in `_num_epochs`.
+    num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
+    sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)
+
+    # Index into sample_idx.
+    sample_index = 0
+    # Index into doc_idx.
+    doc_idx_index = 0
+    # Beginning offset for each document.
+    doc_offset = 0
+    # Start with first document and no offset.
+    sample_idx[sample_index][0] = doc_idx_index
+    sample_idx[sample_index][1] = doc_offset
+    sample_index += 1
+    while sample_index <= num_samples:
+        # Start with a fresh sequence.
+        remaining_seq_length = seq_length + 1
+        while remaining_seq_length != 0:
+            # Get the document length.
+            doc_id = doc_idx[doc_idx_index]
+            doc_length = sizes[doc_id] - doc_offset
+            # And add it to the current sequence.
+            remaining_seq_length -= doc_length
+            # If we have more than a full sequence, adjust offset and set
+            # remaining length to zero so we return from the while loop.
+            # Note that -1 here is for the same reason we have -1 in
+            # `_num_epochs` calculations.
+            if remaining_seq_length <= 0:
+                doc_offset += (remaining_seq_length + doc_length - 1)
+                remaining_seq_length = 0
+            else:
+                # Otherwise, start from the beginning of the next document.
+                doc_idx_index += 1
+                doc_offset = 0
+        # Record the sequence.
+        sample_idx[sample_index][0] = doc_idx_index
+        sample_idx[sample_index][1] = doc_offset
+        sample_index += 1
+
+    return sample_idx
+
+
+def _build_shuffle_idx(num_samples, total_size, np_rng):
+    """Build the range [0, size) and shuffle."""
+    print(' > building shuffle index with split [0, {}) and [{}, {}) '
+          '...'.format(num_samples, num_samples, total_size), flush=True)
+
+    dtype_ = np.uint32
+    if total_size >= (np.iinfo(np.uint32).max - 1):
+        dtype_ = np.int64
+
+    shuffle_idx_first = np.arange(start=0, stop=num_samples,
+                                  step=1, dtype=dtype_)
+    np_rng.shuffle(shuffle_idx_first)
+    if num_samples == total_size:
+        return shuffle_idx_first
+
+    shuffle_idx_last = np.arange(start=num_samples, stop=total_size,
+                                 step=1, dtype=dtype_)
+    np_rng.shuffle(shuffle_idx_last)
+
+    return np.concatenate((shuffle_idx_first, shuffle_idx_last))
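
Note: the `-1` terms in `_num_epochs` and `_build_sample_idx` above come from each sample consuming `seq_length + 1` tokens while sharing its last token with the next sample. A small sanity check of the same arithmetic (all values invented for illustration):

    tokens_per_epoch = 1050
    seq_length = 100
    # One epoch yields (1050 - 1) // 100 = 10 samples of 101 overlapping tokens.
    num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
    assert num_samples_per_epoch == 10
    # To serve 25 samples we need 3 epochs: (3 * 1050 - 1) // 100 = 31 >= 25.
    num_epochs = 0
    total_tokens = 0
    while True:
        num_epochs += 1
        total_tokens += tokens_per_epoch
        if (total_tokens - 1) // seq_length >= 25:
            break
    assert num_epochs == 3
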
megatron/data/helpers.cpp
ADDED
@@ -0,0 +1,717 @@
1 |
+
/*
|
2 |
+
coding=utf-8
|
3 |
+
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
|
5 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
you may not use this file except in compliance with the License.
|
7 |
+
You may obtain a copy of the License at
|
8 |
+
|
9 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
|
11 |
+
Unless required by applicable law or agreed to in writing, software
|
12 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
See the License for the specific language governing permissions and
|
15 |
+
limitations under the License.
|
16 |
+
*/
|
17 |
+
|
18 |
+
|
19 |
+
/* Helper methods for fast index mapping builds */
|
20 |
+
|
21 |
+
#include <algorithm>
|
22 |
+
#include <iostream>
|
23 |
+
#include <limits>
|
24 |
+
#include <math.h>
|
25 |
+
#include <stdexcept>
|
26 |
+
#include <pybind11/pybind11.h>
|
27 |
+
#include <pybind11/numpy.h>
|
28 |
+
#include <random>
|
29 |
+
|
30 |
+
namespace py = pybind11;
|
31 |
+
using namespace std;
|
32 |
+
|
33 |
+
const int32_t LONG_SENTENCE_LEN = 512;
|
34 |
+
|
35 |
+
|
36 |
+
void build_blending_indices(py::array_t<uint8_t>& dataset_index,
|
37 |
+
py::array_t<int64_t>& dataset_sample_index,
|
38 |
+
const py::array_t<double>& weights,
|
39 |
+
const int32_t num_datasets,
|
40 |
+
const int64_t size, const bool verbose) {
|
41 |
+
/* Given multiple datasets and a weighting array, build samples
|
42 |
+
such that it follows those wieghts.*/
|
43 |
+
|
44 |
+
if (verbose) {
|
45 |
+
std::cout << "> building indices for blendable datasets ..." << std::endl;
|
46 |
+
}
|
47 |
+
|
48 |
+
// Get the pointer access without the checks.
|
49 |
+
auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
|
50 |
+
auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
|
51 |
+
auto weights_ptr = weights.unchecked<1>();
|
52 |
+
|
53 |
+
// Initialize buffer for number of samples used for each dataset.
|
54 |
+
int64_t current_samples[num_datasets];
|
55 |
+
for(int64_t i = 0; i < num_datasets; ++i) {
|
56 |
+
current_samples[i] = 0;
|
57 |
+
}
|
58 |
+
|
59 |
+
// For each sample:
|
60 |
+
for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
|
61 |
+
|
62 |
+
// Determine where the max error in sampling is happening.
|
63 |
+
auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
|
64 |
+
int64_t max_error_index = 0;
|
65 |
+
double max_error = weights_ptr[0] * sample_idx_double -
|
66 |
+
static_cast<double>(current_samples[0]);
|
67 |
+
for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
|
68 |
+
double error = weights_ptr[dataset_idx] * sample_idx_double -
|
69 |
+
static_cast<double>(current_samples[dataset_idx]);
|
70 |
+
if (error > max_error) {
|
71 |
+
max_error = error;
|
72 |
+
max_error_index = dataset_idx;
|
73 |
+
}
|
74 |
+
}
|
75 |
+
|
76 |
+
// Populate the indices.
|
77 |
+
dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
|
78 |
+
dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
|
79 |
+
|
80 |
+
// Update the total samples.
|
81 |
+
current_samples[max_error_index] += 1;
|
82 |
+
|
83 |
+
}
|
84 |
+
|
85 |
+
// print info
|
86 |
+
if (verbose) {
|
87 |
+
std::cout << " > sample ratios:" << std::endl;
|
88 |
+
for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
|
89 |
+
auto ratio = static_cast<double>(current_samples[dataset_idx]) /
|
90 |
+
static_cast<double>(size);
|
91 |
+
std::cout << " dataset " << dataset_idx << ", input: " <<
|
92 |
+
weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
|
93 |
+
}
|
94 |
+
}
|
95 |
+
|
96 |
+
}
|
97 |
+
|
98 |
+
|
99 |
+
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
|
100 |
+
const py::array_t<int32_t>& doc_idx_,
|
101 |
+
const int32_t seq_length,
|
102 |
+
const int32_t num_epochs,
|
103 |
+
const int64_t tokens_per_epoch) {
|
104 |
+
/* Sample index (sample_idx) is used for gpt2 like dataset for which
|
105 |
+
the documents are flattened and the samples are built based on this
|
106 |
+
1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
|
107 |
+
where [..., 0] contains the index into `doc_idx` and [..., 1] is the
|
108 |
+
starting offset in that document.*/
|
109 |
+
|
110 |
+
// Consistency checks.
|
111 |
+
assert(seq_length > 1);
|
112 |
+
assert(num_epochs > 0);
|
113 |
+
assert(tokens_per_epoch > 1);
|
114 |
+
|
115 |
+
// Remove bound checks.
|
116 |
+
auto sizes = sizes_.unchecked<1>();
|
117 |
+
auto doc_idx = doc_idx_.unchecked<1>();
|
118 |
+
|
119 |
+
// Mapping and it's length (1D).
|
120 |
+
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
|
121 |
+
int32_t* sample_idx = new int32_t[2*(num_samples+1)];
|
122 |
+
|
123 |
+
cout << " using:" << endl << std::flush;
|
124 |
+
cout << " number of documents: " <<
|
125 |
+
doc_idx_.shape(0) / num_epochs << endl << std::flush;
|
126 |
+
cout << " number of epochs: " << num_epochs <<
|
127 |
+
endl << std::flush;
|
128 |
+
cout << " sequence length: " << seq_length <<
|
129 |
+
endl << std::flush;
|
130 |
+
cout << " total number of samples: " << num_samples <<
|
131 |
+
endl << std::flush;
|
132 |
+
|
133 |
+
// Index into sample_idx.
|
134 |
+
int64_t sample_index = 0;
|
135 |
+
// Index into doc_idx.
|
136 |
+
int64_t doc_idx_index = 0;
|
137 |
+
// Begining offset for each document.
|
138 |
+
int32_t doc_offset = 0;
|
139 |
+
// Start with first document and no offset.
|
140 |
+
sample_idx[2 * sample_index] = doc_idx_index;
|
141 |
+
sample_idx[2 * sample_index + 1] = doc_offset;
|
142 |
+
++sample_index;
|
143 |
+
|
144 |
+
while (sample_index <= num_samples) {
|
145 |
+
// Start with a fresh sequence.
|
146 |
+
int32_t remaining_seq_length = seq_length + 1;
|
147 |
+
while (remaining_seq_length != 0) {
|
148 |
+
// Get the document length.
|
149 |
+
auto doc_id = doc_idx[doc_idx_index];
|
150 |
+
auto doc_length = sizes[doc_id] - doc_offset;
|
151 |
+
// And add it to the current sequence.
|
152 |
+
remaining_seq_length -= doc_length;
|
153 |
+
// If we have more than a full sequence, adjust offset and set
|
154 |
+
// remaining length to zero so we return from the while loop.
|
155 |
+
// Note that -1 here is for the same reason we have -1 in
|
156 |
+
// `_num_epochs` calculations.
|
157 |
+
if (remaining_seq_length <= 0) {
|
158 |
+
doc_offset += (remaining_seq_length + doc_length - 1);
|
159 |
+
remaining_seq_length = 0;
|
160 |
+
} else {
|
161 |
+
// Otherwise, start from the begining of the next document.
|
162 |
+
++doc_idx_index;
|
163 |
+
doc_offset = 0;
|
164 |
+
}
|
165 |
+
}
|
166 |
+
// Record the sequence.
|
167 |
+
sample_idx[2 * sample_index] = doc_idx_index;
|
168 |
+
sample_idx[2 * sample_index + 1] = doc_offset;
|
169 |
+
++sample_index;
|
170 |
+
}
|
171 |
+
|
172 |
+
// Method to deallocate memory.
|
173 |
+
py::capsule free_when_done(sample_idx, [](void *mem_) {
|
174 |
+
int32_t *mem = reinterpret_cast<int32_t*>(mem_);
|
175 |
+
delete[] mem;
|
176 |
+
});
|
177 |
+
|
178 |
+
// Return the numpy array.
|
179 |
+
const auto byte_size = sizeof(int32_t);
|
180 |
+
return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
|
181 |
+
{2*byte_size, byte_size}, // C-style contiguous strides
|
182 |
+
sample_idx, // the data pointer
|
183 |
+
free_when_done); // numpy array references
|
184 |
+
|
185 |
+
}
|


inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
                                     const int32_t max_length,
                                     std::mt19937& rand32_gen) {
    /* Training sample length. */
    if (short_seq_ratio == 0) {
        return max_length;
    }
    const auto random_number = rand32_gen();
    if ((random_number % short_seq_ratio) == 0) {
        return 2 + random_number % (max_length - 1);
    }
    return max_length;
}


template<typename DocIdx>
py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
                             const py::array_t<int32_t>& sizes_,
                             const int32_t num_epochs,
                             const uint64_t max_num_samples,
                             const int32_t max_seq_length,
                             const double short_seq_prob,
                             const int32_t seed,
                             const bool verbose,
                             const int32_t min_num_sent) {
    /* Build a mapping of (start-index, end-index, sequence-length) where
       start and end index are the indices of the sentences in the sample
       and sequence-length is the target sequence length.
    */

    // Consistency checks.
    assert(num_epochs > 0);
    assert(max_seq_length > 1);
    assert(short_seq_prob >= 0.0);
    assert(short_seq_prob <= 1.0);
    assert(seed > 0);

    // Remove bound checks.
    auto docs = docs_.unchecked<1>();
    auto sizes = sizes_.unchecked<1>();

    // For efficiency, convert probability to ratio. Note: rand() generates int.
    int32_t short_seq_ratio = 0;
    if (short_seq_prob > 0) {
        short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
    }
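
    // Illustrative annotation (not part of the original source): short_seq_prob = 0.1
    // gives short_seq_ratio = 10, so get_target_sample_len() draws a random length in
    // [2, max_seq_length] for roughly one sample in ten and returns max_seq_length
    // for the rest.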

    if (verbose) {
        const auto sent_start_index = docs[0];
        const auto sent_end_index = docs[docs_.shape(0) - 1];
        const auto num_sentences = sent_end_index - sent_start_index;
        cout << " using:" << endl << std::flush;
        cout << " number of documents: " << docs_.shape(0) - 1 <<
            endl << std::flush;
        cout << " sentences range: [" << sent_start_index <<
            ", " << sent_end_index << ")" << endl << std::flush;
        cout << " total number of sentences: " << num_sentences <<
            endl << std::flush;
        cout << " number of epochs: " << num_epochs <<
            endl << std::flush;
        cout << " maximum number of samples: " << max_num_samples <<
            endl << std::flush;
        cout << " maximum sequence length: " << max_seq_length <<
            endl << std::flush;
        cout << " short sequence probability: " << short_seq_prob <<
            endl << std::flush;
        cout << " short sequence ratio (1/prob): " << short_seq_ratio <<
            endl << std::flush;
        cout << " seed: " << seed << endl <<
            std::flush;
    }

    // Mapping and its length (1D).
    int64_t num_samples = -1;
    DocIdx* maps = NULL;

    // Perform two iterations, in the first iteration get the size
    // and allocate memory and in the second iteration populate the map.
    bool second = false;
    for (int32_t iteration=0; iteration<2; ++iteration) {

        // Set the seed so both iterations produce the same results.
        std::mt19937 rand32_gen(seed);

        // Set the flag on second iteration.
        second = (iteration == 1);

        // Counters:
        uint64_t empty_docs = 0;
        uint64_t one_sent_docs = 0;
        uint64_t long_sent_docs = 0;

        // Current map index.
        uint64_t map_index = 0;

        // For each epoch:
        for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
            if (map_index >= max_num_samples) {
                if (verbose && (!second)) {
                    cout << " reached " << max_num_samples << " samples after "
                         << epoch << " epochs ..." << endl << std::flush;
                }
                break;
            }
            // For each document:
            for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {

                // Document sentences are in [sent_index_first, sent_index_last)
                const auto sent_index_first = docs[doc];
                const auto sent_index_last = docs[doc + 1];

                // At the beginning of the document previous index is the
                // start index.
                auto prev_start_index = sent_index_first;

                // Remaining sentences in the document.
                auto num_remain_sent = sent_index_last - sent_index_first;

                // Some bookkeeping
                if ((epoch == 0) && (!second)) {
                    if (num_remain_sent == 0) {
                        ++empty_docs;
                    }
                    if (num_remain_sent == 1) {
                        ++one_sent_docs;
                    }
                }

                // Detect documents with long sentences.
                bool contains_long_sentence = false;
                if (num_remain_sent > 1) {
                    for (auto sent_index=sent_index_first;
                         sent_index < sent_index_last; ++sent_index) {
                        if (sizes[sent_index] > LONG_SENTENCE_LEN){
                            if ((epoch == 0) && (!second)) {
                                ++long_sent_docs;
                            }
                            contains_long_sentence = true;
                            break;
                        }
                    }
                }

                // If we have enough sentences and no long sentences.
                if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {

                    // Set values.
                    auto seq_len = int32_t{0};
                    auto num_sent = int32_t{0};
                    auto target_seq_len = get_target_sample_len(short_seq_ratio,
                                                                max_seq_length,
                                                                rand32_gen);

                    // Loop through sentences.
                    for (auto sent_index=sent_index_first;
                         sent_index < sent_index_last; ++sent_index) {

                        // Add the size and number of sentences.
                        seq_len += sizes[sent_index];
                        ++num_sent;
                        --num_remain_sent;

                        // If we have reached the target length,
                        // and more than one sentence is left in the document,
                        // and we have at least the minimum number of sentences;
                        // or if we have reached the end of the document.
                        if (((seq_len >= target_seq_len) &&
                             (num_remain_sent > 1) &&
                             (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {

                            // Check for overflow.
                            if ((3 * map_index + 2) >
                                std::numeric_limits<int64_t>::max()) {
                                cout << "number of samples exceeded maximum "
                                     << "allowed by type int64: "
                                     << std::numeric_limits<int64_t>::max()
                                     << endl;
                                throw std::overflow_error("Number of samples");
                            }

                            // Populate the map.
                            if (second) {
                                const auto map_index_0 = 3 * map_index;
                                maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                                maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                                maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
                            }

                            // Update indices / counters.
                            ++map_index;
                            prev_start_index = sent_index + 1;
                            target_seq_len = get_target_sample_len(short_seq_ratio,
                                                                   max_seq_length,
                                                                   rand32_gen);
                            seq_len = 0;
                            num_sent = 0;
                        }

                    } // for (auto sent_index=sent_index_first; ...
                } // if ((num_remain_sent >= min_num_sent) && ...
            } // for (int doc=0; doc < num_docs; ++doc) {
        } // for (int epoch=0; epoch < num_epochs; ++epoch) {

        if (!second) {
            if (verbose) {
                cout << " number of empty documents: " << empty_docs <<
                    endl << std::flush;
                cout << " number of documents with one sentence: " <<
                    one_sent_docs << endl << std::flush;
                cout << " number of documents with long sentences: " <<
                    long_sent_docs << endl << std::flush;
                cout << " will create mapping for " << map_index <<
                    " samples" << endl << std::flush;
            }
            assert(maps == NULL);
            assert(num_samples < 0);
            maps = new DocIdx[3*map_index];
            num_samples = static_cast<int64_t>(map_index);
        }

    } // for (int iteration=0; iteration < 2; ++iteration) {

    // Shuffle.
    // We need a 64 bit random number generator as we might have more
    // than 2 billion samples.
    // (This is an in-place Fisher-Yates shuffle over the rows of maps.)
    std::mt19937_64 rand64_gen(seed + 1);
    for (auto i=(num_samples - 1); i > 0; --i) {
        const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
        const auto i0 = 3 * i;
        const auto j0 = 3 * j;
        // Swap values.
        swap(maps[i0], maps[j0]);
        swap(maps[i0 + 1], maps[j0 + 1]);
        swap(maps[i0 + 2], maps[j0 + 2]);
    }

    // Method to deallocate memory.
    py::capsule free_when_done(maps, [](void *mem_) {
        DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
        delete[] mem;
    });

    // Return the numpy array.
    const auto byte_size = sizeof(DocIdx);
    return py::array(std::vector<int64_t>{num_samples, 3}, // shape
                     {3*byte_size, byte_size}, // C-style contiguous strides
                     maps, // the data pointer
                     free_when_done); // numpy array references

}


py::array build_mapping(const py::array_t<int64_t>& docs_,
                        const py::array_t<int>& sizes_,
                        const int num_epochs,
                        const uint64_t max_num_samples,
                        const int max_seq_length,
                        const double short_seq_prob,
                        const int seed,
                        const bool verbose,
                        const int32_t min_num_sent) {

    if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
        if (verbose) {
            cout << " using uint64 for data mapping..." << endl << std::flush;
        }
        return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
                                            max_num_samples, max_seq_length,
                                            short_seq_prob, seed, verbose,
                                            min_num_sent);
    } else {
        if (verbose) {
            cout << " using uint32 for data mapping..." << endl << std::flush;
        }
        return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
                                            max_num_samples, max_seq_length,
                                            short_seq_prob, seed, verbose,
                                            min_num_sent);
    }
}

template<typename DocIdx>
py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                                    const py::array_t<int32_t>& sizes_,
                                    const py::array_t<int32_t>& titles_sizes_,
                                    const int32_t num_epochs,
                                    const uint64_t max_num_samples,
                                    const int32_t max_seq_length,
                                    const int32_t seed,
                                    const bool verbose,
                                    const bool use_one_sent_blocks) {
    /* Build a mapping of (start-index, end-index, doc-index, block-id) where
       start and end index are the indices of the sentences in the sample,
       doc-index is the document the block comes from, and block-id is the
       unique id of the block.
    */

    // Consistency checks.
    assert(num_epochs > 0);
    assert(max_seq_length > 1);
    assert(seed > 0);

    // Remove bound checks.
    auto docs = docs_.unchecked<1>();
    auto sizes = sizes_.unchecked<1>();
    auto titles_sizes = titles_sizes_.unchecked<1>();

    if (verbose) {
        const auto sent_start_index = docs[0];
        const auto sent_end_index = docs[docs_.shape(0) - 1];
        const auto num_sentences = sent_end_index - sent_start_index;
        cout << " using:" << endl << std::flush;
        cout << " number of documents: " << docs_.shape(0) - 1 <<
            endl << std::flush;
        cout << " sentences range: [" << sent_start_index <<
            ", " << sent_end_index << ")" << endl << std::flush;
        cout << " total number of sentences: " << num_sentences <<
            endl << std::flush;
        cout << " number of epochs: " << num_epochs <<
            endl << std::flush;
        cout << " maximum number of samples: " << max_num_samples <<
            endl << std::flush;
        cout << " maximum sequence length: " << max_seq_length <<
            endl << std::flush;
        cout << " seed: " << seed << endl <<
            std::flush;
    }

    // Mapping and its length (1D).
    int64_t num_samples = -1;
    DocIdx* maps = NULL;

    // Acceptable number of sentences per block.
    int min_num_sent = 2;
    if (use_one_sent_blocks) {
        min_num_sent = 1;
    }

    // Perform two iterations, in the first iteration get the size
    // and allocate memory and in the second iteration populate the map.
    bool second = false;
    for (int32_t iteration=0; iteration<2; ++iteration) {

        // Set the flag on second iteration.
        second = (iteration == 1);

        // Current map index.
        uint64_t map_index = 0;

        uint64_t empty_docs = 0;
        uint64_t one_sent_docs = 0;
        uint64_t long_sent_docs = 0;
        // For each epoch:
        for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
            // assign every block a unique id
            int32_t block_id = 0;

            if (map_index >= max_num_samples) {
                if (verbose && (!second)) {
                    cout << " reached " << max_num_samples << " samples after "
                         << epoch << " epochs ..." << endl << std::flush;
                }
                break;
            }
            // For each document:
            for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {

                // Document sentences are in [sent_index_first, sent_index_last)
                const auto sent_index_first = docs[doc];
                const auto sent_index_last = docs[doc + 1];
                const auto target_seq_len = max_seq_length - titles_sizes[doc];

                // At the beginning of the document previous index is the
                // start index.
                auto prev_start_index = sent_index_first;

                // Remaining sentences in the document.
                auto num_remain_sent = sent_index_last - sent_index_first;

                // Some bookkeeping
                if ((epoch == 0) && (!second)) {
                    if (num_remain_sent == 0) {
                        ++empty_docs;
                    }
                    if (num_remain_sent == 1) {
                        ++one_sent_docs;
                    }
                }
                // Detect documents with long sentences.
                bool contains_long_sentence = false;
                if (num_remain_sent >= min_num_sent) {
                    for (auto sent_index=sent_index_first;
                         sent_index < sent_index_last; ++sent_index) {
                        if (sizes[sent_index] > LONG_SENTENCE_LEN){
                            if ((epoch == 0) && (!second)) {
                                ++long_sent_docs;
                            }
                            contains_long_sentence = true;
                            break;
                        }
                    }
                }
                // If we have enough sentences and no long sentences.
                if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {

                    // Set values.
                    auto seq_len = int32_t{0};
                    auto num_sent = int32_t{0};

                    // Loop through sentences.
                    for (auto sent_index=sent_index_first;
                         sent_index < sent_index_last; ++sent_index) {

                        // Add the size and number of sentences.
                        seq_len += sizes[sent_index];
                        ++num_sent;
                        --num_remain_sent;

                        // If we have reached the target length,
                        // and there are an acceptable number of sentences left,
                        // and we have at least the minimum number of sentences;
                        // or if we have reached the end of the document.
                        if (((seq_len >= target_seq_len) &&
                             (num_remain_sent >= min_num_sent) &&
                             (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {

                            // Populate the map.
                            if (second) {
                                const auto map_index_0 = 4 * map_index;
                                // Each sample has 4 items: the starting sentence index, ending sentence index,
                                // the index of the document from which the block comes (used for fetching titles)
                                // and the unique id of the block (used for creating block indexes)

                                maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                                maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                                maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
                                maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
                            }

                            // Update indices / counters.
                            ++map_index;
                            ++block_id;
                            prev_start_index = sent_index + 1;
                            seq_len = 0;
                            num_sent = 0;
                        }
                    } // for (auto sent_index=sent_index_first; ...
                } // if ((num_remain_sent >= min_num_sent) && ...
            } // for (int doc=0; doc < num_docs; ++doc) {
        } // for (int epoch=0; epoch < num_epochs; ++epoch) {

        if (!second) {
            if (verbose) {
                cout << " number of empty documents: " << empty_docs <<
                    endl << std::flush;
                cout << " number of documents with one sentence: " <<
                    one_sent_docs << endl << std::flush;
                cout << " number of documents with long sentences: " <<
                    long_sent_docs << endl << std::flush;
                cout << " will create mapping for " << map_index <<
                    " samples" << endl << std::flush;
            }
            assert(maps == NULL);
            assert(num_samples < 0);
            maps = new DocIdx[4*map_index];
            num_samples = static_cast<int64_t>(map_index);
        }

    } // for (int iteration=0; iteration < 2; ++iteration) {

    // Shuffle.
    // We need a 64 bit random number generator as we might have more
    // than 2 billion samples.
    std::mt19937_64 rand64_gen(seed + 1);
    for (auto i=(num_samples - 1); i > 0; --i) {
        const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
        const auto i0 = 4 * i;
        const auto j0 = 4 * j;
        // Swap values.
        swap(maps[i0], maps[j0]);
        swap(maps[i0 + 1], maps[j0 + 1]);
        swap(maps[i0 + 2], maps[j0 + 2]);
        swap(maps[i0 + 3], maps[j0 + 3]);
    }

    // Method to deallocate memory.
    py::capsule free_when_done(maps, [](void *mem_) {
        DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
        delete[] mem;
    });

    // Return the numpy array.
    const auto byte_size = sizeof(DocIdx);
    return py::array(std::vector<int64_t>{num_samples, 4}, // shape
                     {4*byte_size, byte_size}, // C-style contiguous strides
                     maps, // the data pointer
                     free_when_done); // numpy array references

}

py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
                               const py::array_t<int>& sizes_,
                               const py::array_t<int>& titles_sizes_,
                               const int num_epochs,
                               const uint64_t max_num_samples,
                               const int max_seq_length,
                               const int seed,
                               const bool verbose,
                               const bool use_one_sent_blocks) {

    if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
        if (verbose) {
            cout << " using uint64 for data mapping..." << endl << std::flush;
        }
        return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
            num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
    } else {
        if (verbose) {
            cout << " using uint32 for data mapping..." << endl << std::flush;
        }
        return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
            num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
    }
}

PYBIND11_MODULE(helpers, m) {
    m.def("build_mapping", &build_mapping);
    m.def("build_blocks_mapping", &build_blocks_mapping);
    m.def("build_sample_idx", &build_sample_idx);
    m.def("build_blending_indices", &build_blending_indices);
}
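
The four bindings above are the module's whole public surface. A minimal sketch of driving it from Python, assuming the extension has been compiled and is importable as megatron.data.helpers; the toy arrays are invented for illustration:

    import numpy as np
    from megatron.data import helpers  # compiled from this helpers.cpp

    # Hypothetical corpus: two documents of 6 and 7 tokens, visited once.
    sizes = np.array([6, 7], dtype=np.int32)      # tokens per document
    doc_idx = np.array([0, 1], dtype=np.int32)    # document visit order
    seq_length, num_epochs = 4, 1
    tokens_per_epoch = int(sizes[doc_idx].sum())  # 13

    # Returns an int32 array of shape [num_samples + 1, 2]; column 0 indexes
    # into doc_idx and column 1 is the starting token offset in that document.
    sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                          num_epochs, tokens_per_epoch)
    print(sample_idx.shape)  # (4, 2), since num_samples = (13 - 1) // 4 = 3
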
megatron/data/helpers.cpython-38-x86_64-linux-gnu.so
ADDED
Binary file (192 kB).

megatron/data/helpers.cpython-39-x86_64-linux-gnu.so
ADDED
Binary file (212 kB).
megatron/data/ict_dataset.py
ADDED
@@ -0,0 +1,156 @@
import itertools
import random

import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron import get_args
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping

def make_attention_mask(source_block, target_block):
    """
    Returns a 2-dimensional (2-D) attention mask
    :param source_block: 1-D array
    :param target_block: 1-D array
    """
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    mask = mask.astype(np.int64)
    # (source_length, target_length)
    return mask
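
# Illustrative annotation (not in the original file): for source ids [5, 6, 0] and
# target ids [7, 8], the mask is [[1, 1], [1, 1], [0, 0]] -- any position whose id
# is below 1 (e.g. padding) is masked out along both axes.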

def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Get a dataset that uses block sample mappings to get ICT/block indexing data (via get_block())
    rather than for training, since it is built with only a single-epoch sample mapping.
    """
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
        use_one_sent_docs=args.use_one_sent_docs
    )
    dataset = ICTDataset(**kwargs)
    return dataset


class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
                 seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles
        self.use_one_sent_docs = use_one_sent_docs

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return len(self.samples_mapping)

    def __getitem__(self, idx):
        """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
        sample_data = self.samples_mapping[idx]
        start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()

        if self.use_titles:
            title = self.title_dataset[int(doc_idx)]
            title_pad_offset = 3 + len(title)
        else:
            title = None
            title_pad_offset = 2
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1

        # randint() is inclusive for Python rng
        rand_sent_idx = self.rng.randint(0, len(block) - 1)

        # keep the query in the context query_in_block_prob fraction of the time.
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate because blocks are concluded when
        # the sentence lengths have exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)

        query_mask = make_attention_mask(query_tokens, query_tokens)
        context_mask = make_attention_mask(context_tokens, context_tokens)

        block_data = sample_data.as_array()

        sample = {
            'query_tokens': query_tokens,
            'query_mask': query_mask,
            'query_pad_mask': query_pad_mask,
            'context_tokens': context_tokens,
            'context_mask': context_mask,
            'context_pad_mask': context_pad_mask,
            'block_data': block_data,
        }

        return sample

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        title = self.title_dataset[int(doc_idx)]

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def get_null_block(self):
        """Get empty block and title - used in REALM pretraining"""
        block, title = [], []
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def concat_and_pad_tokens(self, tokens, title=None):
        """Concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = list(tokens)
        if title is None:
            tokens = [self.cls_id] + tokens + [self.sep_id]
        else:
            title = list(title)
            tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
        assert len(tokens) <= self.max_seq_length

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad

        return np.array(tokens), np.array(pad_mask)
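
The [CLS] title [SEP] tokens [SEP] packing above is also where title_pad_offset = 3 + len(title) in __getitem__ comes from: one [CLS], two [SEP]s, plus the title tokens. A minimal standalone sketch of the same convention (invented special-token ids; not part of the upload):

    import numpy as np

    def concat_and_pad_tokens(tokens, cls_id, sep_id, pad_id, max_seq_length, title=None):
        # Mirrors ICTDataset.concat_and_pad_tokens: concat with special tokens, then pad.
        if title is None:
            tokens = [cls_id] + list(tokens) + [sep_id]
        else:
            tokens = [cls_id] + list(title) + [sep_id] + list(tokens) + [sep_id]
        num_pad = max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        return np.array(tokens + [pad_id] * num_pad), np.array(pad_mask)

    toks, mask = concat_and_pad_tokens([11, 12], cls_id=101, sep_id=102, pad_id=0,
                                       max_seq_length=8, title=[55])
    print(toks)  # [101  55 102  11  12 102   0   0]
    print(mask)  # [  1   1   1   1   1   1   0   0]
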
megatron/data/image_folder.py
ADDED
@@ -0,0 +1,302 @@
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# code taken from
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py
# added support for classes_fraction and data_per_class_fraction

from torchvision.datasets import VisionDataset
from PIL import Image

import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import numpy as np

def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Checks if a file is an allowed extension.
    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)
    Returns:
        bool: True if the filename ends with one of given extensions
    """
    return filename.lower().endswith(extensions)


def is_image_file(filename: str) -> bool:
    """Checks if a file is an allowed image extension.
    Args:
        filename (string): path to a file
    Returns:
        bool: True if the filename ends with a known image extension
    """
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)


def make_dataset(
    directory: str,
    class_to_idx: Dict[str, int],
    data_per_class_fraction: float,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).
    Args:
        directory (str): root dataset directory
        class_to_idx (Dict[str, int]): dictionary mapping class name to class index
        extensions (optional): A list of allowed extensions.
            Either extensions or is_valid_file should be passed. Defaults to None.
        is_valid_file (optional): A function that takes the path of a file
            and checks if the file is a valid file
            (used to check for corrupt files). Both extensions and
            is_valid_file should not be passed. Defaults to None.
    Raises:
        ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
    Returns:
        List[Tuple[str, int]]: samples of a form (path_to_sample, class)
    """
    instances = []
    directory = os.path.expanduser(directory)
    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
    is_valid_file = cast(Callable[[str], bool], is_valid_file)
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        local_instances = []
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index
                    local_instances.append(item)

        instances.extend(local_instances[0:int(len(local_instances) * data_per_class_fraction)])

    return instances

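
# Illustrative annotation for make_dataset above (not in the original file): with
# data_per_class_fraction=0.25 and a class directory holding 100 valid images, only
# the first 25 paths in sorted walk order are kept for that class.
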
class DatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged in this way: ::
        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/[...]/xxz.ext
        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/[...]/asd932_.ext
    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            Both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes the path of a file
            and checks if the file is a valid file (used to check for corrupt files).
            Both extensions and is_valid_file should not be passed.
    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
            self,
            root: str,
            loader: Callable[[str], Any],
            extensions: Optional[Tuple[str, ...]] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            classes_fraction=1.0,
            data_per_class_fraction=1.0,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super(DatasetFolder, self).__init__(root, transform=transform,
                                            target_transform=target_transform)
        self.classes_fraction = classes_fraction
        self.data_per_class_fraction = data_per_class_fraction
        classes, class_to_idx = self._find_classes(self.root)
        samples = self.make_dataset(self.root,
                                    class_to_idx,
                                    self.data_per_class_fraction,
                                    extensions,
                                    is_valid_file)
        if len(samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)

        self.loader = loader
        self.extensions = extensions
        self.total = len(samples)
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
            directory: str,
            class_to_idx: Dict[str, int],
            data_per_class_fraction: float,
            extensions: Optional[Tuple[str, ...]] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        return make_dataset(directory,
                            class_to_idx,
                            data_per_class_fraction,
                            extensions=extensions,
                            is_valid_file=is_valid_file)

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset.
        Args:
            dir (string): Root directory path.
        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
        Ensures:
            No class is a subdirectory of another.
        """
        all_classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes = all_classes[0:int(len(all_classes) * self.classes_fraction)]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        curr_index = index
        for x in range(self.total):
            try:
                path, target = self.samples[curr_index]
                sample = self.loader(path)
                break
            except Exception:
                # Loading failed (e.g. a corrupt image); retry with a random index.
                curr_index = np.random.randint(0, self.total)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)


IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: str) -> Any:
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way: ::
        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png
    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes the path of an Image file
            and checks if the file is a valid file (used to check for corrupt files)
    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            classes_fraction=1.0,
            data_per_class_fraction=1.0,
            loader: Callable[[str], Any] = default_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
                                          transform=transform,
                                          target_transform=target_transform,
                                          classes_fraction=classes_fraction,
                                          data_per_class_fraction=data_per_class_fraction,
                                          is_valid_file=is_valid_file)
        self.imgs = self.samples

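
Relative to the upstream torchvision ImageFolder, the two added knobs subsample deterministically: classes_fraction keeps a leading fraction of the class directories as scanned, and data_per_class_fraction keeps the first fraction of valid files (in sorted walk order) within each class. A minimal usage sketch (the root path and fraction values are hypothetical):

    from megatron.data.image_folder import ImageFolder

    # Keep half of the class folders and the first quarter of the images per class.
    dataset = ImageFolder(root='/data/imagenet/train',  # hypothetical path
                          classes_fraction=0.5,
                          data_per_class_fraction=0.25)
    print(len(dataset.classes), len(dataset))
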
megatron/data/indexed_dataset.py
ADDED
@@ -0,0 +1,576 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on the Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
# An empty sentence no longer separates documents.

from functools import lru_cache
import os
import shutil
import struct
from itertools import accumulate

import numpy as np
import torch
from megatron import print_rank_0


def __best_fitting_dtype(vocab_size=None):
    if vocab_size is not None and vocab_size < 65500:
        return np.uint16
    else:
        return np.int32


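# Illustrative annotation for __best_fitting_dtype above (not in the original file):
# a GPT-2-style vocabulary of ~50k tokens fits in uint16 (< 65500), halving the
# size of the .bin file relative to int32.
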
def get_available_dataset_impl():
    return ['lazy', 'cached', 'mmap']


def infer_dataset_impl(path):
    if IndexedDataset.exists(path):
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            if magic == IndexedDataset._HDR_MAGIC:
                return 'cached'
            elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
                return 'mmap'
            else:
                return None
    else:
        print(f"Dataset does not exist: {path}")
        print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
        return None


def make_builder(out_file, impl, vocab_size=None):
    if impl == 'mmap':
        return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
    else:
        return IndexedDatasetBuilder(out_file)


def make_dataset(path, impl, skip_warmup=False):
    if not IndexedDataset.exists(path):
        print(f"Dataset does not exist: {path}")
        print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
        return None
    if impl == 'infer':
        impl = infer_dataset_impl(path)
    if impl == 'lazy' and IndexedDataset.exists(path):
        return IndexedDataset(path)
    elif impl == 'cached' and IndexedDataset.exists(path):
        return IndexedCachedDataset(path)
    elif impl == 'mmap' and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path, skip_warmup)
    print(f"Unknown dataset implementation: {impl}")
    return None


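# Illustrative usage of make_dataset above (not in the original file; the basename
# is hypothetical):
#   ds = make_dataset('/data/my-gpt2_text_document', 'mmap')
# reads /data/my-gpt2_text_document.idx and .bin and returns an MMapIndexedDataset.
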
def dataset_exists(path, impl):
    if impl == 'mmap':
        return MMapIndexedDataset.exists(path)
    else:
        return IndexedDataset.exists(path)


def read_longs(f, n):
    a = np.empty(n, dtype=np.int64)
    f.readinto(a)
    return a


def write_longs(f, a):
    f.write(np.array(a, dtype=np.int64))


dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float64,  # was np.float, an alias for float64 that NumPy has since removed
    7: np.double,
    8: np.uint16
}


def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)


def index_file_path(prefix_path):
    return prefix_path + '.idx'


def data_file_path(prefix_path):
    return prefix_path + '.bin'


def create_doc_idx(sizes):
    doc_idx = [0]
    for i, s in enumerate(sizes):
        if s == 0:
            doc_idx.append(i + 1)
    return doc_idx


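# Illustrative annotation for create_doc_idx above (not in the original file): for
# sentence sizes [3, 4, 0, 2, 0] it returns [0, 3, 5]; a zero-length sentence marks
# the end of a document.
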
class IndexedDataset(torch.utils.data.Dataset):
    """Loader for IndexedDataset"""
    _HDR_MAGIC = b'TNTIDX\x00\x00'

    def __init__(self, path):
        super().__init__()
        self.path = path
        self.data_file = None
        self.read_index(path)

    def read_index(self, path):
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            assert magic == self._HDR_MAGIC, (
                'Index file doesn\'t match expected format. '
                'Make sure that --dataset-impl is configured properly.'
            )
            version = f.read(8)
            assert struct.unpack('<Q', version) == (1,)
            code, self.element_size = struct.unpack('<QQ', f.read(16))
            self.dtype = dtypes[code]
            self._len, self.s = struct.unpack('<QQ', f.read(16))
            self.doc_count = struct.unpack('<Q', f.read(8))
            self.dim_offsets = read_longs(f, self._len + 1)
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)
            self.doc_idx = read_longs(f, self.doc_count)

    def read_data(self, path):
        self.data_file = open(data_file_path(path), 'rb', buffering=0)

    def check_index(self, i):
        if i < 0 or i >= self._len:
            raise IndexError('index out of range')

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if not self.data_file:
            self.read_data(self.path)
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            return a
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
            size = sum(sizes)
            a = np.empty(size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[start] * self.element_size)
            self.data_file.readinto(a)
            offsets = list(accumulate(sizes))
            sents = np.split(a, offsets[:-1])
            return sents

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory


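# Illustrative annotation (not in the original file): the legacy .idx layout read by
# read_index above is an 8-byte magic b'TNTIDX\x00\x00', <Q version (always 1),
# <QQ dtype code and element size, <QQ entry count and size count, <Q document
# count, followed by four int64 arrays (dim_offsets, data_offsets, sizes, doc_idx).
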
class IndexedCachedDataset(IndexedDataset):

    def __init__(self, path):
        super().__init__(path)
        self.cache = None
        self.cache_index = {}

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
            a = self.cache[ptx: ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            ptx = self.cache_index[i]
            np.copyto(a, self.cache[ptx: ptx + a.size])
            return a
        elif isinstance(idx, slice):
            # Hack just to make this work; can optimize later if necessary.
            sents = []
            for i in range(*idx.indices(len(self))):
                sents.append(self[i])
            return sents


class IndexedDatasetBuilder(object):
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        # np.float (removed from NumPy) and np.double both alias float64:
        np.float64: 8
    }

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, 'wb')
        self.dtype = dtype
        self.data_offsets = [0]
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]
        self.doc_idx = [0]

    def add_item(self, tensor):
        bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
        self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def end_document(self):
        self.doc_idx.append(len(self.sizes))

    def merge_file_(self, another_file):
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        doc_offset = len(self.sizes)

        begin = self.data_offsets[-1]
        for data_offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + data_offset)
        self.sizes.extend(index.sizes)

        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)

        self.doc_idx.extend((doc_offset + index.doc_idx)[1:])

        with open(data_file_path(another_file), 'rb') as f:
            while True:
                data = f.read(1024)
                if data:
                    self.out_file.write(data)
                else:
                    break

    def finalize(self, index_file):
        self.out_file.close()
        index = open(index_file, 'wb')
        index.write(b'TNTIDX\x00\x00')
        index.write(struct.pack('<Q', 1))
        index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
        index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
        index.write(struct.pack('<Q', len(self.doc_idx)))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
        write_longs(index, self.doc_idx)
        index.close()


334 |
+
def _warmup_mmap_file(path):
|
335 |
+
with open(path, 'rb') as stream:
|
336 |
+
while stream.read(100 * 1024 * 1024):
|
337 |
+
pass
|
338 |
+
|
339 |
+
|
340 |
+
class MMapIndexedDataset(torch.utils.data.Dataset):
|
341 |
+
class Index(object):
|
342 |
+
_HDR_MAGIC = b'MMIDIDX\x00\x00'
|
343 |
+
|
344 |
+

        @classmethod
        def writer(cls, path, dtype):
            class _Writer(object):
                def __enter__(self):
                    self._file = open(path, 'wb')

                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack('<Q', 1))
                    self._file.write(struct.pack('<B', code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []

                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size

                    return pointers

                def write(self, sizes, doc_idx):
                    pointers = self._get_pointers(sizes)

                    self._file.write(struct.pack('<Q', len(sizes)))
                    self._file.write(struct.pack('<Q', len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order='C'))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order='C'))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order='C'))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path, skip_warmup=False):
            with open(path, 'rb') as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    'Index file doesn\'t match expected format. '
                    'Make sure that --dataset-impl is configured properly.'
                )
                version = struct.unpack('<Q', stream.read(8))
                assert (1,) == version

                dtype_code, = struct.unpack('<B', stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack('<Q', stream.read(8))[0]
                self._doc_count = struct.unpack('<Q', stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print_rank_0(" warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print_rank_0(" reading sizes...")
            self._sizes = np.frombuffer(
                self._bin_buffer,
                dtype=np.int32,
                count=self._len,
                offset=offset)
            print_rank_0(" reading pointers...")
            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
                                           offset=offset + self._sizes.nbytes)
            print_rank_0(" reading document index...")
            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, skip_warmup=False):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print_rank_0(" warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print_rank_0(" creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
        print_rank_0(" creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
                                     count=size, offset=ptr)
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
                                     count=total_size, offset=ptr)
            sents = np.split(np_array, offsets[:-1])
            return sents

    def get(self, idx, offset=0, length=None):
        """ Retrieves a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
                                 count=length, offset=ptr)
        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )


class MMapIndexedDatasetBuilder(object):
    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, 'wb')
        self._dtype = dtype
        self._sizes = []
        self._doc_idx = [0]

    def add_item(self, tensor):
        np_array = np.array(tensor.numpy(), dtype=self._dtype)
        self._data_file.write(np_array.tobytes(order='C'))
        self._sizes.append(np_array.size)

    def end_document(self):
        self._doc_idx.append(len(self._sizes))

    def merge_file_(self, another_file):
        # Concatenate index
        index = MMapIndexedDataset.Index(index_file_path(another_file))
        assert index.dtype == self._dtype

        offset = len(self._sizes)
        self._sizes.extend(index.sizes)
        self._doc_idx.extend((offset + index.doc_idx)[1:])

        # Concatenate data
        with open(data_file_path(another_file), 'rb') as f:
            shutil.copyfileobj(f, self._data_file)

    def finalize(self, index_file):
        self._data_file.close()

        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes, self._doc_idx)
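
Taken together, the builder and the reader above round-trip as follows. A minimal
sketch, assuming this module's own data_file_path/index_file_path helpers; the
'toy_corpus' prefix and token values are made up for illustration:

    import torch
    from megatron.data.indexed_dataset import (
        MMapIndexedDataset, MMapIndexedDatasetBuilder,
        data_file_path, index_file_path)

    # Write: stream token tensors into <prefix>.bin, then finalize <prefix>.idx.
    builder = MMapIndexedDatasetBuilder(data_file_path('toy_corpus'))
    builder.add_item(torch.tensor([101, 2023, 2003, 102]))  # sentence 1
    builder.add_item(torch.tensor([101, 2178, 6251, 102]))  # sentence 2
    builder.end_document()                                  # document boundary
    builder.finalize(index_file_path('toy_corpus'))

    # Read: memory-map the pair back and fetch items by index.
    ds = MMapIndexedDataset('toy_corpus', skip_warmup=True)
    print(len(ds))                        # 2 items
    print(ds[0])                          # [ 101 2023 2003  102]
    print(ds.get(0, offset=1, length=2))  # [2023 2003]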
megatron/data/orqa_wiki_dataset.py
ADDED
@@ -0,0 +1,205 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wikipedia dataset from DPR code for ORQA."""

from abc import ABC
import csv
import numpy as np
import random
import torch
from torch.utils.data import Dataset

from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask

def get_open_retrieval_wiki_dataset():
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
                                           'evidence',
                                           args.evidence_data_path,
                                           tokenizer,
                                           args.retriever_seq_length)
    return dataset


def get_open_retrieval_batch(data_iterator):
    # Items and their type.
    keys = ['row_id', 'context', 'context_mask', 'context_types',
            'context_pad_mask']
    datatype = torch.int64

    # Broadcast data.
    data = None if data_iterator is None else next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    row_id = data_b['row_id'].long()
    context = data_b['context'].long()

    # TODO: make the context mask a binary one
    context_mask = (data_b['context_mask'] < 0.5)

    context_types = data_b['context_types'].long()
    context_pad_mask = data_b['context_pad_mask'].long()

    return row_id, context, context_mask, context_types, context_pad_mask


def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    title_ids = tokenizer.tokenize(row['title'])
    context_ids = tokenizer.tokenize(row['text'])

    # Appending the title of the context at front
    extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids

    context_ids, context_types, context_pad_mask = \
        build_tokens_types_paddings_from_ids(extended_context_ids,
            max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)

    return context_ids, context_types, context_pad_mask


# noinspection DuplicatedCode
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""
    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(text_ids)
    enc_ids.extend(text_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)
    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask


def build_sample(row_id, context_ids, context_types, context_pad_mask):
    """Convert to numpy and return a sample consumed by the batch producer."""

    context_ids = np.array(context_ids, dtype=np.int64)
    context_types = np.array(context_types, dtype=np.int64)
    context_mask = make_attention_mask(context_ids, context_ids)

    sample = ({
        'row_id': row_id,
        'context': context_ids,
        'context_mask': context_mask,
        'context_types': context_types,
        'context_pad_mask': context_pad_mask
    })
    return sample


class OpenRetrievalEvidenceDataset(ABC, Dataset):
    """Open Retrieval Evidence dataset class."""

    def __init__(self, task_name, dataset_name, datapath, tokenizer,
                 max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        # Process the files.
        print_rank_0(datapath)
        self.samples, self.id2text = self.process_samples_from_single_path(
            datapath)

        args = get_args()
        if args.sample_rate < 1:  # subsample
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        row = self.samples[idx]

        context_ids, context_types, context_pad_mask = \
            build_tokens_types_paddings_from_text(row, self.tokenizer,
                                                  self.max_seq_length)

        sample = build_sample(row['doc_id'],
                              context_ids,
                              context_types,
                              context_pad_mask)
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        print_rank_0(' > Processing {} ...'.format(filename))
        total = 0

        rows = []
        id2text = {}

        with open(filename) as tsvfile:
            reader = csv.reader(tsvfile, delimiter='\t')
            next(reader, None)  # skip the headers
            for row in reader:
                # file format: doc_id, doc_text, title
                doc_id = int(row[0])
                text = row[1]
                title = row[2]

                rows.append({'doc_id': doc_id,
                             'text': text,
                             'title': title})

                assert doc_id not in id2text
                id2text[doc_id] = (text, title)

                total += 1
                if total % 100000 == 0:
                    print_rank_0(' > processed {} rows so far ...'.format(
                        total))

        print_rank_0(' >> processed {} samples.'.format(len(rows)))
        return rows, id2text
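
The padding helper above is easiest to pin down on a toy input. A worked sketch;
the ids 1/2/0 stand in for [CLS]/[SEP]/[PAD] and 7/8/9 for text ids (all made up):

    enc_ids, tokentypes, pad_mask = build_tokens_types_paddings_from_ids(
        text_ids=[7, 8, 9], max_seq_length=8, cls_id=1, sep_id=2, pad_id=0)
    # enc_ids    == [1, 7, 8, 9, 2, 0, 0, 0]  ([CLS] ids [SEP], then padding)
    # tokentypes == [0, 0, 0, 0, 0, 0, 0, 0]  (single segment; padded with pad_id)
    # pad_mask   == [1, 1, 1, 1, 1, 0, 0, 0]  (1 on real tokens, 0 on padding)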
megatron/data/realm_dataset_utils.py
ADDED
@@ -0,0 +1,198 @@
import os
import time

import numpy as np
import torch

from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy


def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    global_batch_size = micro_batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    assert False, 'DistributedBatchSampler deprecated, change the implementation'
    from megatron.data.samplers import DistributedBatchSampler
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


def get_ict_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_pad_mask',
            'block_tokens', 'block_pad_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_pad_mask = data_b['query_pad_mask'].long()
    block_tokens = data_b['block_tokens'].long()
    block_pad_mask = data_b['block_pad_mask'].long()
    block_indices = data_b['block_data'].long()

    return query_tokens, query_pad_mask,\
        block_tokens, block_pad_mask, block_indices


def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately"""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result


class BlockSampleData(object):
    """A struct for fully describing a fixed-size block of data as used in REALM

    :param start_idx: for first sentence of the block
    :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
    :param doc_idx: the index of the document from which the block comes in the original indexed dataset
    :param block_idx: a unique integer identifier given to every block.
    """
    def __init__(self, start_idx, end_idx, doc_idx, block_idx):
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.doc_idx = doc_idx
        self.block_idx = block_idx

    def as_array(self):
        return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)

    def as_tuple(self):
        return self.start_idx, self.end_idx, self.doc_idx, self.block_idx


class BlockSamplesMapping(object):
    def __init__(self, mapping_array):
        # make sure that the array is compatible with BlockSampleData
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array

    def __len__(self):
        return self.mapping_array.shape[0]

    def __getitem__(self, idx):
        """Get the data associated with an indexed sample."""
        sample_data = BlockSampleData(*self.mapping_array[idx])
        return sample_data


def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
    """Get samples mapping for a dataset over fixed size blocks. This function also requires
    a dataset of the titles for the source documents since their lengths must be taken into account.

    :return: samples_mapping (BlockSamplesMapping)
    """

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    if use_one_sent_docs:
        indexmap_filename += '_1sentok'
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if mpu.get_data_parallel_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert block_dataset.doc_idx.dtype == np.int64
        assert block_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))

        from megatron.data import helpers
        mapping_array = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
            block_dataset.sizes,
            title_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose,
            use_one_sent_docs)

        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, mapping_array, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(
                         time.time() - start_time))

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()

    mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    samples_mapping = BlockSamplesMapping(mapping_array)

    print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0(' total number of samples: {}'.format(
        mapping_array.shape[0]))

    return samples_mapping
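
Two of the small helpers above are worth seeing on concrete values; note that
join_str_list leaves a leading space before the first word:

    pieces = ["hello", "##wor", "##ld", "there"]
    print(repr(join_str_list(pieces)))  # ' helloworld there'

    data = BlockSampleData(start_idx=0, end_idx=5, doc_idx=42, block_idx=7)
    assert data.as_tuple() == (0, 5, 42, 7)
    assert tuple(data.as_array()) == data.as_tuple()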
megatron/data/realm_index.py
ADDED
@@ -0,0 +1,224 @@
import itertools
import os
import pickle
import shutil

import numpy as np
import torch

from megatron import get_args
from megatron import mpu


def detach(tensor):
    return tensor.detach().cpu().numpy()


class OpenRetreivalDataStore(object):
    """
    Serializable data structure for holding data for blocks --
    embeddings and necessary metadata for Retriever
    """
    def __init__(self, embedding_path=None, load_from_path=True, rank=None):
        self.embed_data = dict()
        if embedding_path is None:
            args = get_args()
            embedding_path = args.embedding_path
            rank = args.rank
        self.embedding_path = embedding_path
        self.rank = rank

        if load_from_path:
            self.load_from_file()

        block_data_name = os.path.splitext(self.embedding_path)[0]
        self.temp_dir_name = block_data_name + '_tmp'

    def state(self):
        return {
            'embed_data': self.embed_data,
        }

    def clear(self):
        """
        Clear the embedding data structures to save memory.
        The metadata ends up getting used, and is also much smaller in
        dimensionality so it isn't really worth clearing.
        """
        self.embed_data = dict()

    def load_from_file(self):
        """Populate members from instance saved to file"""

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Unpickling BlockData", flush=True)
        state_dict = pickle.load(open(self.embedding_path, 'rb'))
        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">> Finished unpickling BlockData\n", flush=True)

        self.embed_data = state_dict['embed_data']

    def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
        """
        Add data for set of blocks
        :param row_id: 1D array of unique int ids for the blocks
        :param block_embeds: 2D array of embeddings of the blocks
            In the case of retriever this will be [start_idx, end_idx, doc_idx]
        """
        for idx, embed in zip(row_id, block_embeds):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")

            self.embed_data[idx] = np.float16(embed)

    def save_shard(self):
        """
        Save the block data that was created in this process
        """
        if not os.path.isdir(self.temp_dir_name):
            os.makedirs(self.temp_dir_name, exist_ok=True)

        # save the data for each shard
        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \
                as writer:
            pickle.dump(self.state(), writer)

    def merge_shards_and_save(self):
        # Combine all the shards made using save_shard
        shard_names = os.listdir(self.temp_dir_name)
        seen_own_shard = False

        for fname in os.listdir(self.temp_dir_name):
            shard_rank = int(os.path.splitext(fname)[0])
            if shard_rank == self.rank:
                seen_own_shard = True
                continue

            with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
                data = pickle.load(f)
                old_size = len(self.embed_data)
                shard_size = len(data['embed_data'])

                # add the shard's data and check to make sure there
                # is no overlap
                self.embed_data.update(data['embed_data'])
                assert len(self.embed_data) == old_size + shard_size

        assert seen_own_shard

        # save the consolidated shards and remove temporary directory
        with open(self.embedding_path, 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(self.temp_dir_name, ignore_errors=True)

        print("Finished merging {} shards for a total of {} embeds".format(
            len(shard_names), len(self.embed_data)), flush=True)


class FaissMIPSIndex(object):
    """
    Wrapper object for a BlockData that does similarity search via FAISS under the hood
    """
    def __init__(self, embed_size, embed_data=None, use_gpu=False):
        self.embed_size = embed_size
        self.embed_data = embed_data
        self.use_gpu = use_gpu

        self.mips_index = None
        self._set_mips_index()

    def _set_mips_index(self):
        """
        Create a Faiss Flat index with inner product as the metric
        to search against
        """
        try:
            import faiss
        except ImportError:
            raise Exception("Error: Please install faiss to use FaissMIPSIndex")

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Building index", flush=True)

        cpu_index = faiss.IndexFlatIP(self.embed_size)

        if self.use_gpu:
            # create resources and config for GpuIndex
            config = faiss.GpuMultipleClonerOptions()
            config.shard = True
            config.useFloat16 = True
            gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
            self.mips_index = faiss.IndexIDMap(gpu_index)
            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on GPU", flush=True)
        else:
            # CPU index supports IDs so wrap with IDMap
            self.mips_index = faiss.IndexIDMap(cpu_index)
            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on CPU", flush=True)

        # if we were constructed with a BlockData, then automatically load it
        # when the FAISS structure is built
        if self.embed_data is not None:
            self.add_embed_data(self.embed_data)

    def reset_index(self):
        """Delete the existing index and create a new one"""
        del self.mips_index

        # reset the block data so that _set_block_index will reload it as well
        if self.embed_data is not None:
            embed_data_path = self.embed_data.embedding_path
            del self.embed_data
            self.embed_data = OpenRetreivalDataStore(embed_data_path)

        self._set_mips_index()

    def update_index(self):
        """Delete the existing index and create a new one"""
        del self.mips_index

        # reset the block data so that _set_mips_index will reload it as well
        if self.embed_data is not None:
            self.embed_data.load_from_file()
        self._set_mips_index()

    def add_embed_data(self, all_embed_data):
        """Add the embedding of each block to the underlying FAISS index"""

        # this assumes the embed_data is a dict : {int: np.array<float>}
        block_indices, block_embeds = zip(*all_embed_data.embed_data.items())

        # the embeddings have to be entered in as float32 even though the math
        # internally is done with float16.
        embeds_arr = np.float32(np.array(block_embeds))
        indices_arr = np.array(block_indices)

        # we no longer need the embedding data since it's in the index now
        all_embed_data.clear()

        self.mips_index.add_with_ids(embeds_arr, indices_arr)

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">>> Finished adding block data to index", flush=True)

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """
        Get the top-k blocks by the index distance metric.

        :param reconstruct: if True: return a [num_queries x k x embed_dim]
                                array of blocks
                            if False: return [num_queries x k] array of
                                distances, and another for indices
        """
        query_embeds = np.float32(detach(query_embeds))

        if reconstruct:
            # get the vectors themselves
            top_k_block_embeds = self.mips_index.search_and_reconstruct(
                query_embeds, top_k)
            return top_k_block_embeds
        else:
            # get distances and indices of closest vectors
            distances, block_indices = self.mips_index.search(query_embeds, top_k)
            return distances, block_indices
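
A minimal sketch of driving FaissMIPSIndex end to end on CPU. It needs a working
faiss install; _ToyStore, the embedding size, and the random vectors are made-up
stand-ins for OpenRetreivalDataStore and real block embeddings:

    import numpy as np
    import torch

    class _ToyStore:
        # Mimics the two things add_embed_data uses: .embed_data and .clear().
        def __init__(self, embed_data):
            self.embed_data = embed_data
        def clear(self):
            self.embed_data = dict()

    embed_size = 8
    store = _ToyStore({i: np.random.rand(embed_size).astype(np.float16)
                       for i in range(100)})

    index = FaissMIPSIndex(embed_size=embed_size, embed_data=store, use_gpu=False)

    queries = torch.rand(2, embed_size)
    distances, ids = index.search_mips_index(queries, top_k=5, reconstruct=False)
    print(ids.shape)  # (2, 5): block ids with the largest inner products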
megatron/data/t5_dataset.py
ADDED
@@ -0,0 +1,270 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""T5 Style dataset."""

import collections

import numpy as np
import torch

from megatron import get_tokenizer
from megatron.data.dataset_utils import (
    create_masked_lm_predictions,
    get_samples_mapping
)

class T5Dataset(torch.utils.data.Dataset):

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, max_seq_length_dec,
                 short_seq_prob, seed):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
        self.max_seq_length_dec = max_seq_length_dec

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
                                                   data_prefix,
                                                   num_epochs,
                                                   max_num_samples,
                                                   self.max_seq_length - 2,  # account for added tokens
                                                   short_seq_prob,
                                                   self.seed,
                                                   self.name,
                                                   False)

        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad
        self.bos_id = tokenizer.bos_token_id
        self.eos_id = tokenizer.eos_token_id
        self.sentinel_tokens = tokenizer.additional_special_tokens_ids
        assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script"

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):

        start_index, end_index, seq_length = self.samples_mapping[idx]
        sample = []
        for index in range(start_index, end_index):
            sample.append(self.indexed_dataset[index])
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        np_rng = np.random.RandomState(seed=(self.seed + idx))
        return build_training_sample(sample, seq_length,
                                     self.max_seq_length,  # needed for padding
                                     self.max_seq_length_dec,
                                     self.vocab_id_list,
                                     self.vocab_id_to_token_dict,
                                     self.cls_id, self.sep_id,
                                     self.mask_id, self.pad_id,
                                     self.masked_lm_prob, np_rng,
                                     self.bos_id, self.eos_id,
                                     self.sentinel_tokens)


def build_training_sample(sample, target_seq_length,
                          max_seq_length, max_seq_length_dec,
                          vocab_id_list, vocab_id_to_token_dict,
                          cls_id, sep_id, mask_id, pad_id,
                          masked_lm_prob, np_rng, bos_id=None,
                          eos_id=None, sentinel_tokens=None):
    """Build training sample.

    Arguments:
        sample: A list of sentences in which each sentence is a list of token ids.
        target_seq_length: Desired sequence length.
        max_seq_length: Maximum length of the sequence. All values are padded to
            this length.
        vocab_id_list: List of vocabulary ids. Used to pick a random id.
        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
        cls_id: Start of example id.
        sep_id: Separator id.
        mask_id: Mask token id.
        pad_id: Padding token id.
        masked_lm_prob: Probability to mask tokens.
        np_rng: Random number generator. Note that this rng state should be
            numpy and not python since python randint is inclusive for
            the upper bound whereas the numpy one is exclusive.
        bos_id: start of decoder example id
        eos_id: end of generation id
        sentinel_tokens: unique value to be substituted for every replaced span
    """

    assert target_seq_length <= max_seq_length

    # flatten sentences into one list
    tokens = [token for sentence in sample for token in sentence]

    # Truncate to `target_sequence_length`.
    max_num_tokens = target_seq_length
    truncated = len(tokens) > max_num_tokens
    tokens = tokens[:max_num_tokens]

    # Masking.
    max_predictions_per_seq = masked_lm_prob * max_num_tokens
    (tokens, masked_positions, masked_labels, _, masked_spans) = create_masked_lm_predictions(
        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng,
        max_ngrams=10, geometric_dist=True, masking_style="t5")

    # Padding.
    tokens_enc, tokens_dec_in, labels, enc_mask, \
    dec_mask, enc_dec_mask, loss_mask \
        = pad_and_convert_to_numpy(tokens, masked_positions,
                                   masked_labels, pad_id, max_seq_length,
                                   max_seq_length_dec, masked_spans,
                                   bos_id, eos_id, sentinel_tokens)

    train_sample = {
        'text_enc': tokens_enc,
        'text_dec': tokens_dec_in,
        'labels': labels,
        'loss_mask': loss_mask,
        'truncated': int(truncated),
        'enc_mask': enc_mask,
        'dec_mask': dec_mask,
        'enc_dec_mask': enc_dec_mask,
    }
    return train_sample


def pad_and_convert_to_numpy(tokens, masked_positions,
                             masked_labels, pad_id,
                             max_seq_length, max_seq_length_dec,
                             masked_spans=None, bos_id=None,
                             eos_id=None, sentinel_tokens=None):
    """Pad sequences and convert them to numpy."""

    sentinel_tokens = collections.deque(sentinel_tokens)
    t5_input = []
    (t5_decoder_in, t5_decoder_out) = ([bos_id], [])
    (start_index, end_index) = (0, None)
    for span in masked_spans:
        flag = sentinel_tokens.popleft()

        # Append the same tokens in decoder input and output
        t5_decoder_in.append(flag)
        t5_decoder_in.extend(span.label)
        t5_decoder_out.append(flag)
        t5_decoder_out.extend(span.label)

        end_index = span.index[0]
        t5_input.extend(tokens[start_index: end_index])
        t5_input.append(flag)

        # the next start index is the token after the last span token
        start_index = span.index[-1] + 1

    # Add <eos> token to the t5_decoder_out
    t5_decoder_out.append(eos_id)

    # Add the remaining tokens to the t5 input
    t5_input.extend(tokens[start_index:])

    # assert (len(t5_input) - len(masked_spans)) + \
    #        (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens)

    # Some checks.

    # Encoder-side padding mask.
    num_tokens = len(t5_input)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(masked_positions) == len(masked_labels)

    # Tokens..
    filler = [pad_id] * padding_length
    tokens_enc = np.array(t5_input + filler, dtype=np.int64)

    # Decoder-side padding mask.
    num_tokens_dec = len(t5_decoder_in)
    padding_length_dec = max_seq_length_dec - num_tokens_dec
    assert padding_length_dec >= 0
    filler_dec = [pad_id] * padding_length_dec
    tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64)

    # Create attention masks
    enc_mask = make_attention_mask(tokens_enc, tokens_enc)
    enc_dec_mask = make_attention_mask(tokens_dec_in, tokens_enc)
    dec_mask = make_attention_mask(tokens_dec_in, tokens_dec_in)
    dec_mask = dec_mask * make_history_mask(tokens_dec_in)

    # Labels mask.
    labels = t5_decoder_out + ([-1] * padding_length_dec)
    labels = np.array(labels, dtype=np.int64)

    # Loss mask
    loss_mask = ([1] * num_tokens_dec) + ([0] * padding_length_dec)
    loss_mask = np.array(loss_mask, dtype=np.int64)

    return tokens_enc, tokens_dec_in, labels, enc_mask, \
           dec_mask, enc_dec_mask, loss_mask


def make_attention_mask(source_block, target_block):
    """
    Returns a 2-dimensional (2-D) attention mask
    :param source_block: 1-D array
    :param target_block: 1-D array
    """
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    mask = mask.astype(np.int64)
    # (source_length, target_length)
    return mask


def make_attention_mask_3d(source_block, target_block):
    """
    Returns a 3-dimensional (3-D) attention mask
    :param source_block: 1-D array
    :param target_block: 1-D array
    """
    mask = (target_block[:, None, :] >= 1) * (source_block[:, :, None] >= 1)
    # (batch, source_length, target_length)
    # mask = mask.astype(np.int64)
    return mask


def make_history_mask(block):
    length = block.shape[0]
    arange = np.arange(length)
    history_mask = (arange[None, ] <= arange[:, None])
    history_mask = history_mask.astype(np.int64)
    return history_mask


def make_history_mask_3d(block):
    batch, length = block.shape
    arange = torch.arange(length, device=block.device)
    history_mask = (arange[None, ] <= arange[:, None])[None, ]
    history_mask = history_mask.expand(batch, length, length)
    return history_mask
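
The mask helpers at the end of this file are easiest to read off a tiny input.
With pad id 0 and a 3-token sequence padded to length 4 (values made up):

    import numpy as np

    seq = np.array([5, 6, 7, 0])  # three real tokens, one pad (id 0)
    print(make_attention_mask(seq, seq))
    # [[1 1 1 0]
    #  [1 1 1 0]
    #  [1 1 1 0]
    #  [0 0 0 0]]   pad positions neither attend nor get attended to
    print(make_history_mask(seq))
    # [[1 0 0 0]
    #  [1 1 0 0]
    #  [1 1 1 0]
    #  [1 1 1 1]]   lower-triangular causal mask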
megatron/data/test/test_indexed_dataset.py
ADDED
@@ -0,0 +1,125 @@
# This file isn't really a formal automated test, it's just a place to
# put some code used during development and manual testing of
# indexed_dataset.

from megatron.data import indexed_dataset
from megatron.tokenizer import build_tokenizer
import argparse
import os
import sys

import torch

script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))


def test_indexed_dataset(args):
    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    tokenizer = build_tokenizer(args)
    print(len(ds.doc_idx))
    print(len(ds))
    print(ds.doc_idx[-1])
    if ds.supports_prefetch:
        # just prefetch the whole thing in test (so assume it is small)
        ds.prefetch(range(len(ds)))
    if args.count > len(ds.doc_idx) - 1:
        args.count = len(ds.doc_idx) - 1

    for i in range(args.count):
        start = ds.doc_idx[i]
        end = ds.doc_idx[i + 1]
        ids = ds[start:end]
        print(f"Document {i}:")
        print("--------------")
        for s in ids:
            assert len(s) > 0
            l = s.data.tolist()
            text = tokenizer.detokenize(l)
            print(text)
            print("---")


def test_indexed_dataset_get(args):
    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    tokenizer = build_tokenizer(args)
    size = ds.sizes[0]
    print(f"size: {size}")
    full = ds.get(0)
    print(full)
    # print(tokenizer.detokenize(full.data.tolist()))
    print("---")
    end = ds.get(0, offset=size - 10)
    print(end)
    # print(tokenizer.detokenize(end.data.tolist()))

    start = ds.get(0, length=10)
    print(start)
    # print(tokenizer.detokenize(start.data.tolist()))

    part = ds.get(0, offset=2, length=8)
    print(part)
    # print(tokenizer.detokenize(part.data.tolist()))

# def test_albert_dataset(args):
#     # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
#     # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
#     # ds = AlbertDataset(idataset, tokenizer)
#     ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
#                                   args.epochs, args.max_num_samples,
#                                   args.masked_lm_prob, args.seq_length,
#                                   args.short_seq_prob, args.seed)
#     truncated = 0
#     total = 0
#     for i, s in enumerate(ds):
#         ids = s['text']
#         tokens = ds.tokenizer.convert_ids_to_tokens(ids)
#         print(tokens)
#         if i >= args.count-1:
#             exit()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, help='prefix to data files')
    parser.add_argument('--dataset-impl', type=str, default='infer',
                        choices=['lazy', 'cached', 'mmap', 'infer'])
    parser.add_argument('--count', type=int, default=10,
                        help='Number of samples/documents to print')

    group = parser.add_argument_group(title='tokenizer')
    group.add_argument('--tokenizer-type', type=str, required=True,
                       choices=['BertWordPieceLowerCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file (if necessary).')

    parser.add_argument('--epochs', type=int, default=5,
                        help='Number of epochs to plan for')
    parser.add_argument('--max-num-samples', type=int, default=None,
                        help='Maximum number of samples to plan for')
    parser.add_argument('--masked-lm-prob', type=float, default=0.15,
                        help='probability of masking tokens')
    parser.add_argument('--seq-length', type=int, default=512,
                        help='maximum sequence length')
    parser.add_argument('--short-seq-prob', type=float, default=0.1,
                        help='probability of creating a short sequence')
    parser.add_argument('--seed', type=int, default=1234,
                        help='random seed')
    args = parser.parse_args()
    args.rank = 0
    args.make_vocab_size_divisible_by = 128
    args.tensor_model_parallel_size = 1

    if args.dataset_impl == "infer":
        args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)

    # test_albert_dataset(args)
    test_indexed_dataset_get(args)


if __name__ == "__main__":
    main()
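
The get() calls exercised above reduce to: offset skips elements from the front
of item idx, and length caps how many come back. A sketch, assuming an mmap-style
dataset pair exists under the (hypothetical) prefix 'toy_corpus':

    from megatron.data import indexed_dataset

    ds = indexed_dataset.make_dataset('toy_corpus', 'mmap')
    full = ds.get(0)                         # the whole first item
    tail = ds.get(0, offset=len(full) - 10)  # its last 10 tokens
    head = ds.get(0, length=10)              # its first 10 tokens
    mid = ds.get(0, offset=2, length=8)      # tokens 2..9
    assert (mid == full[2:10]).all()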
megatron/data/test/test_preprocess_data.sh
ADDED
@@ -0,0 +1,10 @@
#!/bin/bash

IMPL=cached
python ../preprocess_data.py \
       --input test_samples.json \
       --vocab vocab.txt \
       --dataset-impl ${IMPL} \
       --output-prefix test_samples_${IMPL} \
       --workers 1 \
       --log-interval 2
megatron/data/vit_dataset.py
ADDED
@@ -0,0 +1,262 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import numpy as np
import torch
import torchvision.transforms as T
from torchvision import datasets
from megatron import get_args
from megatron.data.image_folder import ImageFolder
from megatron.data.autoaugment import ImageNetPolicy
from megatron.data.data_samplers import RandomSeedDataset
from PIL import Image, ImageFilter, ImageOps


class GaussianBlur(object):
    """
    Apply Gaussian Blur to the PIL image.
    """
    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        do_it = random.random() <= self.prob
        if not do_it:
            return img

        return img.filter(
            ImageFilter.GaussianBlur(
                radius=random.uniform(self.radius_min, self.radius_max)
            )
        )


class Solarization(object):
    """
    Apply Solarization to the PIL image.
    """
    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        else:
            return img


class ClassificationTransform():
    def __init__(self, image_size, train=True):
        args = get_args()
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16
        if train:
            self.transform = T.Compose([
                T.RandomResizedCrop(image_size),
                T.RandomHorizontalFlip(),
                T.ColorJitter(0.4, 0.4, 0.4, 0.1),
                ImageNetPolicy(),
                T.ToTensor(),
                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                T.ConvertImageDtype(self.data_type)
            ])
        else:
            self.transform = T.Compose([
                T.Resize(image_size),
                T.CenterCrop(image_size),
                T.ToTensor(),
                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                T.ConvertImageDtype(self.data_type)
            ])

    def __call__(self, input):
        output = self.transform(input)
        return output


class InpaintingTransform():
    def __init__(self, image_size, train=True):

        args = get_args()
        self.mask_factor = args.mask_factor
        self.mask_type = args.mask_type
        self.image_size = image_size
        self.patch_size = args.patch_dim
        self.mask_size = int(self.mask_factor*(image_size[0]/self.patch_size)*(image_size[1]/self.patch_size))
        self.train = train
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16

        if self.train:
            self.transform = T.Compose([
                T.RandomResizedCrop(self.image_size),
                T.RandomHorizontalFlip(),
                T.ColorJitter(0.4, 0.4, 0.4, 0.1),
                ImageNetPolicy(),
                T.ToTensor(),
                T.ConvertImageDtype(self.data_type)
            ])
        else:
            self.transform = T.Compose([
                T.Resize(self.image_size, interpolation=2),
                T.CenterCrop(self.image_size),
                T.ToTensor(),
                T.ConvertImageDtype(self.data_type)
            ])

    def gen_mask(self, image_size, mask_size, mask_type, patch_size):
        # output: mask as a list with indices for missing patches
        action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
        assert image_size[0] == image_size[1]
        img_size_patch = image_size[0] // patch_size

        # drop masked patches
        mask = torch.zeros((image_size[0], image_size[1]), dtype=torch.float)

        if mask_type == 'random':
            x = torch.randint(0, img_size_patch, ())
            y = torch.randint(0, img_size_patch, ())
            for i in range(mask_size):
                r = torch.randint(0, len(action_list), ())
                x = torch.clamp(x + action_list[r][0], min=0, max=img_size_patch - 1)
                y = torch.clamp(y + action_list[r][1], min=0, max=img_size_patch - 1)
                x_offset = x * patch_size
                y_offset = y * patch_size
                mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
        else:
            assert mask_type == 'row'
            count = 0
            for x in reversed(range(img_size_patch)):
                for y in reversed(range(img_size_patch)):
                    if (count < mask_size):
                        count += 1
                        x_offset = x * patch_size
                        y_offset = y * patch_size
                        mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
        return mask

    def __call__(self, input):
        trans_input = self.transform(input)
        mask = self.gen_mask(self.image_size, self.mask_size,
                             self.mask_type, self.patch_size)
        mask = mask.unsqueeze(dim=0)
        return trans_input, mask


class DinoTransform(object):
    def __init__(self, image_size, train=True):
        args = get_args()
        self.data_type = torch.half if args.fp16 else torch.bfloat16

        flip_and_color_jitter = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply(
                [T.ColorJitter(brightness=0.4, contrast=0.4,
                               saturation=0.2, hue=0.1)],
                p=0.8
            ),
            T.RandomGrayscale(p=0.2),
        ])

        if args.fp16 or args.bf16:
            normalize = T.Compose([
                T.ToTensor(),
                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                T.ConvertImageDtype(self.data_type)
            ])
        else:
            normalize = T.Compose([
                T.ToTensor(),
                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ])
187 |
+
|
188 |
+
# first global crop
|
189 |
+
scale_const = 0.4
|
190 |
+
self.global_transform1 = T.Compose([
|
191 |
+
T.RandomResizedCrop(image_size,
|
192 |
+
scale=(scale_const, 1),
|
193 |
+
interpolation=Image.BICUBIC),
|
194 |
+
flip_and_color_jitter,
|
195 |
+
GaussianBlur(1.0),
|
196 |
+
normalize
|
197 |
+
])
|
198 |
+
# second global crop
|
199 |
+
self.global_transform2 = T.Compose([
|
200 |
+
T.RandomResizedCrop(image_size,
|
201 |
+
scale=(scale_const, 1),
|
202 |
+
interpolation=Image.BICUBIC),
|
203 |
+
flip_and_color_jitter,
|
204 |
+
GaussianBlur(0.1),
|
205 |
+
Solarization(0.2),
|
206 |
+
normalize
|
207 |
+
])
|
208 |
+
# transformation for the local small crops
|
209 |
+
self.local_crops_number = args.dino_local_crops_number
|
210 |
+
self.local_transform = T.Compose([
|
211 |
+
T.RandomResizedCrop(args.dino_local_img_size,
|
212 |
+
scale=(0.05, scale_const),
|
213 |
+
interpolation=Image.BICUBIC),
|
214 |
+
flip_and_color_jitter,
|
215 |
+
GaussianBlur(p=0.5),
|
216 |
+
normalize
|
217 |
+
])
|
218 |
+
|
219 |
+
def __call__(self, image):
|
220 |
+
crops = []
|
221 |
+
crops.append(self.global_transform1(image))
|
222 |
+
crops.append(self.global_transform2(image))
|
223 |
+
for _ in range(self.local_crops_number):
|
224 |
+
crops.append(self.local_transform(image))
|
225 |
+
return crops
|
226 |
+
|
227 |
+
|
228 |
+
def build_train_valid_datasets(data_path, image_size=224):
|
229 |
+
args = get_args()
|
230 |
+
|
231 |
+
if args.vision_pretraining_type == 'classify':
|
232 |
+
train_transform = ClassificationTransform(image_size)
|
233 |
+
val_transform = ClassificationTransform(image_size, train=False)
|
234 |
+
elif args.vision_pretraining_type == 'inpaint':
|
235 |
+
train_transform = InpaintingTransform(image_size, train=False)
|
236 |
+
val_transform = InpaintingTransform(image_size, train=False)
|
237 |
+
elif args.vision_pretraining_type == 'dino':
|
238 |
+
train_transform = DinoTransform(image_size, train=True)
|
239 |
+
val_transform = ClassificationTransform(image_size, train=False)
|
240 |
+
else:
|
241 |
+
raise Exception('{} vit pretraining type is not supported.'.format(
|
242 |
+
args.vit_pretraining_type))
|
243 |
+
|
244 |
+
# training dataset
|
245 |
+
train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
|
246 |
+
train_data = ImageFolder(
|
247 |
+
root=train_data_path,
|
248 |
+
transform=train_transform,
|
249 |
+
classes_fraction=args.classes_fraction,
|
250 |
+
data_per_class_fraction=args.data_per_class_fraction
|
251 |
+
)
|
252 |
+
train_data = RandomSeedDataset(train_data)
|
253 |
+
|
254 |
+
# validation dataset
|
255 |
+
val_data_path = data_path[1]
|
256 |
+
val_data = ImageFolder(
|
257 |
+
root=val_data_path,
|
258 |
+
transform=val_transform
|
259 |
+
)
|
260 |
+
val_data = RandomSeedDataset(val_data)
|
261 |
+
|
262 |
+
return train_data, val_data
|
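
Note: `DinoTransform.__call__` implements the usual DINO multi-crop recipe: every image yields two large "global" views plus `args.dino_local_crops_number` small "local" views, and the downstream collate sees that list of crops. A minimal sketch of the same composition outside Megatron's argument plumbing (the crop counts, sizes, and dummy image below are illustrative assumptions, not values taken from this repo):

    import torchvision.transforms as T
    from PIL import Image

    # Stand-ins for args.dino_local_crops_number / args.dino_local_img_size.
    LOCAL_CROPS, LOCAL_SIZE, GLOBAL_SIZE = 10, 96, 224

    global_crop = T.Compose([T.RandomResizedCrop(GLOBAL_SIZE, scale=(0.4, 1.0)), T.ToTensor()])
    local_crop = T.Compose([T.RandomResizedCrop(LOCAL_SIZE, scale=(0.05, 0.4)), T.ToTensor()])

    img = Image.new('RGB', (256, 256))            # dummy image
    crops = [global_crop(img), global_crop(img)]  # two global views
    crops += [local_crop(img) for _ in range(LOCAL_CROPS)]
    print([tuple(c.shape) for c in crops])        # 2 x (3, 224, 224), then 10 x (3, 96, 96)
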
megatron/dist_signal_handler.py
ADDED
@@ -0,0 +1,81 @@
+import signal
+
+import torch
+
+
+def get_world_size():
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        world_size = torch.distributed.get_world_size()
+    else:
+        world_size = 1
+    return world_size
+
+
+def get_device(local_rank=None):
+    backend = torch.distributed.get_backend()
+    if backend == 'nccl':
+        if local_rank is None:
+            device = torch.device('cuda')
+        else:
+            device = torch.device(f'cuda:{local_rank}')
+    elif backend == 'gloo':
+        device = torch.device('cpu')
+    else:
+        raise RuntimeError
+    return device
+
+
+def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None):
+    if not torch.distributed.is_available() or \
+       not torch.distributed.is_initialized():
+        return [item]
+
+    device = get_device(local_rank)
+
+    if group is not None:
+        group_size = group.size()
+    else:
+        group_size = get_world_size()
+
+    tensor = torch.tensor([item], device=device, dtype=dtype)
+    output_tensors = [
+        torch.zeros(1, dtype=tensor.dtype, device=tensor.device)
+        for _ in range(group_size)
+    ]
+    torch.distributed.all_gather(output_tensors, tensor, group, async_op)
+    output = [elem.item() for elem in output_tensors]
+    return output
+
+
+class DistributedSignalHandler:
+    def __init__(self, sig=signal.SIGTERM):
+        self.sig = sig
+
+    def signals_received(self):
+        all_received = all_gather_item(
+            self._signal_received, dtype=torch.int32
+        )
+        return all_received
+
+    def __enter__(self):
+        self._signal_received = False
+        self.released = False
+        self.original_handler = signal.getsignal(self.sig)
+
+        def handler(signum, frame):
+            self._signal_received = True
+
+        signal.signal(self.sig, handler)
+
+        return self
+
+    def __exit__(self, type, value, tb):
+        self.release()
+
+    def release(self):
+        if self.released:
+            return False
+
+        signal.signal(self.sig, self.original_handler)
+        self.released = True
+        return True
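
`DistributedSignalHandler` is a context manager: each rank flips a local flag when the signal arrives, and `signals_received()` all-gathers those flags so every rank can agree that some rank was signalled before stopping. A single-process sketch of the intended usage (the sleep stands in for a training step; in a real job the signal would come from the scheduler rather than `os.kill`):

    import os
    import signal
    import time

    from megatron.dist_signal_handler import DistributedSignalHandler

    with DistributedSignalHandler(signal.SIGTERM) as handler:
        for step in range(5):
            time.sleep(0.1)  # stand-in for one training step
            if step == 2:
                os.kill(os.getpid(), signal.SIGTERM)  # simulate preemption
            if any(handler.signals_received()):
                print(f'signal caught at step {step}; checkpoint and exit here')
                break
    # __exit__ restores the original SIGTERM handler via release().
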
megatron/fp16_deprecated/loss_scaler.py
ADDED
@@ -0,0 +1,39 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""For backward compatibility, we need the class definitions to deserialize."""
+
+class LossScaler:
+    def __init__(self, scale=1):
+        self.cur_scale = scale
+
+class DynamicLossScaler:
+    def __init__(self,
+                 init_scale=2**32,
+                 scale_factor=2.,
+                 scale_window=1000,
+                 min_scale=1,
+                 delayed_shift=1,
+                 consecutive_hysteresis=False):
+        self.cur_scale = init_scale
+        self.cur_iter = 0
+        self.last_overflow_iter = -1
+        self.scale_factor = scale_factor
+        self.scale_window = scale_window
+        self.min_scale = min_scale
+        self.delayed_shift = delayed_shift
+        self.cur_hysteresis = delayed_shift
+        self.consecutive_hysteresis = consecutive_hysteresis
+
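
These classes are kept only so old checkpoints deserialize; the update logic they once carried is the classic dynamic loss-scaling rule: shrink `cur_scale` on overflow, grow it again after `scale_window` overflow-free iterations. A sketch of that rule under those assumptions (this helper is not code from the repo):

    from megatron.fp16_deprecated.loss_scaler import DynamicLossScaler

    def update_scale(scaler, has_overflow):
        """One step of the classic dynamic loss-scaling rule (sketch)."""
        if has_overflow:
            # Shrink the scale, but never below min_scale.
            scaler.cur_scale = max(scaler.cur_scale / scaler.scale_factor,
                                   scaler.min_scale)
            scaler.last_overflow_iter = scaler.cur_iter
        elif (scaler.cur_iter - scaler.last_overflow_iter) % scaler.scale_window == 0:
            # A full window of clean steps: grow the scale again.
            scaler.cur_scale *= scaler.scale_factor
        scaler.cur_iter += 1

    scaler = DynamicLossScaler(init_scale=2**16)
    update_scale(scaler, has_overflow=True)
    print(scaler.cur_scale)  # 65536.0 -> 32768.0
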
megatron/fused_kernels/__init__.py
ADDED
@@ -0,0 +1,125 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import pathlib
+import subprocess
+
+from torch.utils import cpp_extension
+
+# Setting this param to a list has a problem of generating different
+# compilation commands (with diferent order of architectures) and
+# leading to recompilation of fused kernels. Set it to empty string
+# to avoid recompilation and assign arch flags explicity in
+# extra_cuda_cflags below
+os.environ["TORCH_CUDA_ARCH_LIST"] = ""
+
+
+def load(args):
+
+    # Check if cuda 11 is installed for compute capability 8.0
+    cc_flag = []
+    _, bare_metal_major, _ = _get_cuda_bare_metal_version(
+        cpp_extension.CUDA_HOME)
+    if int(bare_metal_major) >= 11:
+        cc_flag.append('-gencode')
+        cc_flag.append('arch=compute_80,code=sm_80')
+
+    # Build path
+    srcpath = pathlib.Path(__file__).parent.absolute()
+    buildpath = srcpath / 'build'
+    _create_build_dir(buildpath)
+
+    # Helper function to build the kernels.
+    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
+        return cpp_extension.load(
+            name=name,
+            sources=sources,
+            build_directory=buildpath,
+            extra_cflags=['-O3',],
+            extra_cuda_cflags=['-O3',
+                               '-gencode', 'arch=compute_70,code=sm_70',
+                               '--use_fast_math'] + extra_cuda_flags + cc_flag,
+            verbose=(args.rank == 0)
+        )
+
+    # ==============
+    # Fused softmax.
+    # ==============
+
+    if args.masked_softmax_fusion:
+        extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
+                            '-U__CUDA_NO_HALF_CONVERSIONS__',
+                            '--expt-relaxed-constexpr',
+                            '--expt-extended-lambda']
+
+        # Upper triangular softmax.
+        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
+                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
+        scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_upper_triang_masked_softmax_cuda",
+            sources, extra_cuda_flags)
+
+        # Masked softmax.
+        sources=[srcpath / 'scaled_masked_softmax.cpp',
+                 srcpath / 'scaled_masked_softmax_cuda.cu']
+        scaled_masked_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_masked_softmax_cuda", sources, extra_cuda_flags)
+
+        # Softmax
+        sources=[srcpath / 'scaled_softmax.cpp',
+                 srcpath / 'scaled_softmax_cuda.cu']
+        scaled_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_softmax_cuda", sources, extra_cuda_flags)
+
+    # =================================
+    # Mixed precision fused layer norm.
+    # =================================
+
+    extra_cuda_flags = ['-maxrregcount=50']
+    sources=[srcpath / 'layer_norm_cuda.cpp',
+             srcpath / 'layer_norm_cuda_kernel.cu']
+    fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
+        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
+
+    # =================================
+    # Fused gradient accumulation to weight gradient computation of linear layer
+    # =================================
+
+    if args.gradient_accumulation_fusion:
+        sources=[srcpath / 'fused_weight_gradient_dense.cpp',
+                 srcpath / 'fused_weight_gradient_dense.cu']
+        fused_dense_cuda = _cpp_extention_load_helper(
+            "fused_dense_cuda", sources, [])
+
+
+def _get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+                                         universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def _create_build_dir(buildpath):
+    try:
+        os.mkdir(buildpath)
+    except OSError:
+        if not os.path.isdir(buildpath):
+            print(f"Creation of the build directory {buildpath} failed")
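
`load(args)` reads only three attributes from `args` (`rank`, `masked_softmax_fusion`, `gradient_accumulation_fusion`), so the JIT build can be exercised without the full Megatron argument parser. A hedged sketch, assuming a working nvcc toolchain (the `SimpleNamespace` stands in for the real Megatron `args` object and is not how the repo itself calls it):

    from types import SimpleNamespace

    from megatron import fused_kernels

    args = SimpleNamespace(
        rank=0,                           # only rank 0 prints ninja output
        masked_softmax_fusion=True,       # build the three softmax kernels
        gradient_accumulation_fusion=False,
    )
    fused_kernels.load(args)  # JIT-compiles into megatron/fused_kernels/build/
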
megatron/fused_kernels/build/.ninja_deps
ADDED
Binary file (128 kB).
megatron/fused_kernels/build/.ninja_log
ADDED
@@ -0,0 +1,99 @@
+# ninja log v5
+0 10410 1666322917541195629 scaled_upper_triang_masked_softmax.o 3f7d6908014e2c4b
+0 44996 1666322952117194853 scaled_upper_triang_masked_softmax_cuda.cuda.o 63e21f5b42e40bd0
+44999 45365 1666322952493194844 scaled_upper_triang_masked_softmax_cuda.so f84b117fa81963e5
+1 10116 1666322962741194614 scaled_masked_softmax.o 9dfad674b44501f2
+1 45749 1666322998365193814 scaled_masked_softmax_cuda.cuda.o aec0e3e7e0fe0af5
+45749 46118 1666322998741193805 scaled_masked_softmax_cuda.so a24b9ad6a01f9db5
+0 10316 1666323009189193571 scaled_softmax.o ba78446f9188fae0
+0 44538 1666323043401192802 scaled_softmax_cuda.cuda.o 77ea362b8652c7e9
+44538 44904 1666323043777192794 scaled_softmax_cuda.so d8fa0ebfd78e8bd8
+0 10918 1666323054829192545 layer_norm_cuda.o 7dc5869ac5593422
+0 11891 1666323055797192524 layer_norm_cuda_kernel.cuda.o 13d0d213fbbb62de
+11891 12255 1666323056165192515 fused_mix_prec_layer_norm_cuda.so a21986e1b00b3401
+0 10072 1666682710301113263 scaled_upper_triang_masked_softmax.o c11d897e7800befb
+0 46206 1666682746425112452 scaled_upper_triang_masked_softmax_cuda.cuda.o bc610e36d8dfd435
+46206 46587 1666682746813112443 scaled_upper_triang_masked_softmax_cuda.so f84b117fa81963e5
+0 9858 1666682756829112218 scaled_masked_softmax.o fedebad209ed2d21
+0 46362 1666682793321111399 scaled_masked_softmax_cuda.cuda.o 51814239e7caea9a
+46362 46747 1666682793717111390 scaled_masked_softmax_cuda.so a24b9ad6a01f9db5
+0 9870 1666682803741111164 scaled_softmax.o 1d9e3231fe352c0b
+0 46512 1666682840373110342 scaled_softmax_cuda.cuda.o f9b5a976cff0a5ef
+46513 46900 1666682840769110333 scaled_softmax_cuda.so d8fa0ebfd78e8bd8
+0 10615 1666682851533110091 layer_norm_cuda.o 3cacb26d8faa2b99
+0 11849 1666682852761110063 layer_norm_cuda_kernel.cuda.o 319de99ce0920143
+11849 12230 1666682853145110055 fused_mix_prec_layer_norm_cuda.so a21986e1b00b3401
+0 12428 1666750089507718000 scaled_upper_triang_masked_softmax.o 8e61e453c7b77ff5
+0 46226 1666750123300534000 scaled_upper_triang_masked_softmax_cuda.cuda.o 193ac2a539f3f292
+46228 47144 1666750124224556000 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4
+0 11486 1666750135832836000 scaled_masked_softmax.o b6378099b518f069
+0 47221 1666750171561699000 scaled_masked_softmax_cuda.cuda.o 4faef3fa30fe1e1d
+47223 48124 1666750172469721000 scaled_masked_softmax_cuda.so d6611febaa933d3d
+0 11564 1666750184410010000 scaled_softmax.o a90db6c821074406
+0 46461 1666750219302852000 scaled_softmax_cuda.cuda.o bf0ec8bfec64157c
+46464 47488 1666750220334877000 scaled_softmax_cuda.so e7199387ed26e64e
+0 12007 1666750232439170000 layer_norm_cuda.o 6a55ca87d1a8c0b2
+0 12020 1666750232447170000 layer_norm_cuda_kernel.cuda.o 655c31ba3cbc10c2
+12022 13866 1666750234299214000 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da
+0 10929 1666856346511840836 scaled_upper_triang_masked_softmax.o 5a6ff0631bbc2735
+0 43626 1666856379199840197 scaled_upper_triang_masked_softmax_cuda.cuda.o c36b22f9d9fc2117
+43626 44026 1666856379607840189 scaled_upper_triang_masked_softmax_cuda.so f84b117fa81963e5
+0 10190 1666856389955839987 scaled_masked_softmax.o e032e1ed28e01e30
+0 44582 1666856424331839315 scaled_masked_softmax_cuda.cuda.o 2c8db7df38489475
+44582 44961 1666856424723839308 scaled_masked_softmax_cuda.so a24b9ad6a01f9db5
+0 10142 1666856435015839107 scaled_softmax.o 446947b66b18fa33
+0 44480 1666856469343838436 scaled_softmax_cuda.cuda.o 4fb07733497ecc29
+44480 44879 1666856469751838428 scaled_softmax_cuda.so d8fa0ebfd78e8bd8
+0 10396 1666856480295838222 layer_norm_cuda.o 73da6101a07a24a7
+1 11899 1666856481791838193 layer_norm_cuda_kernel.cuda.o 9ec8eab79e592ff4
+11899 12298 1666856482195838185 fused_mix_prec_layer_norm_cuda.so a21986e1b00b3401
+1 12100 1666925285098117232 scaled_upper_triang_masked_softmax.o f73d4cea858af8b1
+2 45529 1666925318557711559 scaled_upper_triang_masked_softmax_cuda.cuda.o 2fa8c20456ca471c
+45531 47926 1666925320961971387 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4
+1 12395 1666925333323307273 scaled_masked_softmax.o ae4878f941317da4
+1 46831 1666925368027057696 scaled_masked_softmax_cuda.cuda.o 94fda8fac85d606d
+46833 48525 1666925369731241867 scaled_masked_softmax_cuda.so d6611febaa933d3d
+1 12276 1666925382212590722 scaled_softmax.o 7a00c61166684714
+1 47263 1666925417188370543 scaled_softmax_cuda.cuda.o dfe840df0dc3178d
+47265 50987 1666925420916773471 scaled_softmax_cuda.so e7199387ed26e64e
+1 13036 1666925434150203603 layer_norm_cuda_kernel.cuda.o 128560bba544b6cb
+1 14561 1666925435354333732 layer_norm_cuda.o e02c8859a84e70db
+14569 15455 1666925436574465592 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da
+1 62343 1666962134003529000 scaled_upper_triang_masked_softmax_cuda.cuda.o abba0fca57f22344
+62363 63833 1666962135511529000 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4
+0 12215 1667460191194232000 scaled_upper_triang_masked_softmax.o 96d879ae2bf7b993
+1 49211 1667460228190231000 scaled_upper_triang_masked_softmax_cuda.cuda.o bc4b370b3c3d5c9e
+49213 50896 1667460229886231000 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4
+0 13297 1667460243334231000 scaled_masked_softmax.o a6809b1177c7ca02
+1 50055 1667460280082230000 scaled_masked_softmax_cuda.cuda.o dfbe364852f092fc
+50057 65422 1667460295430230000 scaled_masked_softmax_cuda.so d6611febaa933d3d
+1 12055 1667460307682230000 scaled_softmax.o cd4f40829964c3cb
+1 48489 1667460344126229000 scaled_softmax_cuda.cuda.o a6917d3b3ea80f97
+48526 49856 1667460345502229000 scaled_softmax_cuda.so e7199387ed26e64e
+0 13966 1667460359626229000 layer_norm_cuda.o e644ccb47b3615c
+1 15506 1667460361158229000 layer_norm_cuda_kernel.cuda.o 2d32cb24bea852c7
+15509 40152 1667460385810228000 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da
+0 11006 1676876089228927885 scaled_upper_triang_masked_softmax.o 9c8f7b7399ab2d1f
+0 42047 1676876120260456898 scaled_upper_triang_masked_softmax_cuda.cuda.o 12331cadf47cb899
+42047 42264 1676876120488482829 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4
+1 10823 1676876131453729835 scaled_masked_softmax.o 6e7c1e4df8bc11a8
+1 47902 1676876168521945360 scaled_masked_softmax_cuda.cuda.o c86371ad5d3e19ff
+47902 48087 1676876168713967197 scaled_masked_softmax_cuda.so d6611febaa933d3d
+0 13926 1676876182787567696 scaled_softmax.o eeaf8300d7ba52f3
+0 42044 1676876210894764143 scaled_softmax_cuda.cuda.o 379540e7c9ee343a
+42044 42242 1676876211098787345 scaled_softmax_cuda.so e7199387ed26e64e
+0 11125 1676876222348066652 layer_norm_cuda_kernel.cuda.o 5fa2a5b112be408c
+0 11375 1676876222600095314 layer_norm_cuda.o e369f8c0d20bc213
+11375 11572 1676876222800118061 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da
+2 12764 1682404121987049000 scaled_upper_triang_masked_softmax.o 20d2ceb970d5c522
+3 48533 1682404157745832000 scaled_upper_triang_masked_softmax_cuda.cuda.o e3bb3890927e826b
+48592 49557 1682404158781913000 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4
+2 13484 1682404172586987000 scaled_masked_softmax.o 8b2e2a7ca1fd841b
+2 48141 1682404207237685000 scaled_masked_softmax_cuda.cuda.o 6215ee08e9bfd383
+48143 49151 1682404208257765000 scaled_masked_softmax_cuda.so d6611febaa933d3d
+2 12243 1682404220754738000 scaled_softmax.o 2842472d594d0d1d
+2 51180 1682404259661769000 scaled_softmax_cuda.cuda.o 78419d022eea9ad2
+51184 52124 1682404260637845000 scaled_softmax_cuda.so e7199387ed26e64e
+2 12762 1682404273698863000 layer_norm_cuda_kernel.cuda.o dec411c038eb6254
+1 13361 1682404274238905000 layer_norm_cuda.o 14db0d087e6f7321
+13415 14196 1682404275138975000 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da
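
Each record in a `.ninja_log` is `start_ms end_ms restat_mtime output command_hash`, so the committed log can be mined for build timings. A small sketch (the path is relative to an assumed repo checkout):

    from pathlib import Path

    records = []
    for line in Path('megatron/fused_kernels/build/.ninja_log').read_text().splitlines():
        if line.startswith('#'):
            continue  # skip the "# ninja log v5" header
        start_ms, end_ms, _mtime, output, _cmd_hash = line.split()
        records.append((int(end_ms) - int(start_ms), output))

    # The nvcc steps dominate: each .cuda.o compile takes ~45 s in this log.
    for duration, output in sorted(records, reverse=True)[:3]:
        print(f'{duration / 1000:7.1f} s  {output}')
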
megatron/fused_kernels/build/build.ninja
ADDED
@@ -0,0 +1,28 @@
+ninja_required_version = 1.3
+cxx = c++
+nvcc = /usr/local/cuda/bin/nvcc
+
+cflags = -DTORCH_EXTENSION_NAME=fused_mix_prec_layer_norm_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -isystem /opt/conda/lib/python3.8/site-packages/torch/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++14 -O3
+post_cflags =
+cuda_cflags = -DTORCH_EXTENSION_NAME=fused_mix_prec_layer_norm_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -isystem /opt/conda/lib/python3.8/site-packages/torch/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -O3 -gencode arch=compute_70,code=sm_70 --use_fast_math -maxrregcount=50 -gencode arch=compute_80,code=sm_80 -std=c++14
+cuda_post_cflags =
+ldflags = -shared -L/opt/conda/lib/python3.8/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart
+
+rule compile
+  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
+  depfile = $out.d
+  deps = gcc
+
+rule cuda_compile
+  command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
+
+rule link
+  command = $cxx $in $ldflags -o $out
+
+build layer_norm_cuda.o: compile /root/ouyangxuan/project/big_model_finetune/Megatrion-LM-clear/megatron/fused_kernels/layer_norm_cuda.cpp
+build layer_norm_cuda_kernel.cuda.o: cuda_compile /root/ouyangxuan/project/big_model_finetune/Megatrion-LM-clear/megatron/fused_kernels/layer_norm_cuda_kernel.cu
+
+build fused_mix_prec_layer_norm_cuda.so: link layer_norm_cuda.o layer_norm_cuda_kernel.cuda.o
+
+default fused_mix_prec_layer_norm_cuda.so
+
megatron/fused_kernels/build/fused_mix_prec_layer_norm_cuda.so
ADDED
Binary file (700 kB).

megatron/fused_kernels/build/layer_norm_cuda.o
ADDED
Binary file (293 kB).

megatron/fused_kernels/build/layer_norm_cuda_kernel.cuda.o
ADDED
Binary file (545 kB).

megatron/fused_kernels/build/scaled_masked_softmax.o
ADDED
Binary file (239 kB).

megatron/fused_kernels/build/scaled_masked_softmax_cuda.cuda.o
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:23117ca78a4427a192781d5bff9b2ddc34711ec3c57f6bd5c7c4b7d3b634e429
+size 1196624

megatron/fused_kernels/build/scaled_masked_softmax_cuda.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:409dbcba44e37f1e791e8ce1cb82f4a6884c6eb68aa281311a675462e91762ea
+size 1283032

megatron/fused_kernels/build/scaled_softmax.o
ADDED
Binary file (229 kB).

megatron/fused_kernels/build/scaled_softmax_cuda.cuda.o
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c2e292b61e4060b6a2c263bd95dcc08c4ab9c29ae4cf74df98bbd2ad4b566ee
+size 1084024

megatron/fused_kernels/build/scaled_softmax_cuda.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c0b73baf4b0a10ccf124c4e20aabefc122e2f8b0b0887dccfb7eafe3cd5e39c
+size 1170600

megatron/fused_kernels/build/scaled_upper_triang_masked_softmax.o
ADDED
Binary file (230 kB).

megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.cuda.o
ADDED
Binary file (944 kB).

megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e1b57adfdbb4254303f89ad4c8f786f9f7b4516c8fa2e95339ae5177a69e4a5
+size 1032720
megatron/fused_kernels/compat.h
ADDED
@@ -0,0 +1,31 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*This code is copied fron NVIDIA apex:
+ *     https://github.com/NVIDIA/apex
+ * with minor changes. */
+
+
+
+#ifndef TORCH_CHECK
+#define TORCH_CHECK AT_CHECK
+#endif
+
+#ifdef VERSION_GE_1_3
+#define DATA_PTR data_ptr
+#else
+#define DATA_PTR data
+#endif
megatron/fused_kernels/fused_weight_gradient_dense.cpp
ADDED
@@ -0,0 +1,47 @@
+#include <torch/torch.h>
+#include <torch/extension.h>
+
+#include <vector>
+#include <stdio.h>
+
+#include "type_shim.h"
+
+
+template <typename T>
+int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
+
+void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at::Tensor d_weight) {
+    at::Tensor input_2d, d_output_2d;
+    // input tensor: collapse to the first dim
+    auto in_sizes = input.sizes();
+    if (input.dim() > 2) {
+        input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
+    } else {
+        input_2d = input;
+    }
+    // d_output tensor: collapse to the first dim
+    auto d_out_sizes = d_output.sizes();
+    if (d_output.dim() > 2) {
+        d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
+    } else {
+        d_output_2d = d_output;
+    }
+
+    int hidden_dim = input_2d.size(0);
+    int in_dim = input_2d.size(1);
+    int out_dim = d_weight.size(0);
+
+    DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
+        int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
+            input_2d.data_ptr<scalar_t>(),
+            d_output_2d.data_ptr<scalar_t>(),
+            d_weight.data_ptr<float>(),
+            in_dim,
+            hidden_dim,
+            out_dim);
+    );
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32");
+}
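
The extension mutates `d_weight` in place: after collapsing the leading dimensions it accumulates the weight gradient of a linear layer, i.e. `d_weight += d_output^T @ input`, always in fp32 regardless of the input precision. A plain PyTorch reference for those semantics (a sketch for checking behaviour, not the kernel itself; tensor shapes below are illustrative):

    import torch

    def wgrad_gemm_accum_fp32_reference(input, d_output, d_weight):
        """d_weight (fp32) += d_output^T @ input, matching the fused extension."""
        input_2d = input.reshape(-1, input.shape[-1])
        d_output_2d = d_output.reshape(-1, d_output.shape[-1])
        d_weight += d_output_2d.t().float() @ input_2d.float()

    # [seq, batch, dim] activations and output grads: fp16 in, fp32 accumulation.
    inp = torch.randn(8, 4, 32, dtype=torch.half)
    dout = torch.randn(8, 4, 16, dtype=torch.half)
    dw = torch.zeros(16, 32)  # [out_dim, in_dim], kept in fp32
    wgrad_gemm_accum_fp32_reference(inp, dout, dw)
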
megatron/fused_kernels/fused_weight_gradient_dense.cu
ADDED
@@ -0,0 +1,157 @@
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <assert.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <torch/torch.h>
+
+/* Includes, cuda */
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+
+
+// BF16 Tensor core wrapper around cublas GEMMEx
+cublasStatus_t gemmex_wrapper(
+    cublasHandle_t handle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float* alpha,
+    at::BFloat16* A,
+    int lda,
+    at::BFloat16* B,
+    int ldb,
+    const float* beta,
+    float* C,
+    int ldc) {
+  return cublasGemmEx(
+      handle,
+      transa,
+      transb,
+      m,
+      n,
+      k,
+      alpha,
+      A,
+      CUDA_R_16BF,
+      lda,
+      B,
+      CUDA_R_16BF,
+      ldb,
+      beta,
+      C,
+      CUDA_R_32F,
+      ldc,
+      CUDA_R_32F,
+      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+}
+
+// FP16 Tensor core wrapper around cublas GEMMEx
+cublasStatus_t gemmex_wrapper(
+    cublasHandle_t handle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float* alpha,
+    at::Half* A,
+    int lda,
+    at::Half* B,
+    int ldb,
+    const float* beta,
+    float* C,
+    int ldc) {
+  return cublasGemmEx(
+      handle,
+      transa,
+      transb,
+      m,
+      n,
+      k,
+      alpha,
+      A,
+      CUDA_R_16F,
+      lda,
+      B,
+      CUDA_R_16F,
+      ldb,
+      beta,
+      C,
+      CUDA_R_32F,
+      ldc,
+      CUDA_R_32F,
+      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+}
+
+// FP32 Tensor core wrapper around cublas GEMMEx
+cublasStatus_t gemmex_wrapper(
+    cublasHandle_t handle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float* alpha,
+    float* A,
+    int lda,
+    float* B,
+    int ldb,
+    const float* beta,
+    float* C,
+    int ldc) {
+  return cublasGemmEx(
+      handle,
+      transa,
+      transb,
+      m,
+      n,
+      k,
+      alpha,
+      A,
+      CUDA_R_32F,
+      lda,
+      B,
+      CUDA_R_32F,
+      ldb,
+      beta,
+      C,
+      CUDA_R_32F,
+      ldc,
+      CUDA_R_32F,
+      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+}
+
+template <typename T>
+int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
+    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+    cudaStream_t stream;
+    cublasGetStream(handle, &stream);
+    const float alpha = 1.0;
+    const float beta  = 1.0;
+    int status = 1;
+
+    status = gemmex_wrapper(
+        handle,
+        CUBLAS_OP_N,
+        CUBLAS_OP_T,
+        in_dim,
+        out_dim,
+        hidden_dim,
+        &alpha,
+        input,
+        in_dim,
+        d_output,
+        out_dim,
+        &beta,
+        d_weight,
+        in_dim);
+    return status;
+}
+
+template int wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
+template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
+template int wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
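
Because cuBLAS is column-major while PyTorch tensors are row-major, the call above computes `d_weight += d_output^T @ input` without any explicit transposes: the row-major `[rows, in_dim]` input buffer reads as a column-major `in_dim x rows` matrix, so `op(A)=N`, `op(B)=T` with `m=in_dim, n=out_dim, k=hidden_dim` lands the product directly in the row-major `[out_dim, in_dim]` weight-gradient buffer. A worked shape check of that layout identity under those conventions (illustrative sizes, not repo code):

    import numpy as np

    rows, in_dim, out_dim = 5, 3, 2  # hidden_dim == rows of the collapsed 2-D views
    inp = np.random.rand(rows, in_dim).astype(np.float32)    # row-major [rows, in_dim]
    dout = np.random.rand(rows, out_dim).astype(np.float32)  # row-major [rows, out_dim]

    # What the extension accumulates into d_weight ([out_dim, in_dim], row-major):
    ref = dout.T @ inp

    # The same product as cuBLAS sees it: a column-major view of a row-major
    # buffer is its transpose, so the kernel effectively computes C = A @ B^T.
    A = inp.T    # column-major view of `inp`:  in_dim x rows, op(A) = N
    B = dout.T   # column-major view of `dout`: out_dim x rows, op(B) = T
    C = A @ B.T  # in_dim x out_dim, column-major == ref in row-major
    assert np.allclose(C.T, ref)
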
megatron/fused_kernels/layer_norm_cuda.cpp
ADDED
@@ -0,0 +1,201 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*This code is copied fron NVIDIA apex:
+ *     https://github.com/NVIDIA/apex
+ * with minor changes. */
+
+#include <torch/extension.h>
+#include <vector>
+#include <cassert>
+#include "compat.h"
+
+namespace {
+
+void compute_n1_n2(
+    at::Tensor input,
+    at::IntArrayRef normalized_shape,
+    int& n1,
+    int& n2) {
+  int idiff = input.ndimension() - normalized_shape.size();
+  n2 = 1;
+  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
+    assert( input.sizes()[i+idiff] == normalized_shape[i] );
+    n2 *= normalized_shape[i];
+  }
+  n1 = 1;
+  for (int i = 0; i < idiff; ++i) {
+    n1 *= input.sizes()[i];
+  }
+}
+
+void check_args(
+    at::IntArrayRef normalized_shape,
+    at::Tensor gamma,
+    at::Tensor beta
+    )
+{
+  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
+  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
+}
+
+void check_args(
+    at::Tensor input,
+    at::IntArrayRef normalized_shape,
+    int& n1,
+    int& n2
+    )
+{
+  int64_t normalized_ndim = normalized_shape.size();
+
+  if (normalized_ndim < 1) {
+    std::stringstream ss;
+    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
+       << "containing at least one element, but got normalized_shape="
+       << normalized_shape;
+    throw std::runtime_error(ss.str());
+  }
+
+  auto input_shape = input.sizes();
+  auto input_ndim = input.dim();
+
+  if (input_ndim < normalized_ndim ||
+      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
+    std::stringstream ss;
+    ss << "Given normalized_shape=" << normalized_shape
+       << ", expected input with shape [*";
+    for (auto size : normalized_shape) {
+      ss << ", " << size;
+    }
+    ss << "], but got input of size" << input_shape;
+    throw std::runtime_error(ss.str());
+  }
+
+  compute_n1_n2(input,normalized_shape,n1,n2);
+}
+
+
+void check_args(
+    at::Tensor input,
+    at::IntArrayRef normalized_shape,
+    at::Tensor gamma,
+    at::Tensor beta,
+    int& n1,
+    int& n2
+    )
+{
+  check_args(input,normalized_shape,n1,n2);
+  check_args(normalized_shape,gamma,beta);
+}
+}
+
+void cuda_layer_norm(
+    at::Tensor* output,
+    at::Tensor* mean,
+    at::Tensor* invvar,
+    at::Tensor* input,
+    int n1,
+    int n2,
+    at::IntArrayRef normalized_shape,
+    at::Tensor* gamma,
+    at::Tensor* beta,
+    double epsilon);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+std::vector<at::Tensor> layer_norm_affine(
+    at::Tensor input,
+    at::IntArrayRef normalized_shape,
+    at::Tensor gamma,
+    at::Tensor beta,
+    double epsilon) {
+
+  CHECK_INPUT(input);
+  CHECK_INPUT(gamma);
+  CHECK_INPUT(beta);
+  int n1, n2;
+  check_args(input, normalized_shape, gamma, beta, n1, n2);
+
+  at::Tensor output = at::empty_like(
+      input, gamma.options().dtype(gamma.scalar_type()));
+  at::Tensor mean = at::empty(
+      {n1}, input.options().dtype(at::ScalarType::Float));
+  at::Tensor invvar = at::empty_like(mean);
+
+  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
+                  normalized_shape, &gamma, &beta, epsilon);
+
+  return {output, mean, invvar};
+
+}
+
+
+void cuda_layer_norm_gradient(
+    at::Tensor* dout,
+    at::Tensor* mean,
+    at::Tensor* invvar,
+    at::Tensor* input,
+    int n1,
+    int n2,
+    at::IntArrayRef normalized_shape,
+    at::Tensor* gamma,
+    at::Tensor* beta,
+    double epsilon,
+    at::Tensor* grad_input,
+    at::Tensor* grad_gamma,
+    at::Tensor* grad_beta
+    );
+
+std::vector<at::Tensor> layer_norm_gradient_affine(
+    at::Tensor dout,
+    at::Tensor mean,
+    at::Tensor invvar,
+    at::Tensor input,
+    at::IntArrayRef normalized_shape,
+    at::Tensor gamma,
+    at::Tensor beta,
+    double epsilon) {
+
+  CHECK_INPUT(dout);
+  CHECK_INPUT(mean);
+  CHECK_INPUT(invvar);
+  CHECK_INPUT(input);
+  CHECK_INPUT(gamma);
+  CHECK_INPUT(beta);
+  int n1, n2;
+  check_args(input, normalized_shape, gamma, beta, n1, n2);
+
+  at::Tensor grad_input = at::empty_like(input);
+  at::Tensor grad_gamma = at::empty_like(gamma);
+  at::Tensor grad_beta = at::empty_like(beta);
+
+  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
+                           normalized_shape, &gamma, &beta, epsilon,
+                           &grad_input, &grad_gamma, &grad_beta);
+
+  return {grad_input, grad_gamma, grad_beta};
+
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward_affine", &layer_norm_affine,
+        "LayerNorm forward (CUDA)");
+  m.def("backward_affine", &layer_norm_gradient_affine,
+        "LayerNorm backward (CUDA)");
+}
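
`forward_affine` returns `(output, mean, invvar)`: the input is flattened to `[n1, n2]`, and `mean` and `invvar` are per-row fp32 statistics over the last `n2` elements, saved so the backward pass can reuse them. A plain PyTorch reference for those semantics, agreeing with the kernel below up to rounding (a sketch, not a call into the built extension):

    import torch

    def forward_affine_reference(input, normalized_shape, gamma, beta, epsilon):
        """Mirror of fused_mix_prec_layer_norm_cuda.forward_affine's outputs."""
        n2 = 1
        for s in normalized_shape:
            n2 *= s
        x = input.reshape(-1, n2).float()       # [n1, n2], stats kept in fp32
        mean = x.mean(dim=1)                    # [n1]
        var = x.var(dim=1, unbiased=False)      # biased, matching sigma2 / n2
        invvar = torch.rsqrt(var + epsilon)     # [n1]
        xhat = (x - mean[:, None]) * invvar[:, None]
        out = xhat.reshape(input.shape).to(gamma.dtype) * gamma + beta
        return out, mean, invvar

    x = torch.randn(4, 8, 16)
    g, b = torch.ones(16), torch.zeros(16)
    out, mean, invvar = forward_affine_reference(x, (16,), g, b, 1e-5)
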
megatron/fused_kernels/layer_norm_cuda_kernel.cu
ADDED
@@ -0,0 +1,832 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* coding=utf-8
|
2 |
+
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
/*This code is copied fron NVIDIA apex:
|
18 |
+
* https://github.com/NVIDIA/apex
|
19 |
+
* with minor changes. */
|
20 |
+
|
21 |
+
#include "ATen/ATen.h"
|
22 |
+
#include "ATen/AccumulateType.h"
|
23 |
+
#include "ATen/cuda/CUDAContext.h"
|
24 |
+
#include "ATen/cuda/DeviceUtils.cuh"
|
25 |
+
|
26 |
+
#include <cuda.h>
|
27 |
+
#include <cuda_runtime.h>
|
28 |
+
|
29 |
+
#include "type_shim.h"
|
30 |
+
|
31 |
+
template<typename U> __device__
|
32 |
+
void cuWelfordOnlineSum(
|
33 |
+
const U curr,
|
34 |
+
U& mu,
|
35 |
+
U& sigma2,
|
36 |
+
U& count)
|
37 |
+
{
|
38 |
+
count = count + U(1);
|
39 |
+
U delta = curr - mu;
|
40 |
+
U lmean = mu + delta / count;
|
41 |
+
mu = lmean;
|
42 |
+
U delta2 = curr - lmean;
|
43 |
+
sigma2 = sigma2 + delta * delta2;
|
44 |
+
}
|
45 |
+
|
46 |
+
template<typename U> __device__
|
47 |
+
void cuChanOnlineSum(
|
48 |
+
const U muB,
|
49 |
+
const U sigma2B,
|
50 |
+
const U countB,
|
51 |
+
U& mu,
|
52 |
+
U& sigma2,
|
53 |
+
U& count)
|
54 |
+
{
|
55 |
+
U delta = muB - mu;
|
56 |
+
U nA = count;
|
57 |
+
U nB = countB;
|
58 |
+
count = count + countB;
|
59 |
+
U nX = count;
|
60 |
+
if (nX > U(0)) {
|
61 |
+
nA = nA / nX;
|
62 |
+
nB = nB / nX;
|
63 |
+
mu = nA*mu + nB*muB;
|
64 |
+
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
|
65 |
+
} else {
|
66 |
+
mu = U(0);
|
67 |
+
sigma2 = U(0);
|
68 |
+
}
|
69 |
+
}
|
70 |
+
|
71 |
+
template<typename T, typename U> __device__
|
72 |
+
void cuWelfordMuSigma2(
|
73 |
+
const T* __restrict__ vals,
|
74 |
+
const int n1,
|
75 |
+
const int n2,
|
76 |
+
const int i1,
|
77 |
+
U& mu,
|
78 |
+
U& sigma2,
|
79 |
+
U* buf)
|
80 |
+
{
|
81 |
+
// Assumptions:
|
82 |
+
// 1) blockDim.x == warpSize
|
83 |
+
// 2) Tensor is contiguous
|
84 |
+
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
|
85 |
+
//
|
86 |
+
// compute variance and mean over n2
|
87 |
+
U count = U(0);
|
88 |
+
mu= U(0);
|
89 |
+
sigma2 = U(0);
|
90 |
+
if (i1 < n1) {
|
91 |
+
// one warp normalizes one n1 index,
|
92 |
+
// synchronization is implicit
|
93 |
+
// initialize with standard Welford algorithm
|
94 |
+
const int numx = blockDim.x * blockDim.y;
|
95 |
+
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
96 |
+
const T* lvals = vals + i1*n2;
|
97 |
+
int l = 4*thrx;
|
98 |
+
for (; l+3 < n2; l+=4*numx) {
|
99 |
+
for (int k = 0; k < 4; ++k) {
|
100 |
+
U curr = static_cast<U>(lvals[l+k]);
|
101 |
+
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
|
102 |
+
}
|
103 |
+
}
|
104 |
+
for (; l < n2; ++l) {
|
105 |
+
U curr = static_cast<U>(lvals[l]);
|
106 |
+
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
|
107 |
+
}
|
108 |
+
// intra-warp reductions
|
109 |
+
for (int l = 0; l <= 4; ++l) {
|
110 |
+
int srcLaneB = (threadIdx.x+(1<<l))&31;
|
111 |
+
U muB = WARP_SHFL(mu, srcLaneB);
|
112 |
+
U countB = WARP_SHFL(count, srcLaneB);
|
113 |
+
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
|
114 |
+
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
|
115 |
+
}
|
116 |
+
// threadIdx.x == 0 has correct values for each warp
|
117 |
+
// inter-warp reductions
|
118 |
+
if (blockDim.y > 1) {
|
119 |
+
U* ubuf = (U*)buf;
|
120 |
+
U* ibuf = (U*)(ubuf + blockDim.y);
|
121 |
+
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
|
122 |
+
// upper half of warps write to shared
|
123 |
+
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
|
124 |
+
const int wrt_y = threadIdx.y - offset;
|
125 |
+
ubuf[2*wrt_y] = mu;
|
126 |
+
ubuf[2*wrt_y+1] = sigma2;
|
127 |
+
ibuf[wrt_y] = count;
|
128 |
+
}
|
129 |
+
__syncthreads();
|
130 |
+
// lower half merges
|
131 |
+
if (threadIdx.x == 0 && threadIdx.y < offset) {
|
132 |
+
U muB = ubuf[2*threadIdx.y];
|
133 |
+
U sigma2B = ubuf[2*threadIdx.y+1];
|
134 |
+
U countB = ibuf[threadIdx.y];
|
135 |
+
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
|
136 |
+
}
|
137 |
+
__syncthreads();
|
138 |
+
}
|
139 |
+
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
|
140 |
+
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
141 |
+
ubuf[0] = mu;
|
142 |
+
ubuf[1] = sigma2;
|
143 |
+
}
|
144 |
+
__syncthreads();
|
145 |
+
mu = ubuf[0];
|
146 |
+
sigma2 = ubuf[1]/U(n2);
|
147 |
+
// don't care about final value of count, we know count == n2
|
148 |
+
} else {
|
149 |
+
mu = WARP_SHFL(mu, 0);
|
150 |
+
sigma2 = WARP_SHFL(sigma2/U(n2), 0);
|
151 |
+
}
|
152 |
+
}
|
153 |
+
}
|
154 |
+
|
155 |
+template<> __device__
+void cuWelfordMuSigma2(
+    const at::Half* __restrict__ vals,
+    const int n1,
+    const int n2,
+    const int i1,
+    float& mu,
+    float& sigma2,
+    float* buf)
+{
+  // Assumptions:
+  // 1) blockDim.x == warpSize
+  // 2) Tensor is contiguous
+  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
+  //
+  // compute variance and mean over n2
+  float count = 0.0f;
+  mu = float(0);
+  sigma2 = float(0);
+  if (i1 < n1) {
+    // one warp normalizes one n1 index,
+    // synchronization is implicit
+    // initialize with standard Welford algorithm
+    const int numx = blockDim.x * blockDim.y;
+    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
+    const at::Half* lvals = vals + i1*n2;
+    int l = 8*thrx;
+    if ((((size_t)lvals)&3) != 0) {
+      // 16 bit alignment
+      // first thread consumes first point
+      if (thrx == 0) {
+        float curr = static_cast<float>(lvals[0]);
+        cuWelfordOnlineSum(curr,mu,sigma2,count);
+      }
+      ++l;
+    }
+    // at this point, lvals[l] are 32 bit aligned for all threads.
+    for (; l+7 < n2; l+=8*numx) {
+      for (int k = 0; k < 8; k+=2) {
+        float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
+        cuWelfordOnlineSum(curr.x,mu,sigma2,count);
+        cuWelfordOnlineSum(curr.y,mu,sigma2,count);
+      }
+    }
+    for (; l < n2; ++l) {
+      float curr = static_cast<float>(lvals[l]);
+      cuWelfordOnlineSum(curr,mu,sigma2,count);
+    }
+    // intra-warp reductions
+    for (int l = 0; l <= 4; ++l) {
+      int srcLaneB = (threadIdx.x+(1<<l))&31;
+      float muB = WARP_SHFL(mu, srcLaneB);
+      float countB = WARP_SHFL(count, srcLaneB);
+      float sigma2B = WARP_SHFL(sigma2, srcLaneB);
+      cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
+    }
+    // threadIdx.x == 0 has correct values for each warp
+    // inter-warp reductions
+    if (blockDim.y > 1) {
+      float* ubuf = (float*)buf;
+      float* ibuf = (float*)(ubuf + blockDim.y);
+      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
+        // upper half of warps write to shared
+        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
+          const int wrt_y = threadIdx.y - offset;
+          ubuf[2*wrt_y] = mu;
+          ubuf[2*wrt_y+1] = sigma2;
+          ibuf[wrt_y] = count;
+        }
+        __syncthreads();
+        // lower half merges
+        if (threadIdx.x == 0 && threadIdx.y < offset) {
+          float muB = ubuf[2*threadIdx.y];
+          float sigma2B = ubuf[2*threadIdx.y+1];
+          float countB = ibuf[threadIdx.y];
+          cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
+        }
+        __syncthreads();
+      }
+      // threadIdx.x == 0 && threadIdx.y == 0 is the only thread with correct values
+      if (threadIdx.x == 0 && threadIdx.y == 0) {
+        ubuf[0] = mu;
+        ubuf[1] = sigma2;
+      }
+      __syncthreads();
+      mu = ubuf[0];
+      sigma2 = ubuf[1]/float(n2);
+      // don't care about final value of count, we know count == n2
+    } else {
+      mu = WARP_SHFL(mu, 0);
+      sigma2 = WARP_SHFL(sigma2/float(n2), 0);
+    }
+  }
+}
+
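+// Note on the fp16 path above: elements are fetched two at a time as __half2
+// and widened with __half22float2, so each thread streams 8 values per trip
+// while all Welford accumulation happens in fp32. The scalar prologue keeps
+// the __half2 loads 32-bit aligned when the row pointer starts mid-pair, and
+// the scalar epilogue picks up the remainder when n2 is not a multiple of 8.
+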
+template<typename U> U rsqrt(U v) {
+  return U(1) / sqrt(v);
+}
+template<> float rsqrt(float v) {
+  return rsqrtf(v);
+}
+template<> double rsqrt(double v) {
+  return rsqrt(v);
+}
+
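+// The double specialization is not intended to recurse: in device code the
+// unqualified rsqrt(v) call should resolve to CUDA's ::rsqrt(double) math
+// intrinsic, since a non-template overload is preferred over the function
+// template above.
+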
+namespace {
+// This is the un-specialized struct. Note that we prevent instantiation of this
+// struct by putting an undefined symbol in the function body so it won't compile.
+//  template <typename T>
+//  struct SharedMemory
+//  {
+//      // Ensure that we won't compile any un-specialized types
+//      __device__ T *getPointer()
+//      {
+//          extern __device__ void error(void);
+//          error();
+//          return NULL;
+//      }
+//  };
+// https://github.com/NVIDIA/apex/issues/246
+template <typename T>
+struct SharedMemory;
+
+template <>
+struct SharedMemory <float>
+{
+    __device__ float *getPointer()
+    {
+        extern __shared__ float s_float[];
+        return s_float;
+    }
+};
+
+}
+
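+// SharedMemory exists because dynamic shared memory is exposed through an
+// extern __shared__ array, which cannot be redeclared with a different type in
+// the same translation unit. Declaring only the forward template plus a float
+// specialization also makes SharedMemory<T> for any other T a compile error.
+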
+template<typename T, typename U, typename V> __global__
+void cuApplyLayerNorm(
+    V* __restrict__ output_vals,
+    U* __restrict__ mean,
+    U* __restrict__ invvar,
+    const T* __restrict__ vals,
+    const int n1,
+    const int n2,
+    const U epsilon,
+    const V* __restrict__ gamma,
+    const V* __restrict__ beta
+    )
+{
+  // Assumptions:
+  // 1) blockDim.x == warpSize
+  // 2) Tensors are contiguous
+  //
+  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
+    SharedMemory<U> shared;
+    U* buf = shared.getPointer();
+    U mu,sigma2;
+    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
+    const T* lvals = vals + i1*n2;
+    V* ovals = output_vals + i1*n2;
+    U c_invvar = rsqrt(sigma2 + epsilon);
+    const int numx = blockDim.x * blockDim.y;
+    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
+    if (gamma != NULL && beta != NULL) {
+      for (int i = thrx; i < n2; i+=numx) {
+        U curr = static_cast<U>(lvals[i]);
+        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
+      }
+    } else {
+      for (int i = thrx; i < n2; i+=numx) {
+        U curr = static_cast<U>(lvals[i]);
+        ovals[i] = static_cast<V>(c_invvar * (curr - mu));
+      }
+    }
+    if (threadIdx.x == 0 && threadIdx.y == 0) {
+      mean[i1] = mu;
+      invvar[i1] = c_invvar;
+    }
+    __syncthreads();
+  }
+}
+
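+// The forward kernel above computes, per row i1 of the n1 x n2 input,
+//   invvar = 1/sqrt(sigma2 + epsilon)
+//   y[i]   = gamma[i] * (x[i] - mu) * invvar + beta[i]   (gamma/beta optional)
+// and caches mean[i1] and invvar[i1] so the backward kernels below can reuse
+// them instead of recomputing the row statistics.
+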
+template<typename T, typename U, typename V> __device__
+void cuLoadWriteStridedInputs(
+    const int i1_block,
+    const int thr_load_row_off,
+    const int thr_load_col_off,
+    const int i2_off,
+    const int row_stride,
+    U* warp_buf1,
+    U* warp_buf2,
+    const T* input,
+    const V* dout,
+    const int i1_end,
+    const int n2,
+    const U* __restrict__ mean,
+    const U* __restrict__ invvar
+    )
+{
+  int i1 = i1_block+thr_load_row_off;
+  if (i1 < i1_end) {
+    U curr_mean = mean[i1];
+    U curr_invvar = invvar[i1];
+    for (int k = 0; k < blockDim.y; ++k) {
+      int i2 = i2_off + k;
+      int load_idx = i1*n2+i2;
+      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
+      if (i2<n2) {
+        U curr_input = static_cast<U>(input[load_idx]);
+        U curr_dout = static_cast<U>(dout[load_idx]);
+        warp_buf1[write_idx] = curr_dout;
+        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
+      } else {
+        warp_buf1[write_idx] = U(0);
+        warp_buf2[write_idx] = U(0);
+      }
+    }
+  } else {
+    for (int k = 0; k < blockDim.y; ++k) {
+      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
+      warp_buf1[write_idx] = U(0);
+      warp_buf2[write_idx] = U(0);
+    }
+  }
+}
+
+template<typename T, typename U, typename V> __device__
+void cuLoadAddStridedInputs(
+    const int i1_block,
+    const int thr_load_row_off,
+    const int thr_load_col_off,
+    const int i2_off,
+    const int row_stride,
+    U* warp_buf1,
+    U* warp_buf2,
+    const T* input,
+    const V* dout,
+    const int i1_end,
+    const int n2,
+    const U* __restrict__ mean,
+    const U* __restrict__ invvar
+    )
+{
+  int i1 = i1_block+thr_load_row_off;
+  if (i1 < i1_end) {
+    U curr_mean = mean[i1];
+    U curr_invvar = invvar[i1];
+    for (int k = 0; k < blockDim.y; ++k) {
+      int i2 = i2_off + k;
+      int load_idx = i1*n2+i2;
+      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
+      if (i2<n2) {
+        U curr_input = static_cast<U>(input[load_idx]);
+        U curr_dout = static_cast<U>(dout[load_idx]);
+        warp_buf1[write_idx] += curr_dout;
+        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
+      }
+    }
+  }
+}
+
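+// The two helpers above are a write/accumulate pair: cuLoadWriteStridedInputs
+// initializes warp_buf1/warp_buf2 from the first i1 tile (zero-filling any
+// out-of-range slots), and cuLoadAddStridedInputs accumulates every subsequent
+// tile on top. warp_buf1 gathers dout (for grad_beta) and warp_buf2 gathers
+// dout * (x - mean) * invvar (for grad_gamma).
+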
+template<typename T, typename U, typename V> __global__
+void cuComputePartGradGammaBeta(
+    const V* __restrict__ dout,
+    const T* __restrict__ input,
+    const int n1,
+    const int n2,
+    const U* __restrict__ mean,
+    const U* __restrict__ invvar,
+    U epsilon,
+    U* part_grad_gamma,
+    U* part_grad_beta)
+{
+  const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
+  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
+  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
+  const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
+  const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
+  const int row_stride = blockDim.x+1;
+  const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
+  const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
+  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
+  SharedMemory<U> shared;
+  U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
+  U* warp_buf1 = (U*)buf;
+  U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
+  // compute partial sums from strided inputs
+  // do this to increase number of loads in flight
+  cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
+  for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
+    cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
+  }
+  __syncthreads();
+  // inter-warp reductions
+  // sum within each warp
+  U acc1 = U(0);
+  U acc2 = U(0);
+  for (int k = 0; k < blockDim.y; ++k) {
+    int row1 = threadIdx.y + k*blockDim.y;
+    int idx1 = row1*row_stride + threadIdx.x;
+    acc1 += warp_buf1[idx1];
+    acc2 += warp_buf2[idx1];
+  }
+  warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
+  warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
+  __syncthreads();
+  // sum all warps
+  for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
+    if (threadIdx.y < offset) {
+      int row1 = threadIdx.y;
+      int row2 = threadIdx.y + offset;
+      int idx1 = row1*row_stride + threadIdx.x;
+      int idx2 = row2*row_stride + threadIdx.x;
+      warp_buf1[idx1] += warp_buf1[idx2];
+      warp_buf2[idx1] += warp_buf2[idx2];
+    }
+    __syncthreads();
+  }
+  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
+  if (threadIdx.y == 0 && i2 < n2) {
+    int row1 = threadIdx.y;
+    int row2 = threadIdx.y + 1;
+    int idx1 = row1*row_stride + threadIdx.x;
+    int idx2 = row2*row_stride + threadIdx.x;
+    part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
+    part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
+  }
+}
+
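+// Mathematically the kernel above produces, for each column j and grid row b,
+//   part_grad_beta[b][j]  = sum over its rows i of dout[i][j]
+//   part_grad_gamma[b][j] = sum over its rows i of dout[i][j] * (x[i][j] - mean[i]) * invvar[i]
+// The +1 in row_stride pads the shared-memory tile, the customary trick for
+// keeping column-wise accesses out of the same bank during the transposed
+// reduction.
+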
+template<typename U, typename V> __global__
+void cuComputeGradGammaBeta(
+    const U* part_grad_gamma,
+    const U* part_grad_beta,
+    const int part_size,
+    const int n1,
+    const int n2,
+    V* grad_gamma,
+    V* grad_beta)
+{
+  // sum partial gradients for gamma and beta
+  SharedMemory<U> shared;
+  U* buf = shared.getPointer();
+  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i2 < n2) {
+    // each warp does sequential reductions until reduced part_size is num_warps
+    int num_warp_reductions = part_size / blockDim.y;
+    U sum_gamma = U(0);
+    U sum_beta = U(0);
+    const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
+    const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
+    for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
+      sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
+      sum_beta += part_grad_beta_ptr[warp_offset*n2];
+    }
+    // inter-warp reductions
+    const int nbsize3 = blockDim.x * blockDim.y / 2;
+    for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
+      // top half write to shared memory
+      if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
+        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
+        buf[write_idx] = sum_gamma;
+        buf[write_idx+nbsize3] = sum_beta;
+      }
+      __syncthreads();
+      // bottom half sums
+      if (threadIdx.y < offset) {
+        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
+        sum_gamma += buf[read_idx];
+        sum_beta += buf[read_idx+nbsize3];
+      }
+      __syncthreads();
+    }
+    // write out fully summed gradients
+    if (threadIdx.y == 0) {
+      grad_gamma[i2] = sum_gamma;
+      grad_beta[i2] = sum_beta;
+    }
+  }
+}
+
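+// Gamma/beta gradients are thus reduced in two stages: cuComputePartGradGammaBeta
+// collapses n1 rows down to part_size partial rows, and cuComputeGradGammaBeta
+// folds those part_size rows into the final grad_gamma[j] and grad_beta[j],
+// casting from the fp32 accumulator type U to the parameter type V on the way out.
+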
+template<typename T, typename U, typename V> __global__
+void cuComputeGradInput(
+    const V* __restrict__ dout,
+    const T* __restrict__ input,
+    const int n1,
+    const int n2,
+    const U* __restrict__ mean,
+    const U* __restrict__ invvar,
+    U epsilon,
+    const V* gamma,
+    T* grad_input)
+{
+  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
+    U sum_loss1 = U(0);
+    U sum_loss2 = U(0);
+    const U c_mean = mean[i1];
+    const U c_invvar = invvar[i1];
+    const T* k_input = input + i1*n2;
+    const V* k_dout = dout + i1*n2;
+    const int numx = blockDim.x * blockDim.y;
+    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
+    if (gamma != NULL) {
+      int l = 4*thrx;
+      for (; l+3 < n2; l+=4*numx) {
+        for (int k = 0; k < 4; ++k) {
+          const U c_h = static_cast<U>(k_input[l+k]);
+          const U c_loss = static_cast<U>(k_dout[l+k]);
+          sum_loss1 += c_loss * gamma[l+k];
+          sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
+        }
+      }
+      for (; l < n2; ++l) {
+        const U c_h = static_cast<U>(k_input[l]);
+        const U c_loss = static_cast<U>(k_dout[l]);
+        sum_loss1 += c_loss * gamma[l];
+        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
+      }
+    } else {
+      int l = 4*thrx;
+      for (; l+3 < n2; l+=4*numx) {
+        for (int k = 0; k < 4; ++k) {
+          const U c_h = static_cast<U>(k_input[l+k]);
+          const U c_loss = static_cast<U>(k_dout[l+k]);
+          sum_loss1 += c_loss;
+          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
+        }
+      }
+      for (; l < n2; ++l) {
+        const U c_h = static_cast<U>(k_input[l]);
+        const U c_loss = static_cast<U>(k_dout[l]);
+        sum_loss1 += c_loss;
+        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
+      }
+    }
+    // intra-warp reductions
+    for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
+      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
+      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
+    }
+    // inter-warp reductions
+    if (blockDim.y > 1) {
+      SharedMemory<U> shared;
+      U* buf = shared.getPointer();
+      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
+        // upper half of warps write to shared
+        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
+          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
+          buf[2*wrt_i] = sum_loss1;
+          buf[2*wrt_i+1] = sum_loss2;
+        }
+        __syncthreads();
+        // lower half merges
+        if (threadIdx.y < offset) {
+          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
+          sum_loss1 += buf[2*read_i];
+          sum_loss2 += buf[2*read_i+1];
+        }
+        __syncthreads();
+      }
+      if (threadIdx.y == 0) {
+        buf[2*threadIdx.x] = sum_loss1;
+        buf[2*threadIdx.x+1] = sum_loss2;
+      }
+      __syncthreads();
+      if (threadIdx.y != 0) {
+        sum_loss1 = buf[2*threadIdx.x];
+        sum_loss2 = buf[2*threadIdx.x+1];
+      }
+    }
+    // all threads now have the two sums over l
+    U fH = (U)n2;
+    U term1 = (U(1) / fH) * c_invvar;
+    T* k_grad_input = grad_input + i1*n2;
+    if (gamma != NULL) {
+      for (int l = thrx; l < n2; l+=numx) {
+        const U c_h = static_cast<U>(k_input[l]);
+        const U c_loss = static_cast<U>(k_dout[l]);
+        U f_grad_input = fH * c_loss * gamma[l];
+        f_grad_input -= sum_loss1;
+        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
+        f_grad_input *= term1;
+        k_grad_input[l] = static_cast<T>(f_grad_input);
+      }
+    } else {
+      for (int l = thrx; l < n2; l+=numx) {
+        const U c_h = static_cast<U>(k_input[l]);
+        const U c_loss = static_cast<U>(k_dout[l]);
+        U f_grad_input = fH * c_loss;
+        f_grad_input -= sum_loss1;
+        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
+        f_grad_input *= term1;
+        k_grad_input[l] = static_cast<T>(f_grad_input);
+      }
+    }
+    // prevent race where buf is written again before reads are done
+    __syncthreads();
+  }
+}
+
+
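+// The grad_input computed above is the standard layer norm backward. With
+// xhat = (x - mean) * invvar, sum_loss1 = sum_l dy[l]*gamma[l] and
+// sum_loss2 = sum_l dy[l]*gamma[l]*xhat[l] (gamma omitted when NULL):
+//   dx[l] = (invvar / n2) * (n2 * dy[l]*gamma[l] - sum_loss1 - xhat[l] * sum_loss2)
+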
+template<typename T, typename U, typename V>
+void HostApplyLayerNorm(
+    V* output,
+    U* mean,
+    U* invvar,
+    const T* input,
+    int n1,
+    int n2,
+    double epsilon,
+    const V* gamma,
+    const V* beta
+    )
+{
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    const dim3 threads(32,4,1);
+    const uint64_t maxGridY =
+        at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
+    const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
+    int nshared =
+        threads.y > 1 ?
+            threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
+            0;
+    cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
+        output,
+        mean,
+        invvar,
+        input,
+        n1,n2,
+        U(epsilon),
+        gamma,beta);
+}
+
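+// Launch shape for the forward pass: each block is 4 warps of 32 threads and
+// owns one row i1 at a time; the grid's y-dimension is capped at maxGridSize[1],
+// and cuApplyLayerNorm strides by gridDim.y when n1 exceeds it. Dynamic shared
+// memory is only requested when there is more than one warp to reduce across.
+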
+void cuda_layer_norm(
+    at::Tensor* output,
+    at::Tensor* mean,
+    at::Tensor* invvar,
+    at::Tensor* input,
+    int n1,
+    int n2,
+#ifdef VERSION_GE_1_1
+    at::IntArrayRef normalized_shape,
+#else
+    at::IntList normalized_shape,
+#endif
+    at::Tensor* gamma,
+    at::Tensor* beta,
+    double epsilon)
+{
+    using namespace at;
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
+        input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
+        HostApplyLayerNorm(
+            output->DATA_PTR<scalar_t_out>(),
+            mean->DATA_PTR<float>(),
+            invvar->DATA_PTR<float>(),
+            input->DATA_PTR<scalar_t_in>(),
+            n1,n2,
+            epsilon,
+            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
+            beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
+      )
+}
+
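+// DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES expands the wrapped call once per
+// supported (input, output) dtype pair, binding scalar_t_in/scalar_t_out to
+// the concrete types; mean and invvar stay fp32 regardless of tensor dtype.
+// VERSION_GE_1_1 selects the at::IntArrayRef signature used by PyTorch >= 1.1
+// over the older at::IntList alias.
+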
+template<typename T, typename U, typename V>
+void HostLayerNormGradient(
+    const V* dout,
+    const U* mean,
+    const U* invvar,
+    at::Tensor* input,
+    int n1,
+    int n2,
+    const V* gamma,
+    const V* beta,
+    double epsilon,
+    T* grad_input,
+    V* grad_gamma,
+    V* grad_beta
+    )
+{
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+    if (gamma != NULL && beta != NULL) {
+      // compute grad_gamma(j) and grad_beta(j)
+      const int part_size = 16;
+      const dim3 threads2(32,4,1);
+      const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
+      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
+          (threads2.x + 1);
+      const int nshared2_b = threads2.x * threads2.y * sizeof(U);
+      const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
+      at::Tensor part_grad_gamma = at::empty(
+          {part_size,n2}, input->options().dtype(at::ScalarType::Float));
+      at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
+      cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
+          dout,
+          input->DATA_PTR<T>(),
+          n1,n2,
+          mean,
+          invvar,
+          U(epsilon),
+          part_grad_gamma.DATA_PTR<U>(),
+          part_grad_beta.DATA_PTR<U>());
+
+      const dim3 threads3(32,8,1);
+      const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
+      const int nshared3 = threads3.x * threads3.y * sizeof(U);
+      cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
+          part_grad_gamma.DATA_PTR<U>(),
+          part_grad_beta.DATA_PTR<U>(),
+          part_size,
+          n1,n2,
+          grad_gamma,
+          grad_beta);
+    }
+
+    // compute grad_input
+    const uint64_t maxGridY =
+        at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
+    const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
+    const dim3 threads1(32,4,1);
+    int nshared =
+        threads1.y > 1 ?
+            threads1.y*threads1.x*sizeof(U) :
+            0;
+    cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
+        dout,
+        input->DATA_PTR<T>(),
+        n1,n2,
+        mean,
+        invvar,
+        U(epsilon),
+        gamma,
+        grad_input);
+}
+
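+// The backward host wrapper launches up to three kernels: the two-stage
+// gamma/beta reduction (skipped entirely when the layer has no affine
+// parameters) followed by cuComputeGradInput. part_size = 16 fixes the number
+// of partial-sum rows staged in the fp32 part_grad_gamma/part_grad_beta
+// tensors between the two stages.
+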
+void cuda_layer_norm_gradient(
+    at::Tensor* dout,
+    at::Tensor* mean,
+    at::Tensor* invvar,
+    at::Tensor* input,
+    int n1,
+    int n2,
+#ifdef VERSION_GE_1_1
+    at::IntArrayRef normalized_shape,
+#else
+    at::IntList normalized_shape,
+#endif
+    at::Tensor* gamma,
+    at::Tensor* beta,
+    double epsilon,
+    at::Tensor* grad_input,
+    at::Tensor* grad_gamma,
+    at::Tensor* grad_beta)
+{
+    using namespace at;
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
+        input->scalar_type(), gamma->scalar_type(),
+        "cuda_layer_norm_gradient_kernel",
+        HostLayerNormGradient(
+            dout->DATA_PTR<scalar_t_out>(),
+            mean->DATA_PTR<float>(),
+            invvar->DATA_PTR<float>(),
+            input,
+            n1,n2,
+            // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
+            // if gamma Tensor is NULL on input.
+            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
+            gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
+            epsilon,
+            grad_input->DATA_PTR<scalar_t_in>(),
+            gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
+            gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
+      )
+}