"""
Tensor parallelism for LLM training on TPU v4-32.
Optimized for efficient training of a 600B-parameter model.
"""
import time
from typing import Any, Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

from parallelism.sharding import ShardingStrategy, ParameterSharding, create_device_mesh


class TensorParallel:
    """
    Tensor parallelism for distributed training on TPU v4-32.

    Attributes:
        num_devices: Total number of devices.
        num_tp: Number of tensor-parallel devices.
        mesh: Device mesh with 'dp' and 'tp' axes.
        dp_size: Number of data-parallel devices.
    """

    def __init__(
        self,
        num_devices: Optional[int] = None,
        num_tp: int = 8,
        use_2d_sharding: bool = True,
    ):
        """
        Initialize tensor parallelism optimized for TPU v4-32.

        Args:
            num_devices: Number of devices (defaults to all available devices).
            num_tp: Number of tensor-parallel devices.
            use_2d_sharding: Whether to use 2D sharding for better efficiency.
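
        Example (illustrative sketch; assumes a TPU v4-32 slice with 32 devices):
            tp = TensorParallel(num_tp=8)  # 32 devices / 8-way TP -> dp_size = 4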
"""
# Get number of devices
self.num_devices = num_devices or jax.device_count()
self.num_tp = min(num_tp, self.num_devices)
# Calculate optimal data parallelism size
self.dp_size = self.num_devices // self.num_tp
# Log device configuration
print(f"TPU configuration: {self.num_devices} total devices")
print(f"Tensor parallelism: {self.num_tp} devices")
print(f"Data parallelism: {self.dp_size} devices")
# Create device mesh
self.mesh = create_device_mesh(self.num_devices, num_tp=self.num_tp)
# Create parameter sharding
self.param_sharding = ParameterSharding(
self.mesh,
ShardingStrategy.TENSOR_PARALLEL
)
# Store 2D sharding preference
self.use_2d_sharding = use_2d_sharding

    def shard_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Shard parameters across devices, optimized for TPU v4-32.

        Args:
            params: Model parameters.

        Returns:
            Sharded parameters.
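
        Example (illustrative sketch; ``params`` is assumed to be a pytree of
        jnp.ndarray leaves, e.g. a Flax parameter dict):
            tp = TensorParallel(num_tp=8)
            sharded_params = tp.shard_params(params)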
"""
# Create sharding rules
rules = self.param_sharding.create_sharding_rules(params)
# Measure sharding time for performance monitoring
start_time = time.time()
# Apply 2D sharding for better TPU utilization if enabled
if self.use_2d_sharding:
# Modify rules for 2D sharding of large matrices
rules = self._apply_2d_sharding(rules, params)
# Shard parameters
with self.mesh:
sharded_params = jax.tree_map(
lambda p, r: jax.lax.with_sharding_constraint(p, r),
params,
rules
)
# Log sharding time
sharding_time = time.time() - start_time
print(f"Parameter sharding completed in {sharding_time:.2f} seconds")
return sharded_params

    def _apply_2d_sharding(self, rules: Dict[str, Any], params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply 2D sharding to large matrices for better TPU utilization.

        Args:
            rules: Sharding rules, a pytree of PartitionSpecs mirroring ``params``.
            params: Model parameters.

        Returns:
            Modified sharding rules.
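
        Example (illustrative sketch; the original rule shown is hypothetical):
            An (8192, 8192) kernel whose rule was, say, P(None, 'tp') becomes
            P('dp', 'tp'); an (8192, 1024) kernel keeps its original rule.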
"""
# Flatten parameter tree
flat_params = jax.tree_util.tree_flatten(params)[0]
flat_paths = jax.tree_util.tree_flatten_with_path(params)[0]
paths = ['/'.join(str(p) for p in path) for path, _ in flat_paths]
# Create modified rules
modified_rules = dict(rules)
# Apply 2D sharding to large matrices
for path, param in zip(paths, flat_params):
# Check if parameter is a large matrix (>= 4096 in both dimensions)
if len(param.shape) == 2 and param.shape[0] >= 4096 and param.shape[1] >= 4096:
# Apply 2D sharding
modified_rules[path] = P('dp', 'tp')
return modified_rules

    def parallelize(self, fn: Callable, donate_argnums: Optional[Tuple[int, ...]] = None) -> Callable:
        """
        Parallelize a function across devices with optimizations for TPU v4-32.

        Args:
            fn: Function to parallelize.
            donate_argnums: Indices of arguments to donate (for memory optimization).

        Returns:
            Parallelized function.
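
        Example (illustrative sketch; ``train_step`` and its arguments are
        hypothetical, and every argument must carry a leading dp_size axis,
        e.g. from ``replicate`` or ``split_batch``):
            p_train_step = tp.parallelize(train_step, donate_argnums=(0,))
            new_state, metrics = p_train_step(replicated_state, device_batch)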
"""
# Use cached computation for better performance
fn = jax.jit(fn, donate_argnums=donate_argnums)
# Parallelize function with optimized device mapping
return jax.pmap(
fn,
axis_name="batch",
devices=self.mesh.devices.reshape(self.dp_size, -1)[:, 0],
donate_argnums=donate_argnums
)

    def gather_outputs(self, outputs: Any) -> Any:
        """
        Gather outputs from devices with optimized communication.

        Args:
            outputs: Outputs from the parallelized function.

        Returns:
            Gathered outputs.
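
        Example (illustrative sketch; ``losses`` is assumed to be replicated
        across the leading device axis, so keeping one replica loses nothing):
            mean_loss = tp.gather_outputs(losses)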
"""
# Measure gathering time for performance monitoring
start_time = time.time()
# Gather outputs
gathered = jax.tree_map(lambda x: x[0], outputs)
# Log gathering time for large outputs
if isinstance(outputs, dict) and 'logits' in outputs:
gather_time = time.time() - start_time
if gather_time > 0.1: # Only log if significant
print(f"Output gathering completed in {gather_time:.2f} seconds")
return gathered

    def all_reduce(self, values: Any, reduce_type: str = "mean") -> Any:
        """
        Perform an all-reduce across devices with optimized communication.

        Args:
            values: Values to reduce.
            reduce_type: Type of reduction ("mean" or "sum").

        Returns:
            Reduced values.
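
        Example (illustrative sketch; valid only inside a function wrapped with
        ``parallelize``, since the "batch" axis is bound by ``jax.pmap``):
            mean_loss = tp.all_reduce(per_device_loss, reduce_type="mean")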
"""
if reduce_type == "mean":
return jax.lax.pmean(values, axis_name="batch")
elif reduce_type == "sum":
return jax.lax.psum(values, axis_name="batch")
else:
raise ValueError(f"Unsupported reduce type: {reduce_type}")

    def replicate(self, values: Any) -> Any:
        """
        Replicate values across devices with optimized memory usage.

        Args:
            values: Values to replicate.

        Returns:
            Replicated values.
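
        Example (illustrative sketch; ``train_state`` is hypothetical):
            replicated_state = tp.replicate(train_state)  # adds a leading dp_size axis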
"""
# Use broadcast instead of replication for better performance
return jax.tree_map(lambda x: jnp.broadcast_to(x, (self.dp_size,) + x.shape), values)

    def split_batch(self, batch: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        """
        Split a batch across devices with an optimized memory layout for TPU.

        Args:
            batch: Batch of data.

        Returns:
            Split batch.
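
        Example (illustrative sketch; with dp_size = 4 and a global batch of
        32 sequences of length 2048):
            device_batch = tp.split_batch(batch)
            # batch["input_ids"]: (32, 2048) -> device_batch["input_ids"]: (4, 8, 2048)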
"""
# Compute batch size per device
batch_size = batch["input_ids"].shape[0]
per_device_batch_size = batch_size // self.dp_size
# Check if batch size is divisible by number of devices
if batch_size % self.dp_size != 0:
print(f"Warning: Batch size {batch_size} is not divisible by number of data parallel devices {self.dp_size}")
# Adjust batch size to be divisible
new_batch_size = per_device_batch_size * self.dp_size
batch = jax.tree_map(lambda x: x[:new_batch_size], batch)
# Split batch with optimized memory layout
return jax.tree_map(
lambda x: x.reshape(self.dp_size, per_device_batch_size, *x.shape[1:]),
batch
)
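

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline: it assumes a
    # TPU v4-32 slice with 32 devices and uses toy parameters and a toy batch
    # purely to exercise the utilities above.
    tp = TensorParallel(num_tp=8, use_2d_sharding=True)

    # Toy parameter pytree; the (4096, 4096) kernel is large enough to pick up
    # the 2D ('dp', 'tp') sharding rule.
    params = {"dense": {"kernel": jnp.zeros((4096, 4096)), "bias": jnp.zeros((4096,))}}
    batch = {"input_ids": jnp.zeros((tp.dp_size * 8, 128), dtype=jnp.int32)}

    sharded_params = tp.shard_params(params)
    device_batch = tp.split_batch(batch)
    print("per-device batch shape:", device_batch["input_ids"].shape)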