Spaces:
Running
on
T4
Running
on
T4
File size: 2,358 Bytes
85bd48b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of JAX utility functions for use in protein folding."""
import collections
import numbers
from typing import Mapping
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
def final_init(config):
if config.zero_init:
return 'zeros'
else:
return 'linear'
def batched_gather(params, indices, axis=0, batch_dims=0):
"""Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`."""
take_fn = lambda p, i: jnp.take(p, i, axis=axis)
for _ in range(batch_dims):
take_fn = jax.vmap(take_fn)
return take_fn(params, indices)
def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
"""Masked mean."""
if drop_mask_channel:
mask = mask[..., 0]
mask_shape = mask.shape
value_shape = value.shape
assert len(mask_shape) == len(value_shape)
if isinstance(axis, numbers.Integral):
axis = [axis]
elif axis is None:
axis = list(range(len(mask_shape)))
assert isinstance(axis, collections.Iterable), (
'axis needs to be either an iterable, integer or "None"')
broadcast_factor = 1.
for axis_ in axis:
value_size = value_shape[axis_]
mask_size = mask_shape[axis_]
if mask_size == 1:
broadcast_factor *= value_size
else:
assert mask_size == value_size
return (jnp.sum(mask * value, axis=axis) /
(jnp.sum(mask, axis=axis) * broadcast_factor + eps))
def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params:
"""Convert a dictionary of NumPy arrays to Haiku parameters."""
hk_params = {}
for path, array in params.items():
scope, name = path.split('//')
if scope not in hk_params:
hk_params[scope] = {}
hk_params[scope][name] = jnp.array(array)
return hk_params
|