|
import torch |
|
from torchvision import transforms |
|
from math import pi |
|
import torchvision.transforms.functional as TF |
|
|
|
|
|
|
|
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
|
|
|
|
|
def uniq(arr):
    """Return the distinct elements of ``arr`` as a dict keys view.

    Elements keep the order in which they were first seen.
    """
    first_seen = dict.fromkeys(arr, True)
    return first_seen.keys()
|
|
|
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val
|
|
|
|
|
def max_neg_value(t):
    """Return the negated largest finite value of ``t``'s dtype (via ``torch.finfo``)."""
    dtype_limits = torch.finfo(t.dtype)
    return -dtype_limits.max
|
|
|
|
|
def cast_tuple(val, depth=1):
    """Coerce ``val`` to a tuple.

    A tuple is returned unchanged, a list is converted to a tuple, and any
    other value is repeated ``depth`` times inside a new tuple.
    """
    if isinstance(val, tuple):
        return val
    if isinstance(val, list):
        return tuple(val)
    return (val,) * depth
|
|
|
|
|
def is_empty(t):
    """Return True when the tensor ``t`` holds no elements."""
    element_count = t.numel()
    return element_count == 0
|
|
|
|
|
def masked_mean(t, mask, dim=1):
    """
    Compute the mean of a tensor over the positions selected by a mask

    Positions where ``mask`` is False are zeroed before summing and are not
    counted in the denominator. When no position is selected along ``dim``
    the result is 0 rather than NaN.

    Args:
        t (torch.Tensor): input tensor of shape (batch_size, seq_len, hidden_dim)
        mask (torch.Tensor): boolean mask tensor of shape (batch_size, seq_len)
        dim (int): dimension along which to compute the mean (default=1)

    Returns:
        torch.Tensor: masked mean with ``dim`` reduced away; for the default
        ``dim=1`` the shape is (batch_size, hidden_dim)
    """
    # Broadcast the mask to the full shape of `t` so that both the numerator
    # and the element count can be reduced along the same, caller-chosen axis.
    keep = mask[:, :, None].expand_as(t)
    total = t.masked_fill(~keep, 0.0).sum(dim=dim)
    # Clamp the count to at least 1: a fully masked slice has total == 0,
    # so 0 / 1 gives 0 instead of the 0 / 0 NaN.
    count = keep.sum(dim=dim).clamp(min=1)
    return total / count
|
|
|
|
|
def set_requires_grad(model, value):
    """
    Switch gradient tracking on or off for every parameter of a model

    Args:
        model (torch.nn.Module): the model whose ``parameters()`` are updated in place
        value (bool): new ``requires_grad`` setting applied to each parameter
    """
    for weight in model.parameters():
        weight.requires_grad = value
|
|
|
|
|
def eval_decorator(fn):
    """
    Decorator that runs a function with the model switched to eval mode

    The wrapped function must take the model as its first positional
    argument. The model's ``training`` flag is recorded, ``model.eval()`` is
    called, ``fn`` is run, and the recorded mode is restored afterwards, even
    if ``fn`` raises.

    Args:
        fn (callable): function whose first argument is the model

    Returns:
        callable: the decorated function, carrying ``fn``'s name and docstring
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # Restore the previous mode on every exit path so an exception
            # inside `fn` cannot leave the model stuck in eval mode.
            model.train(was_training)

    return inner
|
|
|
|
|
def log(t, eps=1e-20):
    """
    Natural logarithm of a tensor after shifting it by a small constant

    Args:
        t (torch.Tensor): input tensor
        eps (float): value added to ``t`` before the logarithm is taken (default=1e-20)

    Returns:
        torch.Tensor: ``log(t + eps)``, element-wise
    """
    shifted = t + eps
    return shifted.log()
|
|
|
|
|
def gumbel_noise(t):
    """
    Draw Gumbel noise shaped like the input

    Uniform samples on [0, 1) are drawn with the shape, dtype and device of
    ``t`` and mapped through ``-log(-log(u))`` using this module's ``log``.

    Args:
        t (torch.Tensor): tensor whose shape, dtype and device the noise copies

    Returns:
        torch.Tensor: Gumbel noise with the same shape as ``t``
    """
    uniform = torch.empty_like(t).uniform_(0, 1)
    inner = log(uniform)
    return -log(-inner)
|
|
|
|
|
def gumbel_sample(t, temperature=0.9, dim=-1):
    """
    Perturb temperature-scaled logits with Gumbel noise

    No argmax or softmax is applied here; the noisy logits themselves are
    returned.

    Args:
        t (torch.Tensor): input logits
        temperature (float): divisor applied to ``t``; values below 1e-10 are
            raised to 1e-10 (default=0.9)
        dim (int): accepted but not used by this function (default=-1)

    Returns:
        torch.Tensor: ``t / temperature`` plus Gumbel noise, same shape as ``t``
    """
    # Floor the temperature so the division below can never be by zero.
    safe_temperature = max(temperature, 1e-10)
    scaled_logits = t / safe_temperature
    return scaled_logits + gumbel_noise(t)
|
|
|
|
|
def top_k(logits, thres=0.5):
    """
    Keep the largest logits along the last dimension and mask out the rest

    The number kept is ``int((1 - thres) * num_logits)``, but never fewer than one.

    Args:
        logits (torch.Tensor): input tensor; filtering is done over its last dimension
        thres (float): fraction of entries to discard (default=0.5)

    Returns:
        torch.Tensor: tensor shaped like ``logits`` in which the kept entries
        retain their values and every other entry is negative infinity
    """
    vocab_size = logits.shape[-1]
    keep = max(int((1 - thres) * vocab_size), 1)
    top_values, top_indices = logits.topk(keep)
    filtered = torch.full_like(logits, float("-inf"))
    filtered.scatter_(-1, top_indices, top_values)
    return filtered
|
|
|
|
|
def gamma_func(mode="cosine", scale=0.15):
    """Build a one-argument schedule function ``r -> value`` for the chosen mode

    Args:
        mode (str): one of "linear", "cosine", "square", "cubic" or "scaled-cosine"
        scale (float): multiplier used only by the "scaled-cosine" mode (default=0.15)

    Returns:
        callable: function of a single input ``r``; the cosine variants call
        ``torch.cos`` on ``r``

    Raises:
        NotImplementedError: if ``mode`` is not one of the supported names
    """

    def linear(r):
        return 1 - r

    def cosine(r):
        return torch.cos(r * pi / 2)

    def square(r):
        return 1 - r**2

    def cubic(r):
        return 1 - r**3

    def scaled_cosine(r):
        return scale * (torch.cos(r * pi / 2))

    if mode == "linear":
        return linear
    if mode == "cosine":
        return cosine
    if mode == "square":
        return square
    if mode == "cubic":
        return cubic
    if mode == "scaled-cosine":
        return scaled_cosine
    raise NotImplementedError
|
|
|
|
|
class always:
    """Callable that ignores its arguments and returns a fixed value"""

    def __init__(self, val):
        # The value handed back on every call.
        self.val = val

    def __call__(self, x, *args, **kwargs):
        # At least one positional argument is required; all are discarded.
        constant = self.val
        return constant
|
|
|
|
|
class DivideMax(torch.nn.Module):
    """Module that divides its input by the maximum taken along one dimension"""

    def __init__(self, dim):
        super().__init__()
        # Dimension over which the maximum is computed in forward().
        self.dim = dim

    def forward(self, x):
        peak = torch.amax(x, dim=self.dim, keepdim=True)
        # The divisor is detached, so gradients do not flow through the maximum.
        return x / peak.detach()
|
|
|
def replace_outliers(image, percentile=0.0001):
    """
    Clip extreme values of a tensor to the range of its non-outlier values

    Values below the ``percentile`` quantile or above the ``1 - percentile``
    quantile are treated as outliers and clipped to the smallest / largest
    value that lies inside those two quantiles. The tensor is modified in
    place and also returned.

    Args:
        image (torch.Tensor): floating-point tensor to clean (modified in place)
        percentile (float): quantile, in [0, 1], defining the lower cut-off;
            the upper cut-off is ``1 - percentile`` (default=0.0001)

    Returns:
        torch.Tensor: the same ``image`` object with outliers clipped; returned
        unchanged when no value falls between the two quantiles
    """
    lower_bound = torch.quantile(image, percentile)
    upper_bound = torch.quantile(image, 1 - percentile)

    inliers = image[(image >= lower_bound) & (image <= upper_bound)]
    if inliers.numel() == 0:
        # Nothing lies between the quantiles, so there is no valid range to
        # clip against; leave the data untouched instead of raising.
        return image

    # Tensor reductions replace the Python builtins min()/max(), which would
    # iterate the tensor element by element. Inliers already sit inside
    # [inliers.min(), inliers.max()], so clamping the whole tensor changes
    # only the outliers.
    return image.clamp_(min=inliers.min(), max=inliers.max())
|
|
|
|
|
def process_image(image, dataset, image_type=None):
    """
    Convert an image to a max-scaled tensor and take a random 256 x 256 crop

    The image is converted with ``TF.to_tensor``, index 0 along the first
    axis of the result is kept, and two leading singleton axes are added.
    The tensor is then divided by its maximum and randomly cropped with
    ``transforms.RandomCrop(256)``.

    Args:
        image: image accepted by ``torchvision.transforms.functional.to_tensor``
        dataset (str): dataset name; kept for interface compatibility and
            currently has no effect on the output
        image_type (str, optional): image type; kept for interface
            compatibility and currently has no effect on the output

    Returns:
        torch.Tensor: the randomly cropped, max-scaled tensor

    Note:
        Earlier code selected a per-dataset pair of constants here but never
        applied it, so no such normalization is performed. The values
        (presumably mean and std; confirm before use) were:
        HPA nucleus (0.0655, 0.0650), HPA protein (0.1732, 0.1208),
        OpenCell nucleus (0.0272, 0.0244), OpenCell protein (0.0486, 0.0671).
    """
    image = TF.to_tensor(image)[0].unsqueeze(0).unsqueeze(0)

    # Skip the scaling when the maximum is exactly zero; dividing an all-zero
    # image by its maximum would turn every pixel into NaN.
    peak = image.max()
    if peak != 0:
        image /= peak

    return transforms.RandomCrop(256)(image)
|
|