|
import numpy as np |
|
import torch |
|
|
|
|
|
def sharpen_prob(p, temperature=2):
    """Sharpen a batch of probability rows with a temperature exponent.

    Every entry is raised to the power ``temperature`` and each row is then
    divided by its own sum along dimension 1.

    Args:
        p (torch.Tensor): probability matrix (batch_size, n_classes).
        temperature (float): exponent applied to every entry.

    Returns:
        torch.Tensor: tensor with the same shape as ``p``; the input tensor
        itself is not modified.
    """
    powered = torch.pow(p, temperature)
    # keepdim=True keeps the row sums broadcastable against ``powered``.
    row_totals = powered.sum(1, keepdim=True)
    return powered / row_totals
|
|
|
|
|
def reverse_index(data, label):
    """Return ``data`` and ``label`` with their first dimension reversed.

    The descending index is derived from the length of ``data`` and is used
    to index both tensors.

    Args:
        data (torch.Tensor): tensor whose first dimension is reversed.
        label (torch.Tensor): tensor indexed with the same reversed order.

    Returns:
        tuple: ``(reversed_data, reversed_label)``.
    """
    last = data.shape[0] - 1
    # Counts last, last - 1, ..., 0 (empty when data has no rows).
    descending = torch.arange(last, -1, -1, dtype=torch.long)
    flipped_data = data[descending]
    flipped_label = label[descending]
    return flipped_data, flipped_label
|
|
|
|
|
def shuffle_index(data, label):
    """Apply one random permutation to both ``data`` and ``label``.

    A single permutation of ``range(data.shape[0])`` is drawn with
    ``torch.randperm`` and used to index both tensors, so corresponding
    entries stay aligned.

    Args:
        data (torch.Tensor): tensor to permute along its first dimension.
        label (torch.Tensor): tensor indexed with the same permutation.

    Returns:
        tuple: ``(shuffled_data, shuffled_label)``.
    """
    n_samples = data.shape[0]
    permutation = torch.randperm(n_samples)
    shuffled_data = data[permutation]
    shuffled_label = label[permutation]
    return shuffled_data, shuffled_label
|
|
|
|
|
def create_onehot(label, num_classes):
    """Create one-hot tensor.

    We suggest using nn.functional.one_hot.

    The result is allocated directly on ``label.device`` (no CPU round-trip),
    and labels of any integer dtype are accepted because the index is cast to
    int64 before scattering.

    Args:
        label (torch.Tensor): 1-D tensor of class indices.
        num_classes (int): number of classes.

    Returns:
        torch.Tensor: tensor of shape (label.shape[0], num_classes) in the
        default floating dtype, on the same device as ``label``, with a 1 at
        column ``label[i]`` of row ``i`` and 0 elsewhere.
    """
    # scatter requires an int64 index; detach() keeps autograd out of it.
    index = label.detach().long().unsqueeze(1)
    onehot = torch.zeros(label.shape[0], num_classes, device=label.device)
    return onehot.scatter_(1, index, 1)
|
|
|
|
|
def sigmoid_rampup(current, rampup_length):
    """Sigmoid-shaped rampup: ``exp(-5 * (1 - t) ** 2)``.

    ``t`` is ``current`` clipped to ``[0, rampup_length]`` and divided by
    ``rampup_length``. The value is ``exp(-5)`` at ``current <= 0`` and
    ``1.0`` once ``current >= rampup_length``.

    Args:
        current (int): current step.
        rampup_length (int): maximum step; must be positive.

    Returns:
        float: rampup coefficient.
    """
    assert rampup_length > 0
    clipped_step = np.clip(current, 0.0, rampup_length)
    progress = clipped_step / rampup_length
    remaining = 1.0 - progress
    exponent = -5.0 * remaining * remaining
    return float(np.exp(exponent))
|
|
|
|
|
def linear_rampup(current, rampup_length):
    """Linear rampup: ``current / rampup_length`` clipped to ``[0, 1]``.

    Args:
        current (int): current step.
        rampup_length (int): maximum step; must be positive.

    Returns:
        float: rampup coefficient between 0.0 and 1.0.
    """
    assert rampup_length > 0
    fraction = current / rampup_length
    return float(np.clip(fraction, 0.0, 1.0))
|
|
|
|
|
def ema_model_update(model, ema_model, alpha):
    """Blend the parameters of ``model`` into ``ema_model`` in place.

    Each parameter of ``ema_model`` becomes
    ``alpha * ema_param + (1 - alpha) * param``. Only tensors yielded by
    ``.parameters()`` are touched, and pairing stops at the shorter of the
    two parameter sequences.

    Args:
        model (nn.Module): model being trained.
        ema_model (nn.Module): ema of the model; updated in place.
        alpha (float): ema decay rate.
    """
    fresh_weight = 1 - alpha
    pairs = zip(ema_model.parameters(), model.parameters())
    for averaged, current in pairs:
        # Operate on .data so the update is not tracked by autograd.
        target = averaged.data
        target.mul_(alpha)
        target.add_(current.data, alpha=fresh_weight)
|
|