|
|
|
|
|
import math |
|
from typing import Any |
|
|
|
import mlx.core as mx |
|
from mlx.nn.layers.base import Module |
|
|
|
|
|
class Identity(Module):
    r"""A no-op layer that returns its input unchanged.

    Any positional or keyword constructor arguments are accepted and
    ignored, which makes this class a convenient drop-in replacement
    for other layers.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()

    def __call__(self, x: mx.array) -> mx.array:
        # Pass the input straight through.
        return x
|
|
|
|
|
class Linear(Module):
    r"""Applies an affine transformation to the input.

    Concretely:

    .. math::

        y = x W^\top + b

    where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b`
    has shape ``[output_dims]``.

    The values are initialized from the uniform distribution
    :math:`\mathcal{U}(-{k}, {k})`, where :math:`k = \frac{1}{\sqrt{D_i}}`
    and :math:`D_i` is equal to ``input_dims``.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will
            not use a bias. Default is ``True``.
    """

    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
        super().__init__()
        # Uniform init bound k = 1 / sqrt(fan_in).
        scale = math.sqrt(1.0 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims),
        )
        # The bias attribute is only created when requested; __call__ and
        # _extra_repr detect its presence via membership in the module.
        if bias:
            self.bias = mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(output_dims,),
            )

    def _extra_repr(self) -> str:
        # weight shape is (output_dims, input_dims).
        return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        # y = x W^T, plus the bias when it was created at init time.
        x = x @ self.weight.T
        if "bias" in self:
            x = x + self.bias
        return x
|
|
|
|
|
class Bilinear(Module):
    r"""Applies a bilinear transformation to the inputs.

    Concretely:

    .. math::

        y_i = x_1^\top W_i x_2 + b_i

    where :math:`W` has shape ``[output_dims, input2_dims, input1_dims]``
    (matching the layout allocated in ``__init__`` and relied upon by
    ``__call__``), :math:`b` has shape ``[output_dims]``, and :math:`i`
    indexes the output dimension.

    The values are initialized from the uniform distribution
    :math:`\mathcal{U}(-{k}, {k})`, where :math:`k = \frac{1}{\sqrt{D_1}}`
    and :math:`D_1` is ``input1_dims``.

    Args:
        input1_dims (int): The dimensionality of the input1 features
        input2_dims (int): The dimensionality of the input2 features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will
            not use a bias. Default is ``True``.
    """

    def __init__(
        self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
    ) -> None:
        super().__init__()
        # Uniform init bound k = 1 / sqrt(input1_dims).
        scale = math.sqrt(1.0 / input1_dims)
        # Axis order is (out, in2, in1); __call__'s reshapes depend on it.
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input2_dims, input1_dims),
        )
        # Bias is only created when requested; __call__ checks membership.
        if bias:
            self.bias = mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(output_dims,),
            )

    def _extra_repr(self) -> str:
        # weight shape is (output_dims, input2_dims, input1_dims).
        out, in2, in1 = self.weight.shape
        return (
            f"input1_dims={in1}, input2_dims={in2}, output_dims={out}, "
            f"bias={'bias' in self}"
        )

    def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
        # Flatten all leading (batch) axes so the bilinear product can be
        # expressed as two plain matrix multiplications.
        out, in2, in1 = self.weight.shape
        xshape = x1.shape[:-1]
        x1 = x1.reshape(-1, in1)
        x2 = x2.reshape(-1, 1, in2)

        # First matmul contracts x1 with the in1 axis of the weight:
        # (B, in1) @ (in1, out*in2) -> (B, out*in2)
        w = self.weight.reshape(out * in2, in1)
        y = x1 @ w.T
        # Second matmul contracts x2 with the in2 axis:
        # (B, 1, in2) @ (B, in2, out) -> (B, 1, out) -> (B, out)
        y = y.reshape(-1, out, in2).swapaxes(-2, -1)
        y = x2 @ y
        y = y.squeeze(1)

        # Restore the original batch axes.
        y = y.reshape(*xshape, out)

        # Apply the bias when it was created at init time.
        if "bias" in self:
            y = y + self.bias

        return y
|
|