Jingjing Zhai
Remove dependency on external yairschiff/caduceus_base; switch to self-contained config and local model files
8c41e4d
| """Character tokenizer for Hugging Face. | |
| """ | |
| from typing import List, Optional, Dict, Sequence, Tuple | |
| from transformers import PreTrainedTokenizer | |
| class CaduceusTokenizer(PreTrainedTokenizer): | |
| model_input_names = ["input_ids"] | |

    def __init__(self,
                 model_max_length: int,
                 characters: Sequence[str] = ("A", "C", "G", "T", "N"),
                 complement_map=None,
                 bos_token="[BOS]",
                 eos_token="[SEP]",
                 sep_token="[SEP]",
                 cls_token="[CLS]",
                 pad_token="[PAD]",
                 mask_token="[MASK]",
                 unk_token="[UNK]",
                 **kwargs):
| """Character tokenizer for Hugging Face transformers. | |
| Adapted from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen-hf/blob/main/tokenization_hyena.py | |
| Args: | |
| model_max_length (int): Model maximum sequence length. | |
| characters (Sequence[str]): List of desired characters. Any character which | |
| is not included in this list will be replaced by a special token called | |
| [UNK] with id=6. Following is a list of the special tokens with | |
| their corresponding ids: | |
| "[CLS]": 0 | |
| "[SEP]": 1 | |
| "[BOS]": 2 | |
| "[MASK]": 3 | |
| "[PAD]": 4 | |
| "[RESERVED]": 5 | |
| "[UNK]": 6 | |
| an id (starting at 7) will be assigned to each character. | |
| complement_map (Optional[Dict[str, str]]): Dictionary with string complements for each character. | |
| """ | |
        if complement_map is None:
            complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
        self.characters = characters
        self.model_max_length = model_max_length

        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(self.characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

        add_prefix_space = kwargs.pop("add_prefix_space", False)
        padding_side = kwargs.pop("padding_side", "left")

        # Translate the character-level complement map into an id-level map so
        # reverse complements can be computed directly on input_ids.
        self._complement_map = {}
        for k, v in self._vocab_str_to_int.items():
            complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map else v
            self._complement_map[self._vocab_str_to_int[k]] = complement_id
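        # e.g., with the default characters, "A" (id 7) maps to "T" (id 10)
        # and vice versa, while "N" (id 11) and the special tokens map to
        # their own ids.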

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=add_prefix_space,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    @property
    def complement_map(self) -> Dict[int, int]:
        return self._complement_map

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        # Convert all bases to uppercase before splitting into characters.
        return list(text.upper())

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        # Note: this loses information about which bases were originally lowercase.
        return "".join(tokens)

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        result = ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result
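
    # e.g., get_special_tokens_mask([7, 8, 9]) returns [0, 0, 0, 1]: three
    # sequence positions plus the trailing [SEP] appended by
    # build_inputs_with_special_tokens below.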
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        # cls = [self.cls_token_id]
        result = token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result
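
    # Example for build_inputs_with_special_tokens: token_ids_0 = [7, 8, 9, 10]
    # ("ACGT") becomes [7, 8, 9, 10, 1], i.e. the sequence followed by [SEP].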

    def get_vocab(self) -> Dict[str, int]:
        return self._vocab_str_to_int

    # Fixed vocabulary with no vocab file
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple:
        return ()
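

# A minimal, self-contained usage sketch (illustrative only; it assumes the
# default A/C/G/T/N vocabulary defined above, where A=7, C=8, G=9, T=10, N=11
# and [SEP]=1).
if __name__ == "__main__":
    tokenizer = CaduceusTokenizer(model_max_length=512)

    # Characters map to ids 7+, and build_inputs_with_special_tokens appends
    # [SEP]: "ACGTN" -> [7, 8, 9, 10, 11, 1].
    print(tokenizer("ACGTN")["input_ids"])

    # Reverse complement in id space via the precomputed complement map:
    # "AACG" -> ids [7, 7, 8, 9] -> reversed and complemented -> "CGTT".
    ids = tokenizer("AACG", add_special_tokens=False)["input_ids"]
    rc_ids = [tokenizer.complement_map[i] for i in reversed(ids)]
    print(tokenizer.decode(rc_ids))  # prints "CGTT"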