import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd


class _trunc_exp(Function):
    # Exponential with a truncated gradient: the forward pass is a plain exp,
    # while the backward pass clamps the saved input to [-15, 15] before
    # exponentiating, keeping gradients finite for large activations.

    @staticmethod
    @custom_fwd(cast_inputs=torch.float)  # run the forward in float32 under autocast
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        # d/dx exp(x) = exp(x), evaluated on the clamped input
        return g * torch.exp(x.clamp(-15, 15))


trunc_exp = _trunc_exp.apply
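

# --- Usage sketch (illustrative, not part of the original file) ---
# trunc_exp acts like torch.exp in the forward pass, but its gradient is
# computed on a clamped input, so backprop through large values stays finite.
# Tensor shapes and values below are arbitrary examples.
if __name__ == "__main__":
    x = (torch.randn(4) * 20.0).requires_grad_()  # includes large magnitudes
    y = trunc_exp(x)
    y.sum().backward()
    print(y)
    print(x.grad)  # finite gradients thanks to the clamp in backward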