| from typing import List |
|
|
| from torchtune.modules.tokenizers import TikTokenTokenizer |
| from torchtune.modules.tokenizers._utils import _split_long_repetitions |
| from torchtune.modules.tokenizers._tiktoken import ( |
| MAX_ENCODE_CHARS, |
| MAX_NO_WHITESPACE_CHARS, |
| ALL_SPECIAL_TOKENS, |
| ) |
| |
|
|
| |
# Delimiter strings marking the start and end of image, video and audio spans.
START_IMAGE = "<|start_image|>"
END_IMAGE = "<|end_image|>"
START_VIDEO = "<|start_video|>"
END_VIDEO = "<|end_video|>"
START_AUDIO = "<|start_audio|>"
END_AUDIO = "<|end_audio|>"


# Full special-token list for the A2A tokenizer: the base list with the six
# media delimiters spliced in immediately before its final two entries, so
# those two entries remain at the end of the list.
A2A_SPECIAL_TOKENS = [
    *ALL_SPECIAL_TOKENS[:-2],
    START_IMAGE,
    END_IMAGE,
    START_VIDEO,
    END_VIDEO,
    START_AUDIO,
    END_AUDIO,
    *ALL_SPECIAL_TOKENS[-2:],
]
|
|
| |
class A2ATokenizer(TikTokenTokenizer):
    """TikToken tokenizer whose ``encode`` recognises media delimiter tokens.

    Overrides ``encode`` so that the image, video and audio start/end
    delimiters are passed to the underlying model as allowed special tokens.
    """

    def encode(
        self,
        text: str,
        add_bos: bool,
        add_eos: bool,
    ) -> List[int]:
        """
        Encode a string into a list of token ids.

        The six media delimiters (``START_IMAGE``/``END_IMAGE``,
        ``START_VIDEO``/``END_VIDEO``, ``START_AUDIO``/``END_AUDIO``) are
        passed as ``allowed_special``. ``disallowed_special`` is empty, so no
        other special-token text in the input is rejected.

        Args:
            text (str): The string to encode.
            add_bos (bool): Whether to add the beginning of sequence token.
            add_eos (bool): Whether to add the end of sequence token.

        Returns:
            List[int]: The list of token ids.
        """
        # Loop-invariant: build the allowed set once per call rather than
        # once per encoded substring.
        allowed_special = {
            START_IMAGE,
            END_IMAGE,
            START_VIDEO,
            END_VIDEO,
            START_AUDIO,
            END_AUDIO,
        }

        # Seed with BOS up front instead of inserting at index 0 afterwards.
        tokens: List[int] = [self.bos_id] if add_bos else []

        # NOTE(review): chunks are cut at fixed character offsets, so a media
        # delimiter straddling a chunk boundary would presumably not be seen
        # as a single special token -- confirm whether inputs can be that long.
        for start in range(0, len(text), MAX_ENCODE_CHARS):
            chunk = text[start : start + MAX_ENCODE_CHARS]
            # Each chunk is further split by the helper before encoding.
            for piece in _split_long_repetitions(chunk, MAX_NO_WHITESPACE_CHARS):
                tokens.extend(
                    self.tt_model.encode(
                        piece,
                        allowed_special=allowed_special,
                        disallowed_special=(),
                    )
                )

        if add_eos:
            tokens.append(self.eos_id)
        return tokens
|
|
|
|
def a2a_tokenizer(path: str) -> TikTokenTokenizer:
    """Build an ``A2ATokenizer`` from ``path`` with the A2A special tokens.

    Args:
        path (str): Passed through as the first argument to ``A2ATokenizer``.

    Returns:
        TikTokenTokenizer: The constructed ``A2ATokenizer`` with ``pad_id``
            set to 0.
    """
    tokenizer = A2ATokenizer(path, all_special_tokens=A2A_SPECIAL_TOKENS)
    # Override the padding id on the instance after construction.
    tokenizer.pad_id = 0
    return tokenizer