import codecs
import logging
import os
from typing import Any

from transformers import AutoTokenizer

logger = logging.getLogger(__name__)

TEXT_SEPARATOR = "<TEXT_SEPARATOR>"
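# Note: TEXT_SEPARATOR is only a fallback. Once get_tokenizer() below has run,
# cfg._tokenizer_sep_token holds the tokenizer's sep token, which takes precedence
# in get_texts() when joining multiple prompt columns.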


def get_texts(df, cfg, separator=None):
    if isinstance(cfg.dataset.prompt_column, str):
        # single prompt column
        texts = df[cfg.dataset.prompt_column].astype(str)
        texts = texts.values
    else:
        # multiple prompt columns: join them row-wise with a separator
        columns = list(cfg.dataset.prompt_column)

        for column in columns:
            df[column] = df[column].astype(str)

        if separator is None:
            separator = getattr(cfg, "_tokenizer_sep_token", TEXT_SEPARATOR)

        join_str = f" {separator} "
        texts = df[columns].astype(str)
        texts = texts.apply(lambda x: join_str.join(x), axis=1).values

    return texts
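
# Illustrative sketch (hypothetical dataframe and config values): with
# cfg.dataset.prompt_column = ("instruction", "input") and no tokenizer loaded yet,
# the prompt columns are joined with the TEXT_SEPARATOR fallback:
#
#   df = pd.DataFrame({"instruction": ["Summarize:"], "input": ["A long article."]})
#   get_texts(df, cfg)
#   # -> array(['Summarize: <TEXT_SEPARATOR> A long article.'], dtype=object)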


def get_tokenizer(cfg: Any):
    kwargs = dict(
        revision=cfg.environment.huggingface_branch,
        use_fast=cfg.tokenizer.use_fast,
        trust_remote_code=cfg.environment.trust_remote_code,
        token=os.getenv("HUGGINGFACE_TOKEN"),
    )

    try:
        tokenizer = AutoTokenizer.from_pretrained(cfg.llm_backbone, **kwargs)
    except TypeError as e:
        error_message = str(e)
        if "token" in error_message:
            # older tokenizer classes do not accept the `token` keyword argument;
            # retry without it
            kwargs.pop("token")
            tokenizer = AutoTokenizer.from_pretrained(cfg.llm_backbone, **kwargs)
        elif "not a string" in error_message:
            # retry without `add_prefix_space`; pop with a default so a missing
            # key cannot raise a KeyError
            kwargs.pop("add_prefix_space", None)
            tokenizer = AutoTokenizer.from_pretrained(cfg.llm_backbone, **kwargs)
        else:
            raise e

    tokenizer.padding_side = getattr(
        cfg.tokenizer, "_padding_side", tokenizer.padding_side
    )

    if tokenizer.eos_token == "":
        tokenizer.add_special_tokens({"eos_token": "</s>"})
        tokenizer.eos_token = "</s>"

    if tokenizer.pad_token is None:
        if tokenizer.unk_token is not None:
            tokenizer.pad_token = tokenizer.unk_token
        else:
            tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.bos_token is None:
        tokenizer.bos_token = tokenizer.eos_token
    if tokenizer.cls_token is None:
        tokenizer.cls_token = tokenizer.eos_token
    if tokenizer.sep_token is None:
        tokenizer.sep_token = tokenizer.eos_token

    cfg._tokenizer_sep_token = tokenizer.sep_token

    if tokenizer.unk_token_id is not None:
        cfg._tokenizer_mask_token_id = tokenizer.unk_token_id
    elif tokenizer.mask_token_id is not None:
        cfg._tokenizer_mask_token_id = tokenizer.mask_token_id
    elif tokenizer.pad_token_id is not None:
        cfg._tokenizer_mask_token_id = tokenizer.pad_token_id
    else:
        # no unk/mask/pad token available: fall back to the last id in the vocabulary
        cfg._tokenizer_mask_token_id = len(tokenizer) - 1

    cfg._tokenizer_eos_token = tokenizer.eos_token

    if hasattr(cfg.prediction, "stop_tokens"):
        set_stop_token_ids(cfg, tokenizer)
    cfg.tokenizer._vocab_length = len(tokenizer)

    return tokenizer
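
# Usage sketch (illustrative; `cfg` is the experiment config object used throughout
# this module): loading the tokenizer also populates derived fields on the config,
# e.g. cfg._tokenizer_sep_token, cfg._tokenizer_mask_token_id,
# cfg._tokenizer_eos_token and cfg.tokenizer._vocab_length.
#
#   tokenizer = get_tokenizer(cfg)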


def set_stop_token_ids(cfg, tokenizer):
    cfg.tokenizer._stop_words = list(
        filter(None, cfg.prediction.stop_tokens.split(","))
    )
    for stop_word in [
        cfg.dataset.text_system_start,
        cfg.dataset.text_prompt_start,
        cfg.dataset.text_answer_separator,
    ]:
        stop_word = codecs.decode(stop_word, "unicode_escape").strip()
        if (
            stop_word != ""
            and cfg.tokenizer.add_prompt_answer_tokens
            and (stop_word not in tokenizer.get_vocab())
        ):
            tokenizer.add_tokens([stop_word])
        cfg.tokenizer._stop_words.append(stop_word)
    cfg.tokenizer._stop_words = [
        stop_word for stop_word in cfg.tokenizer._stop_words if stop_word != ""
    ]
    cfg.tokenizer._stop_words_ids = []
    for stop_word in set(cfg.tokenizer._stop_words):
        cfg.tokenizer._stop_words_ids.append(
            tokenizer(stop_word, return_tensors="pt", add_special_tokens=False)[
                "input_ids"
            ][0]
        )
    if cfg.environment._local_rank == 0:
        logger.info(f"Stop token ids: {cfg.tokenizer._stop_words_ids}")
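
# Note: each entry in cfg.tokenizer._stop_words_ids is a 1-D tensor of token ids
# (encoded without special tokens), intended to be matched against generated ids,
# e.g. by a stopping-criteria implementation during inference.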