File size: 4,100 Bytes
5bdad4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from __future__ import annotations

import dataclasses
import threading
from functools import partial
from typing import List, NamedTuple, Optional, Sequence, Tuple

from hivemind import DHT, PeerID
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from src.data_structures import ModuleUID, RemoteModuleInfo
from src.dht_utils import _get_remote_module_infos

# Module-level logging setup, executed once at import time.
# NOTE(review): "in_root_logger" presumably makes hivemind's log handler attach to the root logger — confirm in hivemind.utils.logging
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class Span(NamedTuple):
    """A run of consecutive blocks, from ``start`` (inclusive) to ``end`` (exclusive), all served by one peer."""

    start: int  # index of the first block covered by this span
    end: Optional[int]  # index one past the last block covered by this span
    peer_id: PeerID  # the peer that serves every block in [start, end)


@dataclasses.dataclass(frozen=False, init=False)  # TODO[borzunov@] this is not really a dataclass
class RemoteSequenceInfo:
    """Keeps and updates the meta-information about which peers host which blocks"""

    dht: DHT
    block_uids: List[ModuleUID]
    block_infos: List[Optional[RemoteModuleInfo]]  # None for a block until the DHT has returned an entry for it
    spans_by_priority: List[Span]  # sorted from best to worst
    spans_containing_block: Tuple[List[Span], ...]  # entry i lists every span that covers block i
    lock_changes: threading.Lock  # held by update_() while block_infos and the span tables are rebuilt

    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
        """
        Fetch the current block infos from the DHT and build the span tables.

        :param dht: DHT used to look up the remote module infos for each block
        :param block_uids: UIDs of the blocks that make up the sequence, in order
        :raises AssertionError: if any block has no DHT entry after the first update, or no spans were found
        """
        self.dht = dht
        self.block_uids = list(block_uids)
        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
        self.spans_by_priority = []
        self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
        self.lock_changes = threading.Lock()
        self.update_()

        for uid, info in zip(self.block_uids, self.block_infos):
            assert info is not None, f"Found no remote peers for block {uid}"
        assert self.spans_by_priority and self.spans_containing_block

    def update_(self):
        """Refresh block infos from the DHT and recompute both span tables while holding ``lock_changes``."""
        with self.lock_changes:
            self.update_block_infos_()
            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)

    def update_block_infos_(self):
        """
        Query the DHT for every block UID and store the results in ``self.block_infos``.

        A missing entry (None) or an entry that is not a RemoteModuleInfo is logged and skipped.
        In that case ``self.block_infos`` keeps its previous value for the block, which is None on the
        first update. Other anomalies (no peers, mismatched uid, peer_ids not a set) are logged, but the
        entry is still stored.
        """
        new_block_infos: Sequence[Optional[RemoteModuleInfo]] = self.dht.run_coroutine(
            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")), return_future=False
        )
        assert len(new_block_infos) == len(self.block_uids)
        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
            if info is None:
                logger.warning(f"Found no block info for block {uid}")
                continue  # the attribute checks below would raise AttributeError on None
            if not isinstance(info, RemoteModuleInfo):
                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
                continue  # .peer_ids / .uid are not guaranteed to exist; compute_spans relies on .peer_ids
            if not info.peer_ids:
                logger.warning(f"Found no active peers for block {uid}")
            if info.uid != uid:
                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
            if not isinstance(info.peer_ids, set):
                logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
            self.block_infos[block_index] = info

    @staticmethod
    def compute_spans(
        block_infos: Sequence[Optional[RemoteModuleInfo]],
    ) -> Tuple[List[Span], Tuple[List[Span], ...]]:
        """
        Merge consecutive blocks served by the same peer into spans.

        :param block_infos: one entry per block, in sequence order; a None entry is treated as a block with no peers
        :returns: a pair of (all spans sorted by length, longest first;
            a tuple with one list per block holding every span that covers that block)
        """
        closed_spans = []
        active_spans = {}  # peer_id -> span that may still be extended by the next block
        last_index = len(block_infos) - 1
        for block_index, info in enumerate(block_infos):
            peer_ids = info.peer_ids if info is not None else ()
            for peer_id in peer_ids:
                if peer_id not in active_spans:
                    active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
                else:  # peer_id in active_spans: extend its span to cover this block
                    active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)

            # Close spans whose peer does not serve this block. At the final block every span is closed,
            # so nothing remains in active_spans after the loop.
            for peer_id in list(active_spans.keys()):
                if peer_id not in peer_ids or block_index == last_index:
                    closed_spans.append(active_spans.pop(peer_id))
        assert not active_spans

        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)

        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
        for span in closed_spans:
            for block_index in range(span.start, span.end):
                spans_containing_block[block_index].append(span)

        return closed_spans, spans_containing_block

    def __len__(self):
        """Number of blocks in the sequence."""
        return len(self.block_uids)