# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
import contextlib
from typing import AsyncIterator, Dict, Sequence

import torch
from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils import as_aiter
from hivemind.utils.asyncio import anext
from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming

from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import MAX_LENGTH, TransformerBackend

logger = get_logger(__file__)


class TransformerConnectionHandler(ConnectionHandler):
    """Handles three request types: forward, backward and forward-incremental (inference)"""

    module_backends: Dict[ModuleUID, TransformerBackend]

    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
        super().__init__(dht, module_backends)
        for module_backend in self.module_backends.values():
            assert isinstance(module_backend, TransformerBackend)

    async def rpc_inference(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Compute a single step of inference using attention cache; update attention cache accordingly."""
        try:
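            # Session protocol (as implemented below): the first request carries the chain of module UIDs in
            # `uid` and the initial hidden states in `tensors`; every subsequent request carries only the new
            # hidden states for the next step. The loop keeps serving steps while `request.tensors` is
            # non-empty, so a client ends the session by sending a request with an empty `tensors` field.
            # Each response returns the outputs of the last requested module for that step.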
            logger.info("Opened rpc_inference()")
            request = await anext(requests)
            requested_uids = self._check_header(request)
            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

            cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)  # [cache_handle, prefix_length]
            prefix_length = 0
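            # prefix_length tracks how many tokens are already stored in the attention caches; it is passed
            # to each backend via cache_metadata and advanced after every inference step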

            async with self._allocate_caches(requested_backends) as cache_handles:
                assert len(cache_handles) == len(requested_backends)
                while request.tensors:  # iterate while user is willing to supply tensors
                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]

                    # run request tensors through all requested modules, update caches
                    for backend, cache_handle in zip(requested_backends, cache_handles):
                        cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
                        assert (
                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
                        assert isinstance(hidden_states, (list, tuple))
                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3

                    # serialize and send last layer outputs
                    yield runtime_pb2.ExpertResponse(
                        tensors=[
                            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
                            for result, proto in zip(
                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
                            )
                        ]
                    )

                    # prepare for next step
                    prefix_length += hidden_states[0].shape[1]
                    request = await anext(requests)
        finally:
            logger.info("Closed rpc_inference()")

    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # Parse request and prepare backends
        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_header(request)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a chain of requested backends
        for backend in requested_backends:
            assert isinstance(hidden_states, (list, tuple))
            assert (
                len(hidden_states) == 1 and hidden_states[0].ndim == 3
            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
            hidden_states = await backend.forward_pool.submit_task(*hidden_states)

        # Serialize the overall output and respond
        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
            ]
        )

    async def rpc_forward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        # Parse requests and prepare backends
        uids_header, hidden_states = await self._gather_inputs(requests, context)
        requested_uids = self._check_header_str(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a chain of requested backends
        for backend in requested_backends:
            assert isinstance(hidden_states, (list, tuple))
            assert (
                len(hidden_states) == 1 and hidden_states[0].ndim == 3
            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
            hidden_states = await backend.forward_pool.submit_task(*hidden_states)

        # Serialize the overall output
        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
        serialized_output = [
            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
        ]

        # Split the serialized_output for streaming and respond
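        # (the p2p daemon limits the size of a single message, so each serialized tensor is cut into chunks
        # of at most DEFAULT_MAX_MSG_SIZE bytes and sent as separate ExpertResponse messages for the client
        # to reassemble)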
        output_split = [
            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # Parse the request and prepare backends
        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_header(request)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a forward chain to collect intermediate inputs
        # Note that we do not forward for the last module since we do not need its output
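        # (recomputing these activations trades an extra forward pass for not having to keep per-request
        # activations on the server between the client's forward and backward calls)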
        inter_inputs = [inputs]
        for backend in requested_backends[:-1]:
            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
            inputs = await backend.forward_pool.submit_task(inputs)
            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
            inputs = inputs[0]
            inter_inputs.append(inputs)

        # Run a backward chain for the requested backends
        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
            inputs_and_grads = [inp, grads]
            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
            assert isinstance(grads, (list, tuple)) and len(grads) == 1
            grads = grads[0]

        # Serialize the overall grad_input and respond
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
            ]
        )

    async def rpc_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
        inputs, grads = inputs_and_grads
        requested_uids = self._check_header_str(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a forward chain to collect intermediate inputs
        # Note that we do not forward for the last module since we do not need its outputs
        inter_inputs = [inputs]
        for backend in requested_backends[:-1]:
            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
            inputs = await backend.forward_pool.submit_task(inputs)
            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
            inputs = inputs[0]
            inter_inputs.append(inputs)

        # Run a backward chain for requested backends
        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
            inputs_and_grads = [inp, grads]
            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
            assert isinstance(grads, (list, tuple)) and len(grads) == 1
            grads = grads[0]

        # Serialize the overall grad_inputs
        serialized_grad_inputs = [
            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
        ]
        # Split the serialized_grad_inputs for streaming and respond
        output_split = [
            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]

        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
        """Check that the request specifies a valid chain of module UIDs served by this peer"""
        if not request.uid:
            raise RuntimeError("User did not provide any uids")
        uids = request.uid.split(CHAIN_DELIMITER)
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)

    def _check_header_str(self, header: str) -> Sequence[ModuleUID]:
        """Check that a chain of module UIDs (passed as a plain string header) is served by this peer"""
        if not header:
            raise RuntimeError("User did not provide any uids")
        uids = header.split(CHAIN_DELIMITER)
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)
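
    # Illustrative sketch (not used by the server): a client-side request whose header would pass the checks
    # above. The module UIDs below are hypothetical placeholders; a real chain must name modules that this
    # peer actually serves.
    #
    #   request = runtime_pb2.ExpertRequest(
    #       uid=CHAIN_DELIMITER.join(["bloom.block.3", "bloom.block.4"]),
    #       tensors=[serialize_torch_tensor(hidden_states)],
    #   )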

    @contextlib.asynccontextmanager
    async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> AsyncIterator[Sequence[int]]:
        """Allocate memory caches for each transformer block, return cache handles"""
        async with contextlib.AsyncExitStack() as stack:
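            # AsyncExitStack ties every allocated cache to this context: caches acquired so far are released
            # if a later allocation fails, and all of them are freed once the caller leaves this context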
            handles = []
            for backend in backends:
                num_heads = backend.module.self_attention.num_heads
                head_dim = backend.module.self_attention.head_dim

                cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
                # [key_or_value, batch_size, max_length, num_heads, head_dim]
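                # NB: each cache stores keys and values for a single sequence (batch_size=1) of up to
                # MAX_LENGTH tokens in float32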

                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))

            yield handles