|
import torch
|
|
import torch.nn as nn
|
|
from . import SparseTensor
|
|
|
|
# Names exported by `from <module> import *` — the public activation wrappers.
__all__ = ['SparseReLU', 'SparseSiLU', 'SparseGELU', 'SparseActivation']
|
|
|
|
|
|
class SparseReLU(nn.ReLU):
    """ReLU activation applied element-wise to the features of a SparseTensor.

    Inherits all constructor arguments (e.g. ``inplace``) from ``nn.ReLU``.
    """

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply ReLU to ``input.feats`` and return a SparseTensor with the
        activated features (structure/coordinates are left untouched)."""
        activated = super().forward(input.feats)
        return input.replace(activated)
|
|
|
|
|
|
class SparseSiLU(nn.SiLU):
    """SiLU (swish) activation applied element-wise to the features of a
    SparseTensor.

    Inherits all constructor arguments from ``nn.SiLU``.
    """

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply SiLU to ``input.feats`` and return a SparseTensor with the
        activated features (structure/coordinates are left untouched)."""
        activated = super().forward(input.feats)
        return input.replace(activated)
|
|
|
|
|
|
class SparseGELU(nn.GELU):
    """GELU activation applied element-wise to the features of a SparseTensor.

    Inherits all constructor arguments (e.g. ``approximate``) from ``nn.GELU``.
    """

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply GELU to ``input.feats`` and return a SparseTensor with the
        activated features (structure/coordinates are left untouched)."""
        activated = super().forward(input.feats)
        return input.replace(activated)
|
|
|
|
|
|
class SparseActivation(nn.Module):
    """Adapter that applies an arbitrary dense activation module to the
    feature tensor of a SparseTensor.

    Use this for activations that have no dedicated Sparse* wrapper above.
    """

    def __init__(self, activation: nn.Module):
        """Store the dense activation module to run over ``feats``.

        Args:
            activation: any module mapping a feature tensor to a tensor of
                compatible shape (it is registered as a submodule here).
        """
        super().__init__()
        self.activation = activation

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Run the wrapped activation on ``input.feats`` and return a
        SparseTensor carrying the activated features."""
        activated = self.activation(input.feats)
        return input.replace(activated)
|
|
|
|
|