# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bisect
import functools
import logging
import math
import re
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Optional, Tuple

import haiku as hk
import jax
import jax.experimental.pjit as pjit
import jax.numpy as jnp
import numpy as np
import sentencepiece
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec as P
from jax.typing import ArrayLike

import checkpoint as xai_checkpoint
from model import (
    LanguageModelConfig,
    LanguageModelOutput,
    TrainingState,
    apply_rules,
    Memory,
    KVMemory,
)

logger = logging.getLogger(__name__)
rank_logger = logging.getLogger("rank")

TOP_K = 8


class SampleSettings(NamedTuple):
    temperature: ArrayLike
    nucleus_p: ArrayLike
    mask: ArrayLike
    # Whether a given batch element is actively used. [B]
    active: ArrayLike


class SampleOutput(NamedTuple):
    token_id: ArrayLike
    prob: ArrayLike
    top_k_token_ids: ArrayLike
    top_k_probs: ArrayLike


def insert_slice(memory: Memory, slice, length, i):
    """Inserts a single-sample KV memory `slice` into batch row `i` of `memory`."""
    slice = Memory(
        layers=[
            KVMemory(layer.k, layer.v, step=jnp.array([length]))
            for layer in slice.layers
        ],
    )

    return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
                        memory, slice)
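
# Illustrative note (hypothetical shapes): `insert_slice` overwrites batch row
# `i` of every array in `memory` with the corresponding single row of `slice`
# (e.g. row `i` along axis 0 of a [B, S, H, D] key cache) and resets that
# row's `step` counter to `length`.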


def pad_to_size(x, size):
    """Right-pads a 1-D array with zeros to `size`, left-truncating if too long."""
    if x.shape[0] > size:
        # Left truncate if the context is too long.
        x = x[-size:]
    return np.pad(x, [0, size - x.shape[0]], mode="constant", constant_values=0)
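
# Illustrative examples (hypothetical values):
#   pad_to_size(np.array([1, 2, 3]), 5)     -> array([1, 2, 3, 0, 0])
#   pad_to_size(np.array([1, 2, 3, 4]), 2)  -> array([3, 4])  (left-truncated)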


def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
    """Performs nucleus filtering on logits."""
    assert (
        logits.ndim == top_p.ndim
    ), f"Expected logits.ndim ({logits.ndim}) to equal top_p.ndim ({top_p.ndim})"
    sorted_logits = jax.lax.sort(logits, is_stable=False)
    sorted_probs = jax.nn.softmax(sorted_logits)
    threshold_idx = jnp.argmax(jnp.cumsum(sorted_probs, -1) >= 1 - top_p, axis=-1)
    threshold_largest_logits = jnp.take_along_axis(
        sorted_logits, threshold_idx[..., jnp.newaxis], axis=-1
    )
    assert threshold_largest_logits.shape == logits.shape[:-1] + (1,)
    mask = logits >= threshold_largest_logits
    # Set filtered-out logits to an effectively -inf value.
    logits = jnp.where(mask, logits, -1e10)
    return logits
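
# Illustrative walkthrough (hypothetical numbers): given probabilities
# [0.5, 0.3, 0.15, 0.05] and top_p = 0.6, the ascending cumulative sum
# [0.05, 0.2, 0.5, 1.0] first reaches 1 - top_p = 0.4 at the 0.3-probability
# token, so only the 0.5 and 0.3 tokens keep their logits; all other logits
# are set to -1e10.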


def sample_token(
    rngs: jax.random.PRNGKey,
    lm_outputs: LanguageModelOutput,
    settings: SampleSettings,
) -> SampleOutput:
    # Expand the settings shape to match the logit shape.
    settings = SampleSettings(
        temperature=jnp.expand_dims(settings.temperature, (1, 2)),  # Input [B], output [B, 1, 1].
        nucleus_p=jnp.expand_dims(settings.nucleus_p, (1, 2)),  # Input [B], output [B, 1, 1].
        mask=jnp.expand_dims(settings.mask, 1),  # Input [B, V], output [B, 1, V].
        active=settings.active,  # [B].
    )
    logits = lm_outputs.logits / settings.temperature.astype(lm_outputs.logits.dtype)
    # Mask out all disallowed tokens by assigning them a near-zero probability.
    logits = jnp.where(settings.mask, logits, -1e10)
    # Mask out all tokens outside the nucleus (top-p) set.
    logits = top_p_filter(logits, settings.nucleus_p.astype(logits.dtype))

    new_token = jax.vmap(jax.random.categorical)(rngs, logits)

    probabilities = jax.nn.softmax(logits)
    token_prob = jnp.take_along_axis(probabilities, jnp.expand_dims(new_token, 1), axis=2)
    token_prob = jnp.squeeze(token_prob, 1)

    # Gather the top-k tokens and probabilities.
    top_k_probs, top_k_token_ids = jax.lax.top_k(probabilities, TOP_K)
    top_k_probs = jnp.squeeze(top_k_probs, 1)
    top_k_token_ids = jnp.squeeze(top_k_token_ids, 1)
    return SampleOutput(
        new_token,
        token_prob,
        top_k_token_ids,
        top_k_probs,
    )
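
# Shape sketch (assuming single-token decoding, i.e. logits of shape [B, 1, V]
# and one PRNG key per batch row in `rngs`): the returned SampleOutput holds
# token_id [B, 1], prob [B, 1], top_k_token_ids [B, TOP_K], and
# top_k_probs [B, TOP_K].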


@dataclass
class ModelRunner:
    model: LanguageModelConfig

    bs_per_device: float = 2.0

    load_rename_rules: Optional[list[tuple[str, str]]] = None
    load_exclude_rules: Optional[list[str]] = None

    rng_seed: int = 42  # Initial rng seed.
    transform_forward: bool = False

    checkpoint_path: str = ""

    def make_forward_fn(self, mesh: Any):
        def forward(tokens):
            out = self.model.make(mesh=mesh)(tokens)
            return out, None

        if self.transform_forward:
            forward = hk.transform(forward)
        return forward

    def initialize(
        self,
        init_data,
        local_mesh_config: tuple[int, int],
        between_hosts_config: tuple[int, int],
    ):
        num_replicas = math.prod(between_hosts_config)
        self.model.initialize()
        self.model.fprop_dtype = jnp.bfloat16
        num_local_gpus = len(jax.local_devices())

        # Calculate the global batch size from the local batch size.
        self.batch_size = int(self.bs_per_device * num_local_gpus * num_replicas)

        # Calculate the batch size per host from the global batch size.
        self.local_batch_size = self.batch_size // jax.process_count()

        self.local_mesh_config = local_mesh_config
        self.between_hosts_config = between_hosts_config
        rank_logger.info(
            f"Initializing mesh for {self.local_mesh_config=} {self.between_hosts_config=}..."
        )
        self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
        self.forward = self.make_forward_fn(mesh=self.mesh)
        self.logits_fn = hk.transform(lambda tokens: self.forward(tokens)[0])

        self.eval_forward = self.make_forward_fn(mesh=self.mesh)
        self.logits_eval_fn = hk.transform(lambda tokens: self.eval_forward(tokens)[0])

        if self.transform_forward:
            self.state_sharding = self.get_state_sharding(init_data)
            rank_logger.info(f"State sharding type: {type(self.state_sharding)}")
            self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)

    def init(self, rng: jax.Array, data) -> TrainingState:
        assert self.transform_forward
        rng, init_rng = jax.random.split(rng)
        params = self.forward.init(init_rng, data["inputs"])
        return TrainingState(params=params)

    def get_state_sharding(self, init_data):
        assert self.transform_forward
        rng = jax.random.PRNGKey(self.rng_seed)
        rank_logger.info(f"partition rules: {self.model.partition_rules}")

        with self.mesh:
            shapes = jax.eval_shape(self.init, rng, init_data)
            sharding = jax.tree_util.tree_map_with_path(
                apply_rules(self.model.partition_rules()),
                shapes,
            )
        return sharding

    def load_or_init(
        self,
        init_data: Any,
        from_checkpoint: bool = True,
        init_fn: Optional[Callable] = None,
    ):
        rng = jax.random.PRNGKey(self.rng_seed)

        if not self.checkpoint_path or not from_checkpoint:
            rank_logger.info("Initializing model...")
            with self.mesh:
                if init_fn is not None:
                    state = init_fn(rng, init_data)
                else:
                    assert self.transform_forward
                    state = self.init_fn(rng, init_data)
            rank_logger.info("Model state is newly initialized.")
        else:
            with self.mesh:
                if init_fn:
                    state_shapes = jax.eval_shape(init_fn, rng, init_data)
                else:
                    assert self.transform_forward
                    state_shapes = jax.eval_shape(self.init_fn, rng, init_data)
            init_state = None

            state = xai_checkpoint.restore(
                checkpoint_path=self.checkpoint_path,
                state_shapes=state_shapes,
                mesh=self.mesh,
                between_hosts_config=self.between_hosts_config,
                state_sharding=self.state_sharding,
                init_state=init_state,
                params_only=True,
            )

            del init_state
        return state


@dataclass
class Request:
    prompt: str
    temperature: float
    nucleus_p: float
    rng_seed: int
    max_len: int


@dataclass
class InferenceRunner:
    name: str
    runner: Any
    load: str
    tokenizer_path: str = "/tmp/xai_data/tokenizer.model"
    local_mesh_config: Tuple[int, int] = (1, 1)
    between_hosts_config: Tuple[int, int] = (1, 1)
    pad_sizes: tuple[int] = (1024,)

    def get_pad_bucket(self, size):
        """Returns the smallest pad bucket that fits `size`, clamped to the largest bucket."""
        i = bisect.bisect_left(self.pad_sizes, size)
        return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]
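
    # Illustrative example (hypothetical bucket list): with pad_sizes set to
    # (256, 512, 1024), get_pad_bucket(300) returns 512, and any length above
    # 1024 clamps to the largest bucket, 1024.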

    def initialize(self):
        runner = self.runner
        self.runner.transform_forward = True
        dummy_data = dict(
            inputs=np.zeros((1, 256), dtype=np.int32),
            targets=np.zeros((1, 256), dtype=np.int32),
        )
        runner.initialize(
            dummy_data,
            local_mesh_config=self.local_mesh_config,
            between_hosts_config=self.between_hosts_config,
        )

        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=self.tokenizer_path)

        max_len = runner.model.sequence_len

        self.vocab_size = self.runner.model.vocab_size

        params = runner.load_or_init(dummy_data)
        self.params = params

        def pad_to_max_len(x):
            if len(x.shape) > 1:
                pad_width = max_len - x.shape[1]
                return jnp.pad(x, [(0, 0), (0, pad_width), (0, 0), (0, 0)])
            else:
                return x

        def lm():
            return runner.model.make(mesh=runner.mesh)

        def hk_forward(
            tokens,
            memory=None,
            length=None,
            active=None,
        ) -> LanguageModelOutput:
            if memory is not None:
                assert active is not None
                layers = []
                for l in memory.layers:
                    # Reset steps to 0 for inactive requests to avoid unnecessary computations.
                    step = jnp.where(active, l.step, jnp.zeros_like(l.step))
                    layers.append(l._replace(step=step))
                memory = memory._replace(layers=layers)
            return lm()(tokens, memory, length=length)

        def hk_sample_step(rngs, last_output: SampleOutput, memory, settings):
            rngs, rngs_ = jax.vmap(jax.random.split, out_axes=1)(rngs)
            lm_outputs = hk_forward(last_output.token_id, memory=memory, active=settings.active)
            sample_result = sample_token(rngs_, lm_outputs, settings)
            return rngs, sample_result, lm_outputs.model_state

        def hk_new_memory(batch_size, sequence_len):
            return lm().init_memory(batch_size, sequence_len)

        def hk_prefill_memory(
            rngs,
            memory,
            settings,
            last_output,
            prompt,
            length,
            rng_seed,
            new_settings,
            i,
        ):
            rng = jax.random.PRNGKey(seed=rng_seed)
            rng, rng_ = jax.random.split(rng)

            # Allocate new memory for this sample. The memory length is equal to the length of the
            # prompt.
            slice = hk_new_memory(1, prompt.shape[0])

            # Move the settings for this individual batch entry into the joint settings tensor.
            settings = jax.tree_map(
                lambda o, v: jax.lax.dynamic_update_index_in_dim(o, v, i, axis=0),
                settings,
                new_settings,
            )

            # Get the settings for the batch entry from the joint settings tensor.
            settings_slice = jax.tree_map(lambda t: jnp.expand_dims(t[i], axis=0), settings)

            # Process the first n-1 tokens of the prompt.
            lm_outputs = hk_forward(
                jnp.expand_dims(prompt, 0),
                memory=slice,
                length=jnp.expand_dims(length, 0),
                active=settings_slice.active,
            )

            # The forward pass doesn't correctly set the `step` counter inside the memory. Manually
            # override it so `hk_forward` uses the correct context length in the next call.
            slice = lm_outputs.model_state
            slice = slice._replace(
                layers=[l._replace(step=jnp.array([length])) for l in slice.layers]
            )

            # Sample the actual output token.
            rng_ = jnp.expand_dims(rng_, 0)
            new_output = sample_token(rng_, lm_outputs, settings_slice)

            # Update the KV cache/memory.
            slice = jax.tree_map(pad_to_max_len, slice)
            memory = insert_slice(memory, slice, length, i)

            rng = jnp.expand_dims(rng, 0)
            rngs = jax.lax.dynamic_update_index_in_dim(rngs, rng, i, axis=0)

            # Move the network outputs for this batch entry into the joint output tensor.
            last_output = jax.tree_util.tree_map(
                lambda last, new: jax.lax.dynamic_update_index_in_dim(last, new, i, axis=0),
                last_output,
                new_output,
            )
            return rngs, last_output, memory, settings

        sample_step_ = hk.without_apply_rng(hk.transform(hk_sample_step))
        prefill_memory_ = hk.without_apply_rng(hk.transform(hk_prefill_memory))
        new_memory_ = hk.without_apply_rng(hk.transform(hk_new_memory))
        forward_ = hk.without_apply_rng(hk.transform(hk_forward))

        rng = jax.random.PRNGKey(42)
        dummy_tokens = jnp.zeros((1, max_len), jnp.int32)

        with runner.mesh:
            shapes = jax.eval_shape(forward_.init, rng, dummy_tokens)

        self.params_sharding = jax.tree_util.tree_map_with_path(
            apply_rules(runner.model.partition_rules()),
            shapes,
        )

        ds = P("data")
        ms = runner.model.model.get_memory_sharding()
        self.sample_step = pjit.pjit(
            sample_step_.apply,
            in_shardings=(self.params_sharding, None, ds, ms, None),
            out_shardings=(None, ds, ms),
            donate_argnums=3,
        )
        self.prefill_memory = pjit.pjit(
            functools.partial(prefill_memory_.apply),
            in_shardings=(
                self.params_sharding,
                None,
                ms,
                None,
                ds,
                None,
                None,
                None,
                None,
                None,
            ),
            out_shardings=(None, ds, ms, None),
            donate_argnums=(2,),
        )
        self.new_memory = pjit.pjit(
            new_memory_.apply,
            static_argnums=(1, 2),
            out_shardings=ms,
        )

    def run(self):
        """Generator that accepts prompts."""
        runner = self.runner
        mesh = runner.mesh
        max_len = runner.model.sequence_len
        batch_size = runner.batch_size
        params = self.params
        rngs = jax.random.split(jax.random.PRNGKey(1), batch_size)
        with mesh:
            memory = self.new_memory(params, batch_size, max_len)
            settings = SampleSettings(
                temperature=np.zeros((batch_size,), dtype=np.float32),
                nucleus_p=np.zeros((batch_size,), dtype=np.float32),
                mask=np.ones((batch_size, self.vocab_size), dtype=np.int32),
                active=np.zeros((batch_size), dtype=np.int32),
            )
            last_output = SampleOutput(
                token_id=np.zeros((batch_size, 1), dtype=np.int32),
                prob=np.zeros((batch_size, 1), dtype=jnp.bfloat16),
                top_k_token_ids=np.zeros((batch_size, TOP_K), dtype=np.int32),
                top_k_probs=np.zeros((batch_size, TOP_K), dtype=jnp.bfloat16),
            )

            prompt = np.array([300, 400, 500, 600, 600, 700, 800])

            new_settings = SampleSettings(
                temperature=np.float32(1),
                nucleus_p=np.float32(1),
                mask=np.ones((self.vocab_size,), dtype=np.int32),
                active=np.zeros((), dtype=np.int32),
            )
            rng_seed = np.uint64(1)

            for size in self.pad_sizes:
                if size > runner.model.sequence_len:
                    break
                logger.info("Precompile {}".format(size))
                prompt_len = len(prompt)
                prompt = pad_to_size(prompt, size)
                rngs, last_output, memory, settings = self.prefill_memory(
                    params,
                    rngs,
                    memory,
                    settings,
                    last_output,
                    prompt,
                    prompt_len,
                    rng_seed,
                    new_settings,
                    0,
                )
        with runner.mesh:
            logger.info("Compiling...")
            rngs, last_output, memory = self.sample_step(
                params, rngs, last_output, memory, settings
            )
            logger.info("Done compiling.")

        all_tokens = []
        free_slots = list(range(batch_size))
        requests = [None] * batch_size
        first_output = [None] * batch_size
        jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
        prev_token = last_output
        step = 0
        total_num_tokens = 0
        total_num_sequences = 0
        with mesh:
            while True:
                while free_slots:
                    request: Optional[Request] = yield
                    tokens = self.tokenizer.encode(request.prompt)
                    temperature = request.temperature
                    nucleus_p = request.nucleus_p
                    rng_seed = request.rng_seed

                    i = free_slots.pop()
                    prompt = np.array(tokens, dtype=np.int32)
                    prompt_len = len(prompt)
                    prompt = pad_to_size(prompt, self.get_pad_bucket(prompt.shape[0]))
                    # All tokens are allowed.
                    mask = np.ones((self.vocab_size,), dtype=np.int32)

                    new_settings = SampleSettings(
                        temperature=np.float32(temperature),
                        nucleus_p=np.float32(nucleus_p),
                        mask=mask,
                        active=np.ones((), dtype=np.int32),
                    )
                    rng_seed = np.uint64(rng_seed)
                    rngs, last_output, memory, settings = self.prefill_memory(
                        params,
                        rngs,
                        memory,
                        settings,
                        last_output,
                        prompt,
                        prompt_len,
                        rng_seed,
                        new_settings,
                        i,
                    )
                    jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
                    first_output[i] = last_output
                    requests[i] = request
                    total_num_sequences += 1

                rngs, last_output, memory = self.sample_step(
                    params, rngs, last_output, memory, settings
                )
                total_num_tokens += batch_size - len(free_slots)

                # prev_token should already be on the host.
                prev_token = jax.tree_map(np.array, prev_token)
                for i in range(batch_size):
                    if requests[i] is not None:
                        if first_output[i] is not None:
                            first_output_i = jax.tree_map(np.array, first_output[i])
                            all_tokens.append(int(first_output_i.token_id[i][0]))
                            first_output[i] = None
                            continue

                        all_tokens.append(int(prev_token.token_id[i][0]))
                        cont = len(all_tokens) < requests[i].max_len

                        if not cont:
                            output_str = self.tokenizer.decode(all_tokens)
                            requests[i] = None
                            free_slots.append(i)
                            all_tokens = []
                            settings = settings._replace(active=settings.active.at[i].set(0))
                            yield output_str

                jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
                prev_token = last_output
                step += 1


def make_mesh(
    local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...]
) -> jax.sharding.Mesh:
    assert len(local_mesh_config) == 2
    assert len(between_hosts_config) == 2
    rank_logger.info("Detected %s devices in mesh", jax.device_count())
    device_mesh = mesh_utils.create_hybrid_device_mesh(
        local_mesh_config,
        between_hosts_config,
        devices=jax.devices(),
        process_is_granule=True,
    )
    rank_logger.debug(re.sub("\n+", "\n", f"Job device mesh is:\n{device_mesh}"))
    return jax.sharding.Mesh(device_mesh, ("data", "model"))
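
# Illustrative configuration (hypothetical values): on a single 8-GPU host,
# make_mesh((1, 8), (1, 1)) builds an 8-device mesh with axes ("data", "model"),
# i.e. no data parallelism and 8-way model parallelism.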


def sample_from_model(server, prompt, max_len, temperature):
    """Sends a single sampling request to a running `InferenceRunner.run` generator."""
    next(server)
    inp = Request(
        prompt=prompt,
        temperature=temperature,
        nucleus_p=1.0,
        rng_seed=42,
        max_len=max_len,
    )
    return server.send(inp)
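

# Minimal usage sketch. The paths below are hypothetical placeholders, and
# `model_config` must be a concrete LanguageModelConfig supplied by the
# caller; this file does not construct one.
def _example_usage(model_config: LanguageModelConfig) -> str:
    inference_runner = InferenceRunner(
        name="local",
        runner=ModelRunner(
            model=model_config,
            checkpoint_path="./checkpoints",  # hypothetical path
        ),
        load="./checkpoints",  # hypothetical path
        tokenizer_path="./tokenizer.model",  # hypothetical path
        local_mesh_config=(1, 8),
        between_hosts_config=(1, 1),
    )
    inference_runner.initialize()
    gen = inference_runner.run()
    # Drive the generator with one request and return the decoded completion.
    return sample_from_model(gen, "1 + 1 =", max_len=16, temperature=0.01)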