print("Importing standard...") import subprocess import shutil from pathlib import Path print("Importing external...") import torch import numpy as np from PIL import Image REDUCTION = "pca" if REDUCTION == "umap": from umap import UMAP elif REDUCTION == "tsne": from sklearn.manifold import TSNE elif REDUCTION == "pca": from sklearn.decomposition import PCA def symlog(x): return torch.sign(x) * torch.log(torch.abs(x) + 1) def preprocess_masks_features(masks, features): # Get shapes right B, M, H, W = masks.shape Bf, F, Hf, Wf = features.shape masks = masks.reshape(B, M, 1, H * W) # # the following assertions should work, remove due to speed # assert H == Hf and W == Wf and B == Bf # assert masks.dtype == torch.bool # assert (mask_areas > 0).all(), "you shouldn't have empty masks" # Reduce M if there are empty masks mask_areas = masks.sum(dim=3) # B, M, 1 features = features.reshape(B, 1, F, H * W) # output shapes # features: B, 1, F, H*W # masks: B, M, 1, H*W return masks, features, M, B, H, W, F def get_row_col(H, W, device): # get position of pixels in [0, 1] row = torch.linspace(0, 1, H, device=device) col = torch.linspace(0, 1, W, device=device) return row, col def get_current_git_commit(): try: # Run the git command to get the current commit hash commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip() # Decode from bytes to a string return commit_hash.decode("utf-8") except subprocess.CalledProcessError: # Handle the case where the command fails (e.g., not a Git repository) print("An error occurred while trying to retrieve the git commit hash.") return None def clean_dir(dirname): """Removes all directories in dirname that don't have a done.txt file""" dstdir = Path(dirname) dstdir.mkdir(exist_ok=True, parents=True) for f in dstdir.iterdir(): # if the directory doesn't have a done.txt file remove it if f.is_dir() and not (f / "done.txt").exists(): shutil.rmtree(f) def save_tensor_as_image(tensor, dstfile, global_step): dstfile = Path(dstfile) dstfile = (dstfile.parent / (dstfile.stem + "_" + str(global_step))).with_suffix( ".jpg" ) save(tensor, str(dstfile)) def minmaxnorm(x): return (x - x.min()) / (x.max() - x.min()) def save(tensor, name, channel_offset=0): tensor = to_img(tensor, channel_offset=channel_offset) Image.fromarray(tensor).save(name) def to_img(tensor, channel_offset=0): tensor = minmaxnorm(tensor) tensor = (tensor * 255).to(torch.uint8) C, H, W = tensor.shape if tensor.shape[0] == 1: tensor = tensor[0] elif tensor.shape[0] == 2: tensor = torch.stack([tensor[0], torch.zeros_like(tensor[0]), tensor[1]], dim=0) tensor = tensor.permute(1, 2, 0) elif tensor.shape[0] >= 3: tensor = tensor[channel_offset : channel_offset + 3] tensor = tensor.permute(1, 2, 0) tensor = tensor.cpu().numpy() return tensor def log_input_output( name, x, y_hat, global_step, img_dstdir, out_dstdir, reduce_dim=True, reduction=REDUCTION, resample_size=20000, ): y_hat = y_hat.reshape( y_hat.shape[0], y_hat.shape[2], y_hat.shape[3], y_hat.shape[4] ) if reduce_dim and y_hat.shape[1] >= 3: reducer = ( UMAP(n_components=3) if (reduction == "umap") else ( TSNE(n_components=3) if reduction == "tsne" else PCA(n_components=3) if reduction == "pca" else None ) ) np_y_hat = y_hat.detach().cpu().permute(1, 0, 2, 3).numpy() # F, 1, B, H, W np_y_hat = np_y_hat.reshape(np_y_hat.shape[0], -1) # F, BHW np_y_hat = np_y_hat.T # BHW, F sampled_pixels = np_y_hat[:: np_y_hat.shape[0] // resample_size] print("dim reduction fit..." + " " * 30, end="\r") reducer = reducer.fit(sampled_pixels) print("dim reduction transform..." + " " * 30, end="\r") reducer.transform(np_y_hat[:10]) # to numba compile the function np_y_hat = reducer.transform(np_y_hat) # BHW, 3 # revert back to original shape y_hat2 = ( torch.from_numpy( np_y_hat.T.reshape(3, y_hat.shape[0], y_hat.shape[2], y_hat.shape[3]) ) .to(y_hat.device) .permute(1, 0, 2, 3) ) print("done" + " " * 30, end="\r") else: y_hat2 = y_hat for i in range(min(len(x), 8)): save_tensor_as_image( x[i], img_dstdir / f"input_{name}_{str(i).zfill(2)}", global_step=global_step, ) for c in range(y_hat.shape[1]): save_tensor_as_image( y_hat[i, c : c + 1], out_dstdir / f"pred_channel_{name}_{str(i).zfill(2)}_{c}", global_step=global_step, ) # log color image assert len(y_hat2.shape) == 4, "should be B, F, H, W" if reduce_dim: save_tensor_as_image( y_hat2[i][:3], out_dstdir / f"pred_reduced_{name}_{str(i).zfill(2)}", global_step=global_step, ) save_tensor_as_image( y_hat[i][:3], out_dstdir / f"pred_colorchs_{name}_{str(i).zfill(2)}", global_step=global_step, ) def check_for_nan(loss, model, batch): try: assert torch.isnan(loss) == False except Exception as e: # print things useful to debug # does the batch contain nan? print("img batch contains nan?", torch.isnan(batch[0]).any()) print("mask batch contains nan?", torch.isnan(batch[1]).any()) # does the model weights contain nan? for name, param in model.named_parameters(): if torch.isnan(param).any(): print(name, "contains nan") # does the output contain nan? print("output contains nan?", torch.isnan(model(batch[0])).any()) # now raise the error raise e def calculate_iou(pred, label): intersection = ((label == 1) & (pred == 1)).sum() union = ((label == 1) | (pred == 1)).sum() if not union: return 0 else: iou = intersection.item() / union.item() return iou def load_from_ckpt(net, ckpt_path, strict=True): """Load network weights""" if ckpt_path and Path(ckpt_path).exists(): ckpt = torch.load(ckpt_path, map_location="cpu") if "MODEL_STATE" in ckpt: ckpt = ckpt["MODEL_STATE"] elif "state_dict" in ckpt: ckpt = ckpt["state_dict"] net.load_state_dict(ckpt, strict=strict) print("Loaded checkpoint from", ckpt_path) return net