"""
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
If necessary, one can rewrite this to implement a different behavior, such as:
 - loading files from a local data source (e.g. S3)
 - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
 - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )

"""
from __future__ import annotations

from typing import Optional, OrderedDict, Union

import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from transformers.modeling_utils import WEIGHTS_NAME
from transformers.utils.hub import cached_path, hf_bucket_url

from src.bloom import BloomBlock, BloomConfig

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)

CLIENT_BRANCH = "main"
BLOCK_BRANCH_PREFIX = "block_"
USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
FORCE_DOWNLOAD = False
RESUME_DOWNLOAD = False
LOCAL_FILES_ONLY = False
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")


def load_pretrained_block(
    converted_model_name_or_path: str,
    block_index: int,
    config: Optional[BloomConfig] = None,
    torch_dtype: Union[torch.dtype, str] = "auto",
    use_auth_token: Optional[str] = None,
) -> BloomBlock:
    """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
    if config is None:
        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
    block = BloomBlock(config, layer_number=block_index)
    state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)

    # With torch_dtype="auto", keep each tensor in the dtype it was serialized with; otherwise, cast the whole block.
    # The strict load at the end of this function copies the actual values, so no extra load is needed here.
    if torch_dtype == "auto":
        with torch.no_grad():
            for name, param in block.named_parameters():
                assert name in state_dict, f"{name} not in state dict"
                param.data = param.data.to(state_dict[name].dtype)
    else:
        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
        block = block.to(dtype=torch_dtype)

    report = block.load_state_dict(state_dict, strict=True)
    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
    return block
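
# Example usage -- the model name below is a placeholder; use any checkpoint produced by convert_model.py:
#
#     block = load_pretrained_block("bigscience/test-bloomd", block_index=3, torch_dtype=torch.bfloat16)
#
# With the default torch_dtype="auto", each parameter instead keeps the dtype it was saved with.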


def _load_state_dict(
    pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None
) -> OrderedDict[str, torch.Tensor]:
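    # Each converted block is stored on its own Hub branch ("block_0", "block_1", ...); everything else is on "main"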
    revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
    archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)

    # Load from URL or cache if already cached
    resolved_archive_file = cached_path(
        archive_file,
        cache_dir=None,
        force_download=FORCE_DOWNLOAD,
        proxies=None,
        resume_download=RESUME_DOWNLOAD,
        local_files_only=LOCAL_FILES_ONLY,
        use_auth_token=use_auth_token,
        user_agent=USER_AGENT,
    )
    state_dict = torch.load(resolved_archive_file, map_location="cpu")
    return state_dict


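# The module docstring suggests alternative ways to fetch weights. Below is a minimal sketch of the
# local-directory variant, assuming each revision ("main" or "block_{i}") was exported into a
# subdirectory of local_dir with the same file layout as the Hub branches. It is illustrative only
# and is not called anywhere in this file.
def _load_state_dict_from_local_dir(
    local_dir: str, block_index: Optional[int] = None
) -> OrderedDict[str, torch.Tensor]:
    revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
    # WEIGHTS_NAME is "pytorch_model.bin" -- the same file the Hub-based loader above resolves
    return torch.load(f"{local_dir}/{revision}/{WEIGHTS_NAME}", map_location="cpu")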