#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/boreas-l/zipEnhancer/blob/main/models/layers/scaling.py
"""
import logging
import random
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


def logaddexp_onnx(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    max_value = torch.max(x, y)
    diff = torch.abs(x - y)
    return max_value + torch.log1p(torch.exp(-diff))


# RuntimeError: Exporting the operator logaddexp to ONNX opset version
# 14 is not supported. Please feel free to request support or submit
# a pull request on PyTorch GitHub.
#
# The following function is to solve the above error when exporting
# models to ONNX via torch.jit.trace()
def logaddexp(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    if torch.jit.is_scripting():
        # Note: We cannot use torch.jit.is_tracing() here as it also
        # matches torch.onnx.export().
        return torch.logaddexp(x, y)
    elif torch.onnx.is_in_onnx_export():
        return logaddexp_onnx(x, y)
    else:
        # for torch.jit.trace()
        return torch.logaddexp(x, y)


class PiecewiseLinear(object):
    """
    Piecewise linear function, from float to float, specified as a nonempty
    list of (x, y) pairs with the x values in increasing order.  x values
    <[initial x] or >[final x] map to [initial y] and [final y] respectively.
    """

    def __init__(self, *args):
        assert len(args) >= 1, len(args)
        if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
            self.pairs = list(args[0].pairs)
        else:
            self.pairs = [(float(x), float(y)) for x, y in args]
        for x, y in self.pairs:
            assert isinstance(x, (float, int)), type(x)
            assert isinstance(y, (float, int)), type(y)

        for i in range(len(self.pairs) - 1):
            assert self.pairs[i + 1][0] > self.pairs[i][0], (
                i,
                self.pairs[i],
                self.pairs[i + 1],
            )

    def __str__(self):
        # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
        return f'PiecewiseLinear({str(self.pairs)[1:-1]})'

    def __call__(self, x):
        if x <= self.pairs[0][0]:
            return self.pairs[0][1]
        elif x >= self.pairs[-1][0]:
            return self.pairs[-1][1]
        else:
            cur_x, cur_y = self.pairs[0]
            for i in range(1, len(self.pairs)):
                next_x, next_y = self.pairs[i]
                if cur_x <= x <= next_x:
                    return cur_y + (next_y - cur_y) * (x - cur_x) / (
                        next_x - cur_x)
                cur_x, cur_y = next_x, next_y
            assert False

    def __mul__(self, alpha):
        return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])

    def __add__(self, x):
        if isinstance(x, (float, int)):
            return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
        s, x = self.get_common_basis(x)
        return PiecewiseLinear(*[(sp[0], sp[1] + xp[1])
                                 for sp, xp in zip(s.pairs, x.pairs)])

    def max(self, x):
        if isinstance(x, (float, int)):
            x = PiecewiseLinear((0, x))
        s, x = self.get_common_basis(x, include_crossings=True)
        return PiecewiseLinear(*[(sp[0], max(sp[1], xp[1]))
                                 for sp, xp in zip(s.pairs, x.pairs)])

    def min(self, x):
        if isinstance(x, float) or isinstance(x, int):
            x = PiecewiseLinear((0, x))
        s, x = self.get_common_basis(x, include_crossings=True)
        return PiecewiseLinear(*[(sp[0], min(sp[1], xp[1]))
                                 for sp, xp in zip(s.pairs, x.pairs)])

    def __eq__(self, other):
        return self.pairs == other.pairs

    def get_common_basis(self,
                         p: 'PiecewiseLinear',
                         include_crossings: bool = False):
        """
        Returns (self_mod, p_mod), which are piecewise linear functions
        equivalent to self and p, but defined on the same x values.

          p: the other piecewise linear function
          include_crossings: if true, also include in the x values the
              positions where the functions given by self and p cross.
        """
        assert isinstance(p, PiecewiseLinear), type(p)

        # get sorted x-values without repetition.
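        # Explanatory note: the merged breakpoint set is the union of both
        # functions' x values; evaluating self and p at every merged x gives
        # two representations on a shared basis.  When include_crossings is
        # set, the loop below also inserts the x position where the two
        # functions cross inside a segment, found by linear interpolation
        # between the segment endpoints (pos = diff_cur / (diff_cur + diff_next)).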
        x_vals = sorted(
            set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
        y_vals1 = [self(x) for x in x_vals]
        y_vals2 = [p(x) for x in x_vals]

        if include_crossings:
            extra_x_vals = []
            for i in range(len(x_vals) - 1):
                _compare_results1 = (y_vals1[i] > y_vals2[i])
                _compare_results2 = (y_vals1[i + 1] > y_vals2[i + 1])
                if _compare_results1 != _compare_results2:
                    # i.e. if ((y_vals1[i] > y_vals2[i]) !=
                    #          (y_vals1[i + 1] > y_vals2[i + 1])):
                    # the two lines cross somewhere inside this subsegment.
                    diff_cur = abs(y_vals1[i] - y_vals2[i])
                    diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
                    # `pos`, between 0 and 1, gives the relative x position,
                    # with 0 being x_vals[i] and 1 being x_vals[i+1].
                    pos = diff_cur / (diff_cur + diff_next)
                    extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
                    extra_x_vals.append(extra_x_val)
            if len(extra_x_vals) > 0:
                x_vals = sorted(set(x_vals + extra_x_vals))
        y_vals1 = [self(x) for x in x_vals]
        y_vals2 = [p(x) for x in x_vals]
        return (
            PiecewiseLinear(*zip(x_vals, y_vals1)),
            PiecewiseLinear(*zip(x_vals, y_vals2)),
        )


class ScheduledFloat(torch.nn.Module):
    """
    This object is a torch.nn.Module only because we want it to show up in
    [top_level module].modules(); it does not have a working forward()
    function.  You are supposed to cast it to float, as in,
    float(parent_module.whatever), and use it as something like a dropout prob.

    It is a floating point value whose value changes depending on the batch
    count of the training loop.  It is a piecewise linear function where you
    specify the (x, y) pairs in sorted order on x; x corresponds to the batch
    index.  For batch-index values before the first x or after the last x, we
    just use the first or last y value.

    Example:
        self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)

    `default` is used when self.batch_count is not set, when the module is not
    in training mode, or in torch.jit scripting/tracing mode.
    """

    def __init__(self, *args, default: float = 0.0):
        super().__init__()
        # self.batch_count and self.name will be written to in the training loop.
        self.batch_count = None
        self.name = None
        self.default = default
        self.schedule = PiecewiseLinear(*args)

    def extra_repr(self) -> str:
        return (
            f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}'
        )

    def __float__(self):
        batch_count = self.batch_count
        if (batch_count is None or not self.training
                or torch.jit.is_scripting() or torch.jit.is_tracing()):
            return float(self.default)
        else:
            ans = self.schedule(self.batch_count)
            if random.random() < 0.0002:
                logging.info(
                    f'ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}'
                )
            return ans

    def __add__(self, x):
        if isinstance(x, float) or isinstance(x, int):
            return ScheduledFloat(self.schedule + x, default=self.default)
        else:
            return ScheduledFloat(
                self.schedule + x.schedule, default=self.default + x.default)

    def max(self, x):
        if isinstance(x, float) or isinstance(x, int):
            return ScheduledFloat(self.schedule.max(x), default=self.default)
        else:
            return ScheduledFloat(
                self.schedule.max(x.schedule),
                default=max(self.default, x.default))


FloatLike = Union[float, ScheduledFloat]


class SoftmaxFunction(torch.autograd.Function):
    """
    Tries to handle half-precision derivatives in a randomized way that should
    be more accurate for training than the default behavior.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, dim: int):
        ans = x.softmax(dim=dim)
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.
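        # Explanatory note: when autocast promoted the softmax result to
        # float32, the branch below casts it back to float16 so the saved
        # activation and downstream ops stay in half precision; backward()
        # redoes the gradient math in float32 for accuracy.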
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: torch.Tensor):
        (ans,) = ctx.saved_tensors
        with torch.cuda.amp.autocast(enabled=False):
            ans_grad = ans_grad.to(torch.float32)
            ans = ans.to(torch.float32)
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None


if __name__ == "__main__":
    pass
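    # A minimal, illustrative sketch (added here; not part of the upstream
    # scaling.py) showing how the pieces above behave.  All names below are
    # local to this demo.

    # PiecewiseLinear interpolates between its (x, y) breakpoints and clamps
    # outside the first/last x value.
    sched = PiecewiseLinear((0.0, 0.2), (4000.0, 0.0))
    assert sched(-1.0) == 0.2                      # before first x -> first y
    assert abs(sched(2000.0) - 0.1) < 1e-9         # halfway -> interpolated
    assert sched(5000.0) == 0.0                    # after last x -> last y

    # ScheduledFloat evaluates its schedule at self.batch_count (normally set
    # by the training loop) and falls back to `default` outside training.
    dropout_p = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
    dropout_p.batch_count = 2000
    print('scheduled value at batch 2000:', float(dropout_p))

    # SoftmaxFunction matches torch.softmax in the forward pass; its backward
    # computes the gradient in float32.
    x = torch.randn(3, 5, requires_grad=True)
    y = SoftmaxFunction.apply(x, -1)
    assert torch.allclose(y, x.softmax(dim=-1))
    (y * torch.arange(5.0)).sum().backward()
    print('softmax input-grad norm:', x.grad.norm().item())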