# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Dataset to distilled models | |
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) | |
""" | |
import numpy as np
import torch
from torch.utils.data import Dataset

from utils import logger


class LmSeqsDataset(Dataset):
    """Custom Dataset wrapping language modeling sequences.

    Each sample is retrieved by indexing the list of token_ids and their corresponding lengths.

    Input:
    ------
        params: `argparse.Namespace` parameters
        data: `List[np.ndarray]` of sequences of token ids (int)
    """
    def __init__(self, params, data):
        self.params = params
        # `dtype=object` keeps ragged (variable-length) sequences in a single
        # array; recent NumPy versions refuse to build ragged arrays implicitly.
        self.token_ids = np.array(data, dtype=object)
        self.lengths = np.array([len(t) for t in data])

        # Preprocessing pipeline: validate, then split/filter, then re-validate.
        self.check()
        self.remove_long_sequences()
        self.remove_empty_sequences()
        self.remove_unknown_sequences()
        self.check()
        self.print_statistics()

    def __getitem__(self, index):
        return (self.token_ids[index], self.lengths[index])

    def __len__(self):
        return len(self.lengths)

    def check(self):
        """
        Some sanity checks.
        """
        assert len(self.token_ids) == len(self.lengths)
        assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths)))

    def remove_long_sequences(self):
        """
        Sequences that are too long are split into chunks of at most max_model_input_size tokens.
        """
        max_len = self.params.max_model_input_size
        indices = self.lengths > max_len
        logger.info(f"Splitting {sum(indices)} too long sequences.")

        def divide_chunks(l, n):
            return [l[i : i + n] for i in range(0, len(l), n)]
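        # Illustrative (the real input is a numpy array of token ids):
        #   divide_chunks([0, 1, 2, 3, 4], 2) -> [[0, 1], [2, 3], [4]]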

        new_tok_ids = []
        new_lengths = []
        if self.params.mlm:
            cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"]
        else:
            cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"]

        for seq_, len_ in zip(self.token_ids, self.lengths):
            assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
            if len_ <= max_len:
                new_tok_ids.append(seq_)
                new_lengths.append(len_)
            else:
                sub_seqs = []
                # Chunk to max_len - 2 so there is room to re-add the special
                # tokens at both ends of every chunk.
                for sub_s in divide_chunks(seq_, max_len - 2):
                    if sub_s[0] != cls_id:
                        sub_s = np.insert(sub_s, 0, cls_id)
                    if sub_s[-1] != sep_id:
                        sub_s = np.insert(sub_s, len(sub_s), sep_id)
                    assert len(sub_s) <= max_len
                    assert (sub_s[0] == cls_id) and (sub_s[-1] == sep_id), sub_s
                    sub_seqs.append(sub_s)

                new_tok_ids.extend(sub_seqs)
                new_lengths.extend([len(s) for s in sub_seqs])

        self.token_ids = np.array(new_tok_ids, dtype=object)
        self.lengths = np.array(new_lengths)

    def remove_empty_sequences(self):
        """
        Too short sequences are simply removed. This threshold could be tuned.
        """
        init_size = len(self)
        indices = self.lengths > 11
        self.token_ids = self.token_ids[indices]
        self.lengths = self.lengths[indices]
        new_size = len(self)
        logger.info(f"Removed {init_size - new_size} too short (<=11 tokens) sequences.")

    def remove_unknown_sequences(self):
        """
        Remove sequences with a (too) high proportion of unknown tokens.
        """
        if "unk_token" not in self.params.special_tok_ids:
            return

        unk_token_id = self.params.special_tok_ids["unk_token"]
        init_size = len(self)
        unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids])
        indices = (unk_occs / self.lengths) < 0.5
        self.token_ids = self.token_ids[indices]
        self.lengths = self.lengths[indices]
        new_size = len(self)
        logger.info(f"Removed {init_size - new_size} sequences with a high level of unknown tokens (>=50%).")

    def print_statistics(self):
        """
        Print some statistics on the corpus. Only the master process does the logging.
        """
        if not self.params.is_master:
            return
        logger.info(f"{len(self)} sequences")
        # data_len = sum(self.lengths)
        # nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
        # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
        # unk_idx = self.params.special_tok_ids['unk_token']
        # nb_unknown = sum([(t == unk_idx).sum() for t in self.token_ids])
        # logger.info(f'{nb_unknown} unknown tokens (covering {100 * nb_unknown / data_len:.2f}% of the data)')

    def batch_sequences(self, batch):
        """
        Do the padding and transform into torch.tensor. Suitable as a DataLoader `collate_fn`.
        """
        token_ids = [t[0] for t in batch]
        lengths = [t[1] for t in batch]
        assert len(token_ids) == len(lengths)

        # Max for paddings
        max_seq_len_ = max(lengths)

        # Pad token ids
        if self.params.mlm:
            pad_idx = self.params.special_tok_ids["pad_token"]
        else:
            # CLM tokenizers may lack a dedicated pad token; padded positions
            # are masked downstream using the returned lengths.
            pad_idx = self.params.special_tok_ids["unk_token"]
        tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
        assert len(tk_) == len(token_ids)
        assert all(len(t) == max_seq_len_ for t in tk_)

        tk_t = torch.tensor(tk_)  # (bs, max_seq_len_)
        lg_t = torch.tensor(lengths)  # (bs)
        return tk_t, lg_t
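

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original training pipeline. The
    # `params` namespace below is a hypothetical stand-in for the argparse
    # arguments built by the distillation scripts; the token ids are toy data.
    from types import SimpleNamespace

    from torch.utils.data import DataLoader

    params = SimpleNamespace(
        max_model_input_size=512,
        mlm=True,
        special_tok_ids={"cls_token": 0, "sep_token": 1, "pad_token": 2, "unk_token": 3},
        is_master=True,
    )

    # Toy corpus: each sequence starts with cls_id, ends with sep_id, and is
    # longer than 11 tokens so it survives remove_empty_sequences().
    rng = np.random.default_rng(0)
    data = [np.array([0] + list(rng.integers(4, 100, size=20)) + [1]) for _ in range(8)]

    dataset = LmSeqsDataset(params, data)
    loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.batch_sequences)
    for token_ids, lengths in loader:
        print(token_ids.shape, lengths)  # e.g. torch.Size([4, 22]) and the 4 lengths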