File size: 1,798 Bytes
023f73f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
"""Tokenization classes for Arctic."""
from typing import Any, Dict, Optional
from transformers.models.llama import LlamaTokenizer
class ArcticTokenizer(LlamaTokenizer):
    """Tokenizer for Arctic models, built directly on ``LlamaTokenizer``.

    Construction forwards every argument unchanged to ``LlamaTokenizer``; the
    only differences are the default values declared in ``__init__`` (the
    original author singles out ``legacy=False``) and the Arctic-specific
    ``default_chat_template`` property.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        spaces_between_special_tokens=False,
        legacy=False,
        add_prefix_space=True,
        **kwargs,
    ):
        """Initialise the tokenizer by delegating to ``LlamaTokenizer``.

        Args:
            vocab_file: Vocabulary file, forwarded as-is to ``LlamaTokenizer``
                (presumably a SentencePiece model path -- confirm upstream).
            unk_token, bos_token, eos_token, pad_token: Special tokens,
                forwarded unchanged.
            sp_model_kwargs: Optional extra options, forwarded unchanged.
            add_bos_token, add_eos_token: Forwarded unchanged.
            clean_up_tokenization_spaces: Forwarded unchanged.
            use_default_system_prompt: Forwarded unchanged.
            spaces_between_special_tokens: Forwarded unchanged.
            legacy: Defaults to ``False`` here; this is the default the
                original author called out as differing from LlamaTokenizer.
            add_prefix_space: Forwarded unchanged.
            **kwargs: Any further options, forwarded unchanged.
        """
        # Pure delegation: this class only overrides defaults. All arguments
        # are passed by keyword, so their order here has no effect; it simply
        # follows the signature above.
        super().__init__(
            vocab_file=vocab_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            sp_model_kwargs=sp_model_kwargs,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            spaces_between_special_tokens=spaces_between_special_tokens,
            legacy=legacy,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

    @property
    def default_chat_template(self):
        r"""Return the Jinja chat template for the standard Arctic format.

        Every message is rendered as ``<|im_start|>`` + role + newline +
        content + ``<|im_end|>`` + newline. When the template variable
        ``add_generation_prompt`` is truthy, an opening
        ``<|im_start|>assistant`` line is appended so generation continues
        as the assistant.
        """
        # One turn per message, in the order supplied.
        per_message_turns = (
            "{% for message in messages %}"
            "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
            "{% endfor %}"
        )
        # Optional trailing prompt that opens the assistant's turn.
        assistant_prompt = (
            "{% if add_generation_prompt %}"
            "{{ '<|im_start|>assistant\n' }}"
            "{% endif %}"
        )
        return per_message_turns + assistant_prompt
|