# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""JAX implementation of CLRS baseline models."""

import functools
import os
import pickle
from typing import Dict, List, Optional, Tuple, Union

import chex

from clrs._src import decoders
from clrs._src import losses
from clrs._src import model
from clrs._src import nets
from clrs._src import probing
from clrs._src import processors
from clrs._src import samplers
from clrs._src import specs

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax


_Array = chex.Array
_DataPoint = probing.DataPoint
_Features = samplers.Features
_FeaturesChunked = samplers.FeaturesChunked
_Feedback = samplers.Feedback
_Location = specs.Location
_Seed = jnp.integer
_Spec = specs.Spec
_Stage = specs.Stage
_Trajectory = samplers.Trajectory
_Type = specs.Type
_OutputClass = specs.OutputClass

# pytype: disable=signature-mismatch


def _maybe_pick_first_pmapped(tree):
  if jax.local_device_count() == 1:
    return tree
  return jax.tree_util.tree_map(lambda x: x[0], tree)


def _restack_from_pmap(tree):
  """Stack the results of a pmapped computation across the first two axes."""
  restack_array = lambda x: jnp.reshape(x, (-1,) + x.shape[2:])
  return jax.tree_util.tree_map(restack_array, tree)


def _maybe_restack_from_pmap(tree):
  if jax.local_device_count() == 1:
    return tree
  return _restack_from_pmap(tree)


def _pmap_reshape(x, n_devices, split_axis=0):
  """Splits a pytree over n_devices on axis split_axis for pmapping."""
  def _reshape(arr):
    new_shape = (arr.shape[:split_axis] +
                 (n_devices, arr.shape[split_axis] // n_devices) +
                 arr.shape[split_axis + 1:])
    return jnp.moveaxis(jnp.reshape(arr, new_shape), split_axis, 0)
  return jax.tree_util.tree_map(_reshape, x)


def _maybe_pmap_reshape(x, split_axis=0):
  n_devices = jax.local_device_count()
  if n_devices == 1:
    return x
  return _pmap_reshape(x, n_devices, split_axis)


def _pmap_data(data: Union[_Feedback, _Features], n_devices: int):
  """Replicate/split feedback or features for pmapping."""
  if isinstance(data, _Feedback):
    features = data.features
  else:
    features = data

  pmap_data = features._replace(
      inputs=_pmap_reshape(features.inputs, n_devices),
      hints=_pmap_reshape(features.hints, n_devices, split_axis=1),
      lengths=_pmap_reshape(features.lengths, n_devices),
  )

  if isinstance(data, _Feedback):
    pmap_data = data._replace(
        features=pmap_data,
        outputs=_pmap_reshape(data.outputs, n_devices)
    )
  return pmap_data


def _maybe_pmap_data(data: Union[_Feedback, _Features]):
  n_devices = jax.local_device_count()
  if n_devices == 1:
    return data
  return _pmap_data(data, n_devices)


def _maybe_put_replicated(tree):
  if jax.local_device_count() == 1:
    return jax.device_put(tree)
  else:
    return jax.device_put_replicated(tree, jax.local_devices())


def _maybe_pmap_rng_key(rng_key: _Array):
  n_devices = jax.local_device_count()
  if n_devices == 1:
    return rng_key
  pmap_rng_keys = jax.random.split(rng_key, n_devices)
  return jax.device_put_sharded(list(pmap_rng_keys), jax.local_devices())
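

# Illustrative sketch (not part of the library): how the pmap helpers shard a
# batch across devices. Assuming 2 local devices and a hypothetical array of
# shape [8, 16], `_pmap_reshape` adds a leading device axis and
# `_restack_from_pmap` undoes it:
#
#   x = jnp.zeros((8, 16))
#   sharded = _pmap_reshape(x, n_devices=2)   # shape (2, 4, 16)
#   restored = _restack_from_pmap(sharded)    # shape (8, 16)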


class BaselineModel(model.Model):
  """Model implementation with selectable message passing algorithm."""

  def __init__(
      self,
      spec: Union[_Spec, List[_Spec]],
      dummy_trajectory: Union[List[_Feedback], _Feedback],
      processor_factory: processors.ProcessorFactory,
      hidden_dim: int = 32,
      encode_hints: bool = False,
      decode_hints: bool = True,
      encoder_init: str = 'default',
      use_lstm: bool = False,
      learning_rate: float = 0.005,
      grad_clip_max_norm: float = 0.0,
      checkpoint_path: str = '/tmp/clrs3',
      freeze_processor: bool = False,
      dropout_prob: float = 0.0,
      hint_teacher_forcing: float = 0.0,
      hint_repred_mode: str = 'soft',
      name: str = 'base_model',
      nb_msg_passing_steps: int = 1,
  ):
    """Constructor for BaselineModel.

    The model consists of encoders, processor and decoders. It can train
    and evaluate either a single algorithm or a set of algorithms; in the
    latter case, a single processor is shared among all the algorithms, while
    the encoders and decoders are separate for each algorithm.

    Args:
      spec: Either a single spec for one algorithm, or a list of specs for
        multiple algorithms to be trained and evaluated.
      dummy_trajectory: Either a single feedback batch, in the single-algorithm
        case, or a list of feedback batches, in the multi-algorithm case, that
        comply with the `spec` (or list of specs), to initialize network size.
      processor_factory: A callable that takes an `out_size` parameter
        and returns a processor (see `processors.py`).
      hidden_dim: Size of the hidden state of the model, i.e., size of the
        message-passing vectors.
      encode_hints: Whether to provide hints as model inputs.
      decode_hints: Whether to provide hints as model outputs.
      encoder_init: The initialiser type to use for the encoders.
      use_lstm: Whether to insert an LSTM after message passing.
      learning_rate: Learning rate for training.
      grad_clip_max_norm: if greater than 0, the maximum norm of the gradients.
      checkpoint_path: Path for loading/saving checkpoints.
      freeze_processor: If True, the processor weights will be frozen and
        only encoders and decoders (and, if used, the lstm) will be trained.
      dropout_prob: Dropout rate in the message-passing stage.
      hint_teacher_forcing: Probability of using ground-truth hints instead
        of predicted hints as inputs during training (only relevant if
        `encode_hints`=True)
      hint_repred_mode: How to process predicted hints when fed back as inputs.
        Only meaningful when `encode_hints` and `decode_hints` are True.
        Options are:
          - 'soft', where we use softmaxes for categoricals, pointers
            and mask_one, and sigmoids for masks. This will allow gradients
            to flow through hints during training.
          - 'hard', where we use argmax instead of softmax, and hard
            thresholding of masks. No gradients will go through the hints
            during training; even for scalar hints, which don't have any
            kind of post-processing, gradients will be stopped.
          - 'hard_on_eval', which is soft for training and hard for evaluation.
      name: Model name.
      nb_msg_passing_steps: Number of message passing steps per hint.

    Raises:
      ValueError: if `encode_hints=True` and `decode_hints=False`.
    """
    super(BaselineModel, self).__init__(spec=spec)

    if encode_hints and not decode_hints:
      raise ValueError('`encode_hints=True`, `decode_hints=False` is invalid.')

    assert hint_repred_mode in ['soft', 'hard', 'hard_on_eval']

    self.decode_hints = decode_hints
    self.checkpoint_path = checkpoint_path
    self.name = name
    self._freeze_processor = freeze_processor
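
    # When gradient clipping is requested, the Adam update is built manually
    # with `optax.chain` so that clipping by global norm is applied to the raw
    # gradients before the Adam scaling; otherwise plain `optax.adam` is used.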
    if grad_clip_max_norm != 0.0:
      optax_chain = [optax.clip_by_global_norm(grad_clip_max_norm),
                     optax.scale_by_adam(),
                     optax.scale(-learning_rate)]
      self.opt = optax.chain(*optax_chain)
    else:
      self.opt = optax.adam(learning_rate)

    self.nb_msg_passing_steps = nb_msg_passing_steps

    self.nb_dims = []
    if isinstance(dummy_trajectory, _Feedback):
      assert len(self._spec) == 1
      dummy_trajectory = [dummy_trajectory]
    for traj in dummy_trajectory:
      nb_dims = {}
      for inp in traj.features.inputs:
        nb_dims[inp.name] = inp.data.shape[-1]
      for hint in traj.features.hints:
        nb_dims[hint.name] = hint.data.shape[-1]
      for outp in traj.outputs:
        nb_dims[outp.name] = outp.data.shape[-1]
      self.nb_dims.append(nb_dims)

    self._create_net_fns(hidden_dim, encode_hints, processor_factory, use_lstm,
                         encoder_init, dropout_prob, hint_teacher_forcing,
                         hint_repred_mode)
    self._device_params = None
    self._device_opt_state = None
    self.opt_state_skeleton = None

  def _create_net_fns(self, hidden_dim, encode_hints, processor_factory,
                      use_lstm, encoder_init, dropout_prob,
                      hint_teacher_forcing, hint_repred_mode):
    def _use_net(*args, **kwargs):
      return nets.Net(self._spec, hidden_dim, encode_hints, self.decode_hints,
                      processor_factory, use_lstm, encoder_init,
                      dropout_prob, hint_teacher_forcing,
                      hint_repred_mode,
                      self.nb_dims, self.nb_msg_passing_steps)(*args, **kwargs)

    self.net_fn = hk.transform(_use_net)
    pmap_args = dict(axis_name='batch', devices=jax.local_devices())
    n_devices = jax.local_device_count()
    func, static_arg, extra_args = (
        (jax.jit, 'static_argnums', {}) if n_devices == 1 else
        (jax.pmap, 'static_broadcasted_argnums', pmap_args))
    pmean = functools.partial(jax.lax.pmean, axis_name='batch')
    self._maybe_pmean = pmean if n_devices > 1 else lambda x: x
    extra_args[static_arg] = 3
    self.jitted_grad = func(self._compute_grad, **extra_args)
    extra_args[static_arg] = 4
    self.jitted_feedback = func(self._feedback, donate_argnums=[0, 3],
                                **extra_args)
    extra_args[static_arg] = [3, 4, 5]
    self.jitted_predict = func(self._predict, **extra_args)
    extra_args[static_arg] = [3, 4]
    self.jitted_accum_opt_update = func(accum_opt_update, donate_argnums=[0, 2],
                                        **extra_args)
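
  # Note on the step functions created above: with a single local device they
  # are compiled with `jax.jit` (static Python arguments passed via
  # `static_argnums`); with several devices they are compiled with `jax.pmap`
  # instead (`static_broadcasted_argnums`), and losses/gradients are averaged
  # across devices with `jax.lax.pmean`.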

  def init(self, features: Union[_Features, List[_Features]], seed: _Seed):
    if not isinstance(features, list):
      assert len(self._spec) == 1
      features = [features]
    self.params = self.net_fn.init(jax.random.PRNGKey(seed), features, True,  # pytype: disable=wrong-arg-types  # jax-ndarray
                                   algorithm_index=-1,
                                   return_hints=False,
                                   return_all_outputs=False)
    self.opt_state = self.opt.init(self.params)
    # We will use the optimizer state skeleton for traversal when we
    # want to avoid updating the state of params of untrained algorithms.
    self.opt_state_skeleton = self.opt.init(jnp.zeros(1))
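
  # Illustrative usage sketch (names are assumptions for the example, not part
  # of this module): `dummy_feedback` is a Feedback batch matching the spec.
  #
  #   model = BaselineModel(spec, dummy_feedback, processor_factory)
  #   model.init(dummy_feedback.features, seed=42)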

  @property
  def params(self):
    if self._device_params is None:
      return None
    return jax.device_get(_maybe_pick_first_pmapped(self._device_params))

  @params.setter
  def params(self, params):
    self._device_params = _maybe_put_replicated(params)

  @property
  def opt_state(self):
    if self._device_opt_state is None:
      return None
    return jax.device_get(_maybe_pick_first_pmapped(self._device_opt_state))

  @opt_state.setter
  def opt_state(self, opt_state):
    self._device_opt_state = _maybe_put_replicated(opt_state)

  def _compute_grad(self, params, rng_key, feedback, algorithm_index):
    lss, grads = jax.value_and_grad(self._loss)(
        params, rng_key, feedback, algorithm_index)
    return self._maybe_pmean(lss), self._maybe_pmean(grads)

  def _feedback(self, params, rng_key, feedback, opt_state, algorithm_index):
    lss, grads = jax.value_and_grad(self._loss)(
        params, rng_key, feedback, algorithm_index)
    grads = self._maybe_pmean(grads)
    params, opt_state = self._update_params(params, grads, opt_state,
                                            algorithm_index)
    lss = self._maybe_pmean(lss)
    return lss, params, opt_state

  def _predict(self, params, rng_key: hk.PRNGSequence, features: _Features,
               algorithm_index: int, return_hints: bool,
               return_all_outputs: bool):
    outs, hint_preds = self.net_fn.apply(
        params, rng_key, [features],
        repred=True, algorithm_index=algorithm_index,
        return_hints=return_hints,
        return_all_outputs=return_all_outputs)
    outs = decoders.postprocess(self._spec[algorithm_index],
                                outs,
                                sinkhorn_temperature=0.1,
                                sinkhorn_steps=50,
                                hard=True,
                                )
    return outs, hint_preds

  def compute_grad(
      self,
      rng_key: hk.PRNGSequence,
      feedback: _Feedback,
      algorithm_index: Optional[int] = None,
  ) -> Tuple[float, _Array]:
    """Compute gradients."""
    if algorithm_index is None:
      assert len(self._spec) == 1
      algorithm_index = 0
    assert algorithm_index >= 0

    # Calculate gradients.
    rng_keys = _maybe_pmap_rng_key(rng_key)  # pytype: disable=wrong-arg-types  # numpy-scalars
    feedback = _maybe_pmap_data(feedback)
    loss, grads = self.jitted_grad(
        self._device_params, rng_keys, feedback, algorithm_index)
    loss = _maybe_pick_first_pmapped(loss)
    grads = _maybe_pick_first_pmapped(grads)

    return loss, grads

  def feedback(self, rng_key: hk.PRNGSequence, feedback: _Feedback,
               algorithm_index=None) -> float:
    if algorithm_index is None:
      assert len(self._spec) == 1
      algorithm_index = 0
    # Calculate and apply gradients.
    rng_keys = _maybe_pmap_rng_key(rng_key)  # pytype: disable=wrong-arg-types  # numpy-scalars
    feedback = _maybe_pmap_data(feedback)
    loss, self._device_params, self._device_opt_state = self.jitted_feedback(
        self._device_params, rng_keys, feedback,
        self._device_opt_state, algorithm_index)
    loss = _maybe_pick_first_pmapped(loss)
    return loss

  def predict(self, rng_key: hk.PRNGSequence, features: _Features,
              algorithm_index: Optional[int] = None,
              return_hints: bool = False,
              return_all_outputs: bool = False):
    """Model inference step."""
    if algorithm_index is None:
      assert len(self._spec) == 1
      algorithm_index = 0

    rng_keys = _maybe_pmap_rng_key(rng_key)  # pytype: disable=wrong-arg-types  # numpy-scalars
    features = _maybe_pmap_data(features)
    return _maybe_restack_from_pmap(
        self.jitted_predict(
            self._device_params, rng_keys, features,
            algorithm_index,
            return_hints,
            return_all_outputs))
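
  # Illustrative usage sketch (assuming `rng_key` is a `jax.random.PRNGKey`
  # and `features` comes from a sampler built with the same spec):
  #
  #   outputs, hint_preds = model.predict(rng_key, features, return_hints=True)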

  def _loss(self, params, rng_key, feedback, algorithm_index):
    """Calculates model loss f(feedback; params)."""
    output_preds, hint_preds = self.net_fn.apply(
        params, rng_key, [feedback.features],
        repred=False,
        algorithm_index=algorithm_index,
        return_hints=True,
        return_all_outputs=False)

    nb_nodes = _nb_nodes(feedback, is_chunked=False)
    lengths = feedback.features.lengths
    total_loss = 0.0

    # Calculate output loss.
    for truth in feedback.outputs:
      total_loss += losses.output_loss(
          truth=truth,
          pred=output_preds[truth.name],
          nb_nodes=nb_nodes,
      )

    # Optionally accumulate hint losses.
    if self.decode_hints:
      for truth in feedback.features.hints:
        total_loss += losses.hint_loss(
            truth=truth,
            preds=[x[truth.name] for x in hint_preds],
            lengths=lengths,
            nb_nodes=nb_nodes,
        )

    return total_loss

  def _update_params(self, params, grads, opt_state, algorithm_index):
    updates, opt_state = filter_null_grads(
        grads, self.opt, opt_state, self.opt_state_skeleton, algorithm_index)
    if self._freeze_processor:
      params_subset = _filter_out_processor(params)
      updates_subset = _filter_out_processor(updates)
      assert len(params) > len(params_subset)
      assert params_subset
      new_params = optax.apply_updates(params_subset, updates_subset)
      new_params = hk.data_structures.merge(params, new_params)
    else:
      new_params = optax.apply_updates(params, updates)

    return new_params, opt_state

  def update_model_params_accum(self, grads) -> None:
    grads = _maybe_put_replicated(grads)
    self._device_params, self._device_opt_state = self.jitted_accum_opt_update(
        self._device_params, grads, self._device_opt_state, self.opt,
        self._freeze_processor)

  def verbose_loss(self, feedback: _Feedback, extra_info) -> Dict[str, _Array]:
    """Gets verbose loss information."""
    hint_preds = extra_info

    nb_nodes = _nb_nodes(feedback, is_chunked=False)
    lengths = feedback.features.lengths

    losses_ = {}
    # Optionally accumulate hint losses.
    if self.decode_hints:
      for truth in feedback.features.hints:
        losses_.update(
            losses.hint_loss(
                truth=truth,
                preds=[x[truth.name] for x in hint_preds],
                lengths=lengths,
                nb_nodes=nb_nodes,
                verbose=True,
            ))

    return losses_

  def restore_model(self, file_name: str, only_load_processor: bool = False):
    """Restore model from `file_name`."""
    path = os.path.join(self.checkpoint_path, file_name)
    with open(path, 'rb') as f:
      restored_state = pickle.load(f)
      if only_load_processor:
        restored_params = _filter_in_processor(restored_state['params'])
      else:
        restored_params = restored_state['params']
      self.params = hk.data_structures.merge(self.params, restored_params)
      self.opt_state = restored_state['opt_state']

  def save_model(self, file_name: str):
    """Save model (params and optimizer state) to `file_name`."""
    os.makedirs(self.checkpoint_path, exist_ok=True)
    to_save = {'params': self.params, 'opt_state': self.opt_state}
    path = os.path.join(self.checkpoint_path, file_name)
    with open(path, 'wb') as f:
      pickle.dump(to_save, f)
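

# Illustrative end-to-end training sketch (not part of the library; names such
# as `train_sampler`, `dummy_feedback` and the exact `get_processor_factory`
# arguments are assumptions and may differ between library versions):
#
#   processor_factory = processors.get_processor_factory(
#       'mpnn', use_ln=True, nb_triplet_fts=0)
#   model = BaselineModel(spec, dummy_trajectory=dummy_feedback,
#                         processor_factory=processor_factory)
#   model.init(dummy_feedback.features, seed=42)
#   for step in range(num_steps):
#     feedback = train_sampler.next()
#     loss = model.feedback(jax.random.PRNGKey(step), feedback)
#   predictions, _ = model.predict(jax.random.PRNGKey(0), test_features)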


class BaselineModelChunked(BaselineModel):
  """Model that processes time-chunked data.

  Unlike `BaselineModel`, which processes full samples, `BaselineModelChunked`
  processes fixed-time-length chunks of data. Each tensor of inputs and hints
  has dimensions chunk_length x batch_size x ... The beginning of a new
  sample within the chunk is signalled by a tensor called `is_first` of
  dimensions chunk_length x batch_size.

  The chunked model is intended for training. For validation and test, use
  `BaselineModel`.
  """

  mp_states: List[List[nets.MessagePassingStateChunked]]
  init_mp_states: List[List[nets.MessagePassingStateChunked]]

  def _create_net_fns(self, hidden_dim, encode_hints, processor_factory,
                      use_lstm, encoder_init, dropout_prob,
                      hint_teacher_forcing, hint_repred_mode):
    def _use_net(*args, **kwargs):
      return nets.NetChunked(
          self._spec, hidden_dim, encode_hints, self.decode_hints,
          processor_factory, use_lstm, encoder_init, dropout_prob,
          hint_teacher_forcing, hint_repred_mode,
          self.nb_dims, self.nb_msg_passing_steps)(*args, **kwargs)

    self.net_fn = hk.transform(_use_net)
    pmap_args = dict(axis_name='batch', devices=jax.local_devices())
    n_devices = jax.local_device_count()
    func, static_arg, extra_args = (
        (jax.jit, 'static_argnums', {}) if n_devices == 1 else
        (jax.pmap, 'static_broadcasted_argnums', pmap_args))
    pmean = functools.partial(jax.lax.pmean, axis_name='batch')
    self._maybe_pmean = pmean if n_devices > 1 else lambda x: x
    extra_args[static_arg] = 4
    self.jitted_grad = func(self._compute_grad, **extra_args)
    extra_args[static_arg] = 5
    self.jitted_feedback = func(self._feedback, donate_argnums=[0, 4],
                                **extra_args)
    extra_args[static_arg] = [3, 4]
    self.jitted_accum_opt_update = func(accum_opt_update, donate_argnums=[0, 2],
                                        **extra_args)

  def _init_mp_state(self, features_list: List[List[_FeaturesChunked]],
                     rng_key: _Array):
    def _empty_mp_state():
      return nets.MessagePassingStateChunked(  # pytype: disable=wrong-arg-types  # numpy-scalars
          inputs=None, hints=None, is_first=None,
          hint_preds=None, hiddens=None, lstm_state=None)
    empty_mp_states = [[_empty_mp_state() for _ in f] for f in features_list]
    dummy_params = [self.net_fn.init(rng_key, f, e, False,
                                     init_mp_state=True, algorithm_index=-1)
                    for (f, e) in zip(features_list, empty_mp_states)]
    mp_states = [
        self.net_fn.apply(d, rng_key, f, e, False,
                          init_mp_state=True, algorithm_index=-1)[1]
        for (d, f, e) in zip(dummy_params, features_list, empty_mp_states)]
    return mp_states

  def init(self,
           features: List[List[_FeaturesChunked]],
           seed: _Seed):
    self.mp_states = self._init_mp_state(features,
                                         jax.random.PRNGKey(seed))  # pytype: disable=wrong-arg-types  # jax-ndarray
    self.init_mp_states = [list(x) for x in self.mp_states]
    self.params = self.net_fn.init(
        jax.random.PRNGKey(seed), features[0], self.mp_states[0],  # pytype: disable=wrong-arg-types  # jax-ndarray
        True, init_mp_state=False, algorithm_index=-1)
    self.opt_state = self.opt.init(self.params)
    # We will use the optimizer state skeleton for traversal when we
    # want to avoid updating the state of params of untrained algorithms.
    self.opt_state_skeleton = self.opt.init(jnp.zeros(1))

  def predict(self, rng_key: hk.PRNGSequence, features: _FeaturesChunked,
              algorithm_index: Optional[int] = None):
    """Inference not implemented. Chunked model intended for training only."""
    raise NotImplementedError

  def _loss(self, params, rng_key, feedback, mp_state, algorithm_index):
    (output_preds, hint_preds), mp_state = self.net_fn.apply(
        params, rng_key, [feedback.features],
        [mp_state],
        repred=False,
        init_mp_state=False,
        algorithm_index=algorithm_index)

    nb_nodes = _nb_nodes(feedback, is_chunked=True)

    total_loss = 0.0
    is_first = feedback.features.is_first
    is_last = feedback.features.is_last

    # Calculate output loss.
    for truth in feedback.outputs:
      total_loss += losses.output_loss_chunked(
          truth=truth,
          pred=output_preds[truth.name],
          is_last=is_last,
          nb_nodes=nb_nodes,
      )

    # Optionally accumulate hint losses.
    if self.decode_hints:
      for truth in feedback.features.hints:
        loss = losses.hint_loss_chunked(
            truth=truth,
            pred=hint_preds[truth.name],
            is_first=is_first,
            nb_nodes=nb_nodes,
        )
        total_loss += loss

    return total_loss, (mp_state,)

  def _compute_grad(self, params, rng_key, feedback, mp_state,
                    algorithm_index):
    (lss, (mp_state,)), grads = jax.value_and_grad(self._loss, has_aux=True)(
        params, rng_key, feedback, mp_state, algorithm_index)
    return self._maybe_pmean(lss), mp_state, self._maybe_pmean(grads)

  def _feedback(self, params, rng_key, feedback, mp_state, opt_state,
                algorithm_index):
    (lss, (mp_state,)), grads = jax.value_and_grad(self._loss, has_aux=True)(
        params, rng_key, feedback, mp_state, algorithm_index)
    grads = self._maybe_pmean(grads)
    params, opt_state = self._update_params(params, grads, opt_state,
                                            algorithm_index)
    lss = self._maybe_pmean(lss)
    return lss, params, opt_state, mp_state

  def compute_grad(
      self,
      rng_key: hk.PRNGSequence,
      feedback: _Feedback,
      algorithm_index: Optional[Tuple[int, int]] = None,
  ) -> Tuple[float, _Array]:
    """Compute gradients."""
    if algorithm_index is None:
      assert len(self._spec) == 1
      algorithm_index = (0, 0)
    length_index, algorithm_index = algorithm_index
    # Reusing init_mp_state improves performance.
    # The commented-out line below should be used instead for proper
    # state keeping.
    # mp_state = self.mp_states[length_index][algorithm_index]
    mp_state = self.init_mp_states[length_index][algorithm_index]
    rng_keys = _maybe_pmap_rng_key(rng_key)  # pytype: disable=wrong-arg-types  # numpy-scalars
    feedback = _maybe_pmap_reshape(feedback, split_axis=1)
    mp_state = _maybe_pmap_reshape(mp_state, split_axis=0)
    loss, mp_state, grads = self.jitted_grad(
        self._device_params, rng_keys, feedback, mp_state, algorithm_index)
    loss = _maybe_pick_first_pmapped(loss)
    grads = _maybe_pick_first_pmapped(grads)
    mp_state = _maybe_restack_from_pmap(mp_state)
    self.mp_states[length_index][algorithm_index] = mp_state
    return loss, grads

  def feedback(self, rng_key: hk.PRNGSequence, feedback: _Feedback,
               algorithm_index=None) -> float:
    if algorithm_index is None:
      assert len(self._spec) == 1
      algorithm_index = (0, 0)
    length_index, algorithm_index = algorithm_index
    # Reusing init_mp_state improves performance.
    # The commented-out line below should be used instead for proper
    # state keeping.
    # mp_state = self.mp_states[length_index][algorithm_index]
    mp_state = self.init_mp_states[length_index][algorithm_index]
    rng_keys = _maybe_pmap_rng_key(rng_key)  # pytype: disable=wrong-arg-types  # numpy-scalars
    feedback = _maybe_pmap_reshape(feedback, split_axis=1)
    mp_state = _maybe_pmap_reshape(mp_state, split_axis=0)
    loss, self._device_params, self._device_opt_state, mp_state = (
        self.jitted_feedback(
            self._device_params, rng_keys, feedback,
            mp_state, self._device_opt_state, algorithm_index))
    loss = _maybe_pick_first_pmapped(loss)
    mp_state = _maybe_restack_from_pmap(mp_state)
    self.mp_states[length_index][algorithm_index] = mp_state
    return loss

  def verbose_loss(self, *args, **kwargs):
    raise NotImplementedError


def _nb_nodes(feedback: _Feedback, is_chunked) -> int:
  for inp in feedback.features.inputs:
    if inp.location in [_Location.NODE, _Location.EDGE]:
      if is_chunked:
        return inp.data.shape[2]  # inputs are time x batch x nodes x ...
      else:
        return inp.data.shape[1]  # inputs are batch x nodes x ...
  assert False


def _param_in_processor(module_name):
  return processors.PROCESSOR_TAG in module_name


def _filter_out_processor(params: hk.Params) -> hk.Params:
  return hk.data_structures.filter(
      lambda module_name, n, v: not _param_in_processor(module_name), params)


def _filter_in_processor(params: hk.Params) -> hk.Params:
  return hk.data_structures.filter(
      lambda module_name, n, v: _param_in_processor(module_name), params)


def _is_not_done_broadcast(lengths, i, tensor):
  is_not_done = (lengths > i + 1) * 1.0
  while len(is_not_done.shape) < len(tensor.shape):
    is_not_done = jnp.expand_dims(is_not_done, -1)
  return is_not_done
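

# Illustrative sketch: `_is_not_done_broadcast` turns per-sample lengths into a
# 0/1 mask broadcastable against a prediction tensor. E.g., with hypothetical
# lengths [3, 5] at step i=3, the samples give 0.0 and 1.0 respectively,
# expanded to shape (2, 1, 1) for a (2, nodes, dims) tensor:
#
#   mask = _is_not_done_broadcast(jnp.array([3., 5.]), 3,
#                                 jnp.zeros((2, 4, 8)))  # shape (2, 1, 1)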


def accum_opt_update(params, grads, opt_state, opt, freeze_processor):
  """Update params from gradients collected from several algorithms."""
  # Average the gradients over all algos
  grads = jax.tree_util.tree_map(
      lambda *x: sum(x) / (sum([jnp.any(k) for k in x]) + 1e-12), *grads)
  updates, opt_state = opt.update(grads, opt_state)
  if freeze_processor:
    params_subset = _filter_out_processor(params)
    assert len(params) > len(params_subset)
    assert params_subset
    updates_subset = _filter_out_processor(updates)
    new_params = optax.apply_updates(params_subset, updates_subset)
    new_params = hk.data_structures.merge(params, new_params)
  else:
    new_params = optax.apply_updates(params, updates)

  return new_params, opt_state
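

# Note on the averaging in `accum_opt_update`: the denominator counts, per
# leaf, how many of the per-algorithm gradient trees have any nonzero entry,
# so algorithms that only contributed zero (null) gradients for a parameter do
# not dilute the average; the 1e-12 term guards against division by zero.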


def opt_update(opt, flat_grads, flat_opt_state):
  return opt.update(flat_grads, flat_opt_state)


def filter_null_grads(grads, opt, opt_state, opt_state_skeleton, algo_idx):
  """Compute updates ignoring params that have no gradients.

  This prevents untrained params (e.g., encoders/decoders for algorithms
  that are not being trained) from accumulating, e.g., momentum from spurious
  zero gradients.

  Note: this works as intended for "per-parameter" optimizer state, such as
  momentum. However, when the optimizer has some global state (such as the
  step counts in Adam), the global state will be updated every time,
  affecting also future updates of parameters that had null gradients in the
  current step.

  Args:
    grads: Gradients for all parameters.
    opt: Optax optimizer.
    opt_state: Optimizer state.
    opt_state_skeleton: A "skeleton" of optimizer state that has been
      initialized with scalar parameters. This serves to traverse each
      parameter of the optimizer state during the opt state update.
    algo_idx: Index of algorithm, to filter out unused encoders/decoders.
      If None, no filtering happens.

  Returns:
    Updates and new optimizer state, where the parameters with null gradient
    have not been taken into account.
  """
  def _keep_in_algo(k, v):
    """Ignore params of encoders/decoders irrelevant for this algo."""
    # Note: in shared pointer decoder modes, we should exclude shared params
    # for algos that do not have pointer outputs.
    if ((processors.PROCESSOR_TAG in k) or
        (f'algo_{algo_idx}_' in k)):
      return v
    return jax.tree_util.tree_map(lambda x: None, v)

  if algo_idx is None:
    masked_grads = grads
  else:
    masked_grads = {k: _keep_in_algo(k, v) for k, v in grads.items()}
  flat_grads, treedef = jax.tree_util.tree_flatten(masked_grads)
  flat_opt_state = jax.tree_util.tree_map(
      lambda _, x: x  # pylint:disable=g-long-lambda
      if isinstance(x, (np.ndarray, jax.Array))
      else treedef.flatten_up_to(x),
      opt_state_skeleton,
      opt_state,
  )

  # Compute updates only for the params with gradient.
  flat_updates, flat_opt_state = opt_update(opt, flat_grads, flat_opt_state)

  def unflatten(flat, original):
    """Restore tree structure, filling missing (None) leaves with original."""
    if isinstance(flat, (np.ndarray, jax.Array)):
      return flat
    return jax.tree_util.tree_map(lambda x, y: x if y is None else y, original,
                                  treedef.unflatten(flat))

  # Restore the state and updates tree structure.
  new_opt_state = jax.tree_util.tree_map(lambda _, x, y: unflatten(x, y),
                                         opt_state_skeleton, flat_opt_state,
                                         opt_state)
  updates = unflatten(flat_updates,
                      jax.tree_util.tree_map(lambda x: 0., grads))
  return updates, new_opt_state
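

# Illustrative sketch (hypothetical setup): with two algorithms, calling
# `filter_null_grads(grads, opt, opt_state, opt_state_skeleton, algo_idx=0)`
# masks out the gradients of the `algo_1_*` encoder/decoder parameters,
# computes optimizer updates only for the remaining leaves, and then stitches
# the untouched optimizer state of the masked parameters back into the
# returned state, so their per-parameter state (e.g. Adam momentum) is left
# unchanged; the docstring above notes the caveat about global state such as
# Adam's step count.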