import os
from copy import deepcopy
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import datasets
import psutil
import torch
import transformers as tr
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm

from relik.common.log import get_logger
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset, IterableBaseDataset
from relik.retriever.data.utils import HardNegativesManager

logger = get_logger(__name__)


class SubsampleStrategyEnum(Enum):
    NONE = "none"
    RANDOM = "random"
    IN_ORDER = "in_order"


class GoldenRetrieverDataset:
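    """
    Base class for the GoldenRetriever datasets.

    The data is read from one or more JSON files (or taken from `data` if
    provided), pre-processed with `load_fn`, grouped into batches with
    `batch_fn` and collated into `ModelInputs` with `collate_fn`. Subclasses
    are expected to implement those hooks and `to_torch_dataset`.
    """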

    def __init__(
        self,
        name: str,
        path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None,
        data: Any = None,
        tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
        passage_batch_size: int = 32,
        question_batch_size: int = 32,
        max_positives: int = -1,
        max_negatives: int = 0,
        max_hard_negatives: int = 0,
        max_question_length: int = 256,
        max_passage_length: int = 64,
        shuffle: bool = False,
        subsample_strategy: Optional[
            Union[str, SubsampleStrategyEnum]
        ] = SubsampleStrategyEnum.NONE,
        subsample_portion: float = 0.1,
        num_proc: Optional[int] = None,
        load_from_cache_file: bool = True,
        keep_in_memory: bool = False,
        prefetch: bool = True,
        load_fn_kwargs: Optional[Dict[str, Any]] = None,
        batch_fn_kwargs: Optional[Dict[str, Any]] = None,
        collate_fn_kwargs: Optional[Dict[str, Any]] = None,
    ):
        if path is None and data is None:
            raise ValueError("Either `path` or `data` must be provided")

        if tokenizer is None:
            raise ValueError("A tokenizer must be provided")

        self.name = name
        # Normalize `path` to a list of `Path` objects; a single string or
        # PathLike becomes a one-element list.
        self.path = path
        if path is not None:
            if isinstance(path, (str, os.PathLike)):
                self.path = [Path(path)]
            else:
                self.path = [Path(p) for p in path]

        self.data = data

        self.passage_batch_size = passage_batch_size
        self.question_batch_size = question_batch_size
        self.max_positives = max_positives
        self.max_negatives = max_negatives
        self.max_hard_negatives = max_hard_negatives
        self.max_question_length = max_question_length
        self.max_passage_length = max_passage_length
        self.shuffle = shuffle
        self.num_proc = num_proc
        self.load_from_cache_file = load_from_cache_file
        self.keep_in_memory = keep_in_memory
        self.prefetch = prefetch

        self.tokenizer = tokenizer
        if isinstance(self.tokenizer, str):
            self.tokenizer = tr.AutoTokenizer.from_pretrained(self.tokenizer)

        # Padding functions for the tokenizer outputs, keyed by field name.
        self.padding_ops = {
            "input_ids": partial(
                self.pad_sequence,
                value=self.tokenizer.pad_token_id,
            ),
            "attention_mask": partial(self.pad_sequence, value=0),
            "token_type_ids": partial(
                self.pad_sequence,
                value=self.tokenizer.pad_token_type_id,
            ),
        }

        # Normalize and validate the subsampling strategy.
        if subsample_strategy is not None:
            if isinstance(subsample_strategy, str):
                try:
                    subsample_strategy = SubsampleStrategyEnum(subsample_strategy)
                except ValueError:
                    raise ValueError(
                        f"Subsample strategy `{subsample_strategy}` is not valid. "
                        f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
                    )
            if not isinstance(subsample_strategy, SubsampleStrategyEnum):
                raise ValueError(
                    f"Subsample strategy `{subsample_strategy}` is not valid. "
                    f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
                )
        self.subsample_strategy = subsample_strategy
        self.subsample_portion = subsample_portion

        # Load and pre-process the data, unless it was passed in directly.
        if data is None:
            self.data: Dataset = self.load(
                self.path,
                tokenizer=self.tokenizer,
                load_from_cache_file=load_from_cache_file,
                load_fn_kwargs=load_fn_kwargs,
                num_proc=num_proc,
                shuffle=shuffle,
                keep_in_memory=keep_in_memory,
                max_positives=max_positives,
                max_negatives=max_negatives,
                max_hard_negatives=max_hard_negatives,
                max_question_length=max_question_length,
                max_passage_length=max_passage_length,
            )
        else:
            self.data: Dataset = data

        # Can be set later to provide mined hard negatives to `batch_fn`.
        self.hn_manager: Optional[HardNegativesManager] = None

        # Number of complete passes over the data; used by the subsampling
        # strategies to select a different slice/seed at each epoch.
        self.number_of_complete_iterations = 0

    def __repr__(self) -> str:
        return f"GoldenRetrieverDataset({self.name=}, {self.path=})"

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(
        self, index
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        raise NotImplementedError

    def to_torch_dataset(self, *args, **kwargs) -> torch.utils.data.Dataset:
        raise NotImplementedError

    def load(
        self,
        paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        tokenizer: tr.PreTrainedTokenizer = None,
        load_fn_kwargs: Dict = None,
        load_from_cache_file: bool = True,
        num_proc: Optional[int] = None,
        shuffle: bool = False,
        keep_in_memory: bool = True,
        max_positives: int = -1,
        max_negatives: int = -1,
        max_hard_negatives: int = -1,
        max_passages: int = -1,
        max_question_length: int = 256,
        max_passage_length: int = 64,
        *args,
        **kwargs,
    ) -> Any:
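        """
        Load the JSON file(s) in `paths` with `datasets.load_dataset`, add a
        `sample_idx` column and pre-process every sample with `load_fn`.
        """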
        # Sanity-check that every provided file exists.
        for path in paths:
            if not Path(path).exists():
                raise ValueError(f"{path} does not exist")

        fn_kwargs = dict(
            tokenizer=tokenizer,
            max_positives=max_positives,
            max_negatives=max_negatives,
            max_hard_negatives=max_hard_negatives,
            max_passages=max_passages,
            max_question_length=max_question_length,
            max_passage_length=max_passage_length,
        )
        if load_fn_kwargs is not None:
            fn_kwargs.update(load_fn_kwargs)

        if num_proc is None:
            num_proc = psutil.cpu_count(logical=False)

        logger.info(f"Loading data for dataset {self.name}")
        data = load_dataset(
            "json",
            data_files=[str(p) for p in paths],
            split="train",
            streaming=False,
            keep_in_memory=keep_in_memory,
        )

        # Add a `sample_idx` column so that samples can be tracked across
        # batching and hard-negative lookups.
        if isinstance(data, datasets.Dataset):
            data = data.add_column("sample_idx", range(len(data)))
        else:
            # Build a new dict rather than calling `dict.update`, which
            # returns None and would drop the example.
            data = data.map(
                lambda x, idx: {**x, "sample_idx": idx}, with_indices=True
            )

        map_kwargs = dict(
            function=self.load_fn,
            fn_kwargs=fn_kwargs,
        )
        if isinstance(data, datasets.Dataset):
            map_kwargs.update(
                dict(
                    load_from_cache_file=load_from_cache_file,
                    keep_in_memory=keep_in_memory,
                    num_proc=num_proc,
                    desc="Loading data",
                )
            )

        data = data.map(**map_kwargs)

        if shuffle:
            # `Dataset.shuffle` returns a new dataset, so reassign it.
            data = data.shuffle(seed=42)

        return data

    @staticmethod
    def create_batches(
        data: Dataset,
        batch_fn: Callable,
        batch_fn_kwargs: Optional[Dict[str, Any]] = None,
        prefetch: bool = True,
        *args,
        **kwargs,
    ) -> Union[Iterable, List]:
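        """
        Batch `data` with `batch_fn`. If `prefetch` is False the batches are
        produced lazily as a generator, otherwise they are materialized in a
        list up front.
        """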
        if not prefetch:
            batched_data = (
                batch
                for batch in batch_fn(
                    data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {})
                )
            )
        else:
            batched_data = [
                batch
                for batch in tqdm(
                    batch_fn(
                        data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {})
                    ),
                    desc="Creating batches",
                )
            ]
        return batched_data

    @staticmethod
    def collate_batches(
        batched_data: Union[Iterable, List],
        collate_fn: Callable,
        collate_fn_kwargs: Optional[Dict[str, Any]] = None,
        prefetch: bool = True,
        *args,
        **kwargs,
    ) -> Union[Iterable, List]:
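        """
        Apply `collate_fn` to each batch, lazily (generator) when `prefetch`
        is False and eagerly (list) otherwise.
        """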
        if not prefetch:
            collated_data = (
                collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {}))
                for batch in batched_data
            )
        else:
            collated_data = [
                collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {}))
                for batch in tqdm(batched_data, desc="Collating batches")
            ]
        return collated_data
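
    # The three hooks below define the per-sample pre-processing (`load_fn`),
    # batching (`batch_fn`) and collation (`collate_fn`) logic. Concrete
    # datasets must override them.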
    @staticmethod
    def load_fn(sample: Dict, *args, **kwargs) -> Dict:
        raise NotImplementedError

    @staticmethod
    def batch_fn(data: Dataset, *args, **kwargs) -> Any:
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch: Any, *args, **kwargs) -> Any:
        raise NotImplementedError

    @staticmethod
    def pad_sequence(
        sequence: Union[List, torch.Tensor],
        length: int,
        value: Any = None,
        pad_to_left: bool = False,
    ) -> Union[List, torch.Tensor]:
        """
        Pad the input to the specified length with the given value.

        Args:
            sequence (:obj:`List`, :obj:`torch.Tensor`):
                Element to pad, it can be either a :obj:`List` or a :obj:`torch.Tensor`.
            length (:obj:`int`):
                Length to pad to.
            value (:obj:`Any`, optional):
                Value to use as padding.
            pad_to_left (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True`, pads to the left, right otherwise.

        Returns:
            :obj:`List`, :obj:`torch.Tensor`: The padded sequence.
        """
        padding = [value] * abs(length - len(sequence))
        if isinstance(sequence, torch.Tensor):
            if len(sequence.shape) > 1:
                raise ValueError(
                    f"Sequence tensor must be 1D. Got a tensor with {len(sequence.shape)} dimensions."
                )
            padding = torch.as_tensor(padding)
        if pad_to_left:
            if isinstance(sequence, torch.Tensor):
                return torch.cat((padding, sequence), -1)
            return padding + sequence
        if isinstance(sequence, torch.Tensor):
            return torch.cat((sequence, padding), -1)
        return sequence + padding

    def convert_to_batch(
        self, samples: Any, *args, **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Convert the list of samples to a batch.

        Args:
            samples (:obj:`List`):
                List of samples to convert to a batch.

        Returns:
            :obj:`Dict[str, torch.Tensor]`: The batch.
        """
        # Invert the list of dicts into a dict of lists.
        samples = {k: [d[k] for d in samples] for k in samples[0]}
        # Pad every field that has a padding op to the longest `input_ids`
        # sequence in the batch.
        max_len = max(len(x) for x in samples["input_ids"])
        for key in samples:
            if key in self.padding_ops:
                samples[key] = torch.as_tensor(
                    [self.padding_ops[key](b, max_len) for b in samples[key]]
                )
        return samples

    def shuffle_data(self, seed: int = 42):
        self.data = self.data.shuffle(seed=seed)


class InBatchNegativesDataset(GoldenRetrieverDataset):
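    """
    Dataset for in-batch negatives training: every batch shares a single pool
    of unique passages, and each question is labeled with the positions of its
    positive passages inside that pool.
    """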

    def __len__(self) -> int:
        if isinstance(self.data, datasets.Dataset):
            return len(self.data)

    def __getitem__(
        self, index
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        return self.data[index]

    def to_torch_dataset(self) -> torch.utils.data.Dataset:
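        """
        Build a torch-compatible dataset of pre-batched `ModelInputs`,
        applying the configured subsampling strategy and (optionally)
        shuffling before batching and collating the data.
        """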
        shuffle_this_time = self.shuffle

        if (
            self.subsample_strategy
            and self.subsample_strategy != SubsampleStrategyEnum.NONE
        ):
            number_of_samples = int(len(self.data) * self.subsample_portion)
            if self.subsample_strategy == SubsampleStrategyEnum.RANDOM:
                logger.info(
                    f"Random subsampling {number_of_samples} samples from {len(self.data)}"
                )
                data = (
                    deepcopy(self.data)
                    .shuffle(seed=42 + self.number_of_complete_iterations)
                    .select(range(0, number_of_samples))
                )
            elif self.subsample_strategy == SubsampleStrategyEnum.IN_ORDER:
                # Take the next contiguous slice of the data, based on how
                # many complete iterations have already been performed.
                already_selected = (
                    number_of_samples * self.number_of_complete_iterations
                )
                logger.info(
                    f"Subsampling {number_of_samples} samples out of {len(self.data)}"
                )
                to_select = min(already_selected + number_of_samples, len(self.data))
                logger.info(
                    f"Portion of data selected: {already_selected} to {to_select}"
                )
                data = deepcopy(self.data).select(range(already_selected, to_select))

                # After the first pass, keep the original order so that the
                # in-order slices cover the dataset deterministically.
                if self.number_of_complete_iterations > 0:
                    shuffle_this_time = False

                # Once the whole dataset has been consumed, restart from the
                # beginning on the next iteration.
                if to_select >= len(self.data):
                    self.number_of_complete_iterations = -1
            else:
                raise ValueError(
                    f"Subsample strategy `{self.subsample_strategy}` is not valid. "
                    f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
                )
        else:
            data = self.data

        if self.shuffle and shuffle_this_time:
            logger.info("Shuffling the data")
            data = data.shuffle(seed=42 + self.number_of_complete_iterations)

        batch_fn_kwargs = {
            "passage_batch_size": self.passage_batch_size,
            "question_batch_size": self.question_batch_size,
            "hard_negatives_manager": self.hn_manager,
        }
        batched_data = self.create_batches(
            data,
            batch_fn=self.batch_fn,
            batch_fn_kwargs=batch_fn_kwargs,
            prefetch=self.prefetch,
        )

        batched_data = self.collate_batches(
            batched_data, self.collate_fn, prefetch=self.prefetch
        )

        self.number_of_complete_iterations += 1

        if self.prefetch:
            return BaseDataset(name=self.name, data=batched_data)
        else:
            return IterableBaseDataset(name=self.name, data=batched_data)

    @staticmethod
    def load_fn(
        sample: Dict,
        tokenizer: tr.PreTrainedTokenizer,
        max_positives: int,
        max_negatives: int,
        max_hard_negatives: int,
        max_passages: int = -1,
        max_question_length: int = 256,
        max_passage_length: int = 128,
        *args,
        **kwargs,
    ) -> Dict:
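        """
        Pre-process a single DPR-style sample: deduplicate the positive,
        negative and hard-negative passage texts, apply the per-type caps,
        and tokenize the question and the passages.
        """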
        positives = list(set([p["text"] for p in sample["positive_ctxs"]]))
        if max_positives != -1:
            positives = positives[:max_positives]
        negatives = list(set([n["text"] for n in sample["negative_ctxs"]]))
        if max_negatives != -1:
            negatives = negatives[:max_negatives]
        hard_negatives = list(set([h["text"] for h in sample["hard_negative_ctxs"]]))
        if max_hard_negatives != -1:
            hard_negatives = hard_negatives[:max_hard_negatives]

        question = tokenizer(
            sample["question"], max_length=max_question_length, truncation=True
        )

        passage = positives + negatives + hard_negatives
        if max_passages != -1:
            passage = passage[:max_passages]

        passage = tokenizer(passage, max_length=max_passage_length, truncation=True)

        # Turn the batched tokenizer output into a list of per-passage dicts.
        passage = [dict(zip(passage, t)) for t in zip(*passage.values())]

        output = dict(
            question=question,
            passage=passage,
            positives=positives,
            positive_pssgs=passage[: len(positives)],
        )
        return output

    @staticmethod
    def batch_fn(
        data: Dataset,
        passage_batch_size: int,
        question_batch_size: int,
        hard_negatives_manager: Optional[HardNegativesManager] = None,
        *args,
        **kwargs,
    ) -> Dict[str, List[Dict[str, Any]]]:
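        """
        Greedily accumulate samples until the pool of unique passages reaches
        `passage_batch_size`, then yield the batch, splitting it into chunks
        of at most `question_batch_size` questions if needed. Hard negatives
        from `hard_negatives_manager` (if any) are added to the passage pool.
        """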
        def split_batch(
            batch: Union[Dict[str, Any], ModelInputs], question_batch_size: int
        ) -> List[ModelInputs]:
            """
            Split a batch into multiple batches of size `question_batch_size`
            while keeping the same passages.
            """

            def split_fn(x):
                return [
                    x[i : i + question_batch_size]
                    for i in range(0, len(x), question_batch_size)
                ]

            # Split the question-level fields; the passage pool is shared.
            sample_idx = split_fn(batch["sample_idx"])
            questions = split_fn(batch["questions"])
            positives = split_fn(batch["positives"])
            positives_pssgs = split_fn(batch["positives_pssgs"])

            batches = []
            for i in range(len(questions)):
                batches.append(
                    ModelInputs(
                        dict(
                            sample_idx=sample_idx[i],
                            questions=questions[i],
                            passages=batch["passages"],
                            positives=positives[i],
                            positives_pssgs=positives_pssgs[i],
                        )
                    )
                )
            return batches

        batch = []
        passages_in_batch = {}

        for sample in data:
            if len(passages_in_batch) >= passage_batch_size:
                # The passage pool is full: emit the current batch.
                batch_dict = ModelInputs(
                    dict(
                        sample_idx=[s["sample_idx"] for s in batch],
                        questions=[s["question"] for s in batch],
                        passages=list(passages_in_batch.values()),
                        positives_pssgs=[s["positive_pssgs"] for s in batch],
                        positives=[s["positives"] for s in batch],
                    )
                )
                # Split the batch if there are too many questions.
                if len(batch) > question_batch_size:
                    for sub_batch in split_batch(batch_dict, question_batch_size):
                        yield sub_batch
                else:
                    yield batch_dict

                # Reset the accumulators.
                batch = []
                passages_in_batch = {}

            batch.append(sample)

            # Add the sample's passages to the pool, deduplicating by their
            # token ids.
            passages_in_batch.update(
                {tuple(passage["input_ids"]): passage for passage in sample["passage"]}
            )

            # Add the hard negatives mined for this sample, if any.
            if hard_negatives_manager is not None:
                if sample["sample_idx"] in hard_negatives_manager:
                    passages_in_batch.update(
                        {
                            tuple(passage["input_ids"]): passage
                            for passage in hard_negatives_manager.get(
                                sample["sample_idx"]
                            )
                        }
                    )

        # Yield the leftover batch at the end of the data.
        if len(batch) > 0:
            batch_dict = ModelInputs(
                dict(
                    sample_idx=[s["sample_idx"] for s in batch],
                    questions=[s["question"] for s in batch],
                    passages=list(passages_in_batch.values()),
                    positives_pssgs=[s["positive_pssgs"] for s in batch],
                    positives=[s["positives"] for s in batch],
                )
            )
            if len(batch) > question_batch_size:
                for sub_batch in split_batch(batch_dict, question_batch_size):
                    yield sub_batch
            else:
                yield batch_dict

    def collate_fn(self, batch: Any, *args, **kwargs) -> Any:
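        """
        Pad the questions and passages of a batch and build the binary
        `labels` matrix (questions x passages) that marks, for each question,
        which passages in the shared pool are positives.
        """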
        questions = self.convert_to_batch(batch.questions)
        passages = self.convert_to_batch(batch.passages)

        # Map each passage (by its token ids) to its index in the batch.
        passage_index = {tuple(c["input_ids"]): i for i, c in enumerate(batch.passages)}

        # labels[i, j] == 1 iff passage j is a positive for question i.
        labels = torch.zeros(
            questions["input_ids"].shape[0], passages["input_ids"].shape[0]
        )
        for sample_idx in range(len(questions["input_ids"])):
            for pssg in batch["positives_pssgs"][sample_idx]:
                index = passage_index[tuple(pssg["input_ids"])]
                labels[sample_idx, index] = 1

        model_inputs = ModelInputs(
            {
                "questions": questions,
                "passages": passages,
                "labels": labels,
                "positives": batch["positives"],
                "sample_idx": batch["sample_idx"],
            }
        )
        return model_inputs


class AidaInBatchNegativesDataset(InBatchNegativesDataset):
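    """
    `InBatchNegativesDataset` variant for AIDA-style data that can optionally
    encode the document topic together with the question (`use_topics=True`).
    """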

    def __init__(self, use_topics: bool = False, *args, **kwargs):
        # Forward `use_topics` to `load_fn` through `load_fn_kwargs`.
        if "load_fn_kwargs" not in kwargs:
            kwargs["load_fn_kwargs"] = {}
        kwargs["load_fn_kwargs"]["use_topics"] = use_topics
        super().__init__(*args, **kwargs)

    @staticmethod
    def load_fn(
        sample: Dict,
        tokenizer: tr.PreTrainedTokenizer,
        max_positives: int,
        max_negatives: int,
        max_hard_negatives: int,
        max_passages: int = -1,
        max_question_length: int = 256,
        max_passage_length: int = 128,
        use_topics: bool = False,
        *args,
        **kwargs,
    ) -> Dict:
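        """
        Same as `InBatchNegativesDataset.load_fn`, except that the question is
        tokenized together with the sample's `doc_topic` when `use_topics` is
        True and the field is present.
        """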
        positives = list(set([p["text"] for p in sample["positive_ctxs"]]))
        if max_positives != -1:
            positives = positives[:max_positives]
        negatives = list(set([n["text"] for n in sample["negative_ctxs"]]))
        if max_negatives != -1:
            negatives = negatives[:max_negatives]
        hard_negatives = list(set([h["text"] for h in sample["hard_negative_ctxs"]]))
        if max_hard_negatives != -1:
            hard_negatives = hard_negatives[:max_hard_negatives]

        question = sample["question"]

        # Encode the question together with the document topic, when available.
        if "doc_topic" in sample and use_topics:
            question = tokenizer(
                question,
                sample["doc_topic"],
                max_length=max_question_length,
                truncation=True,
            )
        else:
            question = tokenizer(
                question, max_length=max_question_length, truncation=True
            )

        passage = positives + negatives + hard_negatives
        if max_passages != -1:
            passage = passage[:max_passages]

        passage = tokenizer(passage, max_length=max_passage_length, truncation=True)

        # Turn the batched tokenizer output into a list of per-passage dicts.
        passage = [dict(zip(passage, t)) for t in zip(*passage.values())]

        output = dict(
            question=question,
            passage=passage,
            positives=positives,
            positive_pssgs=passage[: len(positives)],
        )
        return output
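

# Minimal usage sketch (illustrative only, not part of the library). It
# assumes a DPR-style JSON file at "data/train.json" and a HuggingFace
# tokenizer name -- both are hypothetical placeholders.
#
#   dataset = InBatchNegativesDataset(
#       name="train",
#       path="data/train.json",
#       tokenizer="bert-base-uncased",
#       question_batch_size=64,
#       passage_batch_size=200,
#       shuffle=True,
#   )
#   torch_dataset = dataset.to_torch_dataset()
#   first_batch = torch_dataset[0]  # ModelInputs with questions/passages/labels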