# SceneDreamer / activation.py
# gabrielsemiceki9's picture
# Upload 125 files
# b72e09b verified
# raw
# history blame contribute delete
# 464 Bytes
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
class _trunc_exp(Function):
    """Exponential activation whose gradient is computed from a clamped input.

    The forward pass returns the plain ``torch.exp(x)``. The backward pass
    multiplies the incoming gradient by ``exp(clamp(x, -15, 15))`` rather than
    by ``exp(x)``, so the gradient magnitude is bounded even when the forward
    output is very large or very small.

    Use the module-level alias ``trunc_exp`` rather than instantiating this
    class.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # cast to float32
    def forward(ctx, x):
        """Return ``exp(x)`` and save ``x`` for use in ``backward``."""
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        """Return ``g * exp(clamp(x, -15, 15))`` for the saved input ``x``.

        ``g`` is the gradient flowing in from the output of ``forward``.
        """
        x = ctx.saved_tensors[0]
        # Only the gradient is truncated; the forward value above is not.
        return g * torch.exp(x.clamp(-15, 15))


# Public functional form: ``trunc_exp(x)`` dispatches to ``_trunc_exp.apply``.
trunc_exp = _trunc_exp.apply