Spaces:
Runtime error
Runtime error
import datetime | |
import math | |
import os | |
import torch | |
import time | |
import skimage.io | |
import skimage.transform | |
import matplotlib.pyplot as plt | |
import glob | |
import torch.optim as optim | |
import torchvision | |
import torchvision.transforms as transforms | |
from skimage import exposure | |
# Module-wide torchvision converters: ToPILImage maps a tensor to a PIL
# image, ToTensor maps an image back to a tensor.
toPIL = transforms.ToPILImage()
toTensor = transforms.ToTensor()
import numpy as np | |
from PIL import Image | |
from models import * | |
# Default to GPU 0, but keep any device selection already present in the
# environment instead of overwriting it. CUDA only honours this variable if it
# is set before CUDA is initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
def remove_dataparallel_wrapper(state_dict):
    r"""Convert a DataParallel state dict into a plain one.

    Removes the leading ``"module."`` that ``torch.nn.DataParallel`` adds to
    parameter names. Keys without that prefix are copied unchanged instead
    of losing their first seven characters.

    Args:
        state_dict: mapping of parameter names to values, typically taken
            from a ``torch.nn.DataParallel`` model.

    Returns:
        A new ``OrderedDict`` holding the same values in the same order.
    """
    from collections import OrderedDict

    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
from argparse import Namespace | |
def _build_options(n_feats, norm, weights):
    """Return a fresh RCAN options namespace shared by every preset below.

    Only the feature width, the input normalisation mode and the checkpoint
    path differ between presets; all other values are common. Attributes
    are assigned in a fixed order so ``print(opt)`` output is stable.
    """
    opt = Namespace()
    opt.model = "rcan"
    opt.n_resgroups = 3
    opt.n_resblocks = 10
    opt.n_feats = n_feats
    opt.reduction = 16
    opt.narch = 0
    opt.norm = norm
    opt.cpu = False
    opt.multigpu = False
    opt.undomulti = False
    # use the GPU when one is visible, unless CPU mode was requested
    opt.device = torch.device(
        "cuda" if torch.cuda.is_available() and not opt.cpu else "cpu"
    )
    opt.imageSize = 512
    opt.weights = weights
    opt.root = "model/0080.jpg"
    opt.out = "model/myout"
    opt.task = "simin_gtout"
    opt.scale = 1
    opt.nch_in = 9
    opt.nch_out = 1
    return opt


def GetOptions():
    """Preset: 96 features, "minmax" normalisation, ntrain790 checkpoint."""
    return _build_options(
        n_feats=96,
        norm="minmax",
        weights="model/simrec_simin_gtout_rcan_512_2_ntrain790-final.pth",
    )


def GetOptions_allRnd_0215():
    """Preset: 48 features, "adapthist" normalisation, 0216 rndAll checkpoint."""
    return _build_options(
        n_feats=48,
        norm="adapthist",
        weights="model/0216_SIMRec_0214_rndAll_rcan_continued.pth",
    )


def GetOptions_allRnd_0317():
    """Preset: 96 features, "minmax" normalisation, DIV2K 20200317 checkpoint."""
    return _build_options(
        n_feats=96,
        norm="minmax",
        weights="model/DIV2K_randomised_3x3_20200317.pth",
    )
def LoadModel(opt):
    """Build the network described by *opt* and load its checkpoint weights.

    Args:
        opt: options namespace. ``weights``, ``device`` and ``undomulti`` are
            read here; the whole object is printed and passed to ``GetModel``.

    Returns:
        The network returned by ``GetModel`` with the checkpoint loaded.
    """
    print("Loading model")
    print(opt)
    net = GetModel(opt)
    print("loading checkpoint", opt.weights)
    # torch.load unpickles the file: only point opt.weights at trusted files.
    checkpoint = torch.load(opt.weights, map_location=opt.device)
    # A checkpoint is either a bare state dict or a dict wrapping one under
    # "state_dict". Module.state_dict() returns an OrderedDict (a dict
    # subclass), so decide by the key rather than by the exact type.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    if opt.undomulti:
        state_dict = remove_dataparallel_wrapper(state_dict)
    net.load_state_dict(state_dict)
    return net
def _unit_range(t):
    """Rescale tensor *t* linearly to [0, 1]; a constant tensor becomes zeros."""
    lo = torch.min(t)
    span = torch.max(t) - lo
    if span == 0:
        # 0/0 would fill the result with NaN
        return torch.zeros_like(t)
    return (t - lo) / span


def _convert_raw(inputimg, peak):
    """Reorient raw frames and rescale each so its maximum equals *peak*.

    Frames are rotated 90 degrees in the (row, col) plane and reordered as
    [6, 7, 8, 3, 4, 5, 0, 1, 2]. That reorder needs at least 9 frames, so it
    raises IndexError when fewer were selected (nch_in 3 or 6). Rescaled
    values are written back into the array, so integer input keeps its dtype
    and fractions truncate.
    """
    inputimg = np.rot90(inputimg, axes=(1, 2))
    # fancy indexing copies, so the caller's stack is never modified
    # (could also do [8,7,6,5,4,3,2,1,0])
    inputimg = inputimg[[6, 7, 8, 3, 4, 5, 0, 1, 2]]
    for frame in inputimg:
        frame_max = np.max(frame)
        if frame_max != 0:  # leave all-zero frames alone instead of 0/0
            frame[...] = peak / frame_max * frame
    return inputimg


def prepimg(stack, self):
    """Turn a frame stack into network input frames and a widefield image.

    Args:
        stack: array indexed as (frame, row, col); only the first 9 frames
            are used. Frames 0,1,3,4,6,7 are kept when ``nch_in == 6`` and
            frames 0,4,8 when ``nch_in == 3``. Larger frames are cropped to
            512x512.
        self: options object (despite the name); ``nch_in`` and ``norm`` are
            read.

    ``norm`` modes:
        ``"convert"`` / ``"convert<fac>"``: raw input. Frames are rotated and
            reordered, and each is scaled to a peak of 100 or ``fac * 255``.
        ``"adapthist"``: skimage adaptive histogram equalisation on every
            frame and on the widefield.
        ``"minmax"`` / ``"minmax<fac>"``: every frame is rescaled to
            [0, 1], or to [0, fac].
        Any mode other than ``"adapthist"`` also rescales the widefield to
        [0, 1]. Every mode first divides the stack by its global maximum.

    Returns:
        ``(inputimg, widefield)`` float32 tensors. The widefield is the
        per-pixel mean of the frames, taken before per-mode normalisation.
        Constant frames normalise to zeros instead of NaN.
    """
    inputimg = stack[:9]
    if self.nch_in == 6:
        inputimg = inputimg[[0, 1, 3, 4, 6, 7]]
    elif self.nch_in == 3:
        inputimg = inputimg[[0, 4, 8]]
    if inputimg.shape[1] > 512 or inputimg.shape[2] > 512:
        print("Over 512x512! Cropping")
        inputimg = inputimg[:, :512, :512]

    norm = self.norm
    if norm == "convert":
        # raw img from microscope, needs normalisation and correct frame ordering
        print("Raw input assumed - converting")
        inputimg = _convert_raw(inputimg, 100)
    elif "convert" in norm:
        # "convert<fac>": same conversion with a peak of fac * 255
        inputimg = _convert_raw(inputimg, float(norm[7:]) * 255)

    inputimg = inputimg.astype("float")
    peak = np.max(inputimg)
    if peak != 0:  # an all-zero stack stays zero instead of becoming NaN
        inputimg = inputimg / peak  # used to be /255
    widefield = np.mean(inputimg, 0)

    if norm == "adapthist":
        for i in range(len(inputimg)):
            inputimg[i] = exposure.equalize_adapthist(inputimg[i], clip_limit=0.001)
        widefield = exposure.equalize_adapthist(widefield, clip_limit=0.001)
    else:
        inputimg = torch.tensor(inputimg).float()
        widefield = _unit_range(torch.tensor(widefield).float())
        if norm == "minmax":
            for i in range(len(inputimg)):
                inputimg[i] = _unit_range(inputimg[i])
        elif "minmax" in norm:
            fac = float(norm[6:])
            for i in range(len(inputimg)):
                inputimg[i] = fac * _unit_range(inputimg[i])

    # as_tensor reuses values that are already tensors instead of copying them
    return torch.as_tensor(inputimg).float(), torch.as_tensor(widefield).float()
def save_image(data, filename, cmap):
    """Save an image array as a borderless, colour-mapped picture.

    The figure is ``cols / rows`` inches wide and 1 inch tall, saved at
    ``rows`` dpi, so each array element becomes one output pixel and the
    aspect ratio of non-square images is preserved.

    Args:
        data: image array indexed as (row, col[, ...]).
        filename: output path handed to matplotlib's ``savefig``.
        cmap: matplotlib colormap name used by ``imshow``.
    """
    rows, cols = np.shape(data)[:2]
    fig = plt.figure()
    try:
        # width:height must be cols:rows for the image to fill the canvas
        fig.set_size_inches(1.0 * cols / rows, 1, forward=False)
        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])  # axes cover the whole figure
        ax.set_axis_off()
        fig.add_axes(ax)
        ax.imshow(data, cmap=cmap)
        fig.savefig(filename, dpi=rows)
    finally:
        # always release the figure, even if drawing or saving fails
        plt.close(fig)
def EvaluateModel(net, opt, stack):
    """Run the network on one frame stack and write the result images.

    Three PNGs are written to the current working directory, all named from
    the current UTC time (``ML-SIM_HH-MM-SS``):
    ``<stem>_wf.png`` is the widefield preview, ``<stem>.png`` is the
    clamped network output with no colormap, and ``<stem>_sr.png`` is a
    preview of that output. Both previews are upscaled 1.5x and
    colour-mapped.

    Args:
        net: network called on a (1, C, H, W) tensor under ``torch.no_grad``.
            It is assumed to already live on ``opt.device``; TODO confirm that
            GetModel places it there.
        opt: options namespace. ``out``, ``norm`` and ``device`` are read
            here, and the object is passed on to ``prepimg``.
        stack: frame stack passed unchanged to ``prepimg``.

    Returns:
        ``(sr_preview_path, wf_preview_path, sr_output_path)``.
    """
    stamp = datetime.datetime.now(datetime.timezone.utc).strftime("%H-%M-%S")
    outfile = "ML-SIM_%s" % stamp
    # NOTE(review): opt.out is created, but the images below are written to
    # the current directory, not into it. Confirm which location callers expect.
    os.makedirs(opt.out, exist_ok=True)
    print(stack.shape)
    inputimg, widefield = prepimg(stack, opt)
    # prepimg rotates the frames 90 degrees counter-clockwise for every
    # "convert*" norm, so every such norm is rotated back here.
    rotated = "convert" in opt.norm
    if opt.norm == "convert" or "minmax" in opt.norm or "adapthist" in opt.norm:
        cmap = "viridis"
    else:
        cmap = "gray"

    wf = (255 * widefield.numpy()).astype("uint8")
    if rotated:
        wf = np.rot90(wf, k=-1)  # keep the widefield aligned with the output
    # should ideally be done by drawing on client side, in javascript
    wf_upscaled = skimage.transform.rescale(wf, 1.5, order=3)
    save_image(wf_upscaled, "%s_wf.png" % outfile, cmap)

    inputimg = inputimg.unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        sr = net(inputimg.to(opt.device))
    sr = torch.clamp(sr.cpu(), min=0, max=1)
    print("min max", inputimg.min(), inputimg.max())

    pil_sr_img = toPIL(sr[0])
    if rotated:
        # expand=True keeps the full frame when the image is not square
        pil_sr_img = transforms.functional.rotate(pil_sr_img, -90, expand=True)
    sr_img = np.array(pil_sr_img)
    skimage.io.imsave("%s.png" % outfile, sr_img)  # true out for downloading, no LUT
    # should ideally be done by drawing on client side, in javascript
    sr_img = skimage.transform.rescale(sr_img, 1.5, order=3)
    save_image(sr_img, "%s_sr.png" % outfile, cmap)
    return outfile + "_sr.png", outfile + "_wf.png", outfile + ".png"