"""
Tensor parallelism for LLM training on TPU v4-32.
Optimized for training a 600B parameter model efficiently.
"""
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from typing import Any, Callable, Dict, Optional, Tuple
import time
from parallelism.sharding import ShardingStrategy, ParameterSharding, create_device_mesh
class TensorParallel:
"""
Tensor parallelism for distributed training on TPU v4-32.
Attributes:
        num_devices: Total number of devices
        num_tp: Number of tensor parallel devices
        mesh: Device mesh spanning the data parallel and tensor parallel axes
        dp_size: Number of data parallel devices
        param_sharding: Parameter sharding helper
        use_2d_sharding: Whether large matrices are sharded along both mesh axes
"""
def __init__(
self,
num_devices: Optional[int] = None,
num_tp: int = 8,
use_2d_sharding: bool = True
):
"""
Initialize tensor parallelism optimized for TPU v4-32.
Args:
num_devices: Number of devices (defaults to all available devices)
num_tp: Number of tensor parallel devices
use_2d_sharding: Whether to use 2D sharding for better efficiency
"""
        # Get number of devices
        self.num_devices = num_devices or jax.device_count()
        self.num_tp = min(num_tp, self.num_devices)
        if self.num_devices % self.num_tp != 0:
            raise ValueError(
                f"Number of devices ({self.num_devices}) must be divisible by "
                f"the tensor parallel size ({self.num_tp})"
            )
        # Data parallelism uses the remaining factor of the device count
        self.dp_size = self.num_devices // self.num_tp
# Log device configuration
print(f"TPU configuration: {self.num_devices} total devices")
print(f"Tensor parallelism: {self.num_tp} devices")
print(f"Data parallelism: {self.dp_size} devices")
# Create device mesh
self.mesh = create_device_mesh(self.num_devices, num_tp=self.num_tp)
# Create parameter sharding
self.param_sharding = ParameterSharding(
self.mesh,
ShardingStrategy.TENSOR_PARALLEL
)
# Store 2D sharding preference
self.use_2d_sharding = use_2d_sharding
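    # Example layout (a sketch, assuming create_device_mesh arranges the axes as
    # ('dp', 'tp')): on a TPU v4-32 slice with 32 devices and num_tp=8, the mesh
    # is 4 x 8, i.e. dp_size = 32 // 8 = 4 data parallel groups of 8
    # tensor parallel devices each.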
def shard_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""
Shard parameters across devices optimized for TPU v4-32.
Args:
params: Model parameters
Returns:
Sharded parameters
"""
# Create sharding rules
rules = self.param_sharding.create_sharding_rules(params)
# Measure sharding time for performance monitoring
start_time = time.time()
# Apply 2D sharding for better TPU utilization if enabled
if self.use_2d_sharding:
# Modify rules for 2D sharding of large matrices
rules = self._apply_2d_sharding(rules, params)
        # Apply each parameter's partition spec under the device mesh
        with self.mesh:
            sharded_params = jax.tree_util.tree_map(
                lambda p, r: jax.lax.with_sharding_constraint(p, r),
                params,
                rules
            )
# Log sharding time
sharding_time = time.time() - start_time
print(f"Parameter sharding completed in {sharding_time:.2f} seconds")
return sharded_params
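    # Usage sketch (the parameter names are hypothetical; real trees come from
    # the model and real rules from ParameterSharding.create_sharding_rules):
    #
    #     tp = TensorParallel(num_tp=8)
    #     params = {"proj": {"kernel": jnp.zeros((8192, 8192), jnp.bfloat16)}}
    #     sharded = tp.shard_params(params)
    #
    # The returned tree has the same structure as `params`, with each leaf laid
    # out across the mesh according to its partition spec.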
def _apply_2d_sharding(self, rules: Dict[str, Any], params: Dict[str, Any]) -> Dict[str, Any]:
"""
Apply 2D sharding to large matrices for better TPU utilization.
Args:
rules: Sharding rules
params: Model parameters
Returns:
Modified sharding rules
"""
        # Walk the parameter tree and its rule tree together so the result keeps
        # the pytree structure that shard_params expects.
        def _rule_for(param, rule):
            # Large matrices (>= 4096 in both dimensions) are sharded along both
            # the data parallel and tensor parallel mesh axes.
            if param.ndim == 2 and param.shape[0] >= 4096 and param.shape[1] >= 4096:
                return P('dp', 'tp')
            return rule
        return jax.tree_util.tree_map(_rule_for, params, rules)
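    # Worked example (assuming a 4 x 8 mesh with axes ('dp', 'tp')): an
    # 8192 x 8192 kernel sharded with P('dp', 'tp') is split into per-device
    # blocks of 8192/4 x 8192/8 = 2048 x 1024, instead of 8192 x 1024 under
    # plain tensor parallel column sharding.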
def parallelize(self, fn: Callable, donate_argnums: Optional[Tuple[int, ...]] = None) -> Callable:
"""
Parallelize function across devices with optimizations for TPU v4-32.
Args:
fn: Function to parallelize
donate_argnums: Indices of arguments to donate (for memory optimization)
Returns:
Parallelized function
"""
        # jax.pmap compiles the mapped function, so a separate jax.jit wrapper
        # is unnecessary. Map over the first device of each data parallel group;
        # sharding within a group is handled by the mesh partition specs.
        return jax.pmap(
            fn,
            axis_name="batch",
            devices=self.mesh.devices.reshape(self.dp_size, -1)[:, 0],
            donate_argnums=donate_argnums or ()
        )
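    # Usage sketch (train_step is a hypothetical per-device step function):
    #
    #     p_train_step = tp.parallelize(train_step, donate_argnums=(0,))
    #     loss = p_train_step(tp.replicate(params), tp.split_batch(batch))
    #
    # Inputs must carry a leading axis of size dp_size, which replicate() and
    # split_batch() provide.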
def gather_outputs(self, outputs: Any) -> Any:
"""
Gather outputs from devices with optimized communication.
Args:
outputs: Outputs from parallelized function
Returns:
Gathered outputs
"""
# Measure gathering time for performance monitoring
start_time = time.time()
        # Take the first replica's copy of each leaf; replicas are identical
        # after collective reductions such as pmean/psum.
        gathered = jax.tree_util.tree_map(lambda x: x[0], outputs)
# Log gathering time for large outputs
if isinstance(outputs, dict) and 'logits' in outputs:
gather_time = time.time() - start_time
if gather_time > 0.1: # Only log if significant
print(f"Output gathering completed in {gather_time:.2f} seconds")
return gathered
def all_reduce(self, values: Any, reduce_type: str = "mean") -> Any:
"""
Perform all-reduce operation across devices with optimized communication.
Args:
values: Values to reduce
reduce_type: Type of reduction ("mean" or "sum")
Returns:
Reduced values
"""
if reduce_type == "mean":
return jax.lax.pmean(values, axis_name="batch")
elif reduce_type == "sum":
return jax.lax.psum(values, axis_name="batch")
else:
raise ValueError(f"Unsupported reduce type: {reduce_type}")
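    # Note: pmean/psum bind the "batch" axis, so all_reduce must be called from
    # inside a function wrapped by parallelize(). A sketch (compute_loss and
    # train_step are hypothetical):
    #
    #     def train_step(params, batch):
    #         loss = compute_loss(params, batch)
    #         return tp.all_reduce(loss, "mean")  # average across dp groups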
def replicate(self, values: Any) -> Any:
"""
Replicate values across devices with optimized memory usage.
Args:
values: Values to replicate
Returns:
Replicated values
"""
        # Broadcast each leaf to add a leading device axis of size dp_size
        # without materializing separate copies on the host.
        return jax.tree_util.tree_map(
            lambda x: jnp.broadcast_to(x, (self.dp_size,) + x.shape), values
        )
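    # Example: a (vocab_size, d_model) embedding table becomes
    # (dp_size, vocab_size, d_model), i.e. one copy per pmap-visible device,
    # matching the leading axis that parallelize() expects.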
def split_batch(self, batch: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
"""
Split batch across devices with optimized memory layout for TPU.
Args:
batch: Batch of data
Returns:
Split batch
"""
# Compute batch size per device
batch_size = batch["input_ids"].shape[0]
per_device_batch_size = batch_size // self.dp_size
# Check if batch size is divisible by number of devices
if batch_size % self.dp_size != 0:
print(f"Warning: Batch size {batch_size} is not divisible by number of data parallel devices {self.dp_size}")
            # Truncate to the largest batch size divisible by dp_size
            new_batch_size = per_device_batch_size * self.dp_size
            batch = jax.tree_util.tree_map(lambda x: x[:new_batch_size], batch)
        # Add a leading device axis: (dp_size, per_device_batch_size, ...)
        return jax.tree_util.tree_map(
            lambda x: x.reshape(self.dp_size, per_device_batch_size, *x.shape[1:]),
            batch
        )
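

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch. Illustrative only: it assumes a TPU runtime with
# at least 8 devices and an importable parallelism.sharding module;
# `toy_params`, `toy_batch`, and `train_step` are hypothetical stand-ins for a
# real model, batch, and training step.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tp = TensorParallel(num_tp=8)

    # Hypothetical parameters: one large matrix that qualifies for 2D sharding.
    toy_params = {"proj": {"kernel": jnp.zeros((8192, 8192), dtype=jnp.bfloat16)}}
    sharded = tp.shard_params(toy_params)
    print("kernel sharding:", sharded["proj"]["kernel"].sharding)

    def train_step(params, batch):
        # Stand-in "loss": mean token id nudged by one parameter entry.
        loss = jnp.mean(batch["input_ids"].astype(jnp.float32))
        loss = loss + params["proj"]["kernel"][0, 0].astype(jnp.float32)
        # Average the loss across data parallel replicas.
        return tp.all_reduce(loss, reduce_type="mean")

    p_train_step = tp.parallelize(train_step)

    # Hypothetical batch of (global_batch, seq_len) token ids.
    toy_batch = {"input_ids": jnp.zeros((32, 128), dtype=jnp.int32)}
    loss = p_train_step(tp.replicate(toy_params), tp.split_batch(toy_batch))
    print("per-step loss:", tp.gather_outputs(loss))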