File size: 15,452 Bytes
2cc518e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 |
import logging
import warnings
from typing import Any, Dict, List, Optional, Union
import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
# Module-level logger used for non-fatal configuration warnings.
log = logging.getLogger(__name__)
# Label value marking tokens that do not generate loss. Prompt/padding positions
# are filled with it; it matches the default ``ignore_index`` of
# ``torch.nn.CrossEntropyLoss``.
_HF_IGNORE_INDEX = -100
# One dataset example: the collator reads ``example['turns']``, a list of turns,
# each mapping ``'input_ids'`` (prompt) and ``'labels'`` (response) to token ids.
TokenizedExample = Dict[str, List[Dict[str, List[int]]]]
def ensure_list(x: Union[List, torch.Tensor]) -> List:
    """Return ``x`` as a plain Python list.

    Tensors of any shape are flattened and converted with ``Tensor.tolist()``,
    so the result holds Python scalars (e.g. ``int`` token ids) rather than
    0-d tensors. Lists are returned unchanged (same object).

    Raises:
        TypeError: if ``x`` is neither a ``list`` nor a ``torch.Tensor``.
    """
    if isinstance(x, torch.Tensor):
        # ``tolist`` yields Python numbers; ``list(tensor)`` would yield 0-d tensors.
        return x.flatten().tolist()
    if not isinstance(x, list):
        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        raise TypeError(f'Expected a list or torch.Tensor, got {type(x).__name__}')
    return x
def validate_target_settings(target_prompts: str, target_responses: str, decoder_only_format: bool):
    """Raises an error if target settings are invalid.

    Args:
        target_prompts: Prompt target policy: ``"all"``, ``"none"``, or
            ``"length>=XX"`` with ``XX`` a positive integer. Expected lowercase.
        target_responses: Response target policy: ``"all"`` or ``"last"``.
            Expected lowercase.
        decoder_only_format: Whether the decoder-only format is used. The
            encoder-decoder format only supports ``target_prompts="none"`` and
            ``target_responses="last"``.

    Raises:
        ValueError: if any setting is unsupported or malformed.
    """
    if not decoder_only_format and (target_prompts != 'none' or target_responses != 'last'):
        raise ValueError('When using encoder_decoder format, you must use target_prompts="none" and target_responses="last".')
    if target_responses not in {'all', 'last'}:
        raise ValueError(f'target_responses must be either "last" or "all" but target_responses={target_responses!r}')
    prefix = 'length>='
    if target_prompts.startswith(prefix):
        cutoff = target_prompts[len(prefix):]
        # ``isdecimal`` rejects characters such as '²' that ``isdigit`` accepts
        # but ``int`` cannot parse, so malformed cutoffs get the helpful message.
        if not cutoff.isdecimal():
            raise ValueError(f'target_prompts starts with "length>=" but the rest of the string is not digits (target_prompts={target_prompts!r}). ' + 'To use this configuration option, set target_prompts "length>=XX" where "XX" is a positive integer indicating ' + 'the length cutoff. Prompts of at least XX tokens in length will be treated as targets.')
        cutoff = int(cutoff)
        if cutoff <= 0:
            raise ValueError(f'You are trying to set the target_prompts length cutoff to a non-positive number cutoff={cutoff!r}. This is not allowed.')
    elif target_prompts not in {'all', 'none'}:
        raise ValueError(f'target_prompts must either be "all", "none" or "length>=XX" where "XX" is a positive integer, but target_prompts={target_prompts!r}')
def _sequence_to_labels_all(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
    """Labeling policy "all": every token of ``sequence`` is a training target.

    ``is_last_turn`` and ``cutoff`` exist only so all policies share one
    signature; they are ignored. The input list itself is returned.
    """
    _ignored = (is_last_turn, cutoff)
    return sequence
def _sequence_to_labels_none(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
    """Labeling policy "none": mask every token of ``sequence`` out of the loss.

    Returns a new list of ``_HF_IGNORE_INDEX`` with the same length as
    ``sequence``. ``is_last_turn`` and ``cutoff`` are ignored.
    """
    _ignored = (is_last_turn, cutoff)
    return [_HF_IGNORE_INDEX for _token in sequence]
def _sequence_to_labels_last(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
    """Labeling policy "last": only the final turn's tokens are targets.

    Non-final turns are fully masked with ``_HF_IGNORE_INDEX``; the final turn
    returns ``sequence`` unchanged. ``cutoff`` is ignored.
    """
    del cutoff
    if not is_last_turn:
        return [_HF_IGNORE_INDEX] * len(sequence)
    return sequence
def _sequence_to_labels_cutoff(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
    """Labeling policy "length>=XX": keep ``sequence`` as targets only if long enough.

    Sequences with at least ``cutoff`` tokens are returned unchanged; shorter
    ones are masked with ``_HF_IGNORE_INDEX``. ``is_last_turn`` is ignored.

    Raises:
        ValueError: if ``cutoff`` is None.
    """
    del is_last_turn
    if cutoff is None:
        raise ValueError('input ``cutoff`` must be provided')
    long_enough = len(sequence) >= cutoff
    return sequence if long_enough else [_HF_IGNORE_INDEX] * len(sequence)
# Policy name -> labeling function. The 'length' entry serves every
# "length>=XX" prompt policy; callers parse XX and pass it as ``cutoff``.
_TARGET_POLICY_LOOKUP = dict(
    all=_sequence_to_labels_all,
    none=_sequence_to_labels_none,
    last=_sequence_to_labels_last,
    length=_sequence_to_labels_cutoff,
)
def stitch_turns_decoder_only(example_turns: list[dict[str, list[int]]], target_prompts: str, target_responses: str, eos_token_id: Optional[int]=None, validate: bool=False) -> tuple[list[int], list[int]]:
    """Concatenate the turns of one example into decoder-only ``(input_ids, labels)``.

    Each turn contributes its prompt (``turn['input_ids']``) followed by its
    response (``turn['labels']``). The label for each prompt comes from the
    ``target_prompts`` policy and the label for each response from the
    ``target_responses`` policy. Masked positions hold ``_HF_IGNORE_INDEX``.

    Args:
        example_turns: Turns, each a dict with ``'input_ids'`` and ``'labels'``
            token sequences (lists or tensors).
        target_prompts: ``"all"``, ``"none"`` or ``"length>=XX"``
            (case-insensitive).
        target_responses: ``"all"`` or ``"last"`` (case-insensitive).
        eos_token_id: If given, appended to the final response unless it
            already ends with it. This includes the case where the final
            response is empty.
        validate: If True, check the target settings first.

    Returns:
        ``(input_ids, labels)``, two lists of equal length.

    Raises:
        ValueError: if validation fails or the stitched lengths disagree.
        KeyError: if an unknown policy name is used with ``validate=False``.
    """
    target_prompts = target_prompts.lower()
    target_responses = target_responses.lower()
    if validate:
        validate_target_settings(target_prompts, target_responses, decoder_only_format=True)
    if target_prompts.startswith('length'):
        prompt_cutoff = int(target_prompts.split('>=')[-1])
        prompt_to_target = _TARGET_POLICY_LOOKUP['length']
    else:
        prompt_cutoff = None
        prompt_to_target = _TARGET_POLICY_LOOKUP[target_prompts]
    response_to_target = _TARGET_POLICY_LOOKUP[target_responses]
    input_ids: list[int] = []
    labels: list[int] = []
    n_turns = len(example_turns)
    for idx, turn in enumerate(example_turns):
        is_last_turn = idx + 1 == n_turns
        context = ensure_list(turn['input_ids'])
        target = ensure_list(turn['labels'])
        # ``not target`` guards the empty-response case, where ``target[-1]``
        # would raise IndexError; an empty final response still gets EOS.
        if is_last_turn and eos_token_id is not None and (not target or target[-1] != eos_token_id):
            target = target + [eos_token_id]
        input_ids += context
        input_ids += target
        labels += prompt_to_target(context, is_last_turn, prompt_cutoff)
        labels += response_to_target(target, is_last_turn)
    if len(input_ids) != len(labels):
        raise ValueError(f'input_ids and labels should be the same length, len(input_ids)={len(input_ids)!r}, len(labels)={len(labels)!r}')
    return (input_ids, labels)
def stitch_turns_encoder_decoder(example_turns: list[dict[str, list[int]]], eos_token_id: Optional[int]=None) -> tuple[list[int], list[int]]:
    """Split the turns of one example into an encoder context and a decoder target.

    The context is every prompt plus every non-final response, in order. The
    target is the final turn's response.

    Args:
        example_turns: Turns, each a dict with ``'input_ids'`` and ``'labels'``
            token sequences (lists or tensors).
        eos_token_id: If given, appended to the target unless it already ends
            with it. This includes the case where the final response is empty.

    Returns:
        ``(context, target)`` token lists.

    Raises:
        ValueError: if ``example_turns`` is empty, so no target is produced.
    """
    context: list[int] = []
    target = None
    n_turns = len(example_turns)
    for idx, turn in enumerate(example_turns):
        is_last_turn = idx + 1 == n_turns
        turn_context = ensure_list(turn['input_ids'])
        turn_target = ensure_list(turn['labels'])
        context += turn_context
        if is_last_turn:
            # ``not turn_target`` avoids IndexError on an empty final response.
            if eos_token_id is not None and (not turn_target or turn_target[-1] != eos_token_id):
                turn_target = turn_target + [eos_token_id]
            target = turn_target
        else:
            # Earlier responses are conditioning context only.
            context += turn_target
    if target is None:
        raise ValueError('target is still None but should be list[int]')
    return (context, target)
class Seq2SeqFinetuningCollator:
    """A general-purpose collator for sequence-to-sequence training/evaluation.

    Args:
        tokenizer: A HuggingFace tokenizer. Must have a pad_token set.
        max_seq_len (int): The maximum sequence length of the combined
            context/target sequence (decoder-only format) or of each of the
            context sequence and target sequence (encoder-decoder format).
        decoder_only_format (bool): Whether to format the batches for a
            decoder-only model (if True) or an encoder-decoder model (if False).
        target_responses (str): For multi-turn examples, this controls which
            responses are treated as training targets (i.e. generate loss).
            Options are:
                "last": (Default) Only the final response is used as the training
                    target; non-terminal responses are only part of the context.
                "all": All of the responses are used as training targets.
        target_prompts (str): This controls which prompts are treated as
            training targets (i.e. generate loss).
            Options are:
                "none": (Default) Prompts are never used as training targets.
                "all": Prompts are always used as training targets.
                "length>=XX": Prompt sequences are used as training targets when
                    they have length of at least XX tokens. For instance,
                    setting "length>=512" instructs the collator to use a prompt
                    sequence as a training target when it is at least 512 tokens long.
        allow_pad_trimming (bool, optional): Whether to allow the collator
            to trim padding, which may result in smaller but inconsistent batch
            sizes. Default: ``False`` ensures that all sequences are max_seq_len.
            Even when enabled, the first batch is always returned at full
            ``max_seq_len``. Later batches are trimmed to the longest
            non-padded sequence, rounded up to a multiple of 8.
        batch_metadata (dict, optional): A dictionary of metadata which will be added
            to the batch. Each value is repeated once per example and wrapped in
            ``torch.tensor``. Keys used directly by the models are rejected.

    Raises:
        ValueError: if ``batch_metadata`` uses a prohibited key, the tokenizer
            has no pad token, or the target settings are invalid.
    """

    def __init__(self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_seq_len: int, decoder_only_format: bool, target_responses: str='last', target_prompts: str='none', allow_pad_trimming: bool=False, batch_metadata: Optional[Dict[str, Any]]=None):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.decoder_only_format = decoder_only_format
        self.target_responses = target_responses.lower()
        self.target_prompts = target_prompts.lower()
        self.batch_metadata = batch_metadata or {}
        self._allow_pad_trimming = allow_pad_trimming
        # Flipped after the first batch; the first batch is never trimmed.
        self._seen_first_batch = False

        # Metadata is merged into the batch, so it must not shadow model inputs.
        illegal_keys = ['input_ids', 'labels', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask']
        found_keys = [key for key in illegal_keys if key in self.batch_metadata]
        if found_keys:
            raise ValueError(f"The following keys are in batch_metadata but are not allowed: {', '.join(found_keys)}.\n" + f'You cannot use keys that are used directly by the models. The prohibited keys are:\n' + f"{', '.join(illegal_keys)}")
        if max_seq_len % 8 != 0:
            log.warning('For performance, a max_seq_len as a multiple of 8 is recommended.')
        if self.tokenizer.pad_token_id is None:
            raise ValueError(f'{self.__class__.__name__} requires that the tokenizer has the pad token set, but it is None')
        validate_target_settings(self.target_prompts, self.target_responses, self.decoder_only_format)
        if self.target_prompts.startswith('length'):
            self.prompt_cutoff = int(self.target_prompts.split('>=')[-1])
            self.prompt_to_target = _TARGET_POLICY_LOOKUP['length']
        else:
            self.prompt_cutoff = None
            self.prompt_to_target = _TARGET_POLICY_LOOKUP[self.target_prompts]
        self.response_to_target = _TARGET_POLICY_LOOKUP[self.target_responses]
        # One-shot flags: each truncation warning is emitted at most once.
        self._warned_truncated = False
        self._warned_context = False
        self._warned_target = False

    def __call__(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
        """Collate ``examples`` into a batch of tensors plus any ``batch_metadata``.

        Raises:
            KeyError: if the first turn of the first example lacks
                ``'input_ids'`` or ``'labels'``.
        """
        first_turn = examples[0]['turns'][0]
        for check_key in ['input_ids', 'labels']:
            if check_key not in first_turn:
                raise KeyError(f'Examples returned by dataset do not include required key: {check_key}')
        if self.decoder_only_format:
            batch = self._process_and_batch_decoder_only(examples)
        else:
            batch = self._process_and_batch_encoder_decoder(examples)
        batch_size = batch['input_ids'].shape[0]
        batch.update({k: torch.tensor([v] * batch_size) for k, v in self.batch_metadata.items()})
        return batch

    def _process_and_batch_decoder_only(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
        """Stitch, truncate and pad each example into a single decoder-only batch.

        Raises:
            ValueError: if truncation to ``max_seq_len`` leaves an example with
                no loss-generating label.
        """
        processed_examples = []
        for example in examples:
            input_ids, labels = stitch_turns_decoder_only(example_turns=example['turns'], target_prompts=self.target_prompts, target_responses=self.target_responses, eos_token_id=self.tokenizer.eos_token_id)
            orig_size = len(input_ids)
            if orig_size > self.max_seq_len:
                # Truncation keeps the start of the sequence and drops the end.
                input_ids = input_ids[:self.max_seq_len]
                labels = labels[:self.max_seq_len]
                if all(label == _HF_IGNORE_INDEX for label in labels):
                    raise ValueError(f'Truncating to max_seq_len={self.max_seq_len} has removed all loss-generating tokens. ' + f'Pre-truncation sequence length was {orig_size}. ' + 'This sample should have been filtered out before reaching the collator. If using ' + 'pre-tokenized streaming data, this may have resulted from using different ' + '``target_prompts``, ``target_responses``, or ``max_seq_len`` ' + 'settings when preparing the streaming dataset than what are currently being used.')
                if not self._warned_truncated:
                    warnings.warn(f'Truncating sequence of length={orig_size} to fit max_seq_len={self.max_seq_len}. ' + f'If truncation is a problem, consider increasing max_seq_len.')
                    self._warned_truncated = True
            attention_mask = [1] * len(input_ids)
            # Labels are padded by hand on the tokenizer's padding side so they
            # line up with the input_ids that tokenizer.pad pads below.
            i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(input_ids))
            if self.tokenizer.padding_side == 'left':
                labels = i_pad + labels
            else:
                labels = labels + i_pad
            processed_examples.append({'input_ids': input_ids, 'labels': labels, 'attention_mask': attention_mask})
        batch = self.tokenizer.pad(processed_examples, padding='max_length', max_length=self.max_seq_len, return_tensors='pt')
        # Mask 1 (real token) -> sequence id 0; mask 0 (padding) -> -1.
        batch['sequence_id'] = batch['attention_mask'] - 1
        if not self._should_trim():
            return batch
        keep_tokens = self._padded_length(batch['attention_mask'])
        padding_side = self.tokenizer.padding_side
        for key, value in list(batch.items()):
            if len(value.shape) < 2:
                continue
            batch[key] = self._trim(value, keep_tokens, padding_side)
        return batch

    def _process_and_batch_encoder_decoder(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
        """Build encoder inputs plus decoder inputs/labels for an encoder-decoder model.

        Context and target are each limited to ``max_seq_len`` tokens. The target
        is right-padded with ``_HF_IGNORE_INDEX``. ``decoder_input_ids`` is the
        target shifted right by one, with ``pad_token_id`` as the start token.
        """
        processed_examples = []
        for example in examples:
            context, target = stitch_turns_encoder_decoder(example_turns=example['turns'], eos_token_id=self.tokenizer.eos_token_id)
            # A target of exactly max_seq_len tokens fits and needs no truncation.
            if len(target) <= self.max_seq_len:
                target = target + [_HF_IGNORE_INDEX] * (self.max_seq_len - len(target))
            else:
                if not self._warned_target:
                    warnings.warn(f'Truncating TARGET sequence of length={len(target)} ' + f'to max_seq_len={self.max_seq_len}. If truncation is ' + f'a problem, consider increasing max_seq_len.')
                    self._warned_target = True
                target = self._truncate_with_eos(target)
            if len(context) > self.max_seq_len:
                if not self._warned_context:
                    warnings.warn(f'Truncating CONTEXT sequence of length={len(context)} ' + f'to max_seq_len={self.max_seq_len}. If truncation is ' + f'a problem, consider increasing max_seq_len.')
                    self._warned_context = True
                context = self._truncate_with_eos(context)
            processed_examples.append({'input_ids': context, 'labels': target, 'attention_mask': [1] * len(context)})
        batch = self.tokenizer.pad(processed_examples, padding='max_length', max_length=self.max_seq_len, return_tensors='pt')
        batch['decoder_input_ids'] = torch.cat([torch.full((len(processed_examples), 1), self.tokenizer.pad_token_id), batch['labels'][:, :-1]], dim=1)
        batch['decoder_input_ids'].masked_fill_(batch['decoder_input_ids'] == _HF_IGNORE_INDEX, self.tokenizer.pad_token_id)
        batch['decoder_attention_mask'] = torch.not_equal(batch['labels'], _HF_IGNORE_INDEX)
        if not self._should_trim():
            return batch
        # Encoder tensors were padded by tokenizer.pad, so trim from its padding side.
        keep_tokens = self._padded_length(batch['attention_mask'])
        for key in ['input_ids', 'attention_mask']:
            batch[key] = self._trim(batch[key], keep_tokens, self.tokenizer.padding_side)
        # Targets were right-padded by hand above, so decoder tensors are cut on the right.
        keep_tokens = self._padded_length(batch['decoder_attention_mask'])
        for key in ['decoder_input_ids', 'decoder_attention_mask', 'labels']:
            batch[key] = self._trim(batch[key], keep_tokens, 'right')
        return batch

    def _truncate_with_eos(self, sequence: list[int]) -> list[int]:
        """Cut ``sequence`` to ``max_seq_len`` tokens, ending it with EOS if the tokenizer defines one."""
        eos_token_id = self.tokenizer.eos_token_id
        if eos_token_id is None:
            # Appending None would corrupt the later tensor conversion.
            return sequence[:self.max_seq_len]
        return sequence[:self.max_seq_len - 1] + [eos_token_id]

    def _should_trim(self) -> bool:
        """Return whether this batch may be trimmed, and mark the first batch as seen.

        The first batch is always left at full ``max_seq_len``.
        """
        should_trim = self._allow_pad_trimming and self._seen_first_batch
        self._seen_first_batch = True
        return should_trim

    @staticmethod
    def _padded_length(mask: torch.Tensor, multiple_of: int=8) -> int:
        """Longest per-row count of non-zero ``mask`` entries, rounded up to ``multiple_of``."""
        n_non_padding = int(mask.sum(dim=1).max())
        return (n_non_padding + multiple_of - 1) // multiple_of * multiple_of

    @staticmethod
    def _trim(tensor: torch.Tensor, keep_tokens: int, padding_side: str) -> torch.Tensor:
        """Keep ``keep_tokens`` columns of a 2-D tensor, discarding columns on ``padding_side``."""
        if padding_side == 'left':
            start = max(tensor.shape[1] - keep_tokens, 0)
            return tensor[:, start:].contiguous()
        return tensor[:, :keep_tokens].contiguous()