File size: 4,137 Bytes
f14e74e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright © 2023 Apple Inc.

import math
from typing import Any

import mlx.core as mx
from mlx.nn.layers.base import Module


class Identity(Module):
    r"""A pass-through layer: calling it returns its input untouched.

    The constructor accepts arbitrary positional and keyword arguments and
    discards all of them.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Nothing is stored; the arguments exist only so that any call
        # signature is accepted.
        super().__init__()

    def __call__(self, x: mx.array) -> mx.array:
        # No computation is performed; the very same object is handed back.
        return x


class Linear(Module):
    r"""Affine layer mapping ``input_dims`` features to ``output_dims`` features.

    The layer computes:

    .. math::

        y = x W^\top + b

    where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has
    shape ``[output_dims]``.

    Both parameters are sampled from the uniform distribution
    :math:`\mathcal{U}(-{k}, {k})`, with :math:`k = \frac{1}{\sqrt{D_i}}` and
    :math:`D_i` equal to ``input_dims``.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will
          not use a bias. Default is ``True``.
    """

    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
        super().__init__()

        # Half-width of the symmetric sampling interval shared by both parameters.
        bound = math.sqrt(1.0 / input_dims)

        weight_shape = (output_dims, input_dims)
        self.weight = mx.random.uniform(low=-bound, high=bound, shape=weight_shape)

        # Without a bias no ``bias`` attribute is created at all; the other
        # methods rely on its absence to detect a bias-free layer.
        if not bias:
            return

        bias_shape = (output_dims,)
        self.bias = mx.random.uniform(low=-bound, high=bound, shape=bias_shape)

    def _extra_repr(self) -> str:
        # The dimensions are read back from the stored weight rather than
        # kept as separate attributes.
        weight_shape = self.weight.shape
        has_bias = "bias" in self
        return (
            f"input_dims={weight_shape[1]}, output_dims={weight_shape[0]}, "
            f"bias={has_bias}"
        )

    def __call__(self, x: mx.array) -> mx.array:
        projected = x @ self.weight.T
        if "bias" not in self:
            return projected
        return projected + self.bias


class Bilinear(Module):
    r"""Applies a bilinear transformation to the inputs.

    Concretely:

    .. math::

        y_i = x_2^\top W_i x_1 + b_i

    where:
    :math:`W` has shape ``[output_dims, input2_dims, input1_dims]``, :math:`b` has shape ``[output_dims]``,
    and :math:`i` indexes the output dimension.

    Note that the stored ``weight`` keeps the ``input2_dims`` axis *before*
    the ``input1_dims`` axis, so each slice :math:`W_i` is an
    ``[input2_dims, input1_dims]`` matrix.

    The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
    where :math:`k = \frac{1}{\sqrt{D_1}}` and :math:`D_1` is ``input1_dims``.

    Args:
        input1_dims (int): The dimensionality of the input1 features
        input2_dims (int): The dimensionality of the input2 features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will
          not use a bias. Default is ``True``.
    """

    def __init__(
        self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
    ) -> None:
        super().__init__()
        scale = math.sqrt(1.0 / input1_dims)
        # Layout is (out, in2, in1): ``__call__`` and ``_extra_repr`` both
        # unpack the weight shape in exactly this order.
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input2_dims, input1_dims),
        )
        if bias:
            self.bias = mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(output_dims,),
            )

    def _extra_repr(self) -> str:
        out, in2, in1 = self.weight.shape
        return (
            f"input1_dims={in1}, input2_dims={in2}, output_dims={out}, "
            f"bias={'bias' in self}"
        )

    def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
        """Compute the bilinear form of ``x1`` and ``x2``.

        Args:
            x1 (mx.array): Input whose last axis has size ``input1_dims``.
            x2 (mx.array): Input whose last axis has size ``input2_dims``.
              NOTE(review): the output takes its leading dimensions from
              ``x1`` only, so ``x2`` is presumably expected to have the same
              leading dimensions -- confirm against callers.

        Returns:
            mx.array: Array of shape ``(*x1.shape[:-1], output_dims)``.
        """
        # Normalize shapes: flatten all leading axes into one batch axis N.
        out, in2, in1 = self.weight.shape
        xshape = x1.shape[:-1]
        x1 = x1.reshape(-1, in1)  # (N, in1)
        x2 = x2.reshape(-1, 1, in2)  # (N, 1, in2)

        # Perform the bilinear transformation in two matmuls: first contract
        # the in1 axis against x1, then the in2 axis against x2.
        w = self.weight.reshape(out * in2, in1)  # (out * in2, in1)
        y = x1 @ w.T  # (N, out * in2)
        y = y.reshape(-1, out, in2).swapaxes(-2, -1)  # (N, in2, out)
        y = x2 @ y  # (N, 1, out)
        y = y.squeeze(1)  # (N, out)

        # Reset the shape
        y = y.reshape(*xshape, out)

        # Apply the bias
        if "bias" in self:
            y = y + self.bias

        return y