# Copyright © 2023 Apple Inc.

import math
from typing import Any

import mlx.core as mx
from mlx.nn.layers.base import Module


class Identity(Module):
    r"""A placeholder identity operator that is argument-insensitive.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()

    def __call__(self, x: mx.array) -> mx.array:
        return x
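

# A minimal usage sketch (editor's addition, not part of the upstream module):
# Identity accepts and discards any constructor arguments and returns its
# input unchanged, which makes it a convenient no-op when wiring up models.
def _identity_example() -> None:
    layer = Identity("ignored", unused=True)  # arguments are accepted and dropped
    x = mx.arange(6).reshape(2, 3)
    assert (layer(x) == x).all().item()  # output is the input, element for element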


class Linear(Module):
    r"""Applies an affine transformation to the input.

    Concretely:

    .. math::

        y = x W^\top + b

    where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has
    shape ``[output_dims]``.

    The values are initialized from the uniform distribution
    :math:`\mathcal{U}(-k, k)`, where :math:`k = \frac{1}{\sqrt{D_i}}` and
    :math:`D_i` is equal to ``input_dims``.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will
            not use a bias. Default is ``True``.
    """

    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
        super().__init__()
        scale = math.sqrt(1.0 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims),
        )
        if bias:
            self.bias = mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(output_dims,),
            )

    def _extra_repr(self) -> str:
        return (
            f"input_dims={self.weight.shape[1]}, "
            f"output_dims={self.weight.shape[0]}, "
            f"bias={'bias' in self}"
        )

    def __call__(self, x: mx.array) -> mx.array:
        x = x @ self.weight.T
        if "bias" in self:
            x = x + self.bias
        return x
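

# A minimal usage sketch (editor's addition, not part of the upstream module):
# Linear maps the trailing feature dimension from input_dims to output_dims
# and broadcasts over any leading batch dimensions; the sizes are arbitrary.
def _linear_example() -> None:
    layer = Linear(input_dims=4, output_dims=8)
    x = mx.random.normal((2, 4))
    y = layer(x)  # y = x @ W.T + b
    assert list(y.shape) == [2, 8]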


class Bilinear(Module):
    r"""Applies a bilinear transformation to the inputs.

    Concretely:

    .. math::

        y_i = x_1^\top W_i x_2 + b_i

    where :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``,
    :math:`b` has shape ``[output_dims]``, and :math:`i` indexes the output
    dimension.

    The values are initialized from the uniform distribution
    :math:`\mathcal{U}(-k, k)`, where :math:`k = \frac{1}{\sqrt{D_1}}` and
    :math:`D_1` is ``input1_dims``.

    Args:
        input1_dims (int): The dimensionality of the input1 features
        input2_dims (int): The dimensionality of the input2 features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will
            not use a bias. Default is ``True``.
    """

    def __init__(
        self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
    ) -> None:
        super().__init__()
        scale = math.sqrt(1.0 / input1_dims)
        # The weight is stored as (output_dims, input2_dims, input1_dims) so
        # that __call__ can evaluate the bilinear form with two plain matmuls.
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input2_dims, input1_dims),
        )
        if bias:
            self.bias = mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(output_dims,),
            )

    def _extra_repr(self) -> str:
        out, in2, in1 = self.weight.shape
        return (
            f"input1_dims={in1}, input2_dims={in2}, output_dims={out}, "
            f"bias={'bias' in self}"
        )

    def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
        # Normalize shapes: flatten any leading batch dimensions
        out, in2, in1 = self.weight.shape
        xshape = x1.shape[:-1]
        x1 = x1.reshape(-1, in1)
        x2 = x2.reshape(-1, 1, in2)

        # Perform the bilinear transformation as two matmuls: first contract
        # x1 with the weight over in1, then contract the result with x2 over in2
        w = self.weight.reshape(out * in2, in1)
        y = x1 @ w.T
        y = y.reshape(-1, out, in2).swapaxes(-2, -1)
        y = x2 @ y
        y = y.squeeze(1)

        # Reset the shape to the original batch dimensions
        y = y.reshape(*xshape, out)

        # Apply the bias
        if "bias" in self:
            y = y + self.bias

        return y
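

# A minimal usage sketch (editor's addition, not part of the upstream module):
# for each output o, Bilinear computes
#     y[..., o] = sum_{j,k} x2[..., j] * weight[o, j, k] * x1[..., k] + b[o],
# so the two inputs may differ in feature size but must share batch shape.
def _bilinear_example() -> None:
    layer = Bilinear(input1_dims=3, input2_dims=5, output_dims=7)
    x1 = mx.random.normal((2, 3))
    x2 = mx.random.normal((2, 5))
    y = layer(x1, x2)
    assert list(y.shape) == [2, 7]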