"""Prompt tokenizing strategies for instruction-tuning datasets."""
import abc

from transformers import PreTrainedTokenizer

# Label value that PyTorch's cross-entropy loss ignores; used to mask out
# prompt tokens so the loss is computed only on the response.
IGNORE_INDEX = -100

# Default special tokens for LLaMA tokenizers that ship without them.
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
LLAMA_DEFAULT_EOS_TOKEN = "</s>"
LLAMA_DEFAULT_BOS_TOKEN = "<s>"
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"


class InvalidDataException(Exception):
    """Raised when a prompt record is malformed and cannot be tokenized."""


class PromptTokenizingStrategy(abc.ABC):
    """Base class pairing a prompter (prompt-string builder) with a tokenizer."""

    def __init__(
        self,
        prompter,
        tokenizer,
        train_on_inputs: bool = False,
        sequence_len: int = 2048,
    ):
        self.prompter = prompter
        self.tokenizer: PreTrainedTokenizer = tokenizer
        self.train_on_inputs = train_on_inputs
        self.sequence_len = sequence_len

    @abc.abstractmethod
    def tokenize_prompt(self, prompt):
        """Convert one dataset record into input_ids/attention_mask/labels."""


class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
    """Tokenizes Alpaca-style records with instruction/input/output keys."""

    def tokenize_prompt(self, prompt):
        full_prompt = self._build_full_prompt(prompt)
        tokenized_full_prompt = self._tokenize(full_prompt)
        if not self.train_on_inputs:
            # Mask the user-prompt portion of the labels so the loss is
            # computed only on the response tokens.
            user_prompt = self.prompter.build_prompt(
                prompt["instruction"], prompt["input"]
            )
            tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])
            # TODO this could be sped up using numpy array slicing
            tokenized_full_prompt["labels"] = [
                IGNORE_INDEX
            ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
        return tokenized_full_prompt

    def _build_full_prompt(self, prompt):
        # Builds the full prompt string, including the expected output; the
        # tokenization itself happens in _tokenize.
        return self.prompter.build_prompt(
            prompt["instruction"],
            prompt["input"],
            prompt["output"],
        )
    def _tokenize(self, prompt, add_eos_token=True):
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        # Append the EOS token when the tokenizer did not add one and there
        # is still room within the sequence length budget.
        if (
            result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.sequence_len
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)
        # Causal LM training: labels start as a copy of input_ids.
        result["labels"] = result["input_ids"].copy()
        return result
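

# Illustrative example of the masking above (made-up token ids, not produced
# by a real tokenizer): for a 4-token user prompt followed by a 2-token
# output plus EOS, the labels become
#     [-100, -100, -100, -100, 817, 53, eos_id]
# so the gradient only flows through the response tokens.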


class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
    """Like Alpaca, but GPTeacher records use 'response' as the output key."""

    def _build_full_prompt(self, prompt):
        return self.prompter.build_prompt(
            prompt["instruction"],
            prompt["input"],
            prompt["response"],
        )


class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
    """Delegates multi-turn ShareGPT conversations to the prompter, which
    handles the tokenization itself (hence the tokenizer argument)."""

    def tokenize_prompt(self, prompt):
        try:
            return self.prompter.build_prompt(prompt["conversations"], self.tokenizer)
        except (KeyError, AssertionError, IndexError) as err:
            raise InvalidDataException(str(err)) from err
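

# Usage sketch (illustrative; `AlpacaPrompter` is a stand-in for whatever
# prompter class formats instruction/input/output into one string -- it is
# not defined in this module):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
#     strategy = AlpacaPromptTokenizingStrategy(
#         AlpacaPrompter(), tokenizer, train_on_inputs=False, sequence_len=2048
#     )
#     record = {"instruction": "Add the numbers.", "input": "2 and 3", "output": "5"}
#     features = strategy.tokenize_prompt(record)
#     # features has input_ids, attention_mask, and labels, with the prompt
#     # portion of labels masked to IGNORE_INDEX.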