"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""

import copy
import logging
from collections import defaultdict
from typing import Generator, List, Tuple

from axolotl.prompt_tokenizers import (
    PromptTokenizingStrategy,
    parse_tokenized_to_result,
    tokenize_prompt_default,
)

IGNORE_TOKEN_ID = -100
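
# The classes below expect ShareGPT-style records: a "conversations" list of
# {"role": ..., "value": ...} turns, where role is "system", "human", or "bot"
# (see PygmalionPrompter.build_prompt). Only bot replies contribute to the
# training labels; every other position is masked with IGNORE_TOKEN_ID, which
# matches the default ignore_index of torch.nn.CrossEntropyLoss.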


class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for Pygmalion.
    """

    bot_prefix_token_ids: List[int] = []

    def __init__(self, prompter, tokenizer, *args, **kwargs):
        super().__init__(prompter, tokenizer, *args, **kwargs)
        # Cache the token ids of the bare "<|model|>" prefix so tokenize_prompt
        # can mask exactly that many positions at the start of each bot turn.
        res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
        self.bot_prefix_token_ids = res["input_ids"]

    def tokenize_prompt(self, prompt):
        result, current_len = tokenize_prompt_default()
        for part in self.prompter.build_prompt(prompt["conversations"]):
            role, message = part
            if role == "system":
                prefix = "<|system|>"
                # Keep the BOS token, add no EOS token, and strip a trailing
                # "\n<START>" marker before tokenizing the persona.
                if message.endswith("\n<START>"):
                    message = message[: -len("\n<START>")]
                res = self._tokenize(
                    prefix + "Persona: " + message.strip(),
                    add_eos_token=False,
                    strip_bos_token=False,
                )
                # System turns are fully masked out of the labels.
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif role == "human": |
|
prefix = "<|user|>" |
|
res = self._tokenize( |
|
prefix + " " + message.strip(), |
|
add_eos_token=False, |
|
strip_bos_token=True, |
|
) |
|
|
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) |
|
elif role == "bot": |
|
prefix = "<|model|>" |
|
res = self._tokenize( |
|
prefix + " " + message.strip(), |
|
add_eos_token=True, |
|
strip_bos_token=True, |
|
) |
|
|
|
|
|
labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [ |
|
*copy.deepcopy(res["input_ids"]) |
|
][len(self.bot_prefix_token_ids) :] |
|
            else:
                logging.warning(f"unknown role in conversation: {role}")
                # Fall back to empty token and label lists so the unknown turn
                # contributes nothing to the tokenized result.
                res = defaultdict(lambda: [])
                labels = []

            result, current_len = parse_tokenized_to_result(
                result,
                current_len,
                res,
                labels,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        return result


class PygmalionPrompter:
    """
    Prompter for Pygmalion.
    """

    def __init__(self, *args, **kwargs):
        pass

    def build_prompt(
        self, source, *args, **kwargs
    ) -> Generator[Tuple[str, str], None, None]:
        # Pass each ShareGPT-style turn through unchanged as a (role, value) pair.
        for msg in source:
            yield msg["role"], msg["value"]


def load(tokenizer, cfg):
    return PygmalionPromptTokenizingStrategy(
        PygmalionPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
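

# Minimal usage sketch (illustrative only, not part of the module): it assumes
# a HuggingFace tokenizer and an axolotl config object `cfg` exposing
# `train_on_inputs` and `sequence_len`; the model name and conversation below
# are made-up examples.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("some/base-model")
#     strategy = load(tokenizer, cfg)
#     sample = {
#         "conversations": [
#             {"role": "system", "value": "A persona description.\n<START>"},
#             {"role": "human", "value": "Hello!"},
#             {"role": "bot", "value": "Hi, nice to meet you."},
#         ]
#     }
#     tokenized = strategy.tokenize_prompt(sample)
#     # `tokenized` collects input_ids, attention_mask, and labels; every
#     # non-bot token (and the bot "<|model|>" prefix) is masked to
#     # IGNORE_TOKEN_ID in labels.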