File size: 3,513 Bytes
e71a2ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
from hivemind import PeerID, get_logger

from src.data_structures import RemoteModuleInfo, ServerState

# Public API of this module; `Span` and the underscore-prefixed helpers are internal.
__all__ = ["choose_best_blocks", "should_choose_other_blocks"]

# NOTE(review): the logger is named after the file path (`__file__`) rather than the dotted module name
# (`__name__`); confirm this is intended, as it affects how the logger can be configured by name.
logger = get_logger(__file__)


@dataclass
class Span:
    """A half-open range of blocks ``[start, end)`` served by one peer, together with that peer's throughput."""

    start: int  # index of the first block in the span (inclusive)
    end: int  # index one past the last block in the span (exclusive)
    throughput: float  # throughput reported by the peer serving this span

    @property
    def length(self) -> int:
        """Number of blocks covered by the span."""
        return self.end - self.start

    def move_to(self, new_start: int) -> None:
        """Shift the span in place so that it begins at ``new_start``, keeping its length unchanged."""
        width = self.length  # must be read before ``start`` is overwritten
        self.start = new_start
        self.end = new_start + width


def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
    """Collect the span of each server and the total throughput of every block.

    Servers whose state is ``ServerState.OFFLINE`` are ignored.

    :param module_infos: one entry per block; ``None`` entries are skipped
    :returns: a tuple ``(spans, throughputs)`` where ``spans`` maps each peer to the smallest block range
        covering every block it was seen serving (with the throughput it reported on the first such block),
        and ``throughputs[i]`` is the sum of throughputs of the servers of block ``i``
    """
    spans = {}
    throughputs = np.zeros(len(module_infos))
    for block, module in enumerate(module_infos):
        if module is None:
            continue

        for peer_id, server in module.servers.items():
            if server.state == ServerState.OFFLINE:
                continue

            if peer_id in spans:
                # Grow the known span to include this block; each bound is compared against its own previous value.
                # Note: if a peer serves non-contiguous blocks, the span also covers the gap between them.
                spans[peer_id].start = min(spans[peer_id].start, block)
                spans[peer_id].end = max(spans[peer_id].end, block + 1)
            else:
                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput)

            throughputs[block] += server.throughput

    return spans, throughputs


def _choose_best_start(throughputs: np.ndarray, num_blocks: int, cur_start: Optional[int]) -> int:
    options = (
        (sorted(throughputs[i : i + num_blocks]), i != cur_start, i)
        for i in range(0, len(throughputs) - num_blocks + 1)
    )
    return min(options)[-1]


def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
    """Return ``num_blocks`` consecutive block indices starting at the best window for the current swarm.

    The window is selected by ``_choose_best_start`` over per-block throughputs, with no current position to favor.
    """
    block_throughputs = _compute_spans(module_infos)[1]
    first = _choose_best_start(block_throughputs, num_blocks, cur_start=None)
    return [*range(first, first + num_blocks)]


def should_choose_other_blocks(
    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
) -> bool:
    """Decide whether this server should move to other blocks to improve the swarm balance.

    Simulates moving the local server to its best position, then greedily re-places every server (in random order)
    until a full pass makes no moves. Rebalancing is recommended when the minimal per-block throughput before the
    simulation, divided by the minimal per-block throughput after it, is below ``balance_quality``.

    :param local_peer_id: peer ID of this server; its span must be present in ``module_infos``
    :param module_infos: one entry per block (``None`` entries are skipped)
    :param balance_quality: threshold for the ratio described above; values above 1.0 force rebalancing
    :returns: True if this server should choose other blocks
    :raises AssertionError: if ``local_peer_id`` serves no (non-offline) block in ``module_infos``
    """
    if balance_quality > 1.0:
        return True  # Forces rebalancing on each check (may be used for debugging purposes)

    spans, throughputs = _compute_spans(module_infos)
    initial_throughput = throughputs.min()

    # While re-placing a span, its own contribution is over-subtracted by this relative amount, which makes staying
    # in place slightly preferred. Without it, floating-point error accumulated by repeatedly subtracting and adding
    # throughputs can make an equivalent position look marginally better, causing needless moves and possibly
    # endless oscillation of the loop below.
    move_bias = 1e-3
    eps = 1e-6

    assert local_peer_id in spans, "Span served by this server is not present in the DHT"
    local_span = spans[local_peer_id]
    throughputs[local_span.start : local_span.end] -= local_span.throughput * (1 + move_bias)

    new_start = _choose_best_start(throughputs, local_span.length, local_span.start)
    if local_span.start == new_start:
        return False  # This server is on its best place already

    throughputs[local_span.start : local_span.end] += local_span.throughput * move_bias  # undo the bias
    local_span.move_to(new_start)
    throughputs[local_span.start : local_span.end] += local_span.throughput

    moved = True
    while moved:
        servers = list(spans.keys())
        np.random.shuffle(servers)

        moved = False
        for peer_id in servers:
            span = spans[peer_id]
            throughputs[span.start : span.end] -= span.throughput * (1 + move_bias)

            new_start = _choose_best_start(throughputs, span.length, span.start)

            throughputs[span.start : span.end] += span.throughput * move_bias  # undo the bias
            if span.start != new_start:
                span.move_to(new_start)
                moved = True

            throughputs[span.start : span.end] += span.throughput

    new_throughput = throughputs.min()
    if new_throughput < eps or new_throughput < initial_throughput:
        # Rebalancing would not raise the bottleneck throughput, or some block would remain (practically) unserved,
        # so moving does not help. This also avoids dividing by zero, drift residue, or a negative value below.
        return False

    actual_quality = initial_throughput / new_throughput
    logger.info(f"Swarm balance quality: {actual_quality * 100:.1f}%")

    return actual_quality < balance_quality - eps