Spaces:
Running
on
L40S
Running
on
L40S
import torch | |
import torch.nn as nn | |
from . import SparseTensor | |
# Names exported by `from <module> import *`.
__all__ = ["SparseLinear"]
class SparseLinear(nn.Linear):
    """``torch.nn.Linear`` applied to the ``feats`` of a sparse tensor.

    Construction parameters are those of :class:`torch.nn.Linear`. Any extra
    keyword arguments (e.g. the ``device``/``dtype`` factory kwargs on PyTorch
    versions that accept them) are forwarded to ``nn.Linear`` unchanged.
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        # Zero-arg super() matches forward(); **kwargs restores access to
        # nn.Linear options that the previous fixed signature silently blocked.
        super().__init__(in_features, out_features, bias, **kwargs)

    def forward(self, input: 'SparseTensor') -> 'SparseTensor':
        """Apply the linear map to ``input.feats`` and rewrap the result.

        Args:
            input: object exposing ``feats`` (passed to ``nn.Linear.forward``,
                so its last dimension must equal ``in_features``) and a
                ``replace(feats)`` method.

        Returns:
            Whatever ``input.replace`` returns for the transformed features.
            Presumably this is a SparseTensor that keeps ``input``'s layout and
            holds the new features (confirm against SparseTensor.replace).
        """
        return input.replace(super().forward(input.feats))