# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sequence to sequence model."""

from typing import Any, Callable, Dict, Tuple

from absl import logging
from flax import linen as nn
from flax.training import common_utils
import gin
import jax
import jax.numpy as jnp
import metrics_summary
from transformer import decoder_stack
from transformer import metric_utils
from transformer import text_dataset
import numpy as np
import seqio


Array = jnp.ndarray
MetricsSummary = metrics_summary.MetricsSummary


# TODO(mrabe): Remove this function and find a better way to turn text metrics
# into text on tensorboard.
def process_summaries(vocab: seqio.Vocabulary,
                      met_summary: MetricsSummary,
                      mode: str) -> MetricsSummary:
  """Compute some additional summaries, and convert tokens to text.

  Args:
    vocab: The vocabulary to detokenize generated text.
    met_summary: The summary object to process.
    mode: The mode of the summary (e.g. "test", "train").

  Returns:
    The modified summary dictionary.
  """
  mdict = met_summary.current_metric_dict()

  # Calculate perplexity from the average nats_per_token over all replicas.
  # This has to be done here, because the perplexities themselves can't be
  # averaged in the usual way.
  if "nats_per_token" in mdict:
    nats_per_token = mdict["nats_per_token"].to_value()
    met_summary.add({"perplexity": np.exp(nats_per_token)})

  if mode == "generate" and "gen_tokens" in mdict:
    # Convert output tokens to example output text.
    # Write text to both the summary, and pretty-print to the log file.
    gen_toks = mdict["gen_tokens"].to_value()
    if np.ndim(gen_toks) != 2:
      raise ValueError("Unsupported shape for gen_tokens: %s" % gen_toks.shape)

    ntoks = gen_toks.shape[-1]
    gen_text = text_dataset.decode_tokens(gen_toks, vocab, max_length=ntoks)
    logging.info("Generated text = %s", gen_text)
    met_summary.add_text({"gen_text": gen_text})
    del mdict["gen_tokens"]  # Otherwise it will turn into a histogram.
  return met_summary
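

# Note on the conversion above: perplexity = exp(nats_per_token), and
# exp(mean(x)) != mean(exp(x)), so per-replica perplexities cannot simply be
# averaged; the mean must be taken in nats first.  A worked example with
# illustrative numbers only:
#
#   nats = np.array([0.5, 1.5])   # per-replica nats_per_token
#   np.exp(nats.mean())           # = e**1.0 ~= 2.718 (correct)
#   np.exp(nats).mean()           # ~= (1.649 + 4.482) / 2 ~= 3.065 (wrong)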


@gin.configurable
def process_summaries_function(vocab: seqio.Vocabulary) -> Callable[
    [MetricsSummary, str], MetricsSummary]:
  """Return a function that processes summaries with the given vocabulary."""
  # For use with training_loop.process_summaries_function.
  def process_fn(met_summary: MetricsSummary, mode: str):
    return process_summaries(vocab, met_summary, mode)
  return process_fn


@gin.configurable
class DecoderOnlyLanguageModel(nn.Module):
  """Decoder-only language model."""

  mode: str
  task_config: decoder_stack.TransformerTaskConfig = gin.REQUIRED
  decoder_factory: Callable[[], Any] = gin.REQUIRED
  sample_method: str = "sample"  # Can be {"sample", "greedy"}.
  output_token_losses: bool = False

  def get_fake_input(self):
    """Returns a fake input of the appropriate shape for initialization."""
    b = self.task_config.batch_size
    fake_input_dict = {
        "targets": jnp.ones([b, self.task_config.sequence_length],
                            dtype=jnp.int32),
        "start_of_sequence": jnp.ones([b], dtype=jnp.bool_),
        "epoch": jnp.ones([b], dtype=jnp.int32),
    }
    if text_dataset.get_loss_mask_tokens(split=self.mode) != (None, None):
      # By default we don't add the loss mask to the fake input, as it can
      # cause a slowdown during evaluation and perhaps inference.
      fake_input_dict["loss_mask"] = jnp.ones(
          [b, self.task_config.sequence_length], dtype=jnp.bool_)
    return fake_input_dict

  def metrics_summary_operations(self, aggregate_over: str) -> Dict[str, str]:
    """Returns the summary operation to use for each recorded metric."""
    metric_ops = {
        "loss": "mean",
        "nats_per_token": "mean",
        "bits_per_token": "mean",
        "bits_per_char": "mean",
        "accuracy": "mean",
        "num_tokens": "mean",
        "num_chars_per_device": "mean",
        "num_chars_per_batch": "mean",
        "nonzero_tokens": "mean",
        "num_tokens_per_device": "mean",
        "num_tokens_per_batch": "mean",
        "epoch": "mean",
    }
    if aggregate_over == "steps":
      return metric_ops
    elif aggregate_over == "devices":
      # Ensure that statistics that refer to the total batch size stay
      # constant as TPU topologies change.  Those must be summed over
      # devices, but averaged over steps.
      metric_ops.update({
          "num_tokens_per_batch": "sum",
          "num_chars_per_batch": "sum",
          "loss": "sum",
      })
      return metric_ops
    else:
      raise ValueError("Don't know how to aggregate over: %s" % aggregate_over)

  def setup(self):
    self.decoder = self.decoder_factory(
        mode=self.mode,
        task_config=self.task_config)  # pytype: disable=wrong-keyword-args  # trace-all-classes

  def __call__(self, inputs: ...):
    task_config = self.task_config

    input_tokens = inputs["targets"]                 # [b, seq_len]
    start_of_sequence = inputs["start_of_sequence"]  # [b]
    epochs = inputs["epoch"]                         # [b]
    if "loss_mask" in inputs:
      loss_mask = inputs["loss_mask"]                # [b, seq_len]
    else:
      loss_mask = jnp.ones((1, 1), dtype=jnp.bool_)

    input_tokens = jnp.asarray(input_tokens)
    assert input_tokens.ndim == 2
    assert input_tokens.shape[0] == task_config.batch_size
    assert input_tokens.shape[1] == task_config.sequence_length
    assert start_of_sequence.shape[0] == task_config.batch_size

    # Sanity check to avoid out-of-bounds on token lookup.
    input_tokens = input_tokens % task_config.vocab_size

    logging.info("langmodel: Compiling model for mode %s", self.mode)
    logging.info("langmodel: input_tokens = %r", input_tokens)
    logging.info("langmodel: start_of_sequence = %r", start_of_sequence)
    logging.info("langmodel: epochs = %r", epochs)

    # The target outputs are the next token in each sequence.
    # Shift tokens left and pad with a zero at the end.
    # TODO(delesley): We don't predict the first token of each sequence.
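    # For example (a sketch with illustrative values, batch_size=1,
    # sequence_length=4):
    #   input_tokens  = [[5, 7, 9, 2]]
    #   target_tokens = [[7, 9, 2, 0]]
    # The trailing 0 is padding; positions whose target is 0 are masked out
    # of the loss below.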
    target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
    logging.info("langmodel: target_tokens = %r", target_tokens)

    # Invoke the decoder stack.
    # The decoder returns pre-softmax logits for the predicted targets.
    (logits, _, d_metrics) = self.decoder(input_tokens=input_tokens,
                                          target_tokens=target_tokens,
                                          start_of_sequence=start_of_sequence)

    # Softmax cross-entropy loss on target tokens.
    logits = nn.log_softmax(logits, axis=-1)  # (b, seq_len, vocab_size)
    logging.info("langmodel: logits = %r", logits)
    soft_targets = common_utils.onehot(target_tokens, task_config.vocab_size)
    logging.info("langmodel: soft_targets = %r", soft_targets)
    losses = -jnp.sum(soft_targets * logits, axis=-1)  # (b, seq_len)
    logging.info("langmodel: losses = %r", losses)

    # Don't predict null tokens which are past the end-of-sequence.
    # Also don't predict the 0 at the end of the sequence.
    # TODO(delesley): Predict the final end-of-sequence marker.
    loss_mask = jnp.logical_and(loss_mask, input_tokens > 0)
    loss_mask = jnp.logical_and(loss_mask, target_tokens > 0)
    logging.info("langmodel: loss_mask = %r", loss_mask)

    losses = jnp.where(loss_mask, losses, 0.0)  # (batch_size, seq_len)
    loss = jnp.sum(losses)                      # Total loss on this device.
    token_count = jnp.sum(loss_mask)            # Tokens on this device.
    token_count_nz = token_count + 1.0e-6       # Avoid division by zero.
    loss_per_token = loss / token_count_nz
    bits_per_token = loss_per_token * 1.442695  # log2(e): nats -> bits.
    accuracy = metric_utils.compute_accuracy_sum(logits, target_tokens,
                                                 loss_mask)
    accuracy = accuracy / token_count_nz        # Fraction predicted correctly.
    epoch = jnp.mean(epochs)

    if self.mode == "generate" and self.decoder.supports_generate():
      # Generate example text.
      logging.info("lang_model: text inference.")
      gen_tokens = self.generate(inputs, task_config.sequence_length)
      # Return generated text, along with visualizations and histograms.
      metrics = {"gen_tokens": gen_tokens, **d_metrics}
      return (loss, metrics)

    # Just return metrics related to the loss.
    metrics = {
        "loss": loss,  # Will be summed over devices.
        "nats_per_token": (loss_per_token, token_count),
        "bits_per_token": (bits_per_token, token_count),
        "accuracy": (accuracy, token_count),
        "num_tokens_per_device": token_count,
        "num_tokens_per_batch": token_count,  # Will be summed over devices.
        "epoch": epoch,
    }

    # Compute bits per character if we have the number of characters.
    if "num_chars" in inputs:
      num_chars = jnp.sum(inputs["num_chars"])
      bits_per_char = loss / (num_chars + 1e-6) * 1.442695
      metrics["num_chars_per_device"] = num_chars
      metrics["num_chars_per_batch"] = num_chars  # Summed over devices.
      metrics["bits_per_char"] = (bits_per_char, num_chars)

    # Provided to make sure that the data pipeline and the model agree
    # on the number of tokens with a loss.
    if "nonzero_tokens" in inputs:
      nonzero_tokens = jnp.sum(inputs["nonzero_tokens"])
      metrics["nonzero_tokens"] = nonzero_tokens

    if self.output_token_losses:
      metrics["token_losses"] = losses

    return (loss, metrics)
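
  # Worked example (illustrative numbers only) of the metric conversions in
  # __call__ and in process_summaries above: a total loss of 693.15 nats
  # over 1000 unmasked tokens gives loss_per_token = 0.69315 nats,
  # bits_per_token = 0.69315 * 1.442695 ~= 1.0, and
  # perplexity = exp(0.69315) ~= 2.0.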

  def generate(self, inputs: ..., sequence_length: int) -> Array:
    """Generate an output sequence.

    Args:
      inputs: the same as the argument to __call__.
      sequence_length: the length of the sequence to generate.

    Returns:
      An array of generated tokens of shape (batch_size, sequence_length).
    """
    # TODO(delesley): Add support for passing the prefix as an argument.
    # TODO(delesley): Add support for temperature, gumbel softmax, beam search.
    batch_size = self.task_config.batch_size
    input_tokens = inputs["targets"]                 # [b, seq_len]
    start_of_sequence = inputs["start_of_sequence"]  # [b]

    # Initialize the decoder state.
    dstate = self.decoder.init_decoder_state(sequence_length,
                                             start_of_sequence)

    # TODO(delesley): Handle start-of-sequence in a better way.
    # There is no special token for start of sequence, so we grab the first
    # one from the ground-truth input data.
    first_token = input_tokens[:, 0:1]
    no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
    sample_method = self.sample_method
    sample_prng = self.make_rng("sample")

    # Autoregressive decoder function: decodes one token per step, feeding
    # each output token back in as the next input token.
    def loop_fn(scan_state: Any, i: Array) -> Tuple[Any, Array]:
      prng = jax.random.fold_in(sample_prng, i)
      (dstate, input_token) = scan_state
      del i

      (logits, dstate, _) = self.decoder(input_tokens=input_token,
                                         target_tokens=None,
                                         start_of_sequence=no_start_of_seq,
                                         decoder_state=dstate)
      if sample_method == "sample":
        logging.info("Using categorical sampling.")
        output_token = jax.random.categorical(prng, logits, axis=-1)
      elif sample_method == "greedy":
        logging.info("Using greedy sampling.")
        output_token = jnp.argmax(logits, axis=-1)
      else:
        raise ValueError(f"Invalid sampling method: {sample_method}")
      logging.info("generate_loop_fn: output_token = %r", output_token)
      return ((dstate, output_token), output_token)

    # Scan over the sequence length.
    iterations = jnp.arange(sequence_length)
    initial_scan_state = (dstate, first_token)
    (_, output_tokens) = jax.lax.scan(loop_fn, initial_scan_state, iterations)
    logging.info("generate: output_tokens = %r", output_tokens)

    # output_tokens has shape (sequence_length, batch_size, 1).
    assert output_tokens.shape == (sequence_length, batch_size, 1)
    output_tokens = jnp.reshape(
        output_tokens, (sequence_length, self.task_config.batch_size))
    output_tokens = output_tokens.transpose([1, 0])
    return output_tokens
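

# A minimal usage sketch (hypothetical; not invoked anywhere in this module).
# `task_config`, `decoder_factory`, and `batch` stand in for a gin-configured
# TransformerTaskConfig, a decoder_stack factory, and a batch produced by the
# data pipeline, with the shapes expected by get_fake_input().
def _example_train_step(task_config, decoder_factory, batch):
  model = DecoderOnlyLanguageModel(mode="train",
                                   task_config=task_config,
                                   decoder_factory=decoder_factory)
  # The "sample" rng is only used in "generate" mode, but is harmless here.
  rngs = {"params": jax.random.PRNGKey(0), "sample": jax.random.PRNGKey(1)}
  params = model.init(rngs, model.get_fake_input())
  (loss, metrics) = model.apply(params, batch, rngs=rngs)
  return (loss, metrics)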