File size: 4,137 Bytes
f14e74e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
# Copyright © 2023 Apple Inc.
import math
from typing import Any
import mlx.core as mx
from mlx.nn.layers.base import Module
class Identity(Module):
    r"""A pass-through layer: calling it returns its input unchanged.

    The constructor accepts arbitrary positional and keyword arguments and
    ignores all of them.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Nothing is stored; the arguments are accepted only to be discarded.
        super().__init__()

    def __call__(self, x: mx.array) -> mx.array:
        # Identity mapping: hand the very same array back to the caller.
        return x
class Linear(Module):
    r"""Applies an affine transformation to the input.

    Concretely:

    .. math::

        y = x W^\top + b

    where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b`
    has shape ``[output_dims]``.

    The values are initialized from the uniform distribution
    :math:`\mathcal{U}(-{k}, {k})`, where :math:`k = \frac{1}{\sqrt{D_i}}`
    and :math:`D_i` is equal to ``input_dims``.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will
            not use a bias. Default is ``True``.
    """

    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
        super().__init__()

        # Symmetric sampling bound k = 1 / sqrt(input_dims).
        bound = math.sqrt(1.0 / input_dims)

        # The weight is always drawn first so the random stream is consumed
        # in the same order regardless of how the bias is handled.
        self.weight = mx.random.uniform(
            low=-bound, high=bound, shape=(output_dims, input_dims)
        )

        if not bias:
            # No ``bias`` entry is created at all; the other methods detect
            # its absence with ``"bias" in self``.
            return

        self.bias = mx.random.uniform(low=-bound, high=bound, shape=(output_dims,))

    def _extra_repr(self) -> str:
        weight_shape = self.weight.shape
        input_dims = weight_shape[1]
        output_dims = weight_shape[0]
        has_bias = "bias" in self
        return f"input_dims={input_dims}, output_dims={output_dims}, bias={has_bias}"

    def __call__(self, x: mx.array) -> mx.array:
        # Project the trailing feature axis with the transposed weight.
        projected = x @ self.weight.T
        return projected + self.bias if "bias" in self else projected
class Bilinear(Module):
    r"""Applies a bilinear transformation to the inputs.

    Concretely:

    .. math::

        y_i = x_1^\top W_i x_2 + b_i

    where:
    :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``,
    :math:`b` has shape ``[output_dims]``, and :math:`i` indexes the output
    dimension.

    Note:
        The ``weight`` parameter is stored with its last two axes swapped
        relative to :math:`W`: it has shape
        ``[output_dims, input2_dims, input1_dims]``, so ``weight[i]`` holds
        :math:`W_i^\top`. Keep this layout in mind when inspecting or
        loading the parameter.

    The values are initialized from the uniform distribution
    :math:`\mathcal{U}(-{k}, {k})`, where :math:`k = \frac{1}{\sqrt{D_1}}`
    and :math:`D_1` is ``input1_dims``.

    Args:
        input1_dims (int): The dimensionality of the input1 features
        input2_dims (int): The dimensionality of the input2 features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will
            not use a bias. Default is ``True``.
    """

    def __init__(
        self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
    ) -> None:
        super().__init__()

        # Sampling bound k = 1 / sqrt(input1_dims), shared by weight and bias.
        scale = math.sqrt(1.0 / input1_dims)

        # Stored as (output_dims, input2_dims, input1_dims); __call__ and
        # _extra_repr both unpack the shape in exactly this order.
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input2_dims, input1_dims),
        )
        if bias:
            self.bias = mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(output_dims,),
            )

    def _extra_repr(self) -> str:
        # Unpack following the storage layout (out, in2, in1).
        out, in2, in1 = self.weight.shape
        return (
            f"input1_dims={in1}, input2_dims={in2}, output_dims={out}, "
            f"bias={'bias' in self}"
        )

    def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
        """Compute the bilinear form of ``x1`` and ``x2``.

        Args:
            x1 (mx.array): First input; its last axis is flattened against
                ``input1_dims``.
            x2 (mx.array): Second input; its last axis is flattened against
                ``input2_dims``.

        Returns:
            mx.array: Result whose leading axes are taken from ``x1`` and
            whose last axis has size ``output_dims``.
        """
        # Normalize shapes: collapse all leading axes into one batch axis.
        out, in2, in1 = self.weight.shape
        xshape = x1.shape[:-1]  # leading axes of x1, restored at the end
        x1 = x1.reshape(-1, in1)  # (B, in1)
        x2 = x2.reshape(-1, 1, in2)  # (B, 1, in2)

        # Perform the bilinear transformation
        # First contract x1 against the in1 axis of every (out, in2) slice
        # with a single matmul.
        w = self.weight.reshape(out * in2, in1)  # (out * in2, in1)
        y = x1 @ w.T  # (B, out * in2)
        y = y.reshape(-1, out, in2).swapaxes(-2, -1)  # (B, in2, out)
        # Then contract the remaining in2 axis against x2.
        y = x2 @ y  # (B, 1, out)
        y = y.squeeze(1)  # (B, out)

        # Reset the shape
        y = y.reshape(*xshape, out)

        # Apply the bias
        if "bias" in self:
            y = y + self.bias

        return y
|