# Source: mindeyev2old/src/utils.py
# Last commit: "Upload folder using huggingface_hub" by ckadirt (626cbe5, verified)
import numpy as np
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import PIL
import random
import os
import matplotlib.pyplot as plt
import pandas as pd
import math
import webdataset as wds
import tempfile
from torchvision.utils import make_grid
from diffusers.utils import randn_tensor
import json
from torchmetrics.image.fid import FrechetInceptionDistance
from PIL import Image
import requests
import io
import time
# Module-wide default device: CUDA when torch reports it available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def is_interactive():
    """Return True when the ``__main__`` module has no ``__file__`` attribute.

    A script launched from a file exposes ``__main__.__file__``; sessions
    such as a REPL or notebook kernel typically do not.
    """
    import __main__ as entry_module

    launched_from_file = hasattr(entry_module, '__file__')
    return not launched_from_file
def seed_everything(seed=0, cudnn_deterministic=True):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN determinism flag.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random``, the torch
            CPU generator and all CUDA generators. It is also written to the
            ``PYTHONHASHSEED`` environment variable.
        cudnn_deterministic: When truthy, ``torch.backends.cudnn.deterministic``
            is switched on. When falsy it is explicitly switched off (and a
            note is printed), so that an earlier call which enabled it does
            not leak into code that needs the flag to be False.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Always write the flag: the previous implementation only ever set it to
    # True, so a later call with cudnn_deterministic=False had no effect.
    torch.backends.cudnn.deterministic = bool(cudnn_deterministic)
    if not cudnn_deterministic:
        ## needs to be False to use conv3D
        print('Note: not using cudnn.deterministic')
def np_to_Image(x):
    """Convert a channel-first array into a PIL image.

    A 4-D input is treated as a batch and only its first entry is used.
    Values are mapped with ``v * 127.5 + 128``, clipped to ``[0, 255]`` and
    cast to ``uint8`` after moving the channel axis last.
    """
    frame = x[0] if x.ndim == 4 else x
    channels_last = frame.transpose(1, 2, 0)
    pixels = (channels_last * 127.5 + 128).clip(0, 255).astype('uint8')
    return PIL.Image.fromarray(pixels)
def torch_to_Image(x):
    """Convert a tensor to a PIL image via ``transforms.ToPILImage``.

    A 4-D input is treated as a batch and only its first entry is converted.
    """
    frame = x[0] if x.ndim == 4 else x
    to_pil = transforms.ToPILImage()
    return to_pil(frame)
def Image_to_torch(x):
    """Convert an image (or a sequence whose first item is an image) to a tensor.

    The result keeps at most the first three channels, gains a leading batch
    axis and is rescaled with ``(t - 0.5) / 0.5``.

    Args:
        x: Something ``transforms.ToTensor`` accepts, or an indexable
            container whose element 0 is such an object.

    Returns:
        Tensor with a leading batch dimension of size 1.
    """
    to_tensor = transforms.ToTensor()
    try:
        tensor = to_tensor(x)
    except (TypeError, ValueError):
        # ToTensor rejects unsupported inputs with TypeError (wrong type) or
        # ValueError (wrong ndarray rank); in that case fall back to the first
        # element. Other errors (and KeyboardInterrupt) are no longer hidden.
        tensor = to_tensor(x[0])
    return (tensor[:3].unsqueeze(0) - .5) / .5
def torch_to_matplotlib(x, device=device):
    """Turn a batched channel-first tensor into one channel-last uint8 image.

    If the mean of ``x`` exceeds 10 the values are taken to already be in the
    0-255 range; otherwise they are multiplied by 255 first. Only the first
    item of the batch is returned.

    Returns:
        A torch tensor when ``device == 'cpu'`` evaluates true, otherwise a
        NumPy array.

    NOTE(review): the default ``device`` is a ``torch.device`` object, so
    whether it compares equal to the string ``'cpu'`` depends on the PyTorch
    version -- confirm which return type callers actually receive.
    """
    channels_last = x.permute(0, 2, 3, 1)
    already_byte_range = torch.mean(x) > 10
    if not already_byte_range:
        channels_last = channels_last * 255
    as_bytes = channels_last.clamp(0, 255).to(torch.uint8)

    if device == 'cpu':
        return as_bytes[0]
    return as_bytes.cpu().numpy()[0]
def pairwise_cosine_similarity(A, B, dim=1, eps=1e-8):
    """Cosine similarity between every row of ``A`` and every row of ``B``.

    Args:
        A, B: 2-D tensors with matching feature size.
        dim: Axis along which squared entries are summed to get the norms.
        eps: Lower bound on the product of norms, guarding against division
            by zero.

    Returns:
        Tensor whose ``[i, j]`` entry is the similarity of ``A[i]`` and ``B[j]``.
    """
    # https://stackoverflow.com/questions/67199317/pytorch-cosine-similarity-nxn-elements
    dot_products = A @ B.T
    sq_norm_a = (A * A).sum(axis=dim)
    sq_norm_b = (B * B).sum(axis=dim)
    norm_products = torch.sqrt(torch.outer(sq_norm_a, sq_norm_b))
    floor = torch.tensor(eps)
    return dot_products / torch.max(norm_products, floor)
def batchwise_pearson_correlation(Z, B):
    """Pearson correlation between every row of ``Z`` and every row of ``B``.

    Each row is centered by its own mean (along dim 1); the result is the
    matrix of dot products of centered rows divided by the product of their
    L2 norms, so entry ``[i, j]`` correlates ``Z[i]`` with ``B[j]``.
    """
    z_dev = Z - Z.mean(dim=1, keepdim=True)
    b_dev = B - B.mean(dim=1, keepdim=True)

    covariance = z_dev @ b_dev.T
    z_scale = torch.linalg.norm(z_dev, dim=1, keepdim=True)
    b_scale = torch.linalg.norm(b_dev, dim=1, keepdim=True)
    return covariance / (z_scale @ b_scale.T)
def batchwise_cosine_similarity(Z, B):
    """Cosine similarity between all rows of ``Z`` (n rows) and ``B`` (b rows).

    Note the orientation: the result is transposed before being returned, so
    it has shape ``(b, n)`` and entry ``[j, i]`` compares ``B[j]`` with ``Z[i]``.
    """
    # https://www.h4pz.co/blog/2021/4/2/batch-cosine-similarity-in-pytorch-or-numpy-jax-cupy-etc
    b_cols = B.T
    row_norms = torch.linalg.norm(Z, dim=1, keepdim=True)       # (n, 1)
    col_norms = torch.linalg.norm(b_cols, dim=0, keepdim=True)  # (1, b)
    dots = Z @ b_cols
    scale = row_norms @ col_norms
    return (dots / scale).T
def topk(similarities, labels, k=5):
    """Fraction of rows whose label is among the ``k`` highest-scoring columns.

    Args:
        similarities: 2-D score matrix; each row is ranked along dim 1.
        labels: 1-D tensor holding the target column index for each row.
        k: Number of top-ranked columns to accept. It is clamped to the row
            count (as before) and additionally to the column count, because
            the ranking is indexed along columns and a larger ``k`` would
            index out of range.

    Returns:
        A 0-dim tensor with the top-k accuracy (the integer 0 if ``k`` clamps
        to 0).
    """
    k = min(k, similarities.shape[0], similarities.shape[1])
    # The ranking does not change between iterations, so sort only once.
    ranking = torch.argsort(similarities, axis=1)
    topsum = 0
    for rank in range(1, k + 1):
        hits = torch.sum(ranking[:, -rank] == labels)
        topsum += hits / len(labels)
    return topsum
def get_non_diagonals(a):
    """Return a copy of ``a`` whose main diagonal is replaced by -1.

    The copy is assembled from the strictly-upper and strictly-lower
    triangles, so the input tensor itself is left untouched.
    """
    upper = torch.triu(a, diagonal=1)
    lower = torch.tril(a, diagonal=-1)
    off_diag = upper + lower
    # make diagonals -1
    off_diag.fill_diagonal_(-1)
    return off_diag
def gather_features(image_features, voxel_features, accelerator):
    """Gather feature tensors through ``accelerator.gather``.

    Each tensor is made contiguous before being handed to the accelerator.

    Returns:
        Only the gathered image features when ``voxel_features`` is None,
        otherwise the tuple ``(gathered_image, gathered_voxel)``.
    """
    gathered_images = accelerator.gather(image_features.contiguous())
    if voxel_features is None:
        return gathered_images
    gathered_voxels = accelerator.gather(voxel_features.contiguous())
    return gathered_images, gathered_voxels
def soft_clip_loss(preds, targs, temp=0.125):
    """Symmetric soft-target contrastive loss between ``preds`` and ``targs``.

    The soft targets are the row-wise softmax of the target-target similarity
    matrix; the logits are the prediction-target similarities. Both are scaled
    by ``1 / temp``. The loss averages the cross-entropy taken over the logits
    and over their transpose.

    (A distributed variant that gathered targets across processes existed
    only as commented-out code and is not implemented here.)
    """
    target_logits = (targs @ targs.T) / temp
    pred_logits = (preds @ targs.T) / temp
    soft_targets = target_logits.softmax(-1)

    forward = -(pred_logits.log_softmax(-1) * soft_targets).sum(-1).mean()
    backward = -(pred_logits.T.log_softmax(-1) * soft_targets).sum(-1).mean()
    return (forward + backward) / 2
def mixco(voxels, beta=0.15, s_thresh=0.5):
    """Mix randomly selected samples of ``voxels`` with shuffled partners.

    A random permutation picks a partner for every sample, a Beta(beta, beta)
    draw gives each sample a mixing weight, and each sample is selected for
    mixing when a uniform draw is ``<= s_thresh``. Selected rows are replaced
    *in place* by ``w * row + (1 - w) * partner``; unselected rows keep their
    value and have their weight set to 1.

    Returns:
        ``(voxels, perm, betas, select)`` where ``voxels`` is the same
        (mutated) input tensor.
    """
    batch = voxels.shape[0]

    # Keep this draw order (permutation, Beta weights, selection mask) so the
    # random stream is consumed the same way for a given seed.
    perm = torch.randperm(batch)
    partners = voxels[perm].to(voxels.device, dtype=voxels.dtype)
    betas = torch.distributions.Beta(beta, beta).sample([batch]).to(voxels.device, dtype=voxels.dtype)
    select = (torch.rand(batch) <= s_thresh).to(voxels.device)

    # Reshape per-sample weights to broadcast over all trailing dimensions.
    broadcast_shape = [-1] + [1] * (voxels.dim() - 1)
    keep_w = betas[select].reshape(*broadcast_shape)
    swap_w = (1 - betas[select]).reshape(*broadcast_shape)
    voxels[select] = voxels[select] * keep_w + partners[select] * swap_w

    betas[~select] = 1
    return voxels, perm, betas, select
def mixco_clip_target(clip_target, perm, select, betas):
    """Apply a given mixing (perm / select / betas) to 2-D targets, in place.

    Rows flagged by ``select`` become ``w * row + (1 - w) * row[perm]`` with
    ``w`` taken from ``betas``; the other rows are left unchanged. The input
    tensor is modified and returned.
    """
    partners = clip_target[perm]  # indexed copy, taken before any row is overwritten
    keep_w = betas[select].reshape(-1, 1)
    swap_w = (1 - betas[select]).reshape(-1, 1)
    clip_target[select] = clip_target[select] * keep_w + partners[select] * swap_w
    return clip_target
def mixco_nce(preds, targs, temp=0.1, perm=None, betas=None, select=None, distributed=False,
              accelerator=None, local_rank=None, bidirectional=True):
    """Contrastive (InfoNCE-style) loss with optional soft labels from mixing.

    Args:
        preds, targs: 2-D tensors with one row per sample; logits are
            ``preds @ targs.T / temp``.
        temp: Temperature dividing the logits.
        perm, betas, select: Mixing description. When all three are given,
            sample ``i`` has soft-label weight ``betas[i]`` on column ``i`` and
            ``1 - betas[i]`` on column ``perm[i]``. ``select`` is only checked
            for presence. When any is None, plain cross-entropy against the
            diagonal is used.
        distributed, accelerator, local_rank: Accepted for interface
            compatibility; not used by this implementation.
        bidirectional: Also compute the loss on the transposed logits and
            average the two directions.

    Returns:
        Scalar loss tensor.
    """
    brain_clip = (preds @ targs.T) / temp

    if perm is not None and betas is not None and select is not None:
        probs = torch.diag(betas)
        rows = torch.arange(preds.shape[0], device=probs.device)
        # Accumulate rather than assign: if perm[i] == i, the partner column
        # is the diagonal itself, and assignment would overwrite betas[i] with
        # 1 - betas[i] (zeroing the whole row when betas[i] == 1). Each
        # (row, perm[row]) pair is unique, so += is safe.
        probs[rows, perm.to(probs.device)] += 1 - betas
        loss = -(brain_clip.log_softmax(-1) * probs).sum(-1).mean()
        if bidirectional:
            loss2 = -(brain_clip.T.log_softmax(-1) * probs.T).sum(-1).mean()
            loss = (loss + loss2) / 2
        return loss

    diagonal = torch.arange(brain_clip.shape[0], device=brain_clip.device)
    loss = F.cross_entropy(brain_clip, diagonal)
    if bidirectional:
        loss2 = F.cross_entropy(brain_clip.T, diagonal)
        loss = (loss + loss2) / 2
    return loss
def count_params(model):
    """Print the total and trainable parameter counts of ``model``.

    A parameter counts as trainable when its ``requires_grad`` is true.
    Nothing is returned; the counts are only printed.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    print('param counts:\n{:,} total\n{:,} trainable'.format(total, trainable))
def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row-major into a single ``rows`` x ``cols`` RGB canvas.

    The tile size is taken from the first image; image ``i`` lands in row
    ``i // cols`` and column ``i % cols``.
    """
    tile_w, tile_h = imgs[0].size
    canvas = PIL.Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas
def check_loss(loss):
    """Raise ``ValueError('NaN loss')`` if any element of ``loss`` is NaN."""
    has_nan = torch.isnan(loss).any()
    if has_nan:
        raise ValueError('NaN loss')
def cosine_anneal(start, end, steps):
    """Half-cosine schedule from ``start`` to ``end`` over ``steps`` values.

    Args:
        start: Value at the first step.
        end: Value at the last step (when ``steps >= 2``).
        steps: Number of schedule entries.

    Returns:
        1-D tensor of length ``steps``. For ``steps == 1`` it is ``[start]``
        (previously NaN because of a 0/0 division); for ``steps == 0`` it is
        empty.
    """
    # A one-step schedule has no interval to divide by; use 1 to avoid 0/0.
    intervals = max(steps - 1, 1)
    phase = torch.pi * torch.arange(steps) / intervals
    return end + (start - end) / 2 * (1 + torch.cos(phase))
import braceexpand
def get_dataloaders(
    batch_size,
    image_var='images',
    num_devices=None,
    num_workers=None,
    train_url=None,
    val_url=None,
    meta_url=None,
    num_train=None,
    num_val=None,
    cache_dir="/scratch/tmp/wds-cache",
    seed=0,
    voxels_key="nsdgeneral.npy",
    val_batch_size=None,
    to_tuple=["voxels", "images", "trial"],
    local_rank=0,
    world_size=1,
):
    """Build the training and validation DataLoaders from WebDataset shards.

    Args:
        batch_size: Batch size for the training DataLoader.
        image_var: Must be ``'images'`` (asserted).
        num_devices: Device count used for the printed batch statistics.
            Defaults to ``torch.cuda.device_count()``, falling back to 1 when
            no CUDA device is present so the arithmetic below cannot divide
            by zero.
        num_workers: Used only for the printed ``num_worker_batches``; both
            DataLoaders are created with ``num_workers=1``. Defaults to
            ``num_devices``.
        train_url, val_url: Shard URL patterns, expanded with ``braceexpand``.
        meta_url: Path of a JSON file with ``totals.train`` / ``totals.val``;
            read only when ``num_train`` or ``num_val`` is None.
        num_train, num_val: Sample counts; looked up in the metadata if None.
        cache_dir: Passed to ``wds.WebDataset`` as ``cache_dir``.
        seed, local_rank, world_size: Accepted but not used in this function.
        voxels_key: Sample key that is renamed to ``voxels``.
        val_batch_size: Validation batch size; defaults to ``batch_size``.
        to_tuple: Field names extracted from each sample, in order. The
            default list is never mutated here.

    Returns:
        ``(train_dl, val_dl, num_train, num_val)``.
    """
    print("Getting dataloaders...")
    assert image_var == 'images'

    def my_split_by_node(urls):
        # Identity node splitter: every node sees all shard URLs.
        return urls

    def build_pipeline(urls):
        # Shared train/val pipeline; each call gets its own Random(42) so the
        # shuffle buffers are seeded identically but independently.
        return (
            wds.WebDataset(urls, resampled=False, cache_dir=cache_dir, nodesplitter=my_split_by_node)
            .shuffle(500, initial=500, rng=random.Random(42))
            .decode("torch")
            .rename(images="jpg;png", voxels=voxels_key, trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy")
            .to_tuple(*to_tuple)
        )

    train_url = list(braceexpand.braceexpand(train_url))
    val_url = list(braceexpand.braceexpand(val_url))

    if num_devices is None:
        # device_count() is 0 on CPU-only hosts, which would make
        # global_batch_size 0 and raise ZeroDivisionError below.
        num_devices = max(torch.cuda.device_count(), 1)
    if num_workers is None:
        num_workers = num_devices

    if num_train is None or num_val is None:
        # Read the metadata once and close the file deterministically.
        with open(meta_url) as meta_file:
            metadata = json.load(meta_file)
        if num_train is None:
            num_train = metadata['totals']['train']
        if num_val is None:
            num_val = metadata['totals']['val']

    if val_batch_size is None:
        val_batch_size = batch_size

    # These statistics are informational only; they do not configure the loaders.
    global_batch_size = batch_size * num_devices
    num_batches = math.floor(num_train / global_batch_size)
    num_worker_batches = math.floor(num_batches / num_workers)
    if num_worker_batches == 0:
        num_worker_batches = 1

    print("\nnum_train",num_train)
    print("global_batch_size",global_batch_size)
    print("batch_size",batch_size)
    print("num_workers",num_workers)
    print("num_batches",num_batches)
    print("num_worker_batches", num_worker_batches)

    # Original author's note, kept for context:
    # BATCH SIZE SHOULD BE NONE!!! FOR TRAIN AND VAL | resampled=True for train | .batched(val_batch_size, partial=False)
    # (Batching is currently done by the DataLoader, not by the WebDataset pipeline.)
    train_data = build_pipeline(train_url)
    train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=1, shuffle=False)

    # Validation
    print("val_batch_size",val_batch_size)
    val_data = build_pipeline(val_url)
    val_dl = torch.utils.data.DataLoader(val_data, batch_size=val_batch_size, num_workers=1, shuffle=False, drop_last=True)

    return train_dl, val_dl, num_train, num_val