ahassoun's picture
Upload 3018 files
ee6e328
raw
history blame contribute delete
No virus
11.8 kB
import json
import os
from dataclasses import dataclass
from functools import partial
from typing import Callable
import flax.linen as nn
import jax
import jax.numpy as jnp
import joblib
import optax
import wandb
from flax import jax_utils, struct, traverse_util
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import shard
from tqdm.auto import tqdm
from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule
class FlaxBigBirdForNaturalQuestionsModule(FlaxBigBirdForQuestionAnsweringModule):
    """
    BigBird question-answering module with an extra classification head on top
    that predicts the answer category.

    Only a dense head is added to the stock QA module, so the rest of the
    weights can still be loaded with FlaxBigBirdForQuestionAnswering.
    """

    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
    add_pooling_layer: bool = True

    def setup(self):
        """Build the parent QA module, then attach the 5-way category head."""
        super().setup()
        self.cls = nn.Dense(5, dtype=self.dtype)

    def __call__(self, *args, **kwargs):
        """Run the parent module and return its first two outputs plus the category logits."""
        qa_outputs = super().__call__(*args, **kwargs)
        # Index 2 of the parent's outputs feeds the category head
        # (presumably the pooled output, since add_pooling_layer is True -- confirm).
        category_logits = self.cls(qa_outputs[2])
        return (*qa_outputs[:2], category_logits)
class FlaxBigBirdForNaturalQuestions(FlaxBigBirdForQuestionAnswering):
    """Question-answering model whose underlying module is FlaxBigBirdForNaturalQuestionsModule."""

    module_class = FlaxBigBirdForNaturalQuestionsModule
def calculate_loss_for_nq(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooler_labels):
    """
    Combine the start, end and category losses into a single scalar.

    Each term is a mean softmax cross-entropy between logits (classes on the
    last axis) and integer labels; the three means are averaged with equal weight.
    """

    def mean_cross_entropy(logits, labels):
        """
        Args:
            logits: bsz, seqlen, vocab_size
            labels: bsz, seqlen
        """
        num_classes = logits.shape[-1]
        # One-hot encode by comparing each label against every class index.
        one_hot = (labels[..., None] == jnp.arange(num_classes)[None]).astype("f4")
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        per_position = -jnp.sum(one_hot * log_probs, axis=-1)
        return jnp.mean(per_position)

    start_loss = mean_cross_entropy(start_logits, start_labels)
    end_loss = mean_cross_entropy(end_logits, end_labels)
    pooled_loss = mean_cross_entropy(pooled_logits, pooler_labels)
    return (start_loss + end_loss + pooled_loss) / 3
@dataclass
class Args:
    """
    Configuration for a fine-tuning run: model id, logging/saving cadence,
    optimizer settings and data/checkpoint locations.

    After construction `save_dir` is rewritten to live under `base_dir`
    (which is created if missing) and `batch_size` is set to the global
    batch size across all JAX devices.
    """

    model_id: str = "google/bigbird-roberta-base"
    logging_steps: int = 3000
    save_steps: int = 10500

    block_size: int = 128
    num_random_blocks: int = 3

    batch_size_per_device: int = 1
    max_epochs: int = 5

    # tx_args
    lr: float = 3e-5
    init_lr: float = 0.0
    warmup_steps: int = 20000
    weight_decay: float = 0.0095

    save_dir: str = "bigbird-roberta-natural-questions"
    base_dir: str = "training-expt"
    tr_data_path: str = "data/nq-training.jsonl"
    val_data_path: str = "data/nq-validation.jsonl"

    def __post_init__(self):
        root = self.base_dir
        os.makedirs(root, exist_ok=True)
        # Checkpoints are written below the experiment root.
        self.save_dir = os.path.join(root, self.save_dir)
        # Global batch size = per-device batch size times number of devices.
        num_devices = jax.device_count()
        self.batch_size = self.batch_size_per_device * num_devices
@dataclass
class DataCollator:
    """
    Turn a batch of raw features into fixed-length, device-sharded arrays.

    `features` is a mapping with the keys "input_ids", "start_token",
    "end_token" and "category". Every sequence is right-padded with `pad_id`
    up to `max_length`; sequences that are already longer are passed through
    unchanged (no truncation is performed).
    """

    pad_id: int
    max_length: int = 4096  # no dynamic padding on TPUs

    def __call__(self, batch):
        """Collate `batch` and apply `shard` to every array in the result."""
        batch = self.collate_fn(batch)
        batch = jax.tree_util.tree_map(shard, batch)
        return batch

    def collate_fn(self, features):
        """Build the int32 model inputs and label arrays from `features`."""
        input_ids, attention_mask = self.fetch_inputs(features["input_ids"])
        batch = {
            "input_ids": jnp.array(input_ids, dtype=jnp.int32),
            "attention_mask": jnp.array(attention_mask, dtype=jnp.int32),
            "start_labels": jnp.array(features["start_token"], dtype=jnp.int32),
            "end_labels": jnp.array(features["end_token"], dtype=jnp.int32),
            "pooled_labels": jnp.array(features["category"], dtype=jnp.int32),
        }
        return batch

    def fetch_inputs(self, input_ids: list):
        """Pad every sequence; return an iterator over (all_input_ids, all_attention_masks)."""
        inputs = [self._fetch_inputs(ids) for ids in input_ids]
        return zip(*inputs)

    def _fetch_inputs(self, input_ids: list):
        """
        Pad one sequence to `max_length` and build its attention mask.

        Returns new lists; the caller's `input_ids` is left untouched.
        """
        seq_len = len(input_ids)
        # Clamp at zero so over-long sequences are returned unpadded.
        num_pad = max(self.max_length - seq_len, 0)
        padded_ids = list(input_ids) + [self.pad_id] * num_pad
        attention_mask = [1] * seq_len + [0] * num_pad
        return padded_ids, attention_mask
def get_batched_dataset(dataset, batch_size, seed=None):
    """
    Yield consecutive `batch_size`-sized slices of `dataset` as plain dicts.

    The dataset is shuffled first when `seed` is given. A trailing partial
    batch is dropped.
    """
    if seed is not None:
        dataset = dataset.shuffle(seed=seed)

    num_full_batches = len(dataset) // batch_size
    for start in range(0, num_full_batches * batch_size, batch_size):
        yield dict(dataset[start : start + batch_size])
@partial(jax.pmap, axis_name="batch")
def train_step(state, drp_rng, **model_inputs):
    """
    Run one pmapped optimisation step.

    `model_inputs` must contain "start_labels", "end_labels" and
    "pooled_labels"; every remaining entry is forwarded to `state.apply_fn`.

    Returns the updated state, the loss averaged over the "batch" axis, and a
    fresh dropout key for the next step.
    """
    # Split once: the first key drives dropout in this step, the second is
    # handed back to the caller for the next one.
    step_rng, next_rng = jax.random.split(drp_rng)

    # Separate the targets from what the model consumes.
    start_labels = model_inputs.pop("start_labels")
    end_labels = model_inputs.pop("end_labels")
    pooled_labels = model_inputs.pop("pooled_labels")

    def compute_loss(params):
        start_logits, end_logits, pooled_logits = state.apply_fn(
            **model_inputs, params=params, dropout_rng=step_rng, train=True
        )
        return state.loss_fn(
            start_logits,
            start_labels,
            end_logits,
            end_labels,
            pooled_logits,
            pooled_labels,
        )

    loss, grads = jax.value_and_grad(compute_loss)(state.params)

    # Average loss and gradients across devices before applying the update.
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    grads = jax.lax.pmean(grads, "batch")

    updated_state = state.apply_gradients(grads=grads)
    return updated_state, metrics, next_rng
@partial(jax.pmap, axis_name="batch")
def val_step(state, **model_inputs):
    """
    Compute the pmapped evaluation loss for one batch.

    `model_inputs` must contain "start_labels", "end_labels" and
    "pooled_labels"; every remaining entry is forwarded to `state.apply_fn`.
    Returns a dict with the loss averaged over the "batch" axis.
    """
    start_labels = model_inputs.pop("start_labels")
    end_labels = model_inputs.pop("end_labels")
    pooled_labels = model_inputs.pop("pooled_labels")

    # Forward pass with the current parameters, in inference mode.
    start_logits, end_logits, pooled_logits = state.apply_fn(
        **model_inputs, params=state.params, train=False
    )

    loss = state.loss_fn(
        start_logits,
        start_labels,
        end_logits,
        end_labels,
        pooled_logits,
        pooled_labels,
    )
    return jax.lax.pmean({"loss": loss}, axis_name="batch")
class TrainState(train_state.TrainState):
    """Flax train state that additionally carries the loss function read by the train/val steps."""

    # Declared with pytree_node=False, so it is not treated as a pytree leaf.
    loss_fn: Callable = struct.field(pytree_node=False)
@dataclass
class Trainer:
    """
    Drives training, periodic evaluation/logging and checkpointing.

    The step functions, collator, model-saving function and logger are
    injected; `scheduler_fn` maps a step count to a learning rate and is
    only used for logging.
    """

    args: Args
    data_collator: Callable
    train_step_fn: Callable
    val_step_fn: Callable
    model_save_fn: Callable
    logger: wandb
    scheduler_fn: Callable = None

    def create_state(self, model, tx, num_train_steps, ckpt_dir=None):
        """
        Build the replicated training state for `model`.

        When `ckpt_dir` is given, parameters, optimizer state, step counter,
        args and data collator are restored from it; the optimizer and the
        learning-rate schedule are rebuilt from the restored args, and
        `self.args`, `self.data_collator`, `self.scheduler_fn` and
        `model.params` are overwritten accordingly.
        """
        params = model.params
        state = TrainState.create(
            apply_fn=model.__call__,
            params=params,
            tx=tx,
            loss_fn=calculate_loss_for_nq,
        )
        if ckpt_dir is not None:
            params, opt_state, step, args, data_collator = restore_checkpoint(ckpt_dir, state)
            tx_args = {
                "lr": args.lr,
                "init_lr": args.init_lr,
                "warmup_steps": args.warmup_steps,
                "num_train_steps": num_train_steps,
                "weight_decay": args.weight_decay,
            }
            tx, lr = build_tx(**tx_args)
            # The resumed state must be the local TrainState (with loss_fn):
            # the train/val steps read `state.loss_fn`, which the base
            # flax TrainState does not have.
            state = TrainState(
                step=step,
                apply_fn=model.__call__,
                params=params,
                tx=tx,
                opt_state=opt_state,
                loss_fn=calculate_loss_for_nq,
            )
            self.args = args
            self.data_collator = data_collator
            self.scheduler_fn = lr
            model.params = params
        state = jax_utils.replicate(state)
        return state

    def train(self, state, tr_dataset, val_dataset):
        """
        Train for `args.max_epochs` epochs.

        Every `args.logging_steps` batches the running training loss, the
        evaluation loss and the current learning rate are logged; every
        `args.save_steps` batches a checkpoint is written. The batch counter
        and running loss restart at each epoch.
        """
        args = self.args
        total = len(tr_dataset) // args.batch_size

        rng = jax.random.PRNGKey(0)
        drp_rng = jax.random.split(rng, jax.device_count())
        for epoch in range(args.max_epochs):
            running_loss = jnp.array(0, dtype=jnp.float32)
            # The epoch index doubles as the shuffle seed.
            tr_dataloader = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch)
            progress = tqdm(tr_dataloader, total=total, desc=f"Running EPOCH-{epoch}")
            for i, batch in enumerate(progress, start=1):
                batch = self.data_collator(batch)
                state, metrics, drp_rng = self.train_step_fn(state, drp_rng, **batch)
                running_loss += jax_utils.unreplicate(metrics["loss"])
                if i % args.logging_steps == 0:
                    state_step = jax_utils.unreplicate(state.step)
                    tr_loss = running_loss.item() / i
                    lr = self.scheduler_fn(state_step - 1)

                    eval_loss = self.evaluate(state, val_dataset)
                    logging_dict = {
                        "step": state_step.item(),
                        "eval_loss": eval_loss.item(),
                        "tr_loss": tr_loss,
                        "lr": lr.item(),
                    }
                    tqdm.write(str(logging_dict))
                    self.logger.log(logging_dict, commit=True)

                if i % args.save_steps == 0:
                    self.save_checkpoint(args.save_dir + f"-e{epoch}-s{i}", state=state)

    def evaluate(self, state, dataset):
        """Return the mean validation loss over all full batches of `dataset`."""
        dataloader = get_batched_dataset(dataset, self.args.batch_size)
        total = len(dataset) // self.args.batch_size
        running_loss = jnp.array(0, dtype=jnp.float32)
        i = 0
        for i, batch in enumerate(tqdm(dataloader, total=total, desc="Evaluating ... "), start=1):
            batch = self.data_collator(batch)
            metrics = self.val_step_fn(state, **batch)
            running_loss += jax_utils.unreplicate(metrics["loss"])
        # NOTE: if the dataset yields no full batch, i stays 0 and this divides by zero.
        return running_loss / i

    def save_checkpoint(self, save_dir, state):
        """
        Write model weights, optimizer state, args, data collator and the
        step counter into `save_dir`.
        """
        state = jax_utils.unreplicate(state)
        print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
        # Do not rely on model_save_fn to create the directory for the files below.
        os.makedirs(save_dir, exist_ok=True)
        self.model_save_fn(save_dir, params=state.params)
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
            f.write(to_bytes(state.opt_state))
        joblib.dump(self.args, os.path.join(save_dir, "args.joblib"))
        joblib.dump(self.data_collator, os.path.join(save_dir, "data_collator.joblib"))
        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
            json.dump({"step": state.step.item()}, f)
        print("DONE")
def restore_checkpoint(save_dir, state):
    """
    Load a checkpoint from `save_dir`.

    `state` supplies the target structures that the serialized parameters and
    optimizer state are deserialized into.

    Returns (params, opt_state, step, args, data_collator).
    """
    print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")

    def ckpt_file(filename):
        return os.path.join(save_dir, filename)

    with open(ckpt_file("flax_model.msgpack"), "rb") as params_file:
        params = from_bytes(state.params, params_file.read())

    with open(ckpt_file("opt_state.msgpack"), "rb") as opt_file:
        opt_state = from_bytes(state.opt_state, opt_file.read())

    # NOTE(review): joblib files are pickle-based; only restore checkpoints from trusted sources.
    args = joblib.load(ckpt_file("args.joblib"))
    data_collator = joblib.load(ckpt_file("data_collator.joblib"))

    with open(ckpt_file("training_state.json"), "r") as meta_file:
        step = json.load(meta_file)["step"]

    print("DONE")
    return params, opt_state, step, args, data_collator
def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
    """
    Build a learning-rate schedule: linear warmup from `init_lr` to `lr` over
    `warmup_steps`, then linear decay from `lr` to 1e-7 over the remaining
    `num_train_steps - warmup_steps` steps.
    """
    warmup = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps)
    decay = optax.linear_schedule(
        init_value=lr,
        end_value=1e-7,
        transition_steps=num_train_steps - warmup_steps,
    )
    # Switch from warmup to decay once warmup_steps have elapsed.
    return optax.join_schedules(schedules=[warmup, decay], boundaries=[warmup_steps])
def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
    """
    Build the AdamW optimizer together with its learning-rate schedule.

    Weight decay is applied to every parameter except those whose name is
    "bias" and the "scale" of "LayerNorm" modules.

    Returns (tx, lr) where `lr` is the schedule produced by `scheduler_fn`.
    """

    def weight_decay_mask(params):
        """Return a pytree of booleans shaped like `params`; True means "apply weight decay"."""
        flat_params = traverse_util.flatten_dict(params)
        # Decide from the key path (tuple of names), not from the parameter value.
        mask = {
            path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
            for path in flat_params
        }
        return traverse_util.unflatten_dict(mask)

    lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)

    tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
    return tx, lr