import json
import os

from transformers import PreTrainedTokenizer


class ChessTokenizer(PreTrainedTokenizer):
    """Whitespace tokenizer that maps chess move strings to vocabulary ids.

    The vocabulary is read from a JSON file holding two mappings:
    ``"token_to_id"`` (token string -> id) and ``"id_to_token"`` (id -> token
    string; the keys are stored as strings and converted to ``int`` on load).
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, vocab_file="vocab.json", **kwargs):
        """Load the vocabulary and register the special tokens.

        Args:
            vocab_file: Path to the JSON vocabulary file.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``. The
                special tokens default to ``[PAD]``, ``[BOS]``, ``[EOS]`` and
                ``[UNK]`` but may be overridden here.

        Raises:
            ValueError: If ``vocab_file`` does not exist.
        """
        try:
            with open(vocab_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError as err:
            raise ValueError(f"Vocabulary file {vocab_file} not found.") from err

        # The vocabulary must be in place before the base initialiser runs,
        # because the lookup hooks below may be invoked from inside it.
        self.token_to_id = data["token_to_id"]
        self.id_to_token = {int(k): v for k, v in data["id_to_token"].items()}

        # Use defaults instead of hard-coded keyword arguments so that a caller
        # supplying e.g. ``pad_token=...`` does not trigger a duplicate-keyword
        # TypeError.
        kwargs.setdefault("pad_token", "[PAD]")
        kwargs.setdefault("bos_token", "[BOS]")
        kwargs.setdefault("eos_token", "[EOS]")
        kwargs.setdefault("unk_token", "[UNK]")

        # Private copies of the special-token ids (None when the token is not
        # in the vocabulary). The lookup hooks read these rather than the base
        # class's ``*_token_id`` attributes, which may not be usable before
        # ``super().__init__`` has run and may themselves call back into
        # ``_convert_token_to_id``.
        self._unk_str = str(kwargs["unk_token"])
        self._unk_id = self.token_to_id.get(self._unk_str)
        self._bos_id = self.token_to_id.get(str(kwargs["bos_token"]))
        self._eos_id = self.token_to_id.get(str(kwargs["eos_token"]))

        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        """Number of entries in the token -> id mapping."""
        return len(self.token_to_id)

    def get_vocab(self):
        """Return a copy of the token -> id mapping.

        A copy is returned so that callers cannot mutate the tokenizer's
        internal vocabulary by accident.
        """
        return dict(self.token_to_id)

    def _convert_token_to_id(self, token):
        """Return the id for ``token``, or the unknown-token id if absent."""
        return self.token_to_id.get(token, self._unk_id)

    def _convert_id_to_token(self, index):
        """Return the token for ``index``, or the unknown token if absent."""
        return self.id_to_token.get(index, self._unk_str)

    def __call__(self, text, **kwargs):
        """Encode a move string, or a batch of move strings, into ids.

        Args:
            text: A whitespace-separated string of moves, or a list/tuple of
                such strings (encoded element by element).
            **kwargs: Only ``max_length`` is read (default 256). ``None``
                disables truncation.

        Returns:
            ``{"input_ids": ids}`` where ``ids`` is a list of ints for a single
            string, or a list of such lists for a batch. No attention mask is
            produced.
        """
        if isinstance(text, (list, tuple)):
            return {"input_ids": [self(t, **kwargs)["input_ids"] for t in text]}

        ids = [self.token_to_id.get(move, self._unk_id) for move in text.split()]

        # BOS/EOS are only added when the vocabulary actually contains them.
        if self._bos_id is not None:
            ids = [self._bos_id] + ids
        if self._eos_id is not None:
            ids = ids + [self._eos_id]

        # Truncation happens after the special tokens are attached, so an
        # over-long sequence loses its trailing EOS.
        max_len = kwargs.get("max_length", 256)
        if max_len is not None and len(ids) > max_len:
            ids = ids[:max_len]

        return {"input_ids": ids}

    def save_pretrained(self, save_directory, **kwargs):
        """Write ``vocab.json`` and ``tokenizer_config.json`` to a directory.

        Args:
            save_directory: Target directory; created if it does not exist.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            A tuple with the paths of the two files written.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_path = os.path.join(save_directory, "vocab.json")
        config_path = os.path.join(save_directory, "tokenizer_config.json")

        with open(vocab_path, "w", encoding="utf-8") as f:
            # JSON object keys are always strings; __init__ converts the
            # id_to_token keys back to int when loading.
            json.dump(
                {"token_to_id": self.token_to_id, "id_to_token": self.id_to_token},
                f,
            )
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump({"model_type": "chess_transformer"}, f)

        return (vocab_path, config_path)

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        """Build a tokenizer from a directory containing ``vocab.json``.

        Args:
            path: Directory expected to contain ``vocab.json``.
            **kwargs: Forwarded to the constructor.

        Returns:
            A new tokenizer instance. When ``path`` has no ``vocab.json`` the
            constructor is called with ``kwargs`` only, so its own
            ``vocab_file`` default (or a ``vocab_file`` given in ``kwargs``)
            is used instead.
        """
        vocab_path = os.path.join(path, "vocab.json")
        if os.path.exists(vocab_path):
            return cls(vocab_file=vocab_path, **kwargs)
        return cls(**kwargs)