import os
from copy import deepcopy
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import datasets
import psutil
import torch
import transformers as tr
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from relik.common.log import get_console_logger, get_logger
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset, IterableBaseDataset
from relik.retriever.data.utils import HardNegativesManager
console_logger = get_console_logger()
logger = get_logger(__name__)
class SubsampleStrategyEnum(Enum):
NONE = "none"
RANDOM = "random"
IN_ORDER = "in_order"
class GoldenRetrieverDataset:
def __init__(
self,
name: str,
        path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None,
data: Any = None,
tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
passage_batch_size: int = 32,
question_batch_size: int = 32,
max_positives: int = -1,
max_negatives: int = 0,
max_hard_negatives: int = 0,
max_question_length: int = 256,
max_passage_length: int = 64,
shuffle: bool = False,
        subsample_strategy: Optional[Union[str, SubsampleStrategyEnum]] = SubsampleStrategyEnum.NONE,
subsample_portion: float = 0.1,
num_proc: Optional[int] = None,
load_from_cache_file: bool = True,
keep_in_memory: bool = False,
prefetch: bool = True,
load_fn_kwargs: Optional[Dict[str, Any]] = None,
batch_fn_kwargs: Optional[Dict[str, Any]] = None,
collate_fn_kwargs: Optional[Dict[str, Any]] = None,
):
if path is None and data is None:
raise ValueError("Either `path` or `data` must be provided")
if tokenizer is None:
raise ValueError("A tokenizer must be provided")
# dataset parameters
self.name = name
        # normalize `path` to a list of `Path` objects; a plain string is
        # itself a `Sequence`, so check for str/PathLike explicitly
        if path is not None:
            if isinstance(path, (str, os.PathLike)):
                path = [path]
            self.path = [Path(p) for p in path]
        else:
            self.path = None
self.data = data
# hyper-parameters
self.passage_batch_size = passage_batch_size
self.question_batch_size = question_batch_size
self.max_positives = max_positives
self.max_negatives = max_negatives
self.max_hard_negatives = max_hard_negatives
self.max_question_length = max_question_length
self.max_passage_length = max_passage_length
self.shuffle = shuffle
self.num_proc = num_proc
self.load_from_cache_file = load_from_cache_file
self.keep_in_memory = keep_in_memory
self.prefetch = prefetch
self.tokenizer = tokenizer
if isinstance(self.tokenizer, str):
self.tokenizer = tr.AutoTokenizer.from_pretrained(self.tokenizer)
self.padding_ops = {
"input_ids": partial(
self.pad_sequence,
value=self.tokenizer.pad_token_id,
),
"attention_mask": partial(self.pad_sequence, value=0),
"token_type_ids": partial(
self.pad_sequence,
value=self.tokenizer.pad_token_type_id,
),
}
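        # e.g. (a sketch with hypothetical token ids) calling
        # self.padding_ops["input_ids"]([101, 2054], 4) pads with the
        # tokenizer's pad token id: [101, 2054, pad_id, pad_id]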
# check if subsample strategy is valid
if subsample_strategy is not None:
# subsample_strategy can be a string or a SubsampleStrategy
if isinstance(subsample_strategy, str):
try:
subsample_strategy = SubsampleStrategyEnum(subsample_strategy)
except ValueError:
raise ValueError(
f"Subsample strategy {subsample_strategy} is not valid. "
f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
)
if not isinstance(subsample_strategy, SubsampleStrategyEnum):
raise ValueError(
f"Subsample strategy {subsample_strategy} is not valid. "
f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
)
self.subsample_strategy = subsample_strategy
self.subsample_portion = subsample_portion
# load the dataset
if data is None:
self.data: Dataset = self.load(
self.path,
tokenizer=self.tokenizer,
load_from_cache_file=load_from_cache_file,
load_fn_kwargs=load_fn_kwargs,
num_proc=num_proc,
shuffle=shuffle,
keep_in_memory=keep_in_memory,
max_positives=max_positives,
max_negatives=max_negatives,
max_hard_negatives=max_hard_negatives,
max_question_length=max_question_length,
max_passage_length=max_passage_length,
)
else:
self.data: Dataset = data
self.hn_manager: Optional[HardNegativesManager] = None
# keep track of how many times the dataset has been iterated over
self.number_of_complete_iterations = 0
def __repr__(self) -> str:
return f"GoldenRetrieverDataset({self.name=}, {self.path=})"
def __len__(self) -> int:
raise NotImplementedError
def __getitem__(
self, index
) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
raise NotImplementedError
def to_torch_dataset(self, *args, **kwargs) -> torch.utils.data.Dataset:
raise NotImplementedError
def load(
self,
paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        tokenizer: Optional[tr.PreTrainedTokenizer] = None,
        load_fn_kwargs: Optional[Dict] = None,
load_from_cache_file: bool = True,
num_proc: Optional[int] = None,
shuffle: bool = False,
keep_in_memory: bool = True,
max_positives: int = -1,
max_negatives: int = -1,
max_hard_negatives: int = -1,
max_passages: int = -1,
max_question_length: int = 256,
max_passage_length: int = 64,
*args,
**kwargs,
) -> Any:
        # normalize `paths` to a list of `Path` objects and check existence
        if isinstance(paths, (str, os.PathLike)):
            paths = [paths]
        paths = [Path(p) for p in paths]
        for path in paths:
            if not path.exists():
                raise ValueError(f"{path} does not exist")
fn_kwargs = dict(
tokenizer=tokenizer,
max_positives=max_positives,
max_negatives=max_negatives,
max_hard_negatives=max_hard_negatives,
max_passages=max_passages,
max_question_length=max_question_length,
max_passage_length=max_passage_length,
)
if load_fn_kwargs is not None:
fn_kwargs.update(load_fn_kwargs)
        if num_proc is None:
            # `cpu_count` can return None on some platforms, fall back to 1
            num_proc = psutil.cpu_count(logical=False) or 1
# The data is a list of dictionaries, each dictionary is a sample
# Each sample has the following keys:
# - "question": the question
# - "answers": a list of answers
# - "positive_ctxs": a list of positive passages
# - "negative_ctxs": a list of negative passages
# - "hard_negative_ctxs": a list of hard negative passages
# use the huggingface dataset library to load the data, by default it will load the
# data in a dict with the key being "train".
logger.info(f"Loading data for dataset {self.name}")
data = load_dataset(
"json",
data_files=[str(p) for p in paths], # datasets needs str paths and not Path
split="train",
streaming=False, # TODO maybe we can make streaming work
keep_in_memory=keep_in_memory,
)
        # add a `sample_idx` column to keep track of the original sample order
        if isinstance(data, datasets.Dataset):
            data = data.add_column("sample_idx", range(len(data)))
        else:
            # `dict.update` returns None, so build a new dict instead
            data = data.map(
                lambda x, idx: {**x, "sample_idx": idx}, with_indices=True
            )
map_kwargs = dict(
function=self.load_fn,
fn_kwargs=fn_kwargs,
)
if isinstance(data, datasets.Dataset):
map_kwargs.update(
dict(
load_from_cache_file=load_from_cache_file,
keep_in_memory=keep_in_memory,
num_proc=num_proc,
desc="Loading data",
)
)
# preprocess the data
data = data.map(**map_kwargs)
        # shuffle the data (`shuffle` returns a new dataset, it is not in-place)
        if shuffle:
            data = data.shuffle(seed=42)
return data
@staticmethod
def create_batches(
data: Dataset,
batch_fn: Callable,
batch_fn_kwargs: Optional[Dict[str, Any]] = None,
prefetch: bool = True,
*args,
**kwargs,
) -> Union[Iterable, List]:
        fn_kwargs = batch_fn_kwargs if batch_fn_kwargs is not None else {}
        if not prefetch:
            # if we are streaming, we don't need to create batches right now,
            # we will create them on the fly when we need them
            batched_data = batch_fn(data, **fn_kwargs)
        else:
            batched_data = list(
                tqdm(batch_fn(data, **fn_kwargs), desc="Creating batches")
            )
return batched_data
@staticmethod
def collate_batches(
batched_data: Union[Iterable, List],
collate_fn: Callable,
collate_fn_kwargs: Optional[Dict[str, Any]] = None,
prefetch: bool = True,
*args,
**kwargs,
) -> Union[Iterable, List]:
if not prefetch:
collated_data = (
collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {}))
for batch in batched_data
)
else:
collated_data = [
collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {}))
for batch in tqdm(batched_data, desc="Collating batches")
]
return collated_data
@staticmethod
def load_fn(sample: Dict, *args, **kwargs) -> Dict:
raise NotImplementedError
@staticmethod
def batch_fn(data: Dataset, *args, **kwargs) -> Any:
raise NotImplementedError
@staticmethod
def collate_fn(batch: Any, *args, **kwargs) -> Any:
raise NotImplementedError
@staticmethod
def pad_sequence(
sequence: Union[List, torch.Tensor],
length: int,
value: Any = None,
pad_to_left: bool = False,
) -> Union[List, torch.Tensor]:
"""
Pad the input to the specified length with the given value.
Args:
sequence (:obj:`List`, :obj:`torch.Tensor`):
Element to pad, it can be either a :obj:`List` or a :obj:`torch.Tensor`.
            length (:obj:`int`):
                Length to pad the sequence to.
value (:obj:`Any`, optional):
Value to use as padding.
pad_to_left (:obj:`bool`, optional, defaults to :obj:`False`):
If :obj:`True`, pads to the left, right otherwise.
Returns:
:obj:`List`, :obj:`torch.Tensor`: The padded sequence.
"""
        # never pad a sequence that is already at least `length` long
        padding = [value] * max(0, length - len(sequence))
        if isinstance(sequence, torch.Tensor):
            if len(sequence.shape) > 1:
                raise ValueError(
                    f"Sequence tensor must be 1D. Current shape is `{sequence.shape}`"
                )
            padding = torch.as_tensor(padding)
if pad_to_left:
if isinstance(sequence, torch.Tensor):
return torch.cat((padding, sequence), -1)
return padding + sequence
if isinstance(sequence, torch.Tensor):
return torch.cat((sequence, padding), -1)
return sequence + padding
def convert_to_batch(
self, samples: Any, *args, **kwargs
) -> Dict[str, torch.Tensor]:
"""
Convert the list of samples to a batch.
Args:
samples (:obj:`List`):
List of samples to convert to a batch.
Returns:
:obj:`Dict[str, torch.Tensor]`: The batch.
"""
# invert questions from list of dict to dict of list
samples = {k: [d[k] for d in samples] for k in samples[0]}
# get max length of questions
max_len = max(len(x) for x in samples["input_ids"])
# pad the questions
for key in samples:
if key in self.padding_ops:
samples[key] = torch.as_tensor(
[self.padding_ops[key](b, max_len) for b in samples[key]]
)
return samples
def shuffle_data(self, seed: int = 42):
self.data = self.data.shuffle(seed=seed)
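# Example usage of a concrete subclass (a sketch; the file name and tokenizer
# identifier below are placeholders, not part of this module):
#
#     dataset = InBatchNegativesDataset(
#         name="train",
#         path="data/train.jsonl",
#         tokenizer="bert-base-uncased",
#         question_batch_size=32,
#         passage_batch_size=64,
#     )
#     torch_dataset = dataset.to_torch_dataset()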
class InBatchNegativesDataset(GoldenRetrieverDataset):
    def __len__(self) -> int:
        if isinstance(self.data, datasets.Dataset):
            return len(self.data)
        raise TypeError(
            f"`len` is not supported for data of type `{type(self.data)}`"
        )
def __getitem__(
self, index
) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
return self.data[index]
def to_torch_dataset(self) -> torch.utils.data.Dataset:
shuffle_this_time = self.shuffle
if (
self.subsample_strategy
and self.subsample_strategy != SubsampleStrategyEnum.NONE
):
number_of_samples = int(len(self.data) * self.subsample_portion)
if self.subsample_strategy == SubsampleStrategyEnum.RANDOM:
logger.info(
f"Random subsampling {number_of_samples} samples from {len(self.data)}"
)
data = (
deepcopy(self.data)
.shuffle(seed=42 + self.number_of_complete_iterations)
.select(range(0, number_of_samples))
)
elif self.subsample_strategy == SubsampleStrategyEnum.IN_ORDER:
already_selected = (
number_of_samples * self.number_of_complete_iterations
)
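                # e.g. (a sketch) with 1000 samples, portion 0.1 and two
                # iterations already completed, this selects samples 200..299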
logger.info(
f"Subsampling {number_of_samples} samples out of {len(self.data)}"
)
to_select = min(already_selected + number_of_samples, len(self.data))
logger.info(
f"Portion of data selected: {already_selected} " f"to {to_select}"
)
data = deepcopy(self.data).select(range(already_selected, to_select))
                # don't shuffle mid-pass: while we are still consuming in-order
                # windows of the dataset, keep the remaining data in order
                if self.number_of_complete_iterations > 0:
                    shuffle_this_time = False
                if to_select >= len(self.data):
                    # we have completed one full pass over the dataset; set to
                    # -1 so the increment below restarts the count from 0
                    self.number_of_complete_iterations = -1
else:
raise ValueError(
f"Subsample strategy `{self.subsample_strategy}` is not valid. "
f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
)
else:
            data = self.data
# do we need to shuffle the data?
if self.shuffle and shuffle_this_time:
logger.info("Shuffling the data")
data = data.shuffle(seed=42 + self.number_of_complete_iterations)
batch_fn_kwargs = {
"passage_batch_size": self.passage_batch_size,
"question_batch_size": self.question_batch_size,
"hard_negatives_manager": self.hn_manager,
}
batched_data = self.create_batches(
data,
batch_fn=self.batch_fn,
batch_fn_kwargs=batch_fn_kwargs,
prefetch=self.prefetch,
)
batched_data = self.collate_batches(
batched_data, self.collate_fn, prefetch=self.prefetch
)
# increment the number of complete iterations
self.number_of_complete_iterations += 1
if self.prefetch:
return BaseDataset(name=self.name, data=batched_data)
else:
return IterableBaseDataset(name=self.name, data=batched_data)
@staticmethod
def load_fn(
sample: Dict,
tokenizer: tr.PreTrainedTokenizer,
max_positives: int,
max_negatives: int,
max_hard_negatives: int,
max_passages: int = -1,
max_question_length: int = 256,
max_passage_length: int = 128,
*args,
**kwargs,
) -> Dict:
# remove duplicates and limit the number of passages
positives = list(set([p["text"].strip() for p in sample["positive_ctxs"]]))
if max_positives != -1:
positives = positives[:max_positives]
negatives = list(set([n["text"].strip() for n in sample["negative_ctxs"]]))
if max_negatives != -1:
negatives = negatives[:max_negatives]
hard_negatives = list(
set([h["text"].strip() for h in sample["hard_negative_ctxs"]])
)
if max_hard_negatives != -1:
hard_negatives = hard_negatives[:max_hard_negatives]
question = tokenizer(
sample["question"], max_length=max_question_length, truncation=True
)
passage = positives + negatives + hard_negatives
if max_passages != -1:
passage = passage[:max_passages]
passage = tokenizer(passage, max_length=max_passage_length, truncation=True)
# invert the passage data structure from a dict of lists to a list of dicts
passage = [dict(zip(passage, t)) for t in zip(*passage.values())]
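        # e.g. (a sketch) {"input_ids": [[1], [2]], "attention_mask": [[1], [1]]}
        # becomes [{"input_ids": [1], "attention_mask": [1]},
        #          {"input_ids": [2], "attention_mask": [1]}]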
output = dict(
question=question,
passage=passage,
positives=positives,
positive_pssgs=passage[: len(positives)],
)
return output
@staticmethod
def batch_fn(
data: Dataset,
passage_batch_size: int,
question_batch_size: int,
hard_negatives_manager: Optional[HardNegativesManager] = None,
*args,
**kwargs,
) -> Dict[str, List[Dict[str, Any]]]:
def split_batch(
batch: Union[Dict[str, Any], ModelInputs], question_batch_size: int
) -> List[ModelInputs]:
"""
Split a batch into multiple batches of size `question_batch_size` while keeping
the same number of passages.
"""
def split_fn(x):
return [
x[i : i + question_batch_size]
for i in range(0, len(x), question_batch_size)
]
# split the sample_idx
sample_idx = split_fn(batch["sample_idx"])
# split the questions
questions = split_fn(batch["questions"])
# split the positives
positives = split_fn(batch["positives"])
# split the positives_pssgs
positives_pssgs = split_fn(batch["positives_pssgs"])
# collect the new batches
batches = []
for i in range(len(questions)):
batches.append(
ModelInputs(
dict(
sample_idx=sample_idx[i],
questions=questions[i],
passages=batch["passages"],
positives=positives[i],
positives_pssgs=positives_pssgs[i],
)
)
)
return batches
batch = []
passages_in_batch = {}
for sample in data:
if len(passages_in_batch) >= passage_batch_size:
# create the batch dict
batch_dict = ModelInputs(
dict(
sample_idx=[s["sample_idx"] for s in batch],
questions=[s["question"] for s in batch],
passages=list(passages_in_batch.values()),
positives_pssgs=[s["positive_pssgs"] for s in batch],
positives=[s["positives"] for s in batch],
)
)
# split the batch if needed
                if len(batch) > question_batch_size:
                    for sub_batch in split_batch(batch_dict, question_batch_size):
                        yield sub_batch
else:
yield batch_dict
# reset batch
batch = []
passages_in_batch = {}
batch.append(sample)
            # deduplicate passages and track how many are in the batch; a dict
            # keyed by the tuple of input_ids avoids counting the same passage
            # twice (lists are not hashable, tuples are)
passages_in_batch.update(
{tuple(passage["input_ids"]): passage for passage in sample["passage"]}
)
            # add the hard negatives for this sample, if any are available
if hard_negatives_manager is not None:
if sample["sample_idx"] in hard_negatives_manager:
passages_in_batch.update(
{
tuple(passage["input_ids"]): passage
for passage in hard_negatives_manager.get(
sample["sample_idx"]
)
}
)
# left over
if len(batch) > 0:
# create the batch dict
batch_dict = ModelInputs(
dict(
sample_idx=[s["sample_idx"] for s in batch],
questions=[s["question"] for s in batch],
passages=list(passages_in_batch.values()),
positives_pssgs=[s["positive_pssgs"] for s in batch],
positives=[s["positives"] for s in batch],
)
)
# split the batch if needed
            if len(batch) > question_batch_size:
                for sub_batch in split_batch(batch_dict, question_batch_size):
                    yield sub_batch
else:
yield batch_dict
def collate_fn(self, batch: Any, *args, **kwargs) -> Any:
# convert questions and passages to a batch
questions = self.convert_to_batch(batch.questions)
passages = self.convert_to_batch(batch.passages)
# build an index to map the position of the passage in the batch
passage_index = {tuple(c["input_ids"]): i for i, c in enumerate(batch.passages)}
# now we can create the labels
labels = torch.zeros(
questions["input_ids"].shape[0], passages["input_ids"].shape[0]
)
# iterate over the questions and set the labels to 1 if the passage is positive
for sample_idx in range(len(questions["input_ids"])):
for pssg in batch["positives_pssgs"][sample_idx]:
# get the index of the positive passage
index = passage_index[tuple(pssg["input_ids"])]
# set the label to 1
labels[sample_idx, index] = 1
model_inputs = ModelInputs(
{
"questions": questions,
"passages": passages,
"labels": labels,
"positives": batch["positives"],
"sample_idx": batch["sample_idx"],
}
)
return model_inputs
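# Shape of the collated batch (a sketch): `questions` and `passages` are dicts
# of tensors of shape (num_questions, max_q_len) and (num_passages, max_p_len)
# respectively, and `labels` is a (num_questions, num_passages) 0/1 matrix
# marking each question's positive passages.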
class AidaInBatchNegativesDataset(InBatchNegativesDataset):
def __init__(self, use_topics: bool = False, *args, **kwargs):
if "load_fn_kwargs" not in kwargs:
kwargs["load_fn_kwargs"] = {}
kwargs["load_fn_kwargs"]["use_topics"] = use_topics
super().__init__(*args, **kwargs)
@staticmethod
def load_fn(
sample: Dict,
tokenizer: tr.PreTrainedTokenizer,
max_positives: int,
max_negatives: int,
max_hard_negatives: int,
max_passages: int = -1,
max_question_length: int = 256,
max_passage_length: int = 128,
use_topics: bool = False,
*args,
**kwargs,
) -> Dict:
# remove duplicates and limit the number of passages
positives = list(set([p["text"].strip() for p in sample["positive_ctxs"]]))
if max_positives != -1:
positives = positives[:max_positives]
negatives = list(set([n["text"].strip() for n in sample["negative_ctxs"]]))
if max_negatives != -1:
negatives = negatives[:max_negatives]
hard_negatives = list(
set([h["text"].strip() for h in sample["hard_negative_ctxs"]])
)
if max_hard_negatives != -1:
hard_negatives = hard_negatives[:max_hard_negatives]
question = sample["question"]
if "doc_topic" in sample and use_topics:
question = tokenizer(
question,
sample["doc_topic"],
max_length=max_question_length,
truncation=True,
)
else:
question = tokenizer(
question, max_length=max_question_length, truncation=True
)
passage = positives + negatives + hard_negatives
if max_passages != -1:
passage = passage[:max_passages]
passage = tokenizer(passage, max_length=max_passage_length, truncation=True)
# invert the passage data structure from a dict of lists to a list of dicts
passage = [dict(zip(passage, t)) for t in zip(*passage.values())]
output = dict(
question=question,
passage=passage,
positives=positives,
positive_pssgs=passage[: len(positives)],
)
return output
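# With `use_topics=True` and a sample that carries a "doc_topic" field, the
# question is encoded as a sentence pair, roughly (a sketch with placeholder
# strings):
#
#     tokenizer("Rome is the capital of Italy.", "Geography",
#               max_length=256, truncation=True)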