# Ahma-7B / EasyLM/models/gptj/gptj_serve.py
# aapot
# Add training codes
# a85f909
import pprint
from functools import partial
import numpy as np
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
import flax
from flax import linen as nn
from flax.jax_utils import prefetch_to_device
from flax.training.train_state import TrainState
import optax
from transformers import GenerationConfig, FlaxLogitsProcessorList
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.serving import LMServer
from EasyLM.jax_utils import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply,
set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
with_sharding_constraint, FlaxTemperatureLogitsWarper
)
from EasyLM.models.gptj.gptj_model import (
GPTJConfig, FlaxGPTJForCausalLMModule, FlaxGPTJForCausalLM
)
# Command-line flags for the GPT-J serving script, declared with their
# default values. The comments below describe how `main` reads each flag.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    # Passed to set_random_seed() and used as the model's `seed` argument.
    seed=42,
    # NOTE(review): not read anywhere in this file; distributed setup goes
    # through the `jax_distributed` config below.
    initialize_jax_distributed=False,
    # Mesh shape string handed to GPTJConfig.get_jax_mesh().
    mesh_dim='1,-1,1',
    # Float dtype name resolved via get_float_dtype_by_name() when the
    # parameter shard functions are built.
    dtype='bf16',
    # Number of tokens every prompt/prefix is padded or truncated to.
    input_length=1024,
    # Total sequence length; generation may add at most
    # seq_length - input_length new tokens, and rolling log-likelihood
    # scores text in windows of this size.
    seq_length=2048,
    # Sampling options forwarded to GenerationConfig in the sampled
    # generation path (the greedy path hard-codes do_sample=False,
    # num_beams=1).
    top_k=50,
    top_p=1.0,
    do_sample=True,
    num_beams=1,
    # When True, the first prompt position is overwritten with the BOS token
    # for generation, and the BOS position is attended to in `loglikelihood`.
    add_bos_token=False,
    # Argument for GPTJConfig.load_config().
    load_gptj_config='',
    # Checkpoint spec of the form '<type>::<path>'. The 'huggingface' type is
    # loaded through the config's load_pretrained(); any other value is passed
    # whole to StreamingCheckpointer.load_trainstate_checkpoint().
    load_checkpoint='',
    # Nested default configs for the tokenizer, the LM server and
    # jax.distributed initialisation.
    tokenizer=GPTJConfig.get_tokenizer_config(),
    lm_server=LMServer.get_default_config(),
    jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
    """Load a GPT-J model, shard it over the JAX mesh and run the LM server.

    Builds three pjit-compiled forward functions (log-likelihood scoring,
    sampled generation and greedy generation), places the parameters on the
    device mesh, wraps everything in a ``ModelServer`` subclass of
    ``LMServer`` and finally calls its ``run()`` method.

    Args:
        argv: Positional arguments handed over by the launcher; not read
            here.

    Raises:
        ValueError: If ``FLAGS.load_checkpoint`` is not of the form
            ``'<type>::<path>'``.
    """
    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    set_random_seed(FLAGS.seed)

    # Prompts are padded/truncated on the left so the most recent context sits
    # directly before the generated tokens; continuations use the right side.
    prefix_tokenizer = GPTJConfig.get_tokenizer(
        FLAGS.tokenizer, truncation_side='left', padding_side='left'
    )
    tokenizer = GPTJConfig.get_tokenizer(
        FLAGS.tokenizer, truncation_side='right', padding_side='right'
    )

    # Load config and weights with the CPU as default device; the parameters
    # are moved onto the mesh later by the shard functions.
    with jax.default_device(jax.devices("cpu")[0]):
        gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)

        # partition() instead of an unpacked split() so that a missing '::'
        # yields a readable error rather than "not enough values to unpack".
        load_type, separator, load_path = FLAGS.load_checkpoint.partition('::')
        if not separator:
            raise ValueError(
                "load_checkpoint must have the form '<type>::<path>', "
                f"got {FLAGS.load_checkpoint!r}"
            )

        if load_type == 'huggingface':
            params = gptj_config.load_pretrained(load_path)
        else:
            # The checkpointer receives the full '<type>::<path>' spec.
            _, params = StreamingCheckpointer.load_trainstate_checkpoint(
                FLAGS.load_checkpoint, disallow_trainstate=True
            )

        hf_model = FlaxGPTJForCausalLM(
            gptj_config,
            input_shape=(1, FLAGS.seq_length),
            seed=FLAGS.seed,
            _do_init=False
        )

    # Partition specs for every parameter, and the functions that shard (and
    # cast to FLAGS.dtype) each leaf accordingly.
    model_ps = match_partition_rules(
        GPTJConfig.get_partition_rules(), params
    )
    shard_fns, _ = make_shard_and_gather_fns(
        model_ps, get_float_dtype_by_name(FLAGS.dtype)
    )

    @partial(
        pjit,
        in_shardings=(model_ps, PS(), PS()),
        out_shardings=(PS(), PS(), PS())
    )
    def forward_loglikelihood(params, rng, batch):
        """Score ``output_tokens`` under the model for one batch.

        Args:
            params: Model parameters, sharded according to ``model_ps``.
            rng: JAX random key used to seed the module's rng streams.
            batch: Dict with ``input_tokens``, ``output_tokens``,
                ``input_mask`` and ``output_mask``.

        Returns:
            A tuple ``(loglikelihood, is_greedy, rng)``: the masked sum of
            per-token log-probabilities over the last axis, whether every
            unmasked target equals the argmax prediction, and the value
            returned by calling the ``JaxRNG`` helper (used by callers as the
            next rng).
        """
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        rng_generator = JaxRNG(rng)
        input_tokens = batch['input_tokens']
        output_tokens = batch['output_tokens']
        input_mask = batch['input_mask']
        output_mask = batch['output_mask']

        logits = hf_model.module.apply(
            params, input_tokens, attention_mask=input_mask,
            deterministic=True, rngs=rng_generator(gptj_config.rng_keys()),
        ).logits
        if gptj_config.n_real_tokens is not None:
            # Push vocabulary slots at or beyond n_real_tokens to a very low
            # logit so they receive effectively no probability mass.
            logits = logits.at[:, :, gptj_config.n_real_tokens:].set(-1e8)

        loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
            logits, output_tokens
        )
        loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)

        # A row is "greedy" when every scored position matches the argmax.
        match_count = jnp.sum(
            (jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
            axis=-1
        )
        total = jnp.sum(output_mask, axis=-1)
        is_greedy = match_count == total
        return loglikelihood, is_greedy, rng_generator()

    @partial(
        pjit,
        in_shardings=(model_ps, PS(), PS(), PS()),
        out_shardings=(PS(), PS())
    )
    def forward_generate(params, rng, batch, temperature):
        """Generate continuations using the sampling flags and a temperature.

        Args:
            params: Model parameters; the ``'params'`` entry is passed to
                ``generate``.
            rng: JAX random key for sampling.
            batch: Dict with ``input_tokens`` and ``attention_mask``.
            temperature: Value handed to ``FlaxTemperatureLogitsWarper``.

        Returns:
            ``(output, rng)`` where ``output`` holds only the newly generated
            tokens (the prompt columns are sliced off).
        """
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        rng_generator = JaxRNG(rng)
        output = hf_model.generate(
            batch['input_tokens'],
            attention_mask=batch['attention_mask'],
            params=params['params'],
            prng_key=rng_generator(),
            logits_processor=FlaxLogitsProcessorList(
                [FlaxTemperatureLogitsWarper(temperature)]
            ),
            generation_config=GenerationConfig(
                max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
                # EOS doubles as the padding id for generated sequences.
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=FLAGS.do_sample,
                num_beams=FLAGS.num_beams,
                top_k=FLAGS.top_k,
                top_p=FLAGS.top_p,
            )
        ).sequences[:, batch['input_tokens'].shape[1]:]
        return output, rng_generator()

    @partial(
        pjit,
        in_shardings=(model_ps, PS(), PS()),
        out_shardings=(PS(), PS())
    )
    def forward_greedy_generate(params, rng, batch):
        """Generate continuations greedily (no sampling, a single beam).

        Args:
            params: Model parameters; the ``'params'`` entry is passed to
                ``generate``.
            rng: JAX random key forwarded to ``generate``.
            batch: Dict with ``input_tokens`` and ``attention_mask``.

        Returns:
            ``(output, rng)`` where ``output`` holds only the newly generated
            tokens (the prompt columns are sliced off).
        """
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        rng_generator = JaxRNG(rng)
        output = hf_model.generate(
            batch['input_tokens'],
            attention_mask=batch['attention_mask'],
            params=params['params'],
            prng_key=rng_generator(),
            generation_config=GenerationConfig(
                max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=False,
                num_beams=1,
            )
        ).sequences[:, batch['input_tokens'].shape[1]:]
        return output, rng_generator()

    mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
    with mesh:
        params = tree_apply(shard_fns, params)
        sharded_rng = next_rng()

    class ModelServer(LMServer):
        """LMServer whose endpoints run the pjit-compiled GPT-J functions.

        All methods are static and share ``sharded_rng`` from the enclosing
        ``main`` scope, replacing it with the rng returned by each forward
        call.
        """

        @staticmethod
        def loglikelihood(prefix_text, text):
            """Score ``text`` as a continuation of ``prefix_text``.

            Args:
                prefix_text: Batch of context strings; left-padded/truncated
                    to ``FLAGS.input_length`` tokens.
                text: Batch of continuation strings; right-padded/truncated
                    to ``FLAGS.seq_length - FLAGS.input_length`` tokens.

            Returns:
                ``(loglikelihood, is_greedy)`` fetched to the host; only the
                continuation tokens contribute to either value.
            """
            nonlocal sharded_rng
            prefix = prefix_tokenizer(
                prefix_text,
                padding='max_length',
                truncation=True,
                max_length=FLAGS.input_length,
                return_tensors='np',
            )
            inputs = tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=FLAGS.seq_length - FLAGS.input_length,
                return_tensors='np',
            )

            # Targets are prefix + continuation; the model input is the same
            # sequence shifted right by one with a BOS token in front.
            output_tokens = np.concatenate(
                [prefix.input_ids, inputs.input_ids], axis=1
            )
            bos_tokens = np.full(
                (output_tokens.shape[0], 1), tokenizer.bos_token_id,
                dtype=np.int32
            )
            input_tokens = np.concatenate(
                [bos_tokens, output_tokens[:, :-1]], axis=-1
            )

            # The input mask is shifted the same way; the BOS position is only
            # attended to when add_bos_token is set.
            input_mask = np.concatenate(
                [prefix.attention_mask, inputs.attention_mask], axis=1
            )
            if FLAGS.add_bos_token:
                bos_mask = np.ones_like(input_mask[:, :1])
            else:
                bos_mask = np.zeros_like(input_mask[:, :1])
            input_mask = np.concatenate(
                [bos_mask, input_mask[:, :-1]], axis=1
            )

            # Only continuation positions are scored.
            output_mask = np.concatenate(
                [np.zeros_like(prefix.attention_mask), inputs.attention_mask],
                axis=1
            )
            batch = dict(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                input_mask=input_mask,
                output_mask=output_mask,
            )
            with mesh:
                loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
                    params, sharded_rng, batch
                )
                loglikelihood, is_greedy = jax.device_get(
                    (loglikelihood, is_greedy)
                )
            return loglikelihood, is_greedy

        @staticmethod
        def loglikelihood_rolling(text):
            """Score whole texts of arbitrary length in fixed-size windows.

            The texts are tokenized without truncation, padded to at least
            ``FLAGS.seq_length`` tokens and scored in consecutive windows of
            ``FLAGS.seq_length`` tokens. The final, possibly overlapping
            window only scores positions not covered by earlier windows.

            Args:
                text: Batch of strings to score.

            Returns:
                ``(total_loglikelihood, total_is_greedy)`` accumulated over
                all windows.
            """
            nonlocal sharded_rng
            inputs = tokenizer(
                text,
                padding='longest',
                truncation=False,
                max_length=np.iinfo(np.int32).max,
                return_tensors='np',
            )
            batch_size = inputs.input_ids.shape[0]
            output_tokens = inputs.input_ids
            attention_mask = inputs.attention_mask

            # Pad short batches up to one full window.
            if output_tokens.shape[1] < FLAGS.seq_length:
                padding_length = FLAGS.seq_length - output_tokens.shape[1]
                pad_tokens = np.full(
                    (batch_size, padding_length), tokenizer.pad_token_id,
                    dtype=np.int32
                )
                output_tokens = np.concatenate(
                    [output_tokens, pad_tokens], axis=-1
                )
                pad_mask = np.zeros(
                    (batch_size, padding_length),
                    dtype=inputs.attention_mask.dtype
                )
                attention_mask = np.concatenate(
                    [attention_mask, pad_mask], axis=-1
                )

            # Model input is the target sequence shifted right behind a BOS
            # token. The unshifted attention mask is reused as the input mask
            # in every window below.
            bos_tokens = np.full(
                (batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
            )
            input_tokens = np.concatenate(
                [bos_tokens, output_tokens[:, :-1]], axis=-1
            )

            total_seq_length = output_tokens.shape[1]
            total_loglikelihood = 0.0
            total_is_greedy = True

            # Sliding window
            for i in range(0, total_seq_length, FLAGS.seq_length):
                if i + FLAGS.seq_length > total_seq_length:
                    # Last window: take the final seq_length tokens and zero
                    # the mask for everything before position i, which the
                    # previous windows already scored.
                    last_output_mask = np.copy(
                        attention_mask[:, -FLAGS.seq_length:]
                    )
                    last_output_mask[:, :i - total_seq_length] = 0.0
                    batch = dict(
                        input_tokens=input_tokens[:, -FLAGS.seq_length:],
                        output_tokens=output_tokens[:, -FLAGS.seq_length:],
                        input_mask=attention_mask[:, -FLAGS.seq_length:],
                        output_mask=last_output_mask,
                    )
                else:
                    # Normal window
                    window = slice(i, i + FLAGS.seq_length)
                    batch = dict(
                        input_tokens=input_tokens[:, window],
                        output_tokens=output_tokens[:, window],
                        input_mask=attention_mask[:, window],
                        output_mask=attention_mask[:, window],
                    )

                with mesh:
                    loglikelihood, is_greedy, sharded_rng = (
                        forward_loglikelihood(params, sharded_rng, batch)
                    )
                    loglikelihood, is_greedy = jax.device_get(
                        (loglikelihood, is_greedy)
                    )

                total_loglikelihood += loglikelihood
                total_is_greedy = np.logical_and(is_greedy, total_is_greedy)

            return total_loglikelihood, total_is_greedy

        @staticmethod
        def generate(text, temperature):
            """Generate a continuation for each prompt using the sampling flags.

            Args:
                text: Batch of prompt strings; left-padded/truncated to
                    ``FLAGS.input_length`` tokens.
                temperature: Temperature forwarded to ``forward_generate``.

            Returns:
                List of decoded continuations, each cut at the first EOS
                token string if one is present.
            """
            nonlocal sharded_rng
            inputs = prefix_tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=FLAGS.input_length,
                return_tensors='np',
            )
            input_tokens = inputs.input_ids
            input_mask = inputs.attention_mask
            if FLAGS.add_bos_token:
                # Overwrite the first (left-most) slot with an attended BOS.
                input_tokens[:, 0] = tokenizer.bos_token_id
                input_mask[:, 0] = 1
            batch = dict(
                input_tokens=input_tokens,
                attention_mask=input_mask,
            )
            with mesh:
                output, sharded_rng = forward_generate(
                    params, sharded_rng, batch, temperature
                )
                output = jax.device_get(output)

            output_text = []
            for decoded in tokenizer.batch_decode(output):
                # Keep only what precedes the first EOS marker, if any.
                if tokenizer.eos_token in decoded:
                    decoded = decoded.split(
                        tokenizer.eos_token, maxsplit=1
                    )[0]
                output_text.append(decoded)

            return output_text

        @staticmethod
        def greedy_until(prefix_text, until, max_length):
            """Greedily extend each prompt until a stop string or length cap.

            Each prompt is processed on its own: the model is run repeatedly,
            the decoded output is appended to the prompt, and the loop ends
            once a stop string appears in the generated text or the number of
            generated tokens reaches ``max_length``.

            Args:
                prefix_text: Iterable of prompt strings.
                until: Iterable, parallel to ``prefix_text``, of a stop string
                    or a list of stop strings.
                max_length: Generation stops once at least this many tokens
                    have been produced for a prompt.

            Returns:
                List with the generated text for each prompt, truncated
                before any stop string that was found.
            """
            nonlocal sharded_rng
            all_outputs = []
            for pf, ut in zip(prefix_text, until):
                if isinstance(ut, str):
                    ut = [ut]
                total_length = 0
                total_generated = ''

                while total_length < max_length:
                    pf_tokens = tokenizer(
                        pf,
                        padding=False,
                        truncation=False,
                        max_length=np.iinfo(np.int32).max,
                        return_tensors='np',
                    )
                    input_tokens = pf_tokens.input_ids
                    attention_mask = pf_tokens.attention_mask

                    if input_tokens.shape[1] < FLAGS.input_length:
                        # Left-pad short prompts up to input_length.
                        extra = FLAGS.input_length - input_tokens.shape[1]
                        pad_tokens = np.full(
                            (1, extra), tokenizer.pad_token_id, dtype=np.int32
                        )
                        input_tokens = np.concatenate(
                            [pad_tokens, input_tokens], axis=1
                        )
                        pad_attention = np.zeros(
                            (1, extra), dtype=attention_mask.dtype
                        )
                        attention_mask = np.concatenate(
                            [pad_attention, attention_mask], axis=1
                        )
                    elif input_tokens.shape[1] > FLAGS.input_length:
                        # Keep only the most recent input_length tokens.
                        input_tokens = input_tokens[:, -FLAGS.input_length:]
                        attention_mask = attention_mask[
                            :, -FLAGS.input_length:
                        ]

                    if FLAGS.add_bos_token:
                        input_tokens[:, 0] = tokenizer.bos_token_id
                        attention_mask[:, 0] = 1

                    batch = dict(
                        input_tokens=input_tokens,
                        attention_mask=attention_mask
                    )

                    with mesh:
                        output, sharded_rng = forward_greedy_generate(
                            params, sharded_rng, batch
                        )
                        output = jax.device_get(output)

                    total_length += output.shape[1]
                    output_text = tokenizer.batch_decode(output)[0]
                    total_generated = total_generated + output_text
                    # Feed the new text back in as context for the next round.
                    pf = pf + output_text

                    # Every stop string is checked in turn against the text as
                    # truncated so far, so the result ends before the earliest
                    # stop string.
                    done = False
                    for s in ut:
                        if s in total_generated:
                            total_generated = total_generated.split(
                                s, maxsplit=1
                            )[0]
                            done = True
                    if done:
                        break

                all_outputs.append(total_generated)

            return all_outputs

    server = ModelServer(FLAGS.lm_server)
    server.run()
# Script entry point: when executed directly, hand `main` to mlxu's runner.
# NOTE(review): mlxu.run presumably parses the flags declared above before
# invoking main — confirm against the mlxu documentation.
if __name__ == "__main__":
    mlxu.run(main)