"""
Tensor parallelism for LLM training on TPU v4-32.
Optimized for efficient training of a 600B-parameter model.
"""

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from typing import Dict, Optional, Tuple, Any, Callable
import time

from parallelism.sharding import ShardingStrategy, ParameterSharding, create_device_mesh


class TensorParallel:
    """
    Tensor parallelism for distributed training on TPU v4-32.

    Attributes:
        num_devices: Number of devices
        num_tp: Number of tensor parallel devices
        mesh: Device mesh
        dp_size: Number of data parallel devices
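
    Example (illustrative sketch; ``params`` and ``batch`` are assumed to be
    pytrees produced by your own model and data pipeline):
        >>> tp = TensorParallel(num_tp=8)
        >>> sharded_params = tp.shard_params(params)
        >>> device_batch = tp.split_batch(batch)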
    """

    def __init__(
        self,
        num_devices: Optional[int] = None,
        num_tp: int = 8,
        use_2d_sharding: bool = True
    ):
        """
        Initialize tensor parallelism optimized for TPU v4-32.

        Args:
            num_devices: Number of devices (defaults to all available devices)
            num_tp: Number of tensor parallel devices
            use_2d_sharding: Whether to use 2D sharding for better efficiency
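
        Example (illustrative): with 32 JAX devices and ``num_tp=8``, the
        data-parallel size becomes ``dp_size = 32 // 8 = 4``.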
        """
        # Get number of devices
        self.num_devices = num_devices or jax.device_count()
        self.num_tp = min(num_tp, self.num_devices)

        # Calculate optimal data parallelism size
        self.dp_size = self.num_devices // self.num_tp

        # Log device configuration
        print(f"TPU configuration: {self.num_devices} total devices")
        print(f"Tensor parallelism: {self.num_tp} devices")
        print(f"Data parallelism: {self.dp_size} devices")

        # Create device mesh
        self.mesh = create_device_mesh(self.num_devices, num_tp=self.num_tp)

        # Create parameter sharding
        self.param_sharding = ParameterSharding(
            self.mesh,
            ShardingStrategy.TENSOR_PARALLEL
        )

        # Store 2D sharding preference
        self.use_2d_sharding = use_2d_sharding

    def shard_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Shard parameters across devices optimized for TPU v4-32.

        Args:
            params: Model parameters

        Returns:
            Sharded parameters
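
        Example (illustrative; ``params`` is an assumed parameter pytree):
            >>> sharded = tp.shard_params(params)

        With ``use_2d_sharding=True``, matrices whose dimensions are both
        >= 4096 are additionally sharded as ``P('dp', 'tp')``.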
        """
        # Create sharding rules
        rules = self.param_sharding.create_sharding_rules(params)

        # Measure sharding time for performance monitoring
        start_time = time.time()

        # Apply 2D sharding for better TPU utilization if enabled
        if self.use_2d_sharding:
            # Modify rules for 2D sharding of large matrices
            rules = self._apply_2d_sharding(rules, params)

        # Shard parameters
        with self.mesh:
            sharded_params = jax.tree_map(
                lambda p, r: jax.lax.with_sharding_constraint(p, r),
                params,
                rules
            )

        # Log sharding time (block until transfers finish so the measurement
        # reflects completed work rather than async dispatch)
        jax.block_until_ready(sharded_params)
        sharding_time = time.time() - start_time
        print(f"Parameter sharding completed in {sharding_time:.2f} seconds")

        return sharded_params

    def _apply_2d_sharding(self, rules: Dict[str, Any], params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply 2D sharding to large matrices for better TPU utilization.

        Args:
            rules: Sharding rules
            params: Model parameters

        Returns:
            Modified sharding rules
        """
        # Flatten parameters together with their tree paths
        flat_with_paths = jax.tree_util.tree_flatten_with_path(params)[0]

        # Create modified rules
        modified_rules = dict(rules)

        # Apply 2D sharding to large matrices (>= 4096 in both dimensions)
        for path, param in flat_with_paths:
            key = '/'.join(str(p) for p in path)
            if param.ndim == 2 and param.shape[0] >= 4096 and param.shape[1] >= 4096:
                # Shard rows across the data-parallel axis and columns across
                # the tensor-parallel axis
                modified_rules[key] = P('dp', 'tp')

        return modified_rules

    def parallelize(self, fn: Callable, donate_argnums: Optional[Tuple[int, ...]] = None) -> Callable:
        """
        Parallelize function across devices with optimizations for TPU v4-32.

        Args:
            fn: Function to parallelize
            donate_argnums: Indices of arguments to donate (for memory optimization)

        Returns:
            Parallelized function
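
        Example (illustrative; ``train_step`` and its signature are assumed,
        with ``device_batch`` coming from ``split_batch``):
            >>> p_train_step = tp.parallelize(train_step, donate_argnums=(0,))
            >>> new_state, metrics = p_train_step(state, device_batch)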
        """
        # Normalize donate_argnums: pmap expects a tuple, not None
        donate_argnums = donate_argnums or ()

        # pmap compiles the function itself, so wrapping it in jax.jit first is
        # redundant; map over one device per data-parallel replica.
        dp_devices = list(self.mesh.devices.reshape(self.dp_size, -1)[:, 0])
        return jax.pmap(
            fn,
            axis_name="batch",
            devices=dp_devices,
            donate_argnums=donate_argnums
        )

    def gather_outputs(self, outputs: Any) -> Any:
        """
        Gather outputs from devices with optimized communication.

        Args:
            outputs: Outputs from parallelized function

        Returns:
            Gathered outputs
        """
        # Measure gathering time for performance monitoring
        start_time = time.time()

        # Take the first replica's copy of each leaf (assumes outputs are
        # replicated across the data-parallel axis, e.g. after pmean/psum)
        gathered = jax.tree_map(lambda x: x[0], outputs)

        # Log gathering time for large outputs
        if isinstance(outputs, dict) and 'logits' in outputs:
            gather_time = time.time() - start_time
            if gather_time > 0.1:  # Only log if significant
                print(f"Output gathering completed in {gather_time:.2f} seconds")

        return gathered

    def all_reduce(self, values: Any, reduce_type: str = "mean") -> Any:
        """
        Perform all-reduce operation across devices with optimized communication.

        Args:
            values: Values to reduce
            reduce_type: Type of reduction ("mean" or "sum")

        Returns:
            Reduced values
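
        Example (illustrative; valid only inside a function wrapped by
        ``parallelize``, where the ``"batch"`` axis name is bound):
            >>> loss = tp.all_reduce(per_device_loss, reduce_type="mean")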
        """
        if reduce_type == "mean":
            return jax.lax.pmean(values, axis_name="batch")
        elif reduce_type == "sum":
            return jax.lax.psum(values, axis_name="batch")
        else:
            raise ValueError(f"Unsupported reduce type: {reduce_type}")

    def replicate(self, values: Any) -> Any:
        """
        Replicate values across devices with optimized memory usage.

        Args:
            values: Values to replicate

        Returns:
            Replicated values
        """
        # Add a leading axis of size dp_size by broadcasting so the values
        # line up with pmap's per-device inputs
        return jax.tree_map(lambda x: jnp.broadcast_to(x, (self.dp_size,) + x.shape), values)

    def split_batch(self, batch: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        """
        Split batch across devices with optimized memory layout for TPU.

        Args:
            batch: Batch of data

        Returns:
            Split batch
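
        Example (illustrative; assumes ``batch["input_ids"]`` has shape
        ``(global_batch, seq_len)``):
            >>> device_batch = tp.split_batch(batch)
            >>> device_batch["input_ids"].shape  # -> (dp_size, per_device_batch, seq_len)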
        """
        # Compute batch size per device
        batch_size = batch["input_ids"].shape[0]
        per_device_batch_size = batch_size // self.dp_size

        # Check if batch size is divisible by number of devices
        if batch_size % self.dp_size != 0:
            print(f"Warning: Batch size {batch_size} is not divisible by number of data parallel devices {self.dp_size}")
            # Adjust batch size to be divisible
            new_batch_size = per_device_batch_size * self.dp_size
            batch = jax.tree_map(lambda x: x[:new_batch_size], batch)

        # Split batch with optimized memory layout
        return jax.tree_map(
            lambda x: x.reshape(self.dp_size, per_device_batch_size, *x.shape[1:]),
            batch
        )
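

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the library API: it exercises the
    # main entry points with toy data. `loss_fn` and the toy batch below are
    # hypothetical stand-ins for a real train step and data pipeline, and the
    # sketch assumes the parallelism.sharding helpers behave as described above.
    tp = TensorParallel(num_tp=8)

    # Toy global batch: 8 sequences of length 128, split across dp replicas.
    batch = {"input_ids": jnp.zeros((8, 128), dtype=jnp.int32)}
    device_batch = tp.split_batch(batch)

    # Hypothetical per-device step: reduces a scalar across the "batch" axis.
    def loss_fn(b):
        per_device_loss = jnp.mean(b["input_ids"].astype(jnp.float32))
        return tp.all_reduce(per_device_loss, reduce_type="mean")

    p_step = tp.parallelize(loss_fn)
    losses = p_step(device_batch)
    print("loss:", tp.gather_outputs(losses))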