Spaces:
Paused
Paused
| import math | |
| import torch | |
| from torch.optim import Optimizer | |
| import numpy as np | |
class ClampOptimizer(Optimizer):
    """Wrap another optimizer and clamp its parameters to [0, 1] after each step.

    The wrapper builds an inner optimizer from ``optimizer(params, **kwargs)``
    and delegates ``step`` and ``zero_grad`` to it. After every inner step,
    each parameter the inner optimizer manages is clamped in place to the
    closed interval [0, 1].

    Args:
        optimizer: Optimizer class (or factory) called as
            ``optimizer(params, **kwargs)``, e.g. ``torch.optim.Adam``.
        params: Iterable of tensors or of param-group dicts, as accepted by
            the inner optimizer. One-shot iterators such as
            ``model.parameters()`` are supported.
        **kwargs: Forwarded unchanged to the inner optimizer constructor.
    """

    def __init__(self, optimizer, params, **kwargs):
        # Materialize the iterable first: the inner optimizer consumes it, and
        # a generator would otherwise leave ``self.params`` empty afterwards.
        # A bare tensor is passed through untouched so the inner optimizer can
        # raise its own error for it.
        if not isinstance(params, torch.Tensor):
            params = list(params)
        self.opt = optimizer(params, **kwargs)
        # Kept as a public attribute for callers that read it.
        self.params = params

    def step(self, closure=None):
        """Run one inner optimizer step, then clamp all parameters to [0, 1].

        Args:
            closure: Optional callable forwarded to the inner optimizer's
                ``step``.

        Returns:
            Whatever the inner optimizer's ``step`` returns.
        """
        loss = self.opt.step(closure)
        # Walk the inner optimizer's param groups rather than ``self.params``
        # so that param-group dicts are handled as well as plain tensors.
        with torch.no_grad():
            for group in self.opt.param_groups:
                for param in group["params"]:
                    # Clamp directly in place; computing ``x + (clamp(x) - x)``
                    # loses precision for large magnitudes.
                    param.clamp_(0, 1)
        return loss

    def zero_grad(self, set_to_none=None):
        """Reset gradients on the inner optimizer.

        Args:
            set_to_none: When given, forwarded to the inner optimizer's
                ``zero_grad``. When ``None`` (the default), the inner
                optimizer's own default is used.
        """
        if set_to_none is None:
            self.opt.zero_grad()
        else:
            self.opt.zero_grad(set_to_none=set_to_none)