import json
import os
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Union

import numpy as np
import transformers as tr
from tqdm import tqdm

class HardNegativesManager:
    def __init__(
        self,
        tokenizer: tr.PreTrainedTokenizer,
        data: Optional[Union[Dict[int, List], str, os.PathLike]] = None,
        max_length: int = 64,
        batch_size: int = 1000,
        lazy: bool = False,
    ) -> None:
        # keep the tokenizer and max_length around for lazy tokenization in `get`
        self.tokenizer = tokenizer
        self.max_length = max_length

        self._db: dict = None
        if data is None:
            self._db = {}
        else:
            if isinstance(data, Dict):
                self._db = data
            elif isinstance(data, (str, os.PathLike)):
                with open(data) as f:
                    self._db = json.load(f)
            else:
                raise ValueError(
                    f"Data type {type(data)} not supported, only Dict and os.PathLike are supported."
                )

        # invert the db to have a passage -> sample_idx mapping
        self._passage_db = defaultdict(set)
        for sample_idx, passages in self._db.items():
            for passage in passages:
                self._passage_db[passage].add(sample_idx)

        self._passage_hard_negatives = {}
        if not lazy:
            # eagerly tokenize every unique passage in batches and cache the
            # encodings in a passage -> encoding mapping
            batch_size = max(1, min(batch_size, len(self._passage_db)))
            unique_passages = list(self._passage_db.keys())
            for batch_start in tqdm(
                range(0, len(unique_passages), batch_size),
                desc="Tokenizing Hard Negatives",
            ):
                batch = unique_passages[batch_start : batch_start + batch_size]
                tokenized_passages = self.tokenizer(
                    batch,
                    max_length=max_length,
                    truncation=True,
                )
                for i, passage in enumerate(batch):
                    self._passage_hard_negatives[passage] = {
                        k: tokenized_passages[k][i] for k in tokenized_passages.keys()
                    }

    def __len__(self) -> int:
        return len(self._db)

    def __getitem__(self, idx: int) -> Dict:
        return self._db[idx]

    def __iter__(self):
        for sample in self._db:
            yield sample

    def __contains__(self, idx: int) -> bool:
        return idx in self._db

    def get(self, idx: int) -> List[Dict]:
        """Get the tokenized hard negatives for a given sample index."""
        if idx not in self._db:
            raise ValueError(f"Sample index {idx} not in the database.")

        passages = self._db[idx]

        output = []
        for passage in passages:
            # tokenize on demand and cache the encoding
            if passage not in self._passage_hard_negatives:
                self._passage_hard_negatives[passage] = self._tokenize(passage)
            output.append(self._passage_hard_negatives[passage])
        return output

    def _tokenize(self, passage: str) -> Dict:
        return self.tokenizer(passage, max_length=self.max_length, truncation=True)
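

# Minimal usage sketch for `HardNegativesManager`. The checkpoint name and the
# toy data below are illustrative assumptions, not part of this module:
#
#   tokenizer = tr.AutoTokenizer.from_pretrained("bert-base-uncased")
#   hard_negatives = {0: ["first negative passage", "second negative passage"]}
#   manager = HardNegativesManager(tokenizer, data=hard_negatives, max_length=64)
#   manager.get(0)  # -> one encoding dict per passage (input_ids, attention_mask, ...)
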
class NegativeSampler:
    def __init__(
        self, num_elements: int, probabilities: Optional[Union[List, np.ndarray]] = None
    ):
        if probabilities is None:
            # default to random probabilities that sum to 1
            probabilities = np.random.random(num_elements)
            probabilities /= np.sum(probabilities)
        elif not isinstance(probabilities, np.ndarray):
            probabilities = np.array(probabilities)
        self.probabilities = probabilities

    def __call__(
        self,
        sample_size: int,
        num_samples: int = 1,
        probabilities: Optional[np.ndarray] = None,
        exclude: Optional[List[int]] = None,
    ) -> np.ndarray:
        """
        Fast sampling of `sample_size` elements from `num_elements` elements.
        The sampling is done by randomly shifting the probabilities and then
        finding the smallest of the negative numbers. This is much faster than
        sampling from a multinomial distribution.

        Args:
            sample_size (`int`):
                number of elements to sample
            num_samples (`int`, optional):
                number of samples to draw. Defaults to 1.
            probabilities (`np.ndarray`, optional):
                probabilities of each element. Defaults to None.
            exclude (`List[int]`, optional):
                indices of elements to exclude. Defaults to None.

        Returns:
            `np.ndarray`: array of sampled indices with shape `(num_samples, sample_size)`
        """
        if probabilities is None:
            probabilities = self.probabilities
        if exclude is not None:
            # work on a copy so the exclusion does not permanently zero out
            # entries of `self.probabilities`
            probabilities = probabilities.copy()
            probabilities[exclude] = 0
            # re-normalize?
            # probabilities /= np.sum(probabilities)
        # replicate probabilities as many times as `num_samples`
        replicated_probabilities = np.tile(probabilities, (num_samples, 1))
        # get random shifting numbers & scale them correctly
        random_shifts = np.random.random(replicated_probabilities.shape)
        random_shifts /= random_shifts.sum(axis=1)[:, np.newaxis]
        # shift by numbers & find largest (by finding the smallest of the negative)
        shifted_probabilities = random_shifts - replicated_probabilities
        sampled_indices = np.argpartition(shifted_probabilities, sample_size, axis=1)[
            :, :sample_size
        ]
        return sampled_indices
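

# Illustrative sketch of `NegativeSampler` (the sizes here are arbitrary):
# draw 5 negatives for each of 32 queries out of 10_000 candidates, never
# picking the gold index 42.
#
#   sampler = NegativeSampler(num_elements=10_000)
#   negatives = sampler(sample_size=5, num_samples=32, exclude=[42])
#   negatives.shape  # -> (32, 5)
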
def batch_generator(samples: Iterable[Any], batch_size: int) -> Iterable[Any]:
    """
    Generate batches from samples.

    Args:
        samples (`Iterable[Any]`): Iterable of samples.
        batch_size (`int`): Batch size.

    Returns:
        `Iterable[Any]`: Iterable of batches.
    """
    batch = []
    for sample in samples:
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # yield the leftover, partially filled batch
    if len(batch) > 0:
        yield batch
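

# Example: iterating over batches, including the final partial one:
#
#   list(batch_generator(range(5), batch_size=2))
#   # -> [[0, 1], [2, 3], [4]]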