File size: 2,387 Bytes
			
			210c84c  | 
								1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71  | 
								import os
import pickle
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
class NanoGPTTokenizer:
    """Lightweight wrapper over a tiktoken Encoding stored in tokenizer.pkl.

    Provides the minimal encode/decode needed for inference and a
    ``from_pretrained`` constructor so it can be loaded via AutoTokenizer with
    trust_remote_code.
    """

    # File name of the pickled encoding, looked up both locally and on the Hub.
    _TOKENIZER_FILENAME = "tokenizer.pkl"

    # Caller-supplied options that are forwarded to ``hf_hub_download``. Any
    # other keyword argument given to ``from_pretrained`` is accepted and
    # ignored, so unrelated loader options do not break the download call.
    _HUB_DOWNLOAD_KWARGS = (
        "revision",
        "cache_dir",
        "token",
        "local_files_only",
        "force_download",
    )

    def __init__(self, enc):
        """Wrap an already-loaded encoding object.

        Args:
            enc: Encoding object exposing ``encode_single_token``,
                ``encode_ordinary`` and ``decode``. It must know the
                ``"<|bos|>"`` special token, which is resolved eagerly here.
        """
        self.enc = enc
        self.bos_token_id = enc.encode_single_token("<|bos|>")

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoTokenizer"):
        """Required for AutoTokenizer registration; intentionally a no-op."""
        pass

    @staticmethod
    def _load_encoding(path):
        """Unpickle and return the encoding object stored at ``path``."""
        # NOTE(security): unpickling can execute arbitrary code embedded in the
        # file. Only load tokenizer.pkl from directories/repos you trust.
        with open(path, "rb") as f:
            return pickle.load(f)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load the tokenizer from a local directory or a Hugging Face Hub repo.

        A ``tokenizer.pkl`` inside the local directory takes precedence. If it
        is absent, the argument is treated as a Hub repo ID and the file is
        fetched with ``hf_hub_download`` (which manages its own cache).

        Args:
            pretrained_model_name_or_path: Local directory path or Hub repo ID.
            *args: Accepted for signature compatibility; unused.
            **kwargs: ``revision``, ``cache_dir``, ``token``,
                ``local_files_only`` and ``force_download`` are forwarded to
                ``hf_hub_download`` when given (and not ``None``); everything
                else is ignored.

        Returns:
            A new tokenizer instance wrapping the unpickled encoding.

        Raises:
            ValueError: If tokenizer.pkl can be found neither locally nor on
                the Hub. The underlying error is chained as ``__cause__``.
        """
        local_tok_path = os.path.join(
            pretrained_model_name_or_path, cls._TOKENIZER_FILENAME
        )
        if os.path.isfile(local_tok_path):
            return cls(cls._load_encoding(local_tok_path))

        # Honour the caller's Hub options (pinned revision, private-repo token,
        # custom cache, offline mode) instead of silently dropping them.
        hub_kwargs = {
            key: kwargs[key]
            for key in cls._HUB_DOWNLOAD_KWARGS
            if kwargs.get(key) is not None
        }
        try:
            tok_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename=cls._TOKENIZER_FILENAME,
                **hub_kwargs,
            )
            enc = cls._load_encoding(tok_path)
        except (HfHubHTTPError, OSError, ValueError) as e:
            # ValueError covers the Hub's repo-id validation error, raised when
            # the argument is a filesystem path that has no tokenizer.pkl; the
            # caller then gets the actionable message below instead of a
            # repo-id formatting complaint.
            raise ValueError(
                f"Could not load tokenizer.pkl from {pretrained_model_name_or_path}. "
                "Make sure the path exists or the repo is accessible on the Hub."
            ) from e
        return cls(enc)

    def encode(self, text, prepend=None):
        """Encode ``text`` to a list of token ids, ignoring special tokens.

        Args:
            text: String to encode.
            prepend: Optional token placed at the front of the result, given
                either as an int token id or as a special-token string that is
                resolved through ``encode_single_token``.

        Returns:
            The list of token ids, with the prepended id first when requested.
        """
        ids = self.enc.encode_ordinary(text)
        if prepend is not None:
            prepend_id = (
                prepend
                if isinstance(prepend, int)
                else self.enc.encode_single_token(prepend)
            )
            ids.insert(0, prepend_id)
        return ids

    def decode(self, ids):
        """Decode a sequence of token ids back to text."""
        return self.enc.decode(ids)

    def get_bos_token_id(self):
        """Return the id of the ``"<|bos|>"`` token resolved at construction."""
        return self.bos_token_id
 |