from typing import List

from torchtune.modules.tokenizers import TikTokenTokenizer
from torchtune.modules.tokenizers._utils import _split_long_repetitions
from torchtune.modules.tokenizers._tiktoken import (
    MAX_ENCODE_CHARS,
    MAX_NO_WHITESPACE_CHARS,
    ALL_SPECIAL_TOKENS,
)

START_IMAGE = "<|start_image|>"
END_IMAGE = "<|end_image|>"
START_VIDEO = "<|start_video|>"
END_VIDEO = "<|end_video|>"
START_AUDIO = "<|start_audio|>"
END_AUDIO = "<|end_audio|>"

# Splice the multimodal marker tokens into torchtune's special token list,
# inserting them just before the last two entries of ALL_SPECIAL_TOKENS.
A2A_SPECIAL_TOKENS = (
    ALL_SPECIAL_TOKENS[:-2]
    + [
        START_IMAGE,
        END_IMAGE,
        START_VIDEO,
        END_VIDEO,
        START_AUDIO,
        END_AUDIO,
    ]
    + ALL_SPECIAL_TOKENS[-2:]
)


class A2ATokenizer(TikTokenTokenizer):
    def encode(
        self,
        text: str,
        add_bos: bool,
        add_eos: bool,
    ) -> List[int]:
        """
        Encode a string into a list of token ids. The multimodal marker
        tokens (image/video/audio) are encoded as special tokens if present;
        the string is assumed to contain no other special tokens.

        Args:
            text (str): The string to encode.
            add_bos (bool): Whether to add the beginning of sequence token.
            add_eos (bool): Whether to add the end of sequence token.

        Returns:
            List[int]: The list of token ids.
        """
        substrs: List[str] = []
        tokens = []
        for i in range(0, len(text), MAX_ENCODE_CHARS):
            substr = text[i : i + MAX_ENCODE_CHARS]
            # Split pieces that contain very long runs of non-whitespace
            # characters, which tiktoken encodes slowly.
            sliced_substr = _split_long_repetitions(substr, MAX_NO_WHITESPACE_CHARS)
            substrs.extend(sliced_substr)
        for substr in substrs:
            tokens.extend(
                self.tt_model.encode(
                    substr,
                    # Only the multimodal marker tokens are treated as special;
                    # any other special-token text is encoded as ordinary text
                    # rather than raising an error.
                    allowed_special={
                        START_IMAGE,
                        END_IMAGE,
                        START_VIDEO,
                        END_VIDEO,
                        START_AUDIO,
                        END_AUDIO,
                    },
                    disallowed_special=(),
                )
            )
        if add_bos:
            tokens.insert(0, self.bos_id)
        if add_eos:
            tokens.append(self.eos_id)
        return tokens


def a2a_tokenizer(path: str) -> TikTokenTokenizer:
    tiktoken = A2ATokenizer(path, all_special_tokens=A2A_SPECIAL_TOKENS)
    tiktoken.pad_id = 0
    return tiktoken
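

if __name__ == "__main__":
    # Usage sketch, not part of the module's API. Assumes a tiktoken model
    # file exists at the placeholder path "tokenizer.model"; swap in the real
    # path for your checkpoint.
    tokenizer = a2a_tokenizer("tokenizer.model")
    ids = tokenizer.encode(
        f"{START_IMAGE}<image placeholder>{END_IMAGE} Describe this image.",
        add_bos=True,
        add_eos=True,
    )
    # Each marker string maps to a single special-token id, so downstream code
    # can locate image/video/audio spans by id instead of string matching.
    print(f"{len(ids)} tokens, first few: {ids[:8]}")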