|
""" Linear layer (alternate definition) |
|
""" |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn as nn |
|
|
|
|
|
class Linear(nn.Linear):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

    Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
    weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.

    Construction and parameters (``weight``, optional ``bias``) are inherited
    unchanged from ``torch.nn.Linear``; only ``forward`` is overridden.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the linear transformation to ``input``.

        Args:
            input: Tensor to transform.

        Returns:
            The result of ``F.linear`` applied to ``input`` with this layer's
            weight and (if present) bias.
        """
        if torch.jit.is_scripting():
            # Under TorchScript, explicitly cast the parameters to the input's
            # dtype before calling F.linear (the addmm workaround described in
            # the class docstring). The bias may be absent.
            bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
            return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
        else:
            # Eager mode: pass the parameters through as-is, no manual cast.
            return F.linear(input, self.weight, self.bias)
|
|