"""Tensor parallelism for LLM training on TPU v4-32.

Optimized for training a 600B-parameter model efficiently.
"""

import time
from typing import Any, Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

from parallelism.sharding import ShardingStrategy, ParameterSharding, create_device_mesh


class TensorParallel:
    """Tensor parallelism for distributed training on TPU v4-32.

    Attributes:
        num_devices: Total number of devices.
        num_tp: Number of tensor-parallel devices.
        dp_size: Number of data-parallel devices.
        mesh: Device mesh spanning the data- and tensor-parallel axes.
    """

    def __init__(
        self,
        num_devices: Optional[int] = None,
        num_tp: int = 8,
        use_2d_sharding: bool = True,
    ):
        """Initialize tensor parallelism optimized for TPU v4-32.

        Args:
            num_devices: Number of devices (defaults to all available devices).
            num_tp: Number of tensor-parallel devices.
            use_2d_sharding: Whether to use 2D sharding for better efficiency.
        """
        self.num_devices = num_devices or jax.device_count()
        self.num_tp = min(num_tp, self.num_devices)

        # Data parallelism uses whatever remains after the tensor-parallel split;
        # num_devices should be a multiple of num_tp or the remainder is unused.
        self.dp_size = self.num_devices // self.num_tp

        print(f"TPU configuration: {self.num_devices} total devices")
        print(f"Tensor parallelism: {self.num_tp} devices")
        print(f"Data parallelism: {self.dp_size} devices")

        # Build the (dp, tp) device mesh and the parameter-sharding helper.
        self.mesh = create_device_mesh(self.num_devices, num_tp=self.num_tp)
        self.param_sharding = ParameterSharding(
            self.mesh,
            ShardingStrategy.TENSOR_PARALLEL,
        )

        self.use_2d_sharding = use_2d_sharding

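    # Example (illustrative): with 32 visible JAX devices and num_tp=8, this
    # yields an 8-way tensor-parallel, 4-way data-parallel layout; the 'dp' and
    # 'tp' axis names are assumed to be the ones produced by create_device_mesh.
    #
    #   tp = TensorParallel(num_tp=8)   # on a 32-device slice
    #   assert tp.dp_size == 4
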
    def shard_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Shard parameters across devices, optimized for TPU v4-32.

        Args:
            params: Model parameters.

        Returns:
            Sharded parameters.
        """
        # The rules are assumed to be a pytree of PartitionSpecs that mirrors
        # the structure of `params`.
        rules = self.param_sharding.create_sharding_rules(params)

        start_time = time.time()

        if self.use_2d_sharding:
            rules = self._apply_2d_sharding(rules, params)

        # device_put materializes the shards on the mesh; with_sharding_constraint
        # would only take effect inside a jit-compiled function.
        sharded_params = jax.tree_util.tree_map(
            lambda p, r: jax.device_put(p, NamedSharding(self.mesh, r)),
            params,
            rules,
        )

        sharding_time = time.time() - start_time
        print(f"Parameter sharding completed in {sharding_time:.2f} seconds")

        return sharded_params

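    # Example (illustrative): create_sharding_rules is expected to emit
    # column-/row-parallel PartitionSpecs for transformer weights; the parameter
    # names below are hypothetical.
    #
    #   rules = {
    #       "mlp": {
    #           "wi": P(None, "tp"),  # column-parallel: shard the output dim
    #           "wo": P("tp", None),  # row-parallel: shard the input dim
    #       },
    #   }
    #   sharded = tp.shard_params(params)  # params mirrors the rules structure
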
    def _apply_2d_sharding(self, rules: Dict[str, Any], params: Dict[str, Any]) -> Dict[str, Any]:
        """Apply 2D sharding to large matrices for better TPU utilization.

        Args:
            rules: Sharding rules (pytree of PartitionSpecs mirroring `params`).
            params: Model parameters.

        Returns:
            Modified sharding rules.
        """
        def maybe_2d(param, rule):
            # Shard both dimensions of large matrices across the 'dp' and 'tp'
            # mesh axes; leave every other rule unchanged.
            if param.ndim == 2 and param.shape[0] >= 4096 and param.shape[1] >= 4096:
                return P("dp", "tp")
            return rule

        return jax.tree_util.tree_map(maybe_2d, params, rules)

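    # Example (illustrative): shapes only, assuming a 4 x 8 (dp, tp) mesh.
    #
    #   w = jnp.zeros((16384, 16384))
    #   rules = tp._apply_2d_sharding({"w": P(None, "tp")}, {"w": w})
    #   # rules["w"] == P('dp', 'tp'); per-device block is 4096 x 2048
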
    def parallelize(self, fn: Callable, donate_argnums: Optional[Tuple[int, ...]] = None) -> Callable:
        """Parallelize a function across devices with optimizations for TPU v4-32.

        Args:
            fn: Function to parallelize.
            donate_argnums: Indices of arguments to donate (for memory optimization).

        Returns:
            Parallelized function.
        """
        # pmap requires a tuple here; None is normalized to "donate nothing".
        donate_argnums = donate_argnums or ()

        # Map over the data-parallel groups only, using the leading device of
        # each tensor-parallel group; pmap compiles `fn` with XLA, so a separate
        # jax.jit wrapper is unnecessary.
        dp_devices = list(self.mesh.devices.reshape(self.dp_size, -1)[:, 0])

        return jax.pmap(
            fn,
            axis_name="batch",
            devices=dp_devices,
            donate_argnums=donate_argnums,
        )

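    # Example (illustrative): `train_step` is a hypothetical per-device function
    # that calls all_reduce internally for cross-device gradient averaging.
    #
    #   p_train_step = tp.parallelize(train_step, donate_argnums=(0,))
    #   state = tp.replicate(state)            # leading axis = dp_size
    #   sharded_batch = tp.split_batch(batch)  # leading axis = dp_size
    #   state, metrics = p_train_step(state, sharded_batch)
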
    def gather_outputs(self, outputs: Any) -> Any:
        """Gather outputs from devices with optimized communication.

        Args:
            outputs: Outputs from the parallelized function.

        Returns:
            Gathered outputs.
        """
        start_time = time.time()

        # Each output carries a leading device axis; keep the first replica.
        gathered = jax.tree_util.tree_map(lambda x: x[0], outputs)

        if isinstance(outputs, dict) and "logits" in outputs:
            gather_time = time.time() - start_time
            if gather_time > 0.1:
                print(f"Output gathering completed in {gather_time:.2f} seconds")

        return gathered

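    # Example (illustrative): pull per-device outputs back to the host after a
    # pmapped step; `metrics` here is a hypothetical dict of replicated values.
    #
    #   host_metrics = tp.gather_outputs(metrics)  # drops the leading device axis
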
    def all_reduce(self, values: Any, reduce_type: str = "mean") -> Any:
        """Perform an all-reduce across devices with optimized communication.

        Must be called inside a function wrapped by `parallelize`, where the
        "batch" axis name is bound.

        Args:
            values: Values to reduce.
            reduce_type: Type of reduction ("mean" or "sum").

        Returns:
            Reduced values.
        """
        if reduce_type == "mean":
            return jax.lax.pmean(values, axis_name="batch")
        elif reduce_type == "sum":
            return jax.lax.psum(values, axis_name="batch")
        else:
            raise ValueError(f"Unsupported reduce type: {reduce_type}")

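    # Example (illustrative): inside a function wrapped by `parallelize`, the
    # per-device gradients can be averaged over the data-parallel "batch" axis;
    # `loss_and_grads` is a hypothetical helper.
    #
    #   def train_step(state, batch):
    #       loss, grads = loss_and_grads(state, batch)
    #       grads = tp.all_reduce(grads, "mean")   # pmean over "batch"
    #       ...
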
    def replicate(self, values: Any) -> Any:
        """Replicate values across devices with optimized memory usage.

        Args:
            values: Values to replicate.

        Returns:
            Replicated values with a leading data-parallel axis of size dp_size.
        """
        # Add a leading dp_size axis so each data-parallel group gets a copy.
        return jax.tree_util.tree_map(
            lambda x: jnp.broadcast_to(x, (self.dp_size,) + x.shape), values
        )

    def split_batch(self, batch: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        """Split a batch across devices with an optimized memory layout for TPU.

        Args:
            batch: Batch of data.

        Returns:
            Batch reshaped to (dp_size, per_device_batch_size, ...).
        """
        batch_size = batch["input_ids"].shape[0]
        per_device_batch_size = batch_size // self.dp_size

        if batch_size % self.dp_size != 0:
            print(
                f"Warning: batch size {batch_size} is not divisible by the "
                f"{self.dp_size} data-parallel devices; dropping the remainder"
            )
            new_batch_size = per_device_batch_size * self.dp_size
            batch = jax.tree_util.tree_map(lambda x: x[:new_batch_size], batch)

        # The leading axis becomes the device axis expected by pmap.
        return jax.tree_util.tree_map(
            lambda x: x.reshape(self.dp_size, per_device_batch_size, *x.shape[1:]),
            batch,
        )
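

if __name__ == "__main__":
    # Minimal smoke test (illustrative): the parameter tree and batch shapes are
    # placeholders, and it is assumed that create_sharding_rules handles this
    # structure; this is not the 600B training configuration.
    tp = TensorParallel(num_tp=8)

    dummy_params = {"dense": {"kernel": jnp.zeros((4096, 4096))}}
    sharded_params = tp.shard_params(dummy_params)

    dummy_batch = {"input_ids": jnp.zeros((32, 128), dtype=jnp.int32)}
    device_batch = tp.split_batch(dummy_batch)
    print({k: v.shape for k, v in device_batch.items()})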