import logging
import os
import contextlib

import numpy as np
import torch

from fairseq.data import FairseqDataset, data_utils


logger = logging.getLogger(__name__)


class ExtractedFeaturesDataset(FairseqDataset):
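    """Dataset over pre-extracted frame-level features stored in a single
    memory-mapped numpy array.

    For a given ``split``, the following files are expected under ``path``:

    - ``{split}.npy``: all features for the split concatenated along the time
      axis, loaded with ``mmap_mode="r"``.
    - ``{split}.lengths``: one integer per utterance giving its number of
      frames, used to slice the concatenated array.
    - ``{split}.{labels}`` (optional): one whitespace-separated label sequence
      per utterance, encoded with ``label_dict``.

    Utterances shorter than ``min_length`` (or longer than ``max_length``,
    when set) are skipped.

    Minimal usage sketch (the feature directory, label suffix, and dictionary
    file below are placeholders, not part of this module):

        from fairseq.data import Dictionary

        label_dict = Dictionary.load("dict.phn.txt")
        dataset = ExtractedFeaturesDataset(
            path="features/precompute_pca512",
            split="train",
            labels="phn",
            label_dict=label_dict,
        )
        batch = dataset.collater([dataset[0], dataset[1]])
    """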
    def __init__(
        self,
        path,
        split,
        min_length=3,
        max_length=None,
        labels=None,
        label_dict=None,
        shuffle=True,
        sort_by_length=True,
    ):
        super().__init__()

        self.min_length = min_length
        self.max_length = max_length
        self.shuffle = shuffle
        self.sort_by_length = sort_by_length
        self.label_dict = label_dict

        if labels is not None:
            assert label_dict is not None

        self.sizes = []
        self.offsets = []
        self.labels = []

        path = os.path.join(path, split)
        data_path = path
        self.data = np.load(data_path + ".npy", mmap_mode="r")

        offset = 0
        skipped = 0
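
        # If the label file for this split does not exist, fall back to an
        # unlabeled dataset.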
        if not os.path.exists(path + f".{labels}"):
            labels = None
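
        # contextlib.ExitStack() serves as a no-op stand-in for the label file
        # handle when there are no labels to read.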
        with open(data_path + ".lengths", "r") as len_f, open(
            path + f".{labels}", "r"
        ) if labels is not None else contextlib.ExitStack() as lbl_f:
            for line in len_f:
                length = int(line.rstrip())
                lbl = None if labels is None else next(lbl_f).rstrip().split()
                if length >= min_length and (
                    max_length is None or length <= max_length
                ):
                    self.sizes.append(length)
                    self.offsets.append(offset)
                    if lbl is not None:
                        self.labels.append(lbl)
                else:
                    skipped += 1
                offset += length

        self.sizes = np.asarray(self.sizes)
        self.offsets = np.asarray(self.offsets)

        logger.info(f"loaded {len(self.offsets)}, skipped {skipped} samples")

    def __getitem__(self, index):
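        # Slice this utterance out of the memory-mapped array; .copy() detaches
        # the slice from the memmap before it is wrapped in a tensor.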
        offset = self.offsets[index]
        end = self.sizes[index] + offset
        feats = torch.from_numpy(self.data[offset:end].copy()).float()

        res = {"id": index, "features": feats}
        if len(self.labels) > 0:
            res["target"] = self.label_dict.encode_line(
                self.labels[index],
                line_tokenizer=lambda x: x,
                append_eos=False,
            )

        return res

    def __len__(self):
        return len(self.sizes)

    def collater(self, samples):
        if len(samples) == 0:
            return {}
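
        # Zero-pad every feature sequence to the longest one in the batch;
        # padding_mask is True at padded time steps and False elsewhere.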
        features = [s["features"] for s in samples]
        sizes = [len(s) for s in features]

        target_size = max(sizes)

        collated_features = features[0].new_zeros(
            len(features), target_size, features[0].size(-1)
        )
        padding_mask = torch.BoolTensor(collated_features.shape[:-1]).fill_(False)
        for i, (f, size) in enumerate(zip(features, sizes)):
            collated_features[i, :size] = f
            padding_mask[i, size:] = True

        res = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": {"features": collated_features, "padding_mask": padding_mask},
        }

        if len(self.labels) > 0:
            target = data_utils.collate_tokens(
                [s["target"] for s in samples],
                pad_idx=self.label_dict.pad(),
                left_pad=False,
            )
            res["target"] = target
        return res

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        if self.sort_by_length:
            order.append(self.sizes)
            return np.lexsort(order)[::-1]
        else:
            return order[0]