| | """ |
| | Custom Gemma Tokenizer for the Explicit Format |
| | |
| | This tokenizer implements the explicit format for message processing: |
| | it applies the model's standard chat template with proper role labels (user/assistant). |
| | |
| | The explicit format uses the model's built-in chat template and includes proper |
| | loss computation flags for training. |
| | |
| | To save: |
| | uv run tokenizers/gemma_explicit_tokenizer.py |
| | which will save the tokenizer to the repos/explicit-gemma-tokenizer directory. |
| | mkdir repos/explicit12b |
| | # copy model over |
| | cp models_v8/base_modified-google-gemma-3-12b-pt-/models/_explicit/checkpoint-8/* repos/explicit12b/ |
| | # copy tokenizer over |
| | cp repos/explicit-gemma-tokenizer/* repos/explicit12b/ |
| | # upload to hf |
| | |
| | uv run upload_to_hf.py \ |
| | --folder repos/explicit12b \ |
| | --repo-id tsor13/explicit12b |
| | """ |
| |
|
| | from typing import List, Dict, Any, Optional, Union |
| | from transformers import AutoTokenizer |
| | from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast |
| | from transformers.models.gemma.tokenization_gemma import GemmaTokenizer |
| | import warnings |
| | import difflib |
| | import json |
| | import os |
| | import sys |
| |
|
| | |
| | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| | from chat_utils import chat_messages_to_text_loss, chat_messages_to_raw_text |
| |
|
| |
|
class GemmaExplicitTokenizer(GemmaTokenizerFast):
    """
    Custom tokenizer for Gemma models that implements explicit format message processing.

    This tokenizer formats messages using the explicit format where:
    - Messages use the standard chat template with proper role labels
    - Uses the model's built-in chat formatting
    - Loss is computed on the assistant/output sections

    Message rendering is delegated to ``chat_messages_to_text_loss`` and
    ``chat_messages_to_raw_text`` from ``chat_utils`` (both called with
    ``start_gen_as="output"``).

    Attributes:
        start_string (str | None): The starting string used for output generation
            (depends on tokenizer). Currently always initialised to ``None``.
        end_string (str | None): The ending string used for output generation
            (depends on tokenizer). Currently always initialised to ``None``.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the custom tokenizer.

        Accepts the same arguments as GemmaTokenizerFast.
        """
        super().__init__(*args, **kwargs)
        self._init_explicit_attributes()

    def _init_explicit_attributes(self) -> None:
        """Set the explicit-format attributes and mirror them into ``init_kwargs``."""
        self.start_string = None
        self.end_string = None

        # Mirror the custom attributes into init_kwargs so they travel with the
        # tokenizer's init configuration alongside the base-class settings.
        if not hasattr(self, "init_kwargs"):
            self.init_kwargs = {}
        self.init_kwargs["start_string"] = self.start_string
        self.init_kwargs["end_string"] = self.end_string

    @classmethod
    def from_gemma_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Load a tokenizer from a pretrained model or path.

        This method ensures our custom class is used instead of the base GemmaTokenizerFast:
        a plain ``GemmaTokenizerFast`` is loaded, then its state is copied onto an
        instance of this class created without running ``__init__``.

        Args:
            pretrained_model_name_or_path: Hub id or local path, forwarded to
                ``GemmaTokenizerFast.from_pretrained``.
            *args, **kwargs: Forwarded unchanged to ``GemmaTokenizerFast.from_pretrained``.

        Returns:
            A ``GemmaExplicitTokenizer`` sharing the loaded tokenizer's state.
        """
        base_tokenizer = GemmaTokenizerFast.from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )

        # Bypass __init__ so the already-loaded backend is reused rather than rebuilt.
        custom_tokenizer = cls.__new__(cls)

        # setattr (not a raw __dict__ update) so any custom __setattr__ in the base
        # classes still runs for each copied attribute.
        for attr, value in base_tokenizer.__dict__.items():
            setattr(custom_tokenizer, attr, value)

        custom_tokenizer._init_explicit_attributes()
        return custom_tokenizer

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save the tokenizer to a directory, including custom configuration.

        After the base class writes its files, ``tokenizer_config.json`` is rewritten
        so that ``tokenizer_class``, ``start_string``, ``end_string`` and ``auto_map``
        point at this class (loaded from ``gemma_explicit_tokenizer.py``).

        Args:
            save_directory: Target directory, forwarded to the base implementation.
            **kwargs: Forwarded unchanged to the base implementation.

        Returns:
            Whatever the base ``save_pretrained`` returns (previously discarded).
        """
        # NOTE(review): if push_to_hub=True is passed, the base class may upload before
        # the config below is patched — confirm before relying on that path.
        saved_files = super().save_pretrained(save_directory, **kwargs)

        config_file = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            config = {}

        config["tokenizer_class"] = "GemmaExplicitTokenizer"
        config["start_string"] = self.start_string
        config["end_string"] = self.end_string

        config["auto_map"] = {
            "AutoTokenizer": [
                "gemma_explicit_tokenizer.GemmaExplicitTokenizer",
                "gemma_explicit_tokenizer.GemmaExplicitTokenizer",
            ]
        }

        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        return saved_files

    def messages_to_loss_texts(
        self,
        messages: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        From messages (description / input / output) to texts (text / compute_loss) with whether or not loss should be calculated on the text for training.
        Uses the explicit format from chat_utils.
        """
        return chat_messages_to_text_loss(messages, self, loss_on_start_token, start_gen_as="output")

    def messages_to_text(
        self,
        messages: List[Dict[str, Any]],
        start_generation: bool = False,
    ) -> str:
        """
        Messages (description / input / output) to raw text (text).
        Uses the explicit format from chat_utils.
        """
        return chat_messages_to_raw_text(messages, self, start_generation=start_generation, start_gen_as="output")

    def tokenize_messages_for_generation(
        self,
        messages: List[Dict[str, Any]] | List[List[Dict[str, Any]]],
        start_generation: bool = False,
        **kwargs,
    ):
        """
        Render messages to text and tokenize them. Supports batching. Good for generation.

        Kept under its own name so it is not shadowed by the loss-aware
        ``tokenize_messages`` defined on this class.

        Args:
            messages: One conversation (list of message dicts) or a batch of them
                (list of lists).
            start_generation: Forwarded to ``messages_to_text``.
            **kwargs: Forwarded to the tokenizer call (e.g. ``return_tensors``, ``padding``).

        Returns:
            The tokenizer output for a batch of rendered texts. A single conversation
            is wrapped as a batch of one.
        """
        # An empty list is treated as a single (empty) conversation rather than crashing.
        is_batch = bool(messages) and isinstance(messages[0], list)
        conversations = messages if is_batch else [messages]
        all_texts = [self.messages_to_text(conv, start_generation) for conv in conversations]
        return self(text=all_texts, **kwargs)

    def tokenize_loss_texts(
        self,
        texts: Union[str, List[Dict[str, Any]]],
        loss_on_start_token: bool = False,
        loss_on_eos: bool = False,
        include_eos: bool = True,
    ):
        """
        Tokenize texts (text / compute_loss) to tokenized texts (input_ids / attention_mask / labels).

        Each segment is tokenized on its own so its tokens can be labelled according
        to its ``compute_loss`` flag. The segments are then concatenated into one sequence.

        Args:
            texts: Segments as mappings with ``"text"`` and ``"compute_loss"`` keys and
                optional ``"example_ind"`` / ``"data_id"`` keys (``None`` or missing -> -1).
                A plain ``str`` is also accepted and tokenized directly, without labels.
            loss_on_start_token: Not used here; accepted for interface compatibility.
            loss_on_eos: Must be False; loss on EOS is not supported.
            include_eos: If False, a trailing EOS token is removed from the result.

        Returns:
            The tokenizer output with flat ``input_ids``, ``attention_mask``, ``labels``
            (-100 on tokens without loss), ``example_inds`` and ``data_ids``. For the
            ``str`` path, only ``input_ids`` / ``attention_mask`` are adjusted.

        Raises:
            ValueError: if ``loss_on_eos`` is True, ``texts`` is an empty list, or an EOS
                token occurs inside a segment (unless EOS decodes to ``<|im_end|>``).
        """
        if loss_on_eos:
            raise ValueError("Loss on EOS is not currently supported.")

        eos_id = self.eos_token_id

        if isinstance(texts, str):
            processed = self(text=texts)
            if eos_id is not None and (not processed["input_ids"] or processed["input_ids"][-1] != eos_id):
                processed["input_ids"] = processed["input_ids"] + [eos_id]
                processed["attention_mask"] = processed["attention_mask"] + [1]
            if not include_eos and processed["input_ids"] and processed["input_ids"][-1] == eos_id:
                processed["input_ids"] = processed["input_ids"][:-1]
                processed["attention_mask"] = processed["attention_mask"][:-1]
            return processed

        if not texts:
            raise ValueError("texts must contain at least one segment.")

        segments = []
        segment_texts = []
        example_inds = []
        dataset_inds = []

        for i, item in enumerate(texts):
            processed = self(text=item["text"])

            # Only the first segment keeps its BOS; later segments continue the same sequence.
            if i != 0 and processed["input_ids"] and processed["input_ids"][0] == self.bos_token_id:
                processed["input_ids"] = processed["input_ids"][1:]
                processed["attention_mask"] = processed["attention_mask"][1:]

            # Drop a trailing EOS; a single EOS is appended once after concatenation.
            if processed["input_ids"] and processed["input_ids"][-1] == eos_id:
                processed["input_ids"] = processed["input_ids"][:-1]
                processed["attention_mask"] = processed["attention_mask"][:-1]

            # An EOS inside a segment is only tolerated when EOS is the "<|im_end|>" token.
            if eos_id in processed["input_ids"]:
                if self.decode([eos_id]) != "<|im_end|>":
                    raise ValueError(f"EOS token is present in input_ids: {processed['input_ids']}. Not currently supported.")

            n_tokens = len(processed["input_ids"])
            if item["compute_loss"]:
                processed["labels"] = list(processed["input_ids"])
            else:
                processed["labels"] = [-100] * n_tokens

            segments.append(processed)
            segment_texts.append(item["text"])

            # Per-token provenance; missing or None ids become -1.
            example_ind = item.get("example_ind")
            example_inds.extend([-1 if example_ind is None else example_ind] * n_tokens)
            data_id = item.get("data_id")
            dataset_inds.extend([-1 if data_id is None else data_id] * n_tokens)

        # Start from a copy of the first segment so the container type and any other
        # keys are kept, then overwrite the per-token fields with the concatenation.
        processed = segments[0].copy()
        for key in ("input_ids", "attention_mask", "labels"):
            processed[key] = [tok for seg in segments for tok in seg[key]]
        processed["example_inds"] = example_inds
        processed["data_ids"] = dataset_inds

        self._report_tokenization_mismatch("".join(segment_texts), processed["input_ids"])

        if eos_id is not None and (not processed["input_ids"] or processed["input_ids"][-1] != eos_id):
            processed["input_ids"] = processed["input_ids"] + [eos_id]
            processed["attention_mask"] = processed["attention_mask"] + [1]
            # loss_on_eos is rejected above, so the appended EOS never carries loss.
            processed["labels"] = processed["labels"] + [-100]
            processed["example_inds"] = processed["example_inds"] + [-1]
            processed["data_ids"] = processed["data_ids"] + [-1]

        if not include_eos and processed["input_ids"] and processed["input_ids"][-1] == eos_id:
            for key in ("input_ids", "attention_mask", "labels", "example_inds", "data_ids"):
                processed[key] = processed[key][:-1]

        return processed

    def _report_tokenization_mismatch(self, full_text: str, segment_ids: List[int]) -> None:
        """Warn and print diffs when segment-wise token ids differ from tokenizing ``full_text`` in one pass."""
        full_ids = list(self(text=full_text)["input_ids"])
        segment_ids = list(segment_ids)

        if len(full_ids) != len(segment_ids):
            warnings.warn(
                "All texts are not the same length as the first text. Please check your dataset. "
                f"{len(full_ids)} != {len(segment_ids)}"
            )

        # Identical ids imply both diffs are empty; skip the per-example noise.
        if full_ids == segment_ids:
            return

        full_decoded = self.decode(full_ids, skip_special_tokens=False)
        segment_decoded = self.decode(segment_ids, skip_special_tokens=False)
        text_diff = difflib.unified_diff(full_decoded.splitlines(), segment_decoded.splitlines())
        print("Diff between texts:")
        print("\n".join(text_diff))

        token_diff = difflib.unified_diff(
            [str(tok) for tok in full_ids], [str(tok) for tok in segment_ids]
        )
        print("Diff between tokenized texts:")
        print("\n".join(token_diff))

    def tokenize_messages(
        self,
        messages: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
        loss_on_eos: bool = False,
        include_eos: bool = True,
    ) -> Dict[str, Any]:
        """
        Intended for tokenize from messages to tokenized texts with the loss applied.

        Args:
            messages: Message dicts (description / input / output roles).
            loss_on_start_token: Forwarded to ``messages_to_loss_texts``.
            loss_on_eos: Forwarded to ``tokenize_loss_texts``; True raises ``ValueError``.
            include_eos: Forwarded to ``tokenize_loss_texts``.

        Returns:
            The output of ``tokenize_loss_texts``.
        """
        texts = self.messages_to_loss_texts(messages, loss_on_start_token)

        # Pass by keyword: positionally, loss_on_eos would fill tokenize_loss_texts'
        # loss_on_start_token slot and be silently ignored.
        return self.tokenize_loss_texts(
            texts,
            loss_on_start_token=loss_on_start_token,
            loss_on_eos=loss_on_eos,
            include_eos=include_eos,
        )
| | |
| |
|
| |
|
| |
|
| | |
# Register the fast tokenizer under the same name that save_pretrained writes as
# "tokenizer_class", presumably so AutoTokenizer can map that name back to this class.
# NOTE(review): the key is a plain string rather than a config class — confirm this
# still works after transformers upgrades.
AutoTokenizer.register(
    "GemmaExplicitTokenizer",
    fast_tokenizer_class=GemmaExplicitTokenizer,
    slow_tokenizer_class=None,
)
| |
|
| |
|
if __name__ == "__main__":
    # Demo script: build the tokenizer, show how a sample conversation is rendered,
    # then save it (plus this source file) so it can be loaded via auto_map.
    import shutil

    custom_tokenizer = GemmaExplicitTokenizer.from_gemma_pretrained("google/gemma-3-1b-it")

    # Sample conversation: a task description, one solved example, then an open query.
    messages = [
        {"role": "description", "content": "This is a test task"},
        {"role": "input", "content": "What is 2+2?"},
        {"role": "output", "content": "4"},
        {"role": "input", "content": "What is 3+3?"},
    ]

    texts = custom_tokenizer.messages_to_loss_texts(messages)
    print("Texts with loss flags:")
    for i, text in enumerate(texts):
        print(f" {i}: {text}")

    text = custom_tokenizer.messages_to_text(messages, start_generation=True)
    print("\nFull text with generation prompt:")
    print(text)

    # Only saving is exercised here; nothing is reloaded.
    print("\nTesting save:")

    tokenizer_path = "repos/explicit-gemma-tokenizer"
    custom_tokenizer.save_pretrained(tokenizer_path)
    print("Tokenizer saved successfully!")

    # The saved auto_map points at gemma_explicit_tokenizer.GemmaExplicitTokenizer,
    # so this module must ship next to the tokenizer files under that name.
    shutil.copy(__file__, os.path.join(tokenizer_path, "gemma_explicit_tokenizer.py"))
    print("gemma_explicit_tokenizer.py saved successfully!")