import json
import os
from dataclasses import dataclass
from functools import partial
from typing import Callable

import flax.linen as nn
import jax
import jax.numpy as jnp
import joblib
import optax
import wandb
from flax import jax_utils, struct, traverse_util
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import shard
from tqdm.auto import tqdm

from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule

class FlaxBigBirdForNaturalQuestionsModule(FlaxBigBirdForQuestionAnsweringModule):
    """
    BigBirdForQuestionAnswering with a CLS head on top for predicting the answer category.

    This way we can load its weights with FlaxBigBirdForQuestionAnswering.
    """

    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
    add_pooling_layer: bool = True

    def setup(self):
        super().setup()
        self.cls = nn.Dense(5, dtype=self.dtype)

    def __call__(self, *args, **kwargs):
        outputs = super().__call__(*args, **kwargs)
        cls_out = self.cls(outputs[2])  # outputs[2] is the pooled output (hence add_pooling_layer=True)
        return outputs[:2] + (cls_out,)


class FlaxBigBirdForNaturalQuestions(FlaxBigBirdForQuestionAnswering):
    module_class = FlaxBigBirdForNaturalQuestionsModule

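# Usage sketch: weights are loaded with the standard `from_pretrained` API; the extra
# 5-way `cls` head is newly initialised on top of the pretrained encoder, and config
# kwargs such as `block_size` / `num_random_blocks` are forwarded to the BigBirdConfig:
#
#   model = FlaxBigBirdForNaturalQuestions.from_pretrained(
#       "google/bigbird-roberta-base", block_size=128, num_random_blocks=3
#   )
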
def calculate_loss_for_nq(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooler_labels):
    def cross_entropy(logits, labels, reduction=None):
        """
        Args:
            logits: (..., num_classes)
            labels: (...) integer class ids
        """
        num_classes = logits.shape[-1]
        labels = (labels[..., None] == jnp.arange(num_classes)[None]).astype("f4")
        logits = jax.nn.log_softmax(logits, axis=-1)
        loss = -jnp.sum(labels * logits, axis=-1)
        if reduction is not None:
            loss = reduction(loss)
        return loss

    cross_entropy = partial(cross_entropy, reduction=jnp.mean)
    start_loss = cross_entropy(start_logits, start_labels)
    end_loss = cross_entropy(end_logits, end_labels)
    pooled_loss = cross_entropy(pooled_logits, pooler_labels)
    return (start_loss + end_loss + pooled_loss) / 3

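# Note: the inner `cross_entropy` above is the plain one-hot softmax cross entropy; an
# equivalent formulation with optax (sketch) would be
#   optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, logits.shape[-1])).mean()
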
@dataclass
class Args:
    model_id: str = "google/bigbird-roberta-base"
    logging_steps: int = 3000
    save_steps: int = 10500

    block_size: int = 128
    num_random_blocks: int = 3

    batch_size_per_device: int = 1
    max_epochs: int = 5

    lr: float = 3e-5
    init_lr: float = 0.0
    warmup_steps: int = 20000
    weight_decay: float = 0.0095

    save_dir: str = "bigbird-roberta-natural-questions"
    base_dir: str = "training-expt"
    tr_data_path: str = "data/nq-training.jsonl"
    val_data_path: str = "data/nq-validation.jsonl"

    def __post_init__(self):
        os.makedirs(self.base_dir, exist_ok=True)
        self.save_dir = os.path.join(self.base_dir, self.save_dir)
        self.batch_size = self.batch_size_per_device * jax.device_count()

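# Example: on an 8-device host (e.g. a TPU v3-8), `Args()` creates `training-expt/`
# and derives the global batch size as
#   batch_size = batch_size_per_device * jax.device_count() = 1 * 8 = 8
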
@dataclass
class DataCollator:
    pad_id: int
    max_length: int = 4096  # no dynamic padding on TPUs

    def __call__(self, batch):
        batch = self.collate_fn(batch)
        batch = jax.tree_util.tree_map(shard, batch)
        return batch

    def collate_fn(self, features):
        input_ids, attention_mask = self.fetch_inputs(features["input_ids"])
        batch = {
            "input_ids": jnp.array(input_ids, dtype=jnp.int32),
            "attention_mask": jnp.array(attention_mask, dtype=jnp.int32),
            "start_labels": jnp.array(features["start_token"], dtype=jnp.int32),
            "end_labels": jnp.array(features["end_token"], dtype=jnp.int32),
            "pooled_labels": jnp.array(features["category"], dtype=jnp.int32),
        }
        return batch

    def fetch_inputs(self, input_ids: list):
        inputs = [self._fetch_inputs(ids) for ids in input_ids]
        return zip(*inputs)

    def _fetch_inputs(self, input_ids: list):
        # pad `input_ids` in place up to `max_length` and build the matching attention mask
        attention_mask = [1 for _ in range(len(input_ids))]
        while len(input_ids) < self.max_length:
            input_ids.append(self.pad_id)
            attention_mask.append(0)
        return input_ids, attention_mask

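# Usage sketch (with a hypothetical `tokenizer`): pad every sample to `max_length`,
# then shard each array across the local devices for `pmap`:
#   data_collator = DataCollator(pad_id=tokenizer.pad_token_id, max_length=4096)
#   model_inputs = data_collator(batch)  # leading axis becomes (n_devices, batch_size_per_device, ...)
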
def get_batched_dataset(dataset, batch_size, seed=None):
    if seed is not None:
        dataset = dataset.shuffle(seed=seed)
    # note: the trailing partial batch (len(dataset) % batch_size samples) is dropped
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        yield dict(batch)

@partial(jax.pmap, axis_name="batch")
def train_step(state, drp_rng, **model_inputs):
    def loss_fn(params):
        start_labels = model_inputs.pop("start_labels")
        end_labels = model_inputs.pop("end_labels")
        pooled_labels = model_inputs.pop("pooled_labels")

        outputs = state.apply_fn(**model_inputs, params=params, dropout_rng=drp_rng, train=True)
        start_logits, end_logits, pooled_logits = outputs

        return state.loss_fn(
            start_logits,
            start_labels,
            end_logits,
            end_labels,
            pooled_logits,
            pooled_labels,
        )

    drp_rng, new_drp_rng = jax.random.split(drp_rng)
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    grads = jax.lax.pmean(grads, "batch")  # average gradients across devices before the update

    state = state.apply_gradients(grads=grads)
    return state, metrics, new_drp_rng

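# Usage sketch (mirrors Trainer.train below): `state` must be replicated across devices
# and `drp_rng` must hold one dropout key per device, e.g.
#   drp_rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())
#   state, metrics, drp_rng = train_step(state, drp_rng, **data_collator(batch))
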
@partial(jax.pmap, axis_name="batch")
def val_step(state, **model_inputs):
    start_labels = model_inputs.pop("start_labels")
    end_labels = model_inputs.pop("end_labels")
    pooled_labels = model_inputs.pop("pooled_labels")

    outputs = state.apply_fn(**model_inputs, params=state.params, train=False)
    start_logits, end_logits, pooled_logits = outputs

    loss = state.loss_fn(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels)
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    return metrics

class TrainState(train_state.TrainState):
    # static (non-pytree) field: the loss callable is not treated as a parameter leaf
    loss_fn: Callable = struct.field(pytree_node=False)

@dataclass
class Trainer:
    args: Args
    data_collator: Callable
    train_step_fn: Callable
    val_step_fn: Callable
    model_save_fn: Callable
    logger: wandb
    scheduler_fn: Callable = None

    def create_state(self, model, tx, num_train_steps, ckpt_dir=None):
        params = model.params
        state = TrainState.create(
            apply_fn=model.__call__,
            params=params,
            tx=tx,
            loss_fn=calculate_loss_for_nq,
        )
        if ckpt_dir is not None:
            params, opt_state, step, args, data_collator = restore_checkpoint(ckpt_dir, state)
            tx_args = {
                "lr": args.lr,
                "init_lr": args.init_lr,
                "warmup_steps": args.warmup_steps,
                "num_train_steps": num_train_steps,
                "weight_decay": args.weight_decay,
            }
            tx, lr = build_tx(**tx_args)
            # rebuild the custom TrainState (not the base train_state.TrainState), so that
            # `loss_fn` is available to train_step / val_step after resuming
            state = TrainState(
                step=step,
                apply_fn=model.__call__,
                params=params,
                tx=tx,
                opt_state=opt_state,
                loss_fn=calculate_loss_for_nq,
            )
            self.args = args
            self.data_collator = data_collator
            self.scheduler_fn = lr
            model.params = params
        state = jax_utils.replicate(state)
        return state

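    # `create_state` is called once per run: with ckpt_dir=None it starts from the
    # pretrained `model.params`; with a checkpoint directory it restores params,
    # optimizer state and step (plus the saved Args / DataCollator) before the state
    # is replicated across devices.
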
    def train(self, state, tr_dataset, val_dataset):
        args = self.args
        total = len(tr_dataset) // args.batch_size

        rng = jax.random.PRNGKey(0)
        drp_rng = jax.random.split(rng, jax.device_count())
        for epoch in range(args.max_epochs):
            running_loss = jnp.array(0, dtype=jnp.float32)
            tr_dataloader = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch)
            i = 0
            for batch in tqdm(tr_dataloader, total=total, desc=f"Running EPOCH-{epoch}"):
                batch = self.data_collator(batch)
                state, metrics, drp_rng = self.train_step_fn(state, drp_rng, **batch)
                running_loss += jax_utils.unreplicate(metrics["loss"])
                i += 1
                if i % args.logging_steps == 0:
                    state_step = jax_utils.unreplicate(state.step)
                    tr_loss = running_loss.item() / i
                    lr = self.scheduler_fn(state_step - 1)

                    eval_loss = self.evaluate(state, val_dataset)
                    logging_dict = {
                        "step": state_step.item(),
                        "eval_loss": eval_loss.item(),
                        "tr_loss": tr_loss,
                        "lr": lr.item(),
                    }
                    tqdm.write(str(logging_dict))
                    self.logger.log(logging_dict, commit=True)

                if i % args.save_steps == 0:
                    self.save_checkpoint(args.save_dir + f"-e{epoch}-s{i}", state=state)

    def evaluate(self, state, dataset):
        dataloader = get_batched_dataset(dataset, self.args.batch_size)
        total = len(dataset) // self.args.batch_size
        running_loss = jnp.array(0, dtype=jnp.float32)
        i = 0
        for batch in tqdm(dataloader, total=total, desc="Evaluating ... "):
            batch = self.data_collator(batch)
            metrics = self.val_step_fn(state, **batch)
            running_loss += jax_utils.unreplicate(metrics["loss"])
            i += 1
        return running_loss / i

    def save_checkpoint(self, save_dir, state):
        state = jax_utils.unreplicate(state)
        print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
        self.model_save_fn(save_dir, params=state.params)
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
            f.write(to_bytes(state.opt_state))
        joblib.dump(self.args, os.path.join(save_dir, "args.joblib"))
        joblib.dump(self.data_collator, os.path.join(save_dir, "data_collator.joblib"))
        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
            json.dump({"step": state.step.item()}, f)
        print("DONE")

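# A checkpoint directory written by `Trainer.save_checkpoint` therefore contains the
# model weights saved by `model_save_fn` (flax_model.msgpack when `save_pretrained`
# is used), plus opt_state.msgpack, args.joblib, data_collator.joblib and
# training_state.json; `restore_checkpoint` below reads them back in that layout.
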
def restore_checkpoint(save_dir, state):
    print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")
    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
        params = from_bytes(state.params, f.read())

    with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())

    args = joblib.load(os.path.join(save_dir, "args.joblib"))
    data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))

    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
        training_state = json.load(f)
    step = training_state["step"]

    print("DONE")
    return params, opt_state, step, args, data_collator

def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
    # linear warmup from init_lr to lr over warmup_steps, then linear decay towards
    # zero (1e-7) over the remaining training steps
    decay_steps = num_train_steps - warmup_steps
    warmup_fn = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps)
    decay_fn = optax.linear_schedule(init_value=lr, end_value=1e-7, transition_steps=decay_steps)
    lr = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
    return lr

def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
    def weight_decay_mask(params):
        # apply weight decay to everything except biases and LayerNorm scales;
        # the mask is built from the flattened parameter paths (key tuples)
        params = traverse_util.flatten_dict(params)
        mask = {k: (k[-1] != "bias" and k[-2:] != ("LayerNorm", "scale")) for k in params}
        return traverse_util.unflatten_dict(mask)

    lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)

    tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
    return tx, lr
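# End-to-end wiring sketch (hypothetical glue code; dataset loading, tokenisation and
# `num_train_steps` are assumed to be handled elsewhere):
#
#   args = Args()
#   model = FlaxBigBirdForNaturalQuestions.from_pretrained(args.model_id)
#   tx, lr = build_tx(args.lr, args.init_lr, args.warmup_steps, num_train_steps, args.weight_decay)
#   trainer = Trainer(
#       args=args,
#       data_collator=DataCollator(pad_id=0, max_length=4096),  # pad_id=0 assumes BigBird's pad token id
#       train_step_fn=train_step,
#       val_step_fn=val_step,
#       model_save_fn=model.save_pretrained,
#       logger=wandb,
#       scheduler_fn=lr,
#   )
#   state = trainer.create_state(model, tx, num_train_steps)
#   trainer.train(state, tr_dataset, val_dataset)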