# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Most of the code here has been copied from:
#   https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.

import math
import time
import collections
import numpy as np
import re

from fengshen.data.megatron_dataloader.utils import (
    print_rank_0
)
from fengshen.data.megatron_dataloader.blendable_dataset import BlendableDataset
from fengshen.data.megatron_dataloader.indexed_dataset import make_dataset as make_indexed_dataset

DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
DSET_TYPE_BERT_CN_WWM = 'bert_cn_wwm'
DSET_TYPE_BART = 'bart'
DSET_TYPE_COCOLM = 'coco_lm'

DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT,
              DSET_TYPE_T5, DSET_TYPE_BERT_CN_WWM,
              DSET_TYPE_BART, DSET_TYPE_COCOLM]


def get_datasets_weights_and_num_samples(data_prefix,
                                         train_valid_test_num_samples):

    # The data prefix should be in the format of:
    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
    assert len(data_prefix) % 2 == 0
    num_datasets = len(data_prefix) // 2
    weights = [0] * num_datasets
    prefixes = [0] * num_datasets
    for i in range(num_datasets):
        weights[i] = float(data_prefix[2 * i])
        prefixes[i] = (data_prefix[2 * i + 1]).strip()
    # Normalize weights
    weight_sum = 0.0
    for weight in weights:
        weight_sum += weight
    assert weight_sum > 0.0
    weights = [weight / weight_sum for weight in weights]

    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
    # not uniformly distribute the number of samples, we still have
    # samples left to feed to the network.
    datasets_train_valid_test_num_samples = []
    for weight in weights:
        datasets_train_valid_test_num_samples.append(
            [int(math.ceil(val * weight * 1.005))
             for val in train_valid_test_num_samples])

    return prefixes, weights, datasets_train_valid_test_num_samples
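
# Illustrative sketch (hypothetical values, not from the original code): with
#   data_prefix = ['0.3', 'corpus-A', '0.7', 'corpus-B']
#   train_valid_test_num_samples = [1000, 100, 10]
# the helper above returns
#   prefixes = ['corpus-A', 'corpus-B']
#   weights = [0.3, 0.7]
#   datasets_train_valid_test_num_samples = [[302, 31, 4], [704, 71, 8]]
# where each per-dataset count is ceil(n * weight * 1.005).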


def compile_helper():
    """Compile helper function at runtime. Make sure this
    is invoked on a single process."""
    import os
    import subprocess
    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(['make', '-C', path])
    if ret.returncode != 0:
        print("Making C++ dataset helpers module failed, exiting.")
        import sys
        sys.exit(1)


def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments."""

    # Number of sentences in the sample.
    n_sentences = len(sample)
    # Make sure we always have two sentences.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'

    # First part:
    # `a_end` is how many sentences go into the `A`.
    a_end = 1
    if n_sentences >= 3:
        # Note that randint in numpy excludes the upper bound.
        a_end = np_rng.randint(1, n_sentences)
    tokens_a = []
    for j in range(a_end):
        tokens_a.extend(sample[j])

    # Second part:
    tokens_b = []
    for j in range(a_end, n_sentences):
        tokens_b.extend(sample[j])

    # Random next:
    is_next_random = False
    if np_rng.random() < 0.5:
        is_next_random = True
        tokens_a, tokens_b = tokens_b, tokens_a

    return tokens_a, tokens_b, is_next_random


def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    # print(len_a, len_b, max_num_tokens)
    assert len_a > 0
    if len_a + len_b <= max_num_tokens:
        return False
    while len_a + len_b > max_num_tokens:
        if len_a > len_b:
            len_a -= 1
            tokens = tokens_a
        else:
            len_b -= 1
            tokens = tokens_b
        if np_rng.random() < 0.5:
            del tokens[0]
        else:
            tokens.pop()
    return True
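
# Illustrative sketch (hypothetical values): with len_a=3, len_b=5 and
# max_num_tokens=6, two tokens are dropped from tokens_b (the longer segment),
# each removed at random from either its front or its back, and the function
# returns True to signal that truncation happened.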


def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
    tokens = []
    tokentypes = []
    # [CLS].
    tokens.append(cls_id)
    tokentypes.append(0)
    # Segment A.
    for token in tokens_a:
        tokens.append(token)
        tokentypes.append(0)
    # [SEP].
    tokens.append(sep_id)
    tokentypes.append(0)
    # Segment B.
    for token in tokens_b:
        tokens.append(token)
        tokentypes.append(1)
    if tokens_b:
        # [SEP].
        tokens.append(sep_id)
        tokentypes.append(1)

    return tokens, tokentypes
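
# Illustrative sketch (hypothetical ids): with tokens_a=[5, 6], tokens_b=[7],
# cls_id=1 and sep_id=2, the layout is [CLS] A A [SEP] B [SEP], i.e.
#   tokens     = [1, 5, 6, 2, 7, 2]
#   tokentypes = [0, 0, 0, 0, 1, 1]
# Segment A (including [CLS] and the first [SEP]) gets type 0, and segment B
# (including the trailing [SEP]) gets type 1.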


MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])


def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT)."""
    # When a word has been split into
    # WordPieces, the first token does not have any marker and any subsequent
    # tokens are prefixed with ##. So whenever we see the ## token, we
    # append it to the previous set of word indexes.
    return not piece.startswith("##")


def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 tokenizer,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False,
                                 geometric_dist=False,
                                 masking_style="bert",
                                 zh_tokenizer=None):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""

    cand_indexes = []
    # Note(mingdachen): We create a list for recording if the piece is
    # the starting piece of the current token, where 1 means true, so that
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)
    # If no Chinese word tokenizer is given, fall back to the ## convention.
    if zh_tokenizer is None:
        for (i, token) in enumerate(tokens):
            if token == cls_id or token == sep_id:
                token_boundary[i] = 1
                continue
            # Whole Word Masking means that we mask all of the wordpieces
            # corresponding to an original word.
            #
            # Note that Whole Word Masking does *not* change the training code
            # at all -- we still predict each WordPiece independently, softmaxed
            # over the entire vocabulary.
            if (do_whole_word_mask and len(cand_indexes) >= 1 and
                    not is_start_piece(vocab_id_to_token_dict[token])):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
                if is_start_piece(vocab_id_to_token_dict[token]):
                    token_boundary[i] = 1
    else:
        # A Chinese word tokenizer was given: segment the text with it first
        # and then decide the word boundaries.
        # Get the raw text with CLS and SEP removed.
        raw_tokens = []
        for t in tokens:
            if t != cls_id and t != sep_id:
                raw_tokens.append(t)
        raw_tokens = [vocab_id_to_token_dict[i] for i in raw_tokens]
        # Segment the text, then record for every starting character the
        # length of the longest word that begins with it.
        word_list = set(zh_tokenizer(''.join(raw_tokens), HMM=True))
        word_length_dict = {}
        for w in word_list:
            if len(w) < 1:
                continue
            if w[0] not in word_length_dict:
                word_length_dict[w[0]] = len(w)
            elif word_length_dict[w[0]] < len(w):
                word_length_dict[w[0]] = len(w)
        i = 0
        # Look the pieces up against the segmented word list.
        while i < len(tokens):
            token_id = tokens[i]
            token = vocab_id_to_token_dict[token_id]
            if len(token) == 0 or token_id == cls_id or token_id == sep_id:
                token_boundary[i] = 1
                i += 1
                continue
            word_max_length = 1
            if token[0] in word_length_dict:
                word_max_length = word_length_dict[token[0]]
            j = 0
            word = ''
            word_end = i + 1
            # Stay compatible with the old ## style: if the following pieces
            # start with ##, merge them into the current piece as one word.
            old_style = False
            while word_end < len(tokens) and vocab_id_to_token_dict[tokens[word_end]].startswith('##'):
                old_style = True
                word_end += 1
            if not old_style:
                while j < word_max_length and i + j < len(tokens):
                    cur_token = tokens[i + j]
                    word += vocab_id_to_token_dict[cur_token]
                    j += 1
                    if word in word_list:
                        word_end = i + j
            cand_indexes.append([p for p in range(i, word_end)])
            token_boundary[i] = 1
            i = word_end

    output_tokens = list(tokens)
    # Added by ganruyi.
    if masking_style == 'bert-cn-wwm':
        # For Chinese whole word masking, strip the "##" prefix that was
        # previously added to non-leading pieces of a Chinese word.
        new_token_ids = []
        for token_id in output_tokens:
            token = tokenizer.convert_ids_to_tokens([token_id])[0]
            if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
                token = token[2:]
            new_token_id = tokenizer.convert_tokens_to_ids([token])[0]
            new_token_ids.append(new_token_id)
        output_tokens = new_token_ids

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    if not geometric_dist:
        # Note(mingdachen):
        # By default, we set the probabilities to favor shorter ngram sequences.
        pvals = 1. / np.arange(1, max_ngrams + 1)
        pvals /= pvals.sum(keepdims=True)
        if favor_longer_ngram:
            pvals = pvals[::-1]

    # For every word, collect the candidate index sets of the ngrams that
    # start at that word.
    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)

    (masked_lms, masked_spans) = ([], [])
    covered_indexes = set()

    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Skip the current piece if it is covered by LM masking or previous ngrams.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        if not geometric_dist:
            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
        else:
            # Sampling "n" from the geometric distribution and clipping it to
            # the max_ngrams. Using p=0.2 default from the SpanBERT paper
            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
            n = min(np_rng.geometric(0.2), max_ngrams)

        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
        # Repeatedly looking for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)
            masked_token = None
            if masking_style == "bert":
                # 80% of the time, replace with [MASK]
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep original
                    if np_rng.random() < 0.5:
                        masked_token = tokens[index]
                    # 10% of the time, replace with random word
                    else:
                        masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == 'bert-cn-wwm':
                # 80% of the time, replace with [MASK]
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep original
                    if np_rng.random() < 0.5:
                        # For Chinese whole word masking, strip the "##"
                        # from the kept token.
                        token_id = tokens[index]
                        token = tokenizer.convert_ids_to_tokens([token_id])[0]
                        if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
                            token = token[2:]
                        new_token_id = tokenizer.convert_tokens_to_ids([token])[0]
                        masked_token = new_token_id
                    # 10% of the time, replace with random word
                    else:
                        masked_token = vocab_id_list[np_rng.randint(
                            0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(
                index=index, label=tokens[index]))

        masked_spans.append(MaskedLmInstance(
            index=index_set,
            label=[tokens[index] for index in index_set]))

    assert len(masked_lms) <= num_to_predict

    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Note(mingdachen):
            # Skip the current piece if it is covered by LM masking or previous ngrams.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            n = np.random.choice(ngrams[:len(cand_index_set)],
                                 p=pvals[:len(cand_index_set)] /
                                 pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(
                index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    # Sort the spans by the index of the first span
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels,
            token_boundary, masked_spans)


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Labels and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1

    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
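
# Illustrative sketch (hypothetical ids): with tokens=[1, 5, 6, 2],
# tokentypes=[0, 0, 0, 0], masked_positions=[2], masked_labels=[6],
# pad_id=0 and max_seq_length=6, the arrays come out as
#   tokens_np       = [1, 5, 6, 2, 0, 0]
#   tokentypes_np   = [0, 0, 0, 0, 0, 0]
#   labels_np       = [-1, -1, 6, -1, -1, -1]
#   padding_mask_np = [1, 1, 1, 1, 0, 0]
#   loss_mask_np    = [0, 0, 1, 0, 0, 0]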


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length,
                                    masked_lm_prob, short_seq_prob, seed,
                                    tokenizer,
                                    skip_warmup, binary_head=False,
                                    max_seq_length_dec=None,
                                    dataset_type='standard_bert',
                                    zh_tokenizer=None,
                                    span=None):

    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                binary_head,
                                                max_seq_length_dec,
                                                tokenizer,
                                                dataset_type=dataset_type,
                                                zh_tokenizer=zh_tokenizer,
                                                span=span)
    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, binary_head, max_seq_length_dec,
            tokenizer, dataset_type=dataset_type, zh_tokenizer=zh_tokenizer)
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length,
                                     masked_lm_prob, short_seq_prob, seed,
                                     skip_warmup, binary_head,
                                     max_seq_length_dec,
                                     tokenizer,
                                     dataset_type='standard_bert',
                                     zh_tokenizer=None,
                                     span=None):

    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: ", dataset_type)

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    # Get start and end indices of train/valid/test into doc-idx.
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0(' {}:'.format(name))
        print_rank_0(' document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        print_rank_0(' sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        from fengshen.data.megatron_dataloader.bert_dataset import BertDataset
        from fengshen.data.megatron_dataloader.bart_dataset import BartDataset
        from fengshen.data.megatron_dataloader.cocolm_dataset import COCOLMDataset
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Build the dataset accordingly.
            kwargs = dict(
                name=name,
                data_prefix=data_prefix,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed,
            )

            if dataset_type == DSET_TYPE_BERT or dataset_type == DSET_TYPE_BERT_CN_WWM:
                dataset = BertDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    binary_head=binary_head,
                    # Extra arguments to distinguish bert from bert-cn-wwm.
                    tokenizer=tokenizer,
                    masking_style='bert' if dataset_type == DSET_TYPE_BERT else 'bert-cn-wwm',
                    **kwargs
                )
            elif dataset_type == DSET_TYPE_BART:
                dataset = BartDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    tokenizer=tokenizer,
                    zh_tokenizer=zh_tokenizer,
                    **kwargs
                )
            elif dataset_type == DSET_TYPE_COCOLM:
                dataset = COCOLMDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    tokenizer=tokenizer,
                    masking_style='bert',
                    span=span,
                    **kwargs
                )
            else:
                raise NotImplementedError(
                    "Dataset type not fully implemented.")

            # Set the original pointer so dataset remains the main dataset.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == \
                (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):

    print_rank_0(' > building dataset index ...')

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix,
                                           data_impl,
                                           skip_warmup)
    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))

    print_rank_0(' > indexed dataset stats:')
    print_rank_0(' number of documents: {}'.format(
        indexed_dataset.doc_idx.shape[0] - 1))
    print_rank_0(' number of sentences: {}'.format(
        indexed_dataset.sizes.shape[0]))

    return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
    """Get dataset splits from a comma or '/' separated string list."""

    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index
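
# Illustrative sketch (hypothetical values):
#   get_train_valid_test_split_('949,50,1', 10000)
# normalizes the weights to [0.949, 0.05, 0.001] and returns the document
# boundaries [0, 9490, 9990, 10000], i.e. train gets documents [0, 9490),
# validation [9490, 9990) and test [9990, 10000).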


def get_samples_mapping(indexed_dataset,
                        data_prefix,
                        num_epochs,
                        max_num_samples,
                        max_seq_length,
                        short_seq_prob,
                        seed,
                        name,
                        binary_head):
    """Get a list that maps a sample index to a starting
    sentence index, end sentence index, and length"""

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # This should be a barrier, but the nccl barrier assumes
    # device_index=rank, which is not the case for the model
    # parallel case.
    # Commented out by ganruyi:
    # counts = torch.cuda.LongTensor([1])
    # torch.distributed.all_reduce(
    #     counts, group=mpu.get_data_parallel_group())
    # torch.distributed.all_reduce(
    #     counts, group=mpu.get_pipeline_model_parallel_group())
    # assert counts[0].item() == (
    #     torch.distributed.get_world_size() //
    #     torch.distributed.get_world_size(
    #         group=mpu.get_tensor_model_parallel_group()))

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(
        indexmap_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0(' total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping
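
# Illustrative sketch (hypothetical arguments): with data_prefix='my-corpus',
# name='train', num_epochs=None, max_num_samples=1000000, max_seq_length=512,
# short_seq_prob=0.1 and seed=1234, the mapping is loaded from
#   my-corpus_train_indexmap_1000000mns_512msl_0.10ssp_1234s.npy
# Note that this helper only loads a pre-built .npy mapping from that path;
# building the mapping is expected to have happened elsewhere.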