ML-SIM / NNfunctions.py
charlesnchr's picture
Fix for Skimage deprecation of multichannel
c5a3315
import datetime
import math
import os
import torch
import time
import skimage.io
import skimage.transform
import matplotlib.pyplot as plt
import glob
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from skimage import exposure
# Module-level torchvision converter instances. toPIL is used by
# EvaluateModel to turn the network output tensor into a PIL image; toTensor
# has no use in the code visible in this file.
toTensor = transforms.ToTensor()
toPIL = transforms.ToPILImage()
import numpy as np
from PIL import Image
from models import *
# Import-time side effect: sets CUDA_VISIBLE_DEVICES to "0" for this process.
# NOTE(review): presumably meant to pin inference to the first GPU; this runs
# after `import torch`, so confirm CUDA has not been initialised before it.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def remove_dataparallel_wrapper(state_dict):
    r"""Converts a DataParallel model to a normal one by removing the "module."
    wrapper in the module dictionary

    Keys that do not start with "module." are copied through unchanged, so
    calling this on an already-unwrapped state dictionary is harmless.

    Args:
        state_dict: a torch.nn.DataParallel state dictionary

    Returns:
        A new OrderedDict with the same values in the same order, keyed by
        the parameter names without the leading "module.".
    """
    from collections import OrderedDict

    prefix = "module."
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Strip only a prefix that is really present: blindly dropping the
        # first seven characters would corrupt the keys of a checkpoint that
        # was saved without the DataParallel wrapper.
        name = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[name] = value
    return new_state_dict
from argparse import Namespace
def GetOptions():
    """Return the default option namespace.

    The namespace describes an "rcan" model with 96 features and "minmax"
    normalisation, loading the weights file
    ``model/simrec_simin_gtout_rcan_512_2_ntrain790-final.pth``.

    Returns:
        argparse.Namespace holding the model, device and I/O settings.
    """
    force_cpu = False
    # CUDA is used only when it is available and CPU execution is not forced.
    device_name = "cpu" if force_cpu or not torch.cuda.is_available() else "cuda"

    settings = {
        # network architecture
        "model": "rcan",
        "n_resgroups": 3,
        "n_resblocks": 10,
        "n_feats": 96,
        "reduction": 16,
        "narch": 0,
        # input normalisation mode
        "norm": "minmax",
        # execution device
        "cpu": force_cpu,
        "multigpu": False,
        "undomulti": False,
        "device": torch.device(device_name),
        # data and checkpoint locations
        "imageSize": 512,
        "weights": "model/simrec_simin_gtout_rcan_512_2_ntrain790-final.pth",
        "root": "model/0080.jpg",
        "out": "model/myout",
        "task": "simin_gtout",
        "scale": 1,
        "nch_in": 9,
        "nch_out": 1,
    }
    return Namespace(**settings)
def GetOptions_allRnd_0215():
    """Return the option namespace for the "0216_SIMRec_0214_rndAll" weights.

    The namespace describes an "rcan" model with 48 features and "adapthist"
    normalisation, loading
    ``model/0216_SIMRec_0214_rndAll_rcan_continued.pth``.

    Returns:
        argparse.Namespace holding the model, device and I/O settings.
    """
    force_cpu = False
    # CUDA is used only when it is available and CPU execution is not forced.
    device_name = "cpu" if force_cpu or not torch.cuda.is_available() else "cuda"

    settings = {
        # network architecture
        "model": "rcan",
        "n_resgroups": 3,
        "n_resblocks": 10,
        "n_feats": 48,
        "reduction": 16,
        "narch": 0,
        # input normalisation mode
        "norm": "adapthist",
        # execution device
        "cpu": force_cpu,
        "multigpu": False,
        "undomulti": False,
        "device": torch.device(device_name),
        # data and checkpoint locations
        "imageSize": 512,
        "weights": "model/0216_SIMRec_0214_rndAll_rcan_continued.pth",
        "root": "model/0080.jpg",
        "out": "model/myout",
        "task": "simin_gtout",
        "scale": 1,
        "nch_in": 9,
        "nch_out": 1,
    }
    return Namespace(**settings)
def GetOptions_allRnd_0317():
    """Return the option namespace for the "DIV2K_randomised_3x3" weights.

    The namespace describes an "rcan" model with 96 features and "minmax"
    normalisation, loading ``model/DIV2K_randomised_3x3_20200317.pth``.

    Returns:
        argparse.Namespace holding the model, device and I/O settings.
    """
    force_cpu = False
    # CUDA is used only when it is available and CPU execution is not forced.
    device_name = "cpu" if force_cpu or not torch.cuda.is_available() else "cuda"

    settings = {
        # network architecture
        "model": "rcan",
        "n_resgroups": 3,
        "n_resblocks": 10,
        "n_feats": 96,
        "reduction": 16,
        "narch": 0,
        # input normalisation mode
        "norm": "minmax",
        # execution device
        "cpu": force_cpu,
        "multigpu": False,
        "undomulti": False,
        "device": torch.device(device_name),
        # data and checkpoint locations
        "imageSize": 512,
        "weights": "model/DIV2K_randomised_3x3_20200317.pth",
        "root": "model/0080.jpg",
        "out": "model/myout",
        "task": "simin_gtout",
        "scale": 1,
        "nch_in": 9,
        "nch_out": 1,
    }
    return Namespace(**settings)
def LoadModel(opt):
    """Build the network described by ``opt`` and load its saved weights.

    Args:
        opt: option namespace; this function reads ``opt.weights`` (checkpoint
            path), ``opt.device`` (map location for loading) and
            ``opt.undomulti`` (strip the DataParallel "module." prefix from
            the checkpoint keys), and hands the whole namespace to
            ``GetModel``.

    Returns:
        The network returned by ``GetModel`` with the checkpoint weights
        loaded.
    """
    print("Loading model")
    print(opt)

    net = GetModel(opt)

    print("loading checkpoint", opt.weights)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from a trusted source.
    checkpoint = torch.load(opt.weights, map_location=opt.device)

    # A checkpoint is either a wrapper mapping holding the weights under
    # "state_dict" or the bare state dictionary itself. Testing for the key
    # (rather than for the exact `dict` type) handles a wrapper saved as an
    # OrderedDict and a bare state dictionary saved as a plain dict.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    if opt.undomulti:
        state_dict = remove_dataparallel_wrapper(state_dict)
    net.load_state_dict(state_dict)

    return net
def _minmax_rescale(tensor, gain=None):
    """Min-max rescale ``tensor`` to [0, 1], optionally multiplied by ``gain``.

    A constant tensor (max == min) is mapped to all zeros instead of the NaNs
    a 0/0 division would produce.
    """
    lowest = torch.min(tensor)
    span = torch.max(tensor) - lowest
    if span == 0:
        return torch.zeros_like(tensor)
    shifted = tensor - lowest
    if gain is not None:
        shifted = gain * shifted
    return shifted / span


def prepimg(stack, self):
    """Turn a frame stack into a normalised network input and a widefield image.

    Args:
        stack: numpy array indexed as [frame, row, column]; only the first
            nine frames are used.
        self: option namespace (not a class instance). Reads ``self.nch_in``
            (9, 6 or 3 input frames) and ``self.norm``, one of:
            "convert" (rotate, reorder and scale each frame to a peak of 100),
            "convert<fac>" (same, with a peak of ``fac * 255``),
            "adapthist" (adaptive histogram equalisation),
            "minmax" (per-frame min-max to [0, 1]) or
            "minmax<fac>" (per-frame min-max to [0, fac]).

    Returns:
        ``(inputimg, widefield)`` as float32 torch tensors: the selected
        frames, cropped to at most 512x512, and the mean over those frames.
    """
    inputimg = stack[:9]

    # Select the subset of the nine frames the model was built for.
    if self.nch_in == 6:
        inputimg = inputimg[[0, 1, 3, 4, 6, 7]]
    elif self.nch_in == 3:
        inputimg = inputimg[[0, 4, 8]]

    if inputimg.shape[1] > 512 or inputimg.shape[2] > 512:
        print("Over 512x512! Cropping")
        inputimg = inputimg[:, :512, :512]

    if "convert" in self.norm:
        # raw img from microscope, needs normalisation and correct frame ordering
        if self.norm == "convert":
            print("Raw input assumed - converting")
            target_peak = 100
        else:
            target_peak = float(self.norm[7:]) * 255

        inputimg = np.rot90(inputimg, axes=(1, 2))
        # could also do [8,7,6,5,4,3,2,1,0]
        inputimg = inputimg[[6, 7, 8, 3, 4, 5, 0, 1, 2]]
        for i in range(len(inputimg)):
            frame_peak = np.max(inputimg[i])
            # An all-zero frame has nothing to scale; dividing by its peak
            # would only produce inf/NaN.
            if frame_peak != 0:
                inputimg[i] = target_peak / frame_peak * inputimg[i]

    # Scale the whole stack by its global peak (used to be /255); an all-zero
    # stack is left as zeros rather than turned into NaN.
    global_peak = np.max(inputimg)
    inputimg = inputimg.astype("float")
    if global_peak != 0:
        inputimg = inputimg / global_peak
    widefield = np.mean(inputimg, 0)

    if self.norm == "adapthist":
        for i in range(len(inputimg)):
            inputimg[i] = exposure.equalize_adapthist(inputimg[i], clip_limit=0.001)
        widefield = exposure.equalize_adapthist(widefield, clip_limit=0.001)
    else:
        # normalise
        inputimg = torch.tensor(inputimg).float()
        widefield = _minmax_rescale(torch.tensor(widefield).float())

        if "minmax" in self.norm:
            gain = None if self.norm == "minmax" else float(self.norm[6:])
            for i in range(len(inputimg)):
                inputimg[i] = _minmax_rescale(inputimg[i], gain)

    # as_tensor accepts both the numpy arrays of the "adapthist" branch and
    # the tensors of the other branch, without the copy-construct warning
    # that torch.tensor() raises when given an existing tensor.
    inputimg = torch.as_tensor(inputimg).float()
    widefield = torch.as_tensor(widefield).float()
    return inputimg, widefield
def save_image(data, filename, cmap):
    """Save a 2-D array as a borderless, colour-mapped image file.

    The figure is one inch tall and rendered at ``dpi = rows``, so the output
    has one pixel per array element.

    Args:
        data: array indexed as [row, column].
        filename: path of the image file to write.
        cmap: matplotlib colour map name passed to ``imshow``.
    """
    sizes = np.shape(data)
    rows, cols = sizes[0], sizes[1]

    fig = plt.figure()
    try:
        # Width in inches is cols/rows so the figure has the image's aspect
        # ratio; the inverse ratio letterboxes every non-square image.
        fig.set_size_inches(1.0 * cols / rows, 1, forward=False)
        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)
        ax.imshow(data, cmap=cmap)
        # Save and close this figure explicitly rather than whichever figure
        # pyplot currently considers active.
        fig.savefig(filename, dpi=rows)
    finally:
        plt.close(fig)
def EvaluateModel(net, opt, stack):
    """Run ``net`` on a frame stack and write the result images.

    Three PNG files are written, named after the current UTC time
    (``ML-SIM_HH-MM-SS``): the upscaled colour-mapped widefield image
    (``_wf.png``), the upscaled colour-mapped reconstruction (``_sr.png``)
    and the raw reconstruction without a colour map (``.png``).

    Args:
        net: the network, called with a batch of one input.
        opt: option namespace; reads ``opt.out``, ``opt.norm`` and
            ``opt.device`` and is passed on to ``prepimg``.
        stack: input frame stack, forwarded to ``prepimg``.

    Returns:
        Tuple of file names ``(<name>_sr.png, <name>_wf.png, <name>.png)``.
    """
    # datetime.utcnow() is deprecated; the timezone-aware call formats to the
    # same HH-MM-SS string.
    timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%H-%M-%S")
    outfile = "ML-SIM_%s" % timestamp

    # NOTE(review): opt.out is created here, but the files below are written
    # using `outfile` alone, i.e. relative to the working directory.
    os.makedirs(opt.out, exist_ok=True)

    print(stack.shape)
    inputimg, widefield = prepimg(stack, opt)

    if opt.norm == "convert" or "minmax" in opt.norm or "adapthist" in opt.norm:
        cmap = "viridis"
    else:
        cmap = "gray"

    wf = (255 * widefield.numpy()).astype("uint8")
    # should ideally be done by drawing on client side, in javascript
    wf_upscaled = skimage.transform.rescale(wf, 1.5, order=3)
    save_image(wf_upscaled, "%s_wf.png" % outfile, cmap)

    # Add the batch dimension the network expects.
    inputimg = inputimg.unsqueeze(0)

    with torch.no_grad():
        sr = net(inputimg.to(opt.device))
        sr = sr.cpu()
        sr = torch.clamp(sr, min=0, max=1)

    print("min max", inputimg.min(), inputimg.max())
    pil_sr_img = toPIL(sr[0])

    if opt.norm == "convert":
        # Undo the rotation prepimg applies to raw input.
        pil_sr_img = transforms.functional.rotate(pil_sr_img, -90)

    sr_img = np.array(pil_sr_img)
    skimage.io.imsave("%s.png" % outfile, sr_img)  # true out for downloading, no LUT

    # should ideally be done by drawing on client side, in javascript
    sr_img = skimage.transform.rescale(sr_img, 1.5, order=3)
    save_image(sr_img, "%s_sr.png" % outfile, cmap)

    return outfile + "_sr.png", outfile + "_wf.png", outfile + ".png"