Commit: first files

Files changed:
- README.md +104 -12
- app.py +7 -0
- cli/__init__.py +0 -0
- cli/config.json +20 -0
- cli/convert_model.py +86 -0
- cli/deploy_server.sh +87 -0
- cli/inference_one_block.py +53 -0
- cli/local_server_config_example.cfg +5 -0
- cli/remote_server_config_example.cfg +6 -0
- cli/run_local_servers.sh +111 -0
- cli/run_remote_servers.sh +112 -0
- cli/run_server.py +85 -0
- requirements.txt +4 -0
- src/__init__.py +5 -0
- src/bloom/__init__.py +2 -0
- src/bloom/block.py +248 -0
- src/bloom/from_pretrained.py +80 -0
- src/bloom/model.py +408 -0
- src/bloom/ops.py +246 -0
- src/client/__init__.py +4 -0
- src/client/remote_block.py +135 -0
- src/client/remote_model.py +58 -0
- src/client/remote_sequence_info.py +94 -0
- src/client/remote_sequential.py +135 -0
- src/data_structures.py +8 -0
- src/dht_utils.py +132 -0
- src/server/__init__.py +0 -0
- src/server/backend.py +58 -0
- src/server/cache.py +127 -0
- src/server/handler.py +229 -0
- src/server/server.py +254 -0
README.md
CHANGED
@@ -1,12 +1,104 @@
# bloom-demo
Early dev prototype for decentralized bloom. Not for public eyes **yet**.

Roadmap: [issue #12](https://github.com/learning-at-home/bloom-demo/issues/12)

Latest news @ main branch (max 5):
- [Jul 4] @dbaranchuk implemented chained rpc_forward and rpc_backward (for prompt tuning)
- [Jul 3] @dbaranchuk optimized DistributedBloom to reduce embeddings/logits RAM usage
- [Jul 1] @yozh added RemoteSequential and test for full model exact match
- [June 28] @dbaranchuk added quick deployment scripts for testnet

### install

```bash
conda create -y --name bloom-demo python=3.8.12 pip
conda activate bloom-demo

conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
pip install bitsandbytes-cuda113==0.26.0
pip install https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
```

### run local inference

No networking whatsoever; used to verify architecture optimizations.

```bash
# run one bloom block for a few steps -- on a local machine
python -m cli.inference_one_block --config cli/config.json  # see other args
```

### run distributed inference / training

First, run one or more servers like this:
```bash
# minimalistic server with non-trained bloom blocks
python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 \
  --block_indices 3:5 --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
# when running multiple servers:
# - give each server a unique --identity_path (or remove the --identity_path arg when debugging)
# - if running multiple servers on the same machine, give each a unique port (last integer in --host_maddrs, 0 means random port)
# - when running over the internet, change --host_maddrs according to https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet
# - each server except the first should have --initial_peers pointing to one of the pre-existing servers
```

Then open a python notebook or console and run:
```python
import torch
import hivemind
from src import get_remote_module


dht = hivemind.DHT(
    initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS],  # e.g. /ip4/127.0.0.1/...
    client_mode=True, start=True,
)

layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'])
assert layer3 is not None and layer4 is not None, "one or both layers were not found in DHT"
# test forward/backward, two blocks
outputs, = layer4(*layer3(torch.randn(1, 64, 4096)))
loss = (outputs * torch.randn_like(outputs)).norm()
loss.backward()

# test inference, one block
with layer3.begin_inference_session() as sess:
    for i in range(10):
        res = sess.step(torch.ones(1, 1, 4096))
```

### convert regular bloom to distributed

```bash
# convert model from HF hub to a distributed format (can take hours depending on your connection!)
MY_WRITE_TOKEN=TODO_WRITE_TOKEN_FROM_https://huggingface.co/settings/token
python -m cli.convert_model --model bigscience/bloom-6b3 \
  --output_path ./converted_model --output_repo bigscience/test-bloomd-6b3 \
  --use_auth_token $MY_WRITE_TOKEN  # ^-- todo replace output repo with something you have access to
```

### test local vs remote block (allclose)

To test distributed inference, run one or more servers, then open a new shell and run pytest with environment variables:
```bash
# shell A: serve blocks 3 and 4
python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 \
  --block_indices 3:5 --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337

# shell B: connect to the swarm and test individual blocks for exact match
export PYTHONPATH=. INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT"
BLOCK_UID=bigscience/test-bloomd-6b3.3 pytest tests/test_block_exact_match.py
BLOCK_UID=bigscience/test-bloomd-6b3.4 pytest tests/test_block_exact_match.py

# the test below will fail because there is no server that serves layer 7
# BLOCK_UID=bigscience/test-bloomd-6b3.7 pytest tests/test_block_exact_match.py
```
app.py
ADDED
@@ -0,0 +1,7 @@
```python
import gradio as gr

def greet(name):
    return "Hello " + name + "!!"

iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch()
```
cli/__init__.py
ADDED
File without changes
cli/config.json
ADDED
@@ -0,0 +1,20 @@
```json
{
  "apply_residual_connection_post_layernorm": false,
  "attention_dropout": 0.0,
  "attention_softmax_in_fp32": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_dropout": 0.0,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "masked_softmax_fusion": true,
  "model_type": "bloom",
  "n_embed": 14336,
  "n_layer": 70,
  "num_attention_heads": 112,
  "pretraining_tp": 4,
  "slow_but_exact": false,
  "transformers_version": "4.20.0.dev0",
  "use_cache": true,
  "vocab_size": 250880
}
```
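As a usage sketch (not part of this commit), the config can be loaded the same way `cli/inference_one_block.py` does, via `BloomConfig.from_json_file`; this assumes the repo root is on `PYTHONPATH`:

```python
# sketch: load the block config exactly as cli/inference_one_block.py does
from src.bloom.model import BloomConfig

config = BloomConfig.from_json_file("cli/config.json")
print(config.n_layer, config.num_attention_heads)  # 70, 112
```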
cli/convert_model.py
ADDED
@@ -0,0 +1,86 @@
```python
import argparse
import os

import psutil
import torch.backends.quantized
import torch.nn as nn
import transformers
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from huggingface_hub import Repository
from tqdm.auto import tqdm

from src import BloomModel
from src.client import DistributedBloomConfig
from src.bloom.from_pretrained import CLIENT_BRANCH, BLOCK_BRANCH_PREFIX

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")

    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
    parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
    parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
    parser.add_argument(
        "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
    )
    parser.add_argument(
        "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
    )
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
    args = parser.parse_args()

    free_ram_gb = psutil.virtual_memory().available / 2**30
    if args.model == "bigscience/bloom" and free_ram_gb < 400:
        logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")

    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
    if os.path.exists(args.output_path) and (
        len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
    ):
        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")

    logger.info(f"Loading source model {args.model} (this may take a few minutes)")
    config = DistributedBloomConfig.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    config.dht_prefix = args.output_repo

    model = BloomModel.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    os.makedirs(args.output_path, exist_ok=True)

    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
    repo.git_pull()

    transformer_blocks = model.h
    logger.info(
        f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
        f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
    )
    for i, block in enumerate(tqdm(transformer_blocks)):
        repo.git_checkout(args.client_branch, create_branch_ok=True)
        with repo.commit(
            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
        ):
            torch.save(block.state_dict(), "./pytorch_model.bin")

    logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
    repo.git_checkout(args.client_branch, create_branch_ok=True)
    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
        model.h = nn.ModuleList()
        model.save_pretrained(".")
        tokenizer.save_pretrained(".")
        config.save_pretrained(".")

    logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
```
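The script stores every transformer block on its own branch of the output repo (`block_0`, `block_1`, … per `BLOCK_BRANCH_PREFIX`), with the client-side modules on the client branch. A hedged sketch of inspecting one block branch, assuming the output repo exists and is readable:

```bash
# sketch: each converted block lives on its own git branch of the hub repo
git clone --branch block_0 --depth 1 https://huggingface.co/bigscience/test-bloomd-6b3 block0
ls block0  # pytorch_model.bin -- the saved state dict of block 0
```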
cli/deploy_server.sh
ADDED
@@ -0,0 +1,87 @@
```bash
#!/usr/bin/env bash

#################
# Parse options #
#################

instructions() {
  echo "Usage: $0 [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
  echo " -i: initial peer" >&2
  echo " -d: device" >&2
  echo " -p: server identity path" >&2
  echo " -b: block_ids" >&2
  echo " -a: host maddrs" >&2
  echo " -t: whether to run local tests" >&2
  exit 1
}

if [ ! $# -ge 8 ]; then
    instructions
fi

while getopts ":i:d:p:b:a:t:" option; do
    case $option in
        i)  INITIAL_PEER=${OPTARG}
            ;;
        d)  DEVICE=${OPTARG}
            ;;
        p)  SERVER_ID_PATH=${OPTARG}
            ;;
        b)  BLOCK_IDS=${OPTARG}
            ;;
        a)  HOST_MADDR=${OPTARG} # TODO: allow several maddrs
            ;;
        t)  RUN_LOCAL_TESTS=true
            ;;
        \?) instructions
            ;;
    esac
done


echo "=========="
echo "= Config ="
echo "=========="
echo "Initial peer: ${INITIAL_PEER}"
echo "Device: ${DEVICE}"
echo "Server name: ${SERVER_ID_PATH}"
echo "Server address: ${HOST_MADDR}"
echo "Bloom blocks: ${BLOCK_IDS}"


###########################
# Install or activate env #
###########################

# TODO fix bug with self calling
source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
    conda activate bloom-demo
else
    conda create -y --name bloom-demo python=3.8.12 pip
    conda activate bloom-demo

    conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
    pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
    pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
    pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
fi


##############
# Local test #
##############

if [ "$RUN_LOCAL_TESTS" = true ] ; then
    echo "Run test on your local machine"
    python -m cli.inference_one_block --config cli/config.json --device ${DEVICE} # see other args
fi


##############
# Run server #
##############

python -m cli.run_server --prefix bloom6b3 --converted_model_name_or_path bigscience/test-bloomd-6b3 --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
    --block_indices ${BLOCK_IDS} --torch_dtype float32 --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} &> ${SERVER_ID_PATH}.log
```
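For reference, a hypothetical invocation matching the options parsed above; the `/p2p/Qm...` peer id is a placeholder you would copy from a running server or `hivemind-dht`:

```bash
# sketch: deploy one server hosting blocks 3:5 on cpu
bash cli/deploy_server.sh -i /ip4/127.0.0.1/tcp/31337/p2p/QmPLACEHOLDER \
    -d cpu -p ./server1.id -b 3:5 -a /ip4/127.0.0.1/tcp/31337
```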
cli/inference_one_block.py
ADDED
@@ -0,0 +1,53 @@
```python
import argparse

import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tqdm.auto import trange

from src.bloom.block import BloomBlock
from src.bloom.model import BloomConfig
from src.bloom.ops import build_alibi_tensor

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

logger.warning("inference_one_block will soon be deprecated in favour of tests!")


def print_device_info(device=None):
    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    logger.info(f"Using device: {device}")

    # Additional Info when using cuda
    if device.type == "cuda":
        logger.info(torch.cuda.get_device_name(0))
        logger.info(f"Memory Usage:")
        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
        logger.info(f"Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
    parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
    parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
    parser.add_argument("--layer_index", default=0, type=int, help="Index of the block within the full model")
    parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
    parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
    args = parser.parse_args()

    if args.device is None:
        args.device = "cuda" if torch.cuda.is_available() else "cpu"

    config = BloomConfig.from_json_file(args.config)
    block = BloomBlock(config, args.layer_index).to(args.device)

    cache = None

    for i in trange(args.num_steps):
        dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
        alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
        with torch.no_grad():
            outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)

    print_device_info(args.device)
```
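A minimal invocation (the flags correspond to the argparse options above):

```bash
# run 100 inference steps of a single block on cpu using the bundled config
python -m cli.inference_one_block --config cli/config.json --device cpu --num_steps 100
```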
cli/local_server_config_example.cfg
ADDED
@@ -0,0 +1,5 @@
```
device=cpu
block_ids=2:3
id_path=./server.id
maddr=/ip4/127.0.0.1/tcp/30000
#
```
cli/remote_server_config_example.cfg
ADDED
@@ -0,0 +1,6 @@
```
name=bloom-peer-0.bloom.net
device=cpu
block_ids=1:3
id_path=./server.id
maddr=/ip4/0.0.0.0/tcp/30000
#
```
cli/run_local_servers.sh
ADDED
@@ -0,0 +1,111 @@
```bash
#!/usr/bin/env bash

#################
# Parse options #
#################

instructions() {
  echo "Usage: $0 [-n] [-c]" >&2
  echo " -n: number of servers to run" >&2
  echo " -c: path to the server configs" >&2
  exit 1
}

if [ $# != 4 ]; then
    instructions
fi

while getopts ":n:c:t:" option; do
    case $option in
        n)  NUM_SERVERS=${OPTARG}
            ;;
        c)  CONFIG_PATH=${OPTARG}
            ;;
        \?) instructions
            ;;
    esac
done


###########################
# Install or activate env #
###########################

source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" &>/dev/null; then
    conda activate bloom-demo
else
    conda create -y --name bloom-demo python=3.8.12 pip
    conda activate bloom-demo

    conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
    pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
    pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
    pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
fi


#######################
# Create Initial peer #
#######################

hivemind-dht &> tmp.out &
sleep 3
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
echo "Initial peer: ${INITIAL_PEER}"


##############################
# Initialize the config file #
##############################

typeset -A cfg
cfg=( # set default values in config array
    [device]="cpu"
    [block_ids]="1:2"
    [id_path]="server.id"
    [maddr]="/ip4/127.0.0.1/tcp/30000"
)

###############
# Run servers #
###############

for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
do
    ###############
    # Read config #
    ###############

    while read line
    do
        if echo $line | grep -F = &>/dev/null
        then
            varname=$(echo "$line" | cut -d '=' -f 1)
            cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
        fi
    done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg

    echo "=== Server #${SERVER_ID} ==="
    echo "Server ID: ${cfg[id_path]}"
    echo "Device: ${cfg[device]}"
    echo "Bloom block ids: ${cfg[block_ids]}"
    echo "Host maddr: ${cfg[maddr]}"
    echo ""

    ##############
    # Run server #
    ##############

    tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
done


#####################
# Kill initial peer #
#####################

sleep 10
pkill -f hivemind-dht # TODO: kill only particular pids of hivemind-dht
rm tmp.out
```
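A hypothetical invocation; the script reads one `server_<i>.cfg` per server (same format as `cli/local_server_config_example.cfg`) from the directory passed with `-c`:

```bash
# sketch: launch two local servers from ./configs/server_0.cfg and ./configs/server_1.cfg
mkdir -p configs
cp cli/local_server_config_example.cfg configs/server_0.cfg
cp cli/local_server_config_example.cfg configs/server_1.cfg  # give each server its own port and blocks
bash cli/run_local_servers.sh -n 2 -c ./configs
```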
cli/run_remote_servers.sh
ADDED
@@ -0,0 +1,112 @@
```bash
#!/usr/bin/env bash

SSH_KEY_PATH="~/.ssh/<YOUR_KEY>"

#################
# Parse options #
#################

instructions() {
  echo "Usage: $0 [-u] [-n] [-c]" >&2
  echo " -u: username" >&2
  echo " -n: number of servers to run" >&2
  echo " -c: path to the server configs" >&2
  exit 1
}

if [ $# != 6 ]; then
    instructions
fi

while getopts ":u:n:c:" option; do
    case $option in
        u)  USERNAME=${OPTARG}
            ;;
        n)  NUM_SERVERS=${OPTARG}
            ;;
        c)  CONFIG_PATH=${OPTARG}
            ;;
        \?) instructions
            ;;
    esac
done


###########################
# Install or activate env #
###########################

source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" &>/dev/null; then
    conda activate bloom-demo
else
    conda create -y --name bloom-demo python=3.8.12 pip
    conda activate bloom-demo

    conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
    pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
    pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
    pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
fi


#######################
# Create Initial peer #
#######################

hivemind-dht &> tmp.out &

sleep 3
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
rm tmp.out
echo "Initial peer: ${INITIAL_PEER}"


##############################
# Initialize the config file #
##############################

typeset -A cfg
cfg=( # set default values in config array
    [name]=""
    [device]="cpu"
    [block_ids]="1:2"
    [id_path]="server.id"
    [maddr]="/ip4/0.0.0.0/tcp/30000"
)

###############
# Run servers #
###############

for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
do
    ###############
    # Read config #
    ###############

    while read line
    do
        if echo $line | grep -F = &>/dev/null
        then
            varname=$(echo "$line" | cut -d '=' -f 1)
            cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
        fi
    done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg

    SERVER_NAME="${USERNAME}@${cfg[name]}"
    echo "=== Server #${SERVER_ID} ==="
    echo "Server name ${SERVER_NAME}"
    echo "Server ID: ${cfg[id_path]}"
    echo "Device: ${cfg[device]}"
    echo "Bloom block ids: ${cfg[block_ids]}"
    echo "Host maddr: ${cfg[maddr]}"
    echo "================="

    ##############
    # Run server #
    ##############

    ssh -i ${SSH_KEY_PATH} ${SERVER_NAME} "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'"
done
```
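And the remote counterpart (hypothetical; assumes `SSH_KEY_PATH` was edited in the script and each `server_<i>.cfg` names a reachable host, as in `cli/remote_server_config_example.cfg`):

```bash
# sketch: deploy two servers over ssh as user "ubuntu"
bash cli/run_remote_servers.sh -u ubuntu -n 2 -c ./configs
```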
cli/run_server.py
ADDED
@@ -0,0 +1,85 @@
```python
import configargparse
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from src.server.server import Server

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


def main():
    # fmt:off
    parser = configargparse.ArgParser(default_config_files=["config.yml"])
    parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')

    parser.add_argument('--converted_model_name_or_path', type=str, default='bigscience/test-bloomd-6b3',
                        help="path or name of a pretrained model, converted with cli/convert_model.py (see README.md)")
    parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
    parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
    parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default, "
                                                                 "use the same name as in the converted model.")
    parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
                        help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
    parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
                        help='Visible multiaddrs the host announces for external connections from other p2p instances')

    parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')

    parser.add_argument('--num_handlers', type=int, default=None, required=False,
                        help='server will use this many processes to handle incoming requests')
    parser.add_argument('--min_batch_size', type=int, default=1,
                        help='Minimum required batch size for all expert operations')
    parser.add_argument('--max_batch_size', type=int, default=16384,
                        help='The total number of examples in the same batch will not exceed this value')
    parser.add_argument('--cache_size_bytes', type=int, default=None,
                        help='The size of memory cache for storing past attention keys/values between inference steps')
    parser.add_argument('--device', type=str, default=None, required=False,
                        help='all experts will use this device in torch notation; default: cuda if available else cpu')
    parser.add_argument("--torch_dtype", type=str, default="auto",
                        help="Use this dtype to store block weights and do computations. "
                             "By default, respect the dtypes in the pre-trained state dict.")

    parser.add_argument('--update_period', type=float, required=False, default=30,
                        help='Server will report experts to DHT once in this many seconds')
    parser.add_argument('--expiration', type=float, required=False, default=None,
                        help='DHT entries will expire after this many seconds')
    parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
                        help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
    parser.add_argument('--increase_file_limit', action='store_true',
                        help='On *nix, this will increase the max number of processes '
                             'a server can spawn before hitting "Too many open files"; Use at your own risk.')
    parser.add_argument('--stats_report_interval', type=int, required=False,
                        help='Interval between two reports of batch processing performance statistics')

    parser.add_argument('--custom_module_path', type=str, required=False,
                        help='Path of a file with custom nn.modules, wrapped into special decorator')
    parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")

    # fmt:on
    args = vars(parser.parse_args())
    args.pop("config", None)

    if args.pop("increase_file_limit"):
        increase_file_limit()

    compression_type = args.pop("compression")
    compression = getattr(CompressionType, compression_type)

    use_auth_token = args.pop("use_auth_token")
    args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token

    server = Server.create(**args, start=True, compression=compression)

    try:
        server.join()
    except KeyboardInterrupt:
        logger.info("Caught KeyboardInterrupt, shutting down")
    finally:
        server.shutdown()


if __name__ == "__main__":
    main()
```
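Besides the README's `--block_indices` examples, the arguments above also let a server pick a number of blocks itself and join an existing swarm (a sketch; the peer multiaddr is a placeholder copied from a running server):

```bash
# sketch: serve 2 server-chosen blocks, joining an existing DHT
python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 \
    --num_blocks 2 --torch_dtype float32 --identity_path ./server2.id \
    --host_maddrs /ip4/127.0.0.1/tcp/0 \
    --initial_peers /ip4/127.0.0.1/tcp/31337/p2p/QmPLACEHOLDER
```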
requirements.txt
ADDED
@@ -0,0 +1,4 @@
```
https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
https://github.com/learning-at-home/hivemind/archive/61e5e8c1f33dd2390e6d0d0221e2de6e75741a9c.zip
huggingface-hub==0.7.0
accelerate==0.10.0
```
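These pins can be installed directly (a lighter alternative to the conda flow in the README; the CUDA-specific torch and bitsandbytes wheels still need the README commands):

```bash
pip install -r requirements.txt
```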
src/__init__.py
ADDED
@@ -0,0 +1,5 @@
```python
from src.bloom import *
from src.client import *
from src.dht_utils import declare_active_modules, get_remote_module

__version__ = "0.2"
```
src/bloom/__init__.py
ADDED
@@ -0,0 +1,2 @@
```python
from src.bloom.block import BloomBlock
from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
```
src/bloom/block.py
ADDED
@@ -0,0 +1,248 @@
```python
"""
Bloom intermediate layer
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
import math

import torch
import torch.nn as nn
import torch.nn.quantized.dynamic.modules.linear

from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
                           pre_process_alibi_for_pad, split_tensor_along_last_dim)


class BloomAttention(nn.Module):
    def __init__(self, config, layer_number=None):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.n_head
        self.head_dim = self.hidden_size // self.num_heads
        self.split_size = self.hidden_size
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        self.masked_softmax_fusion = config.masked_softmax_fusion
        self.hidden_dropout = config.hidden_dropout

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
                f" {self.num_heads})."
            )

        # Layer-wise attention scaling
        self.layer_number = max(1, layer_number)
        self.norm_factor = math.sqrt(self.head_dim) * self.layer_number

        # Scaled Softmax
        self.scale_mask_softmax = BloomScaledSoftmax(
            self.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            self.layer_number,
        )

        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)

        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def forward(
        self,
        hidden_states,
        residual,
        layer_past=None,
        attention_mask=None,
        alibi=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        if alibi is None:
            current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
            alibi = build_alibi_tensor(
                current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
            )

        # hidden_states: [batch_size, seq_length, hidden_size]
        # apply preprocessing if the input is padded
        if attention_mask is not None:
            alibi = pre_process_alibi_for_pad(alibi, attention_mask)
        # otherwise repeat alibi tensor with the batch size
        else:
            alibi = alibi.repeat(hidden_states.shape[0], 1, 1)

        mixed_x_layer = self.query_key_value(hidden_states)

        # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)

        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None

        # [batch_size, head_dim, q_length, k_length]
        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))

        # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
        query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)

        # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
        key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)

        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
        beta = 1.0 / self.layer_number

        matmul_result = torch.baddbmm(
            alibi,
            query_layer.transpose(1, 0),
            key_layer.transpose(1, 0).transpose(1, 2),
            beta=beta,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [batch_size, num_heads, q_length, k_length]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
        attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # context layer shape: [batch_size, num_heads, q_length, head_dim]
        output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))

        # change view [k_length, batch_size x num_heads, head_dim]
        value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)

        # change view [batch_size x num_heads, q_length, k_length]
        attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = context_layer.view(*output_size)

        # [batch_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)

        context_layer = context_layer.view(*new_context_layer_shape)

        # Output. [q_length, batch_size, hidden_size]

        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        output_tensor = self.dense(context_layer)
        output = output_tensor.transpose(1, 0)

        output = dropout_add(output, residual, self.hidden_dropout, self.training)

        outputs = (output, present)
        if output_attentions:
            outputs += (attention_probs,)

        return outputs


class BloomMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
        self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
        self.hidden_dropout = config.hidden_dropout
        self.gelu_impl = BloomGelu()

    def forward(self, hidden_states, residual):
        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
        intermediate_output = self.dense_4h_to_h(hidden_states)
        output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
        return output


class BloomBlock(nn.Module):
    def __init__(self, config, layer_number=None):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
        self.n_head = config.n_head
        self.self_attention = BloomAttention(config, layer_number=layer_number)
        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = BloomMLP(config)

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
        self.hidden_dropout = config.hidden_dropout

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
        alibi=None,
    ):
        # hidden_states: [batch_size, seq_length, hidden_size]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Layer norm post the self attention.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attn_outputs = self.self_attention(
            layernorm_output,
            residual,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attention_output = attn_outputs[0]

        outputs = attn_outputs[1:]

        layernorm_output = self.post_attention_layernorm(attention_output)

        # Get residual
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output, residual)

        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, present, attentions
```
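To make the block interface concrete, here is a small sketch adapted from `cli/inference_one_block.py`: it feeds one token embedding per step and reuses the attention cache (assumes the repo root is on `PYTHONPATH`; with the bundled 176B-sized config, even a single block takes several GB of RAM):

```python
# sketch: incremental inference through one BloomBlock, as in cli/inference_one_block.py
import torch

from src.bloom.block import BloomBlock
from src.bloom.model import BloomConfig
from src.bloom.ops import build_alibi_tensor

config = BloomConfig.from_json_file("cli/config.json")
block = BloomBlock(config, layer_number=0)

cache = None  # (key, value) tensors grow by one position per step
with torch.no_grad():
    for step in range(3):
        new_token_embedding = torch.randn(1, 1, config.hidden_size)
        alibi = build_alibi_tensor(step + 1, config.num_attention_heads)
        hidden_states, cache = block.forward(
            new_token_embedding, alibi=alibi, use_cache=True, layer_past=cache
        )
```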
src/bloom/from_pretrained.py
ADDED
@@ -0,0 +1,80 @@
```python
"""
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
If necessary, one can rewrite this to implement a different behavior, such as:
 - loading files from a local data source (e.g. S3)
 - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
 - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )

"""
from __future__ import annotations

from typing import Optional, OrderedDict, Union

import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from transformers.modeling_utils import WEIGHTS_NAME
from transformers.utils.hub import cached_path, hf_bucket_url

from src.bloom import BloomBlock, BloomConfig

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

CLIENT_BRANCH = "main"
BLOCK_BRANCH_PREFIX = "block_"
USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
FORCE_DOWNLOAD = False
RESUME_DOWNLOAD = False
LOCAL_FILES_ONLY = False


def load_pretrained_block(
    converted_model_name_or_path: str,
    block_index: int,
    config: Optional[BloomConfig] = None,
    torch_dtype: Union[torch.dtype, str] = "auto",
    use_auth_token: Optional[str] = None,
) -> BloomBlock:
    """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
    if config is None:
        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
    block = BloomBlock(config, layer_number=block_index)
    state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
    block.load_state_dict(state_dict)

    if torch_dtype == "auto":
        with torch.no_grad():
            for name, param in block.named_parameters():
                assert name in state_dict, f"{name} not in state dict"
                param.data = param.data.to(state_dict[name].dtype)
    else:
        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
        block = block.to(dtype=torch_dtype)

    report = block.load_state_dict(state_dict, strict=True)
    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
    return block


def _load_state_dict(
    pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None
) -> OrderedDict[str, torch.Tensor]:
    revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
    archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)

    # Load from URL or cache if already cached
    resolved_archive_file = cached_path(
        archive_file,
        cache_dir=None,
        force_download=FORCE_DOWNLOAD,
        proxies=None,
        resume_download=RESUME_DOWNLOAD,
        local_files_only=LOCAL_FILES_ONLY,
        use_auth_token=use_auth_token,
        user_agent=USER_AGENT,
    )
    state_dict = torch.load(resolved_archive_file, map_location="cpu")
    return state_dict


DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
```
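For example, a single converted block can be fetched from the hub with the function above (it downloads one `pytorch_model.bin` from the corresponding `block_*` branch):

```python
# sketch: download block 3 of the converted 6B model in float32
import torch

from src.bloom.from_pretrained import load_pretrained_block

block = load_pretrained_block("bigscience/test-bloomd-6b3", block_index=3, torch_dtype=torch.float32)
print(sum(p.numel() for p in block.parameters()), "parameters")
```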
src/bloom/model.py
ADDED
@@ -0,0 +1,408 @@
```python
"""
PyTorch BLOOM model that implements several memory-efficient modes.
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
from typing import Tuple

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
                                     add_start_docstrings_to_model_forward)
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.utils import logging

from src.bloom.block import BloomBlock

use_hivemind_log_handler("in_root_logger")
logger = logging.get_logger(__file__)

_CHECKPOINT_FOR_DOC = "bigscience/Bloom"
_CONFIG_FOR_DOC = "BloomConfig"
_TOKENIZER_FOR_DOC = "BloomTokenizer"


class BloomPreTrainedModel(PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BloomConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BloomBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, BloomModel):
            module.gradient_checkpointing = value


BLOOM_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MemoryEfficientBloomConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

BLOOM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
    BLOOM_START_DOCSTRING,
)
class BloomModel(BloomPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"

        self.embed_dim = config.hidden_size
        self.n_head = config.n_head

        # Embedding + LN Embedding

        # TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)  # dtype=config.torch_dtype
        self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

        # Forbid accumulate grads for embeddings and layernorm
        self.set_requires_grad(False)

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.word_embeddings = new_embeddings

    def set_requires_grad(self, value):
        for p in self.parameters():
            p.requires_grad = value

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if position_ids is not None:
            logger.warning("position_ids are ignored in this bloom implementation")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_head x N x N
        # head_mask has shape n_layer x batch x n_head x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Compute alibi tensor: check build_alibi_tensor documentation
        current_sequence_length = hidden_states.shape[1]
        if past_key_values and past_key_values[0]:
            current_sequence_length += past_key_values[0][0].shape[1]

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
```

*(diff truncated here; model.py continues to line 408)*