import os
from copy import deepcopy
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import datasets
import psutil
import torch
import transformers as tr
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm

from relik.common.log import get_console_logger, get_logger
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset, IterableBaseDataset
from relik.retriever.data.utils import HardNegativesManager

console_logger = get_console_logger()
logger = get_logger(__name__)


class SubsampleStrategyEnum(Enum):
    NONE = "none"
    RANDOM = "random"
    IN_ORDER = "in_order"


class GoldenRetrieverDataset:
    def __init__(
        self,
        name: str,
        path: Union[str, os.PathLike, List[str], List[os.PathLike]] = None,
        data: Any = None,
        tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
        # passages: Union[str, os.PathLike, List[str]] = None,
        passage_batch_size: int = 32,
        question_batch_size: int = 32,
        max_positives: int = -1,
        max_negatives: int = 0,
        max_hard_negatives: int = 0,
        max_question_length: int = 256,
        max_passage_length: int = 64,
        shuffle: bool = False,
        subsample_strategy: Optional[str] = SubsampleStrategyEnum.NONE,
        subsample_portion: float = 0.1,
        num_proc: Optional[int] = None,
        load_from_cache_file: bool = True,
        keep_in_memory: bool = False,
        prefetch: bool = True,
        load_fn_kwargs: Optional[Dict[str, Any]] = None,
        batch_fn_kwargs: Optional[Dict[str, Any]] = None,
        collate_fn_kwargs: Optional[Dict[str, Any]] = None,
    ):
        if path is None and data is None:
            raise ValueError("Either `path` or `data` must be provided")
        if tokenizer is None:
            raise ValueError("A tokenizer must be provided")

        # dataset parameters
        self.name = name
        # normalize `path` to a list of `Path` objects; strings need an explicit
        # check because they are sequences themselves
        if path is None:
            self.path = None
        elif isinstance(path, (str, os.PathLike)):
            self.path = [Path(path)]
        else:
            self.path = [Path(p) for p in path]
        # self.project_folder = Path(__file__).parent.parent.parent
        self.data = data
        # hyper-parameters
        self.passage_batch_size = passage_batch_size
        self.question_batch_size = question_batch_size
        self.max_positives = max_positives
        self.max_negatives = max_negatives
        self.max_hard_negatives = max_hard_negatives
        self.max_question_length = max_question_length
        self.max_passage_length = max_passage_length
        self.shuffle = shuffle
        self.num_proc = num_proc
        self.load_from_cache_file = load_from_cache_file
        self.keep_in_memory = keep_in_memory
        self.prefetch = prefetch

        self.tokenizer = tokenizer
        if isinstance(self.tokenizer, str):
            self.tokenizer = tr.AutoTokenizer.from_pretrained(self.tokenizer)

        self.padding_ops = {
            "input_ids": partial(
                self.pad_sequence,
                value=self.tokenizer.pad_token_id,
            ),
            "attention_mask": partial(self.pad_sequence, value=0),
            "token_type_ids": partial(
                self.pad_sequence,
                value=self.tokenizer.pad_token_type_id,
            ),
        }
        # check that the subsample strategy is valid
        if subsample_strategy is not None:
            # subsample_strategy can be a string or a SubsampleStrategyEnum
            if isinstance(subsample_strategy, str):
                try:
                    subsample_strategy = SubsampleStrategyEnum(subsample_strategy)
                except ValueError:
                    raise ValueError(
                        f"Subsample strategy `{subsample_strategy}` is not valid. "
                        f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
                    )
            if not isinstance(subsample_strategy, SubsampleStrategyEnum):
                raise ValueError(
                    f"Subsample strategy `{subsample_strategy}` is not valid. "
                    f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
                )
        self.subsample_strategy = subsample_strategy
        self.subsample_portion = subsample_portion
        # load the dataset
        if data is None:
            self.data: Dataset = self.load(
                self.path,
                tokenizer=self.tokenizer,
                load_from_cache_file=load_from_cache_file,
                load_fn_kwargs=load_fn_kwargs,
                num_proc=num_proc,
                shuffle=shuffle,
                keep_in_memory=keep_in_memory,
                max_positives=max_positives,
                max_negatives=max_negatives,
                max_hard_negatives=max_hard_negatives,
                max_question_length=max_question_length,
                max_passage_length=max_passage_length,
            )
        else:
            self.data: Dataset = data

        self.hn_manager: Optional[HardNegativesManager] = None

        # keep track of how many times the dataset has been iterated over
        self.number_of_complete_iterations = 0

    def __repr__(self) -> str:
        return f"GoldenRetrieverDataset({self.name=}, {self.path=})"

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(
        self, index
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        raise NotImplementedError

    def to_torch_dataset(self, *args, **kwargs) -> torch.utils.data.Dataset:
        raise NotImplementedError

    def load(
        self,
        paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        tokenizer: tr.PreTrainedTokenizer = None,
        load_fn_kwargs: Dict = None,
        load_from_cache_file: bool = True,
        num_proc: Optional[int] = None,
        shuffle: bool = False,
        keep_in_memory: bool = True,
        max_positives: int = -1,
        max_negatives: int = -1,
        max_hard_negatives: int = -1,
        max_passages: int = -1,
        max_question_length: int = 256,
        max_passage_length: int = 64,
        *args,
        **kwargs,
    ) -> Any:
        # if isinstance(paths, Sequence):
        #     paths = [self.project_folder / path for path in paths]
        # else:
        #     paths = [self.project_folder / paths]

        # check that all the provided paths exist
        for path in paths:
            if not path.exists():
                raise ValueError(f"{path} does not exist")

        fn_kwargs = dict(
            tokenizer=tokenizer,
            max_positives=max_positives,
            max_negatives=max_negatives,
            max_hard_negatives=max_hard_negatives,
            max_passages=max_passages,
            max_question_length=max_question_length,
            max_passage_length=max_passage_length,
        )
        if load_fn_kwargs is not None:
            fn_kwargs.update(load_fn_kwargs)

        if num_proc is None:
            num_proc = psutil.cpu_count(logical=False)

        # The data is a list of dictionaries, each dictionary is a sample
        # Each sample has the following keys:
        # - "question": the question
        # - "answers": a list of answers
        # - "positive_ctxs": a list of positive passages
        # - "negative_ctxs": a list of negative passages
        # - "hard_negative_ctxs": a list of hard negative passages
        # use the huggingface datasets library to load the data; by default it
        # loads everything into a dict with the key being "train"
        logger.info(f"Loading data for dataset {self.name}")
        data = load_dataset(
            "json",
            data_files=[str(p) for p in paths],  # datasets needs str paths, not Path
            split="train",
            streaming=False,  # TODO maybe we can make streaming work
            keep_in_memory=keep_in_memory,
        )

        # add a sample index if not present
        if isinstance(data, datasets.Dataset):
            data = data.add_column("sample_idx", range(len(data)))
        else:
            # `map` must return the updated sample; `dict.update` returns None
            data = data.map(
                lambda x, idx: {**x, "sample_idx": idx}, with_indices=True
            )
        map_kwargs = dict(
            function=self.load_fn,
            fn_kwargs=fn_kwargs,
        )
        if isinstance(data, datasets.Dataset):
            map_kwargs.update(
                dict(
                    load_from_cache_file=load_from_cache_file,
                    keep_in_memory=keep_in_memory,
                    num_proc=num_proc,
                    desc="Loading data",
                )
            )
        # preprocess the data
        data = data.map(**map_kwargs)

        # shuffle the data
        if shuffle:
            # `shuffle` returns a new dataset, it does not shuffle in place
            data = data.shuffle(seed=42)

        return data

    @staticmethod
    def create_batches(
        data: Dataset,
        batch_fn: Callable,
        batch_fn_kwargs: Optional[Dict[str, Any]] = None,
        prefetch: bool = True,
        *args,
        **kwargs,
    ) -> Union[Iterable, List]:
        if not prefetch:
            # if we are streaming, we don't need to create the batches right now,
            # we will create them on the fly when we need them
            batched_data = (
                batch
                for batch in batch_fn(
                    data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {})
                )
            )
        else:
            batched_data = [
                batch
                for batch in tqdm(
                    batch_fn(
                        data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {})
                    ),
                    desc="Creating batches",
                )
            ]
        return batched_data

    @staticmethod
    def collate_batches(
        batched_data: Union[Iterable, List],
        collate_fn: Callable,
        collate_fn_kwargs: Optional[Dict[str, Any]] = None,
        prefetch: bool = True,
        *args,
        **kwargs,
    ) -> Union[Iterable, List]:
        if not prefetch:
            collated_data = (
                collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {}))
                for batch in batched_data
            )
        else:
            collated_data = [
                collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {}))
                for batch in tqdm(batched_data, desc="Collating batches")
            ]
        return collated_data

    @staticmethod
    def load_fn(sample: Dict, *args, **kwargs) -> Dict:
        raise NotImplementedError

    @staticmethod
    def batch_fn(data: Dataset, *args, **kwargs) -> Any:
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch: Any, *args, **kwargs) -> Any:
        raise NotImplementedError

    @staticmethod
    def pad_sequence(
        sequence: Union[List, torch.Tensor],
        length: int,
        value: Any = None,
        pad_to_left: bool = False,
    ) -> Union[List, torch.Tensor]:
        """
        Pad the input to the specified length with the given value.

        Args:
            sequence (:obj:`List`, :obj:`torch.Tensor`):
                Element to pad, it can be either a :obj:`List` or a :obj:`torch.Tensor`.
            length (:obj:`int`):
                Length to pad to.
            value (:obj:`Any`, optional):
                Value to use as padding.
            pad_to_left (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True`, pads to the left, right otherwise.

        Returns:
            :obj:`List`, :obj:`torch.Tensor`: The padded sequence.
        """
        padding = [value] * abs(length - len(sequence))
        if isinstance(sequence, torch.Tensor):
            if len(sequence.shape) > 1:
                raise ValueError(
                    f"Sequence tensor must be 1D. Current shape is `{len(sequence.shape)}`"
                )
            padding = torch.as_tensor(padding)
        if pad_to_left:
            if isinstance(sequence, torch.Tensor):
                return torch.cat((padding, sequence), -1)
            return padding + sequence
        if isinstance(sequence, torch.Tensor):
            return torch.cat((sequence, padding), -1)
        return sequence + padding
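
    # A couple of illustrative calls (hypothetical values, not taken from the
    # codebase), following the behaviour of `pad_sequence` above:
    #   pad_sequence([1, 2, 3], length=5, value=0)                    -> [1, 2, 3, 0, 0]
    #   pad_sequence([1, 2, 3], length=5, value=0, pad_to_left=True)  -> [0, 0, 1, 2, 3]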

    def convert_to_batch(
        self, samples: Any, *args, **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Convert the list of samples to a batch.

        Args:
            samples (:obj:`List`):
                List of samples to convert to a batch.

        Returns:
            :obj:`Dict[str, torch.Tensor]`: The batch.
        """
        # invert the samples from a list of dicts to a dict of lists
        samples = {k: [d[k] for d in samples] for k in samples[0]}
        # get the max length of the sequences
        max_len = max(len(x) for x in samples["input_ids"])
        # pad the sequences
        for key in samples:
            if key in self.padding_ops:
                samples[key] = torch.as_tensor(
                    [self.padding_ops[key](b, max_len) for b in samples[key]]
                )
        return samples

    def shuffle_data(self, seed: int = 42):
        self.data = self.data.shuffle(seed=seed)


class InBatchNegativesDataset(GoldenRetrieverDataset):
    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(
        self, index
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        return self.data[index]

    def to_torch_dataset(self) -> torch.utils.data.Dataset:
        shuffle_this_time = self.shuffle

        if (
            self.subsample_strategy
            and self.subsample_strategy != SubsampleStrategyEnum.NONE
        ):
            number_of_samples = int(len(self.data) * self.subsample_portion)
            if self.subsample_strategy == SubsampleStrategyEnum.RANDOM:
                logger.info(
                    f"Random subsampling {number_of_samples} samples from {len(self.data)}"
                )
                data = (
                    deepcopy(self.data)
                    .shuffle(seed=42 + self.number_of_complete_iterations)
                    .select(range(0, number_of_samples))
                )
            elif self.subsample_strategy == SubsampleStrategyEnum.IN_ORDER:
                # number_of_samples = int(len(self.data) * self.subsample_portion)
                already_selected = (
                    number_of_samples * self.number_of_complete_iterations
                )
                logger.info(
                    f"Subsampling {number_of_samples} samples out of {len(self.data)}"
                )
                to_select = min(already_selected + number_of_samples, len(self.data))
                logger.info(
                    f"Portion of data selected: {already_selected} to {to_select}"
                )
                data = deepcopy(self.data).select(range(already_selected, to_select))

                # don't shuffle the data if we are subsampling and we have not yet
                # completed one full iteration over the dataset
                if self.number_of_complete_iterations > 0:
                    shuffle_this_time = False

                if to_select >= len(self.data):
                    # we have completed one full iteration over the dataset;
                    # the value is -1 because we want to start from 0 at the next call
                    self.number_of_complete_iterations = -1
            else:
                raise ValueError(
                    f"Subsample strategy `{self.subsample_strategy}` is not valid. "
                    f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
                )
        else:
            data = self.data

        # do we need to shuffle the data?
        if self.shuffle and shuffle_this_time:
            logger.info("Shuffling the data")
            data = data.shuffle(seed=42 + self.number_of_complete_iterations)

        batch_fn_kwargs = {
            "passage_batch_size": self.passage_batch_size,
            "question_batch_size": self.question_batch_size,
            "hard_negatives_manager": self.hn_manager,
        }
        batched_data = self.create_batches(
            data,
            batch_fn=self.batch_fn,
            batch_fn_kwargs=batch_fn_kwargs,
            prefetch=self.prefetch,
        )
        batched_data = self.collate_batches(
            batched_data, self.collate_fn, prefetch=self.prefetch
        )

        # increment the number of complete iterations
        self.number_of_complete_iterations += 1

        if self.prefetch:
            return BaseDataset(name=self.name, data=batched_data)
        else:
            return IterableBaseDataset(name=self.name, data=batched_data)
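
    # Minimal usage sketch for `to_torch_dataset` (illustrative only; the file
    # name, tokenizer and batch sizes below are assumptions, not values taken
    # from the codebase). Batches are pre-built, so the DataLoader is given
    # `batch_size=None`:
    #
    #   dataset = InBatchNegativesDataset(
    #       name="train",
    #       path="train.dpr.jsonl",
    #       tokenizer="bert-base-uncased",
    #       question_batch_size=32,
    #       passage_batch_size=64,
    #   )
    #   loader = torch.utils.data.DataLoader(
    #       dataset.to_torch_dataset(), batch_size=None, num_workers=0
    #   )
    #   # call `to_torch_dataset()` again at the next epoch to rebuild the batches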

    @staticmethod
    def load_fn(
        sample: Dict,
        tokenizer: tr.PreTrainedTokenizer,
        max_positives: int,
        max_negatives: int,
        max_hard_negatives: int,
        max_passages: int = -1,
        max_question_length: int = 256,
        max_passage_length: int = 128,
        *args,
        **kwargs,
    ) -> Dict:
        # remove duplicates and limit the number of passages
        positives = list(set([p["text"].strip() for p in sample["positive_ctxs"]]))
        if max_positives != -1:
            positives = positives[:max_positives]

        negatives = list(set([n["text"].strip() for n in sample["negative_ctxs"]]))
        if max_negatives != -1:
            negatives = negatives[:max_negatives]

        hard_negatives = list(
            set([h["text"].strip() for h in sample["hard_negative_ctxs"]])
        )
        if max_hard_negatives != -1:
            hard_negatives = hard_negatives[:max_hard_negatives]

        question = tokenizer(
            sample["question"], max_length=max_question_length, truncation=True
        )

        passage = positives + negatives + hard_negatives
        if max_passages != -1:
            passage = passage[:max_passages]
        passage = tokenizer(passage, max_length=max_passage_length, truncation=True)
        # invert the passage data structure from a dict of lists to a list of dicts
        passage = [dict(zip(passage, t)) for t in zip(*passage.values())]

        output = dict(
            question=question,
            passage=passage,
            positives=positives,
            positive_pssgs=passage[: len(positives)],
        )
        return output

    @staticmethod
    def batch_fn(
        data: Dataset,
        passage_batch_size: int,
        question_batch_size: int,
        hard_negatives_manager: Optional[HardNegativesManager] = None,
        *args,
        **kwargs,
    ) -> Dict[str, List[Dict[str, Any]]]:
        def split_batch(
            batch: Union[Dict[str, Any], ModelInputs], question_batch_size: int
        ) -> List[ModelInputs]:
            """
            Split a batch into multiple batches of size `question_batch_size` while
            keeping the same number of passages.
            """

            def split_fn(x):
                return [
                    x[i : i + question_batch_size]
                    for i in range(0, len(x), question_batch_size)
                ]

            # split the sample_idx
            sample_idx = split_fn(batch["sample_idx"])
            # split the questions
            questions = split_fn(batch["questions"])
            # split the positives
            positives = split_fn(batch["positives"])
            # split the positives_pssgs
            positives_pssgs = split_fn(batch["positives_pssgs"])

            # collect the new batches
            batches = []
            for i in range(len(questions)):
                batches.append(
                    ModelInputs(
                        dict(
                            sample_idx=sample_idx[i],
                            questions=questions[i],
                            passages=batch["passages"],
                            positives=positives[i],
                            positives_pssgs=positives_pssgs[i],
                        )
                    )
                )
            return batches
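
        # For example (illustrative numbers): with question_batch_size=2, a batch
        # holding 5 questions is split into sub-batches of 2, 2 and 1 questions,
        # all sharing the same `passages` list.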
        batch = []
        passages_in_batch = {}

        for sample in data:
            if len(passages_in_batch) >= passage_batch_size:
                # create the batch dict
                batch_dict = ModelInputs(
                    dict(
                        sample_idx=[s["sample_idx"] for s in batch],
                        questions=[s["question"] for s in batch],
                        passages=list(passages_in_batch.values()),
                        positives_pssgs=[s["positive_pssgs"] for s in batch],
                        positives=[s["positives"] for s in batch],
                    )
                )
                # split the batch if needed
                if len(batch) > question_batch_size:
                    for sub_batch in split_batch(batch_dict, question_batch_size):
                        yield sub_batch
                else:
                    yield batch_dict

                # reset the batch
                batch = []
                passages_in_batch = {}

            batch.append(sample)
            # keep one entry per unique passage in the batch; a tuple of the
            # input_ids is used as key because lists are not hashable
            passages_in_batch.update(
                {tuple(passage["input_ids"]): passage for passage in sample["passage"]}
            )
            # add the mined hard negatives for this sample, if any
            if hard_negatives_manager is not None:
                if sample["sample_idx"] in hard_negatives_manager:
                    passages_in_batch.update(
                        {
                            tuple(passage["input_ids"]): passage
                            for passage in hard_negatives_manager.get(
                                sample["sample_idx"]
                            )
                        }
                    )

        # leftover batch
        if len(batch) > 0:
            # create the batch dict
            batch_dict = ModelInputs(
                dict(
                    sample_idx=[s["sample_idx"] for s in batch],
                    questions=[s["question"] for s in batch],
                    passages=list(passages_in_batch.values()),
                    positives_pssgs=[s["positive_pssgs"] for s in batch],
                    positives=[s["positives"] for s in batch],
                )
            )
            # split the batch if needed
            if len(batch) > question_batch_size:
                for sub_batch in split_batch(batch_dict, question_batch_size):
                    yield sub_batch
            else:
                yield batch_dict

    def collate_fn(self, batch: Any, *args, **kwargs) -> Any:
        # convert questions and passages to a batch
        questions = self.convert_to_batch(batch.questions)
        passages = self.convert_to_batch(batch.passages)

        # build an index to map the position of each passage in the batch
        passage_index = {
            tuple(c["input_ids"]): i for i, c in enumerate(batch.passages)
        }

        # now we can create the labels
        labels = torch.zeros(
            questions["input_ids"].shape[0], passages["input_ids"].shape[0]
        )
        # iterate over the questions and set the label to 1 if the passage is positive
        for sample_idx in range(len(questions["input_ids"])):
            for pssg in batch["positives_pssgs"][sample_idx]:
                # get the index of the positive passage
                index = passage_index[tuple(pssg["input_ids"])]
                # set the label to 1
                labels[sample_idx, index] = 1

        model_inputs = ModelInputs(
            {
                "questions": questions,
                "passages": passages,
                "labels": labels,
                "positives": batch["positives"],
                "sample_idx": batch["sample_idx"],
            }
        )
        return model_inputs
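
    # Illustrative shape of the `labels` matrix built above (made-up example):
    # with 2 questions and 4 unique passages in the batch, where question 0 has
    # passages 0 and 1 as positives and question 1 has passage 2 as positive,
    # `labels` is
    #   [[1., 1., 0., 0.],
    #    [0., 0., 1., 0.]]
    # and every other passage in a row acts as an in-batch negative for that question.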


class AidaInBatchNegativesDataset(InBatchNegativesDataset):
    def __init__(self, use_topics: bool = False, *args, **kwargs):
        if "load_fn_kwargs" not in kwargs:
            kwargs["load_fn_kwargs"] = {}
        kwargs["load_fn_kwargs"]["use_topics"] = use_topics
        super().__init__(*args, **kwargs)

    @staticmethod
    def load_fn(
        sample: Dict,
        tokenizer: tr.PreTrainedTokenizer,
        max_positives: int,
        max_negatives: int,
        max_hard_negatives: int,
        max_passages: int = -1,
        max_question_length: int = 256,
        max_passage_length: int = 128,
        use_topics: bool = False,
        *args,
        **kwargs,
    ) -> Dict:
        # remove duplicates and limit the number of passages
        positives = list(set([p["text"].strip() for p in sample["positive_ctxs"]]))
        if max_positives != -1:
            positives = positives[:max_positives]

        negatives = list(set([n["text"].strip() for n in sample["negative_ctxs"]]))
        if max_negatives != -1:
            negatives = negatives[:max_negatives]

        hard_negatives = list(
            set([h["text"].strip() for h in sample["hard_negative_ctxs"]])
        )
        if max_hard_negatives != -1:
            hard_negatives = hard_negatives[:max_hard_negatives]

        question = sample["question"]
        if "doc_topic" in sample and use_topics:
            # encode the question together with the document topic
            question = tokenizer(
                question,
                sample["doc_topic"],
                max_length=max_question_length,
                truncation=True,
            )
        else:
            question = tokenizer(
                question, max_length=max_question_length, truncation=True
            )

        passage = positives + negatives + hard_negatives
        if max_passages != -1:
            passage = passage[:max_passages]
        passage = tokenizer(passage, max_length=max_passage_length, truncation=True)
        # invert the passage data structure from a dict of lists to a list of dicts
        passage = [dict(zip(passage, t)) for t in zip(*passage.values())]

        output = dict(
            question=question,
            passage=passage,
            positives=positives,
            positive_pssgs=passage[: len(positives)],
        )
        return output
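

# Minimal end-to-end sketch (illustrative only; the file path, tokenizer name and
# batch sizes below are assumptions, not values taken from this codebase):
#
#   train_data = AidaInBatchNegativesDataset(
#       name="aida_train",
#       path="data/aida_train.dpr.jsonl",
#       tokenizer="bert-base-uncased",
#       use_topics=True,
#       question_batch_size=64,
#       passage_batch_size=400,
#   )
#   # optionally attach mined hard negatives before building the epoch batches
#   # train_data.hn_manager = HardNegativesManager(...)
#   torch_train = train_data.to_torch_dataset()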