|
|
|
import os |
|
from typing import Optional, Tuple, Union |
|
|
|
import hivemind |
|
from hivemind import DHT, get_logger, use_hivemind_log_handler |
|
|
|
from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict |
|
from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel |
|
from src.client.remote_sequential import RemoteSequential |
|
from src.data_structures import UID_DELIMITER |
|
|
|
# Select hivemind's "in_root_logger" log-handler mode; done before the module logger below is created.
use_hivemind_log_handler("in_root_logger")

# Module-level logger, named after this file's path.
logger = get_logger(__file__)
|
|
|
|
|
class DistributedBloomConfig(BloomConfig):
    """
    BloomConfig extended with the settings needed to reach DHT peers.

    A distributed model requires dht_prefix together with one of initial_peers or dht.
    """

    # Peer addresses used to start a new DHT when no `dht` instance is given.
    initial_peers: Tuple[str, ...] = ()

    # Required by the distributed model; declared without a default, so the caller must supply it.
    dht_prefix: str

    # An existing DHT instance to use instead of creating one from `initial_peers`.
    dht: Optional[DHT] = None
|
|
|
|
|
class DistributedBloomModel(BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        """
        Build the local part of the model and replace ``self.h`` with a RemoteSequential.

        :param config: must provide a non-empty dht_prefix and either initial_peers or dht
        :raises AssertionError: if dht_prefix is missing or empty, if neither initial_peers
            nor dht is given, if the parent constructor created any local layers, or if the
            resulting DHT is not a running hivemind.DHT instance
        """
        # dht_prefix is declared on the config without a default value; read it with getattr so
        # that a missing attribute reaches the explanatory assertion instead of an AttributeError
        dht_prefix = getattr(config, "dht_prefix", None)
        assert dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
        assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"

        # Temporarily report zero layers so that the parent constructor creates no local blocks.
        # The original value is restored in `finally`, so a failure inside the parent constructor
        # cannot leave the caller's config object with n_layer == 0.
        n_layer, config.n_layer = config.n_layer, 0
        try:
            super().__init__(config)
            assert len(self.h) == 0
        finally:
            config.n_layer = n_layer

        # Prefer a DHT supplied by the caller; otherwise start a client-mode DHT from initial_peers
        dht = (
            config.dht
            if config.dht is not None
            else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
        )
        assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"

        # Replace the (empty) local layer container with the remote one
        self.h = RemoteSequential(config, dht, dht_prefix)
|
|
|
|
|
class DistributedBloomForCausalLM(BloomForCausalLM):
    """BloomForCausalLM whose transformer is a DistributedBloomModel (layers hosted by the swarm)"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        """Initialize the base model state, then attach a DistributedBloomModel as ``self.transformer``."""
        # Calls BloomPreTrainedModel.__init__ explicitly instead of BloomForCausalLM.__init__;
        # presumably so that the parent does not build its own local transformer
        # NOTE(review): confirm against BloomForCausalLM.__init__
        BloomPreTrainedModel.__init__(self, config)
        self.transformer = DistributedBloomModel(config)

        self.post_init()
|
|