import datetime
import math
import os
import torch
import time
import skimage.io
import skimage.transform
import matplotlib.pyplot as plt
import glob
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from skimage import exposure

toTensor = transforms.ToTensor()
toPIL = transforms.ToPILImage()

import numpy as np
from PIL import Image
from models import *  # presumably provides GetModel used by LoadModel — confirm

os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def remove_dataparallel_wrapper(state_dict):
    r"""Converts a DataParallel model to a normal one by removing the "module." wrapper in the module dictionary

    Only keys that actually start with ``"module."`` are rewritten; any other
    key is passed through unchanged instead of losing its first 7 characters.

    Args:
        state_dict: a torch.nn.DataParallel state dictionary

    Returns:
        An ``OrderedDict`` with the same values and the prefix removed from keys.
    """
    from collections import OrderedDict

    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


from argparse import Namespace


def _make_options(n_feats, norm, weights):
    """Build the inference option Namespace shared by all GetOptions* variants.

    Only the fields that differ between the bundled checkpoints are parameters.
    """
    opt = Namespace()
    opt.model = 'rcan'
    opt.n_resgroups = 3
    opt.n_resblocks = 10
    opt.n_feats = n_feats
    opt.reduction = 16
    opt.narch = 0
    opt.norm = norm
    opt.cpu = False
    opt.multigpu = False
    opt.undomulti = False
    opt.device = torch.device('cuda' if torch.cuda.is_available() and not opt.cpu else 'cpu')
    opt.imageSize = 512
    opt.weights = weights
    opt.root = "model/0080.jpg"
    opt.out = "model/myout"
    opt.task = 'simin_gtout'
    opt.scale = 1
    opt.nch_in = 9
    opt.nch_out = 1
    return opt


def GetOptions():
    """Inference options for the 96-feature RCAN checkpoint with 'minmax' normalisation."""
    return _make_options(
        n_feats=96,
        norm='minmax',
        weights="model/simrec_simin_gtout_rcan_512_2_ntrain790-final.pth",
    )


def GetOptions_allRnd_0215():
    """Inference options for the 48-feature RCAN checkpoint with 'adapthist' normalisation."""
    return _make_options(
        n_feats=48,
        norm='adapthist',
        weights="model/0216_SIMRec_0214_rndAll_rcan_continued.pth",
    )


def GetOptions_allRnd_0317():
    """Inference options for the 96-feature RCAN DIV2K-randomised checkpoint with 'minmax' normalisation."""
    return _make_options(
        n_feats=96,
        norm='minmax',
        weights="model/DIV2K_randomised_3x3_20200317.pth",
    )


def LoadModel(opt):
    """Build the network described by ``opt`` and load the weights at ``opt.weights``.

    The checkpoint may either be a dict wrapping the weights under
    ``'state_dict'`` or the state dict itself.  When ``opt.undomulti`` is set,
    the DataParallel ``module.`` key prefix is stripped before loading.
    """
    print('Loading model')
    print(opt)
    net = GetModel(opt)
    print('loading checkpoint', opt.weights)
    # torch.load unpickles the file -- only load checkpoints from trusted sources.
    checkpoint = torch.load(opt.weights, map_location=opt.device)
    # A bare state dict is usually an OrderedDict (a dict subclass), so test for
    # the wrapper key rather than the exact container type.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    if opt.undomulti:
        state_dict = remove_dataparallel_wrapper(state_dict)
    net.load_state_dict(state_dict)
    return net


def _convert_raw(img, peak):
    """Rotate a raw microscope stack, reorder its 9 frames and rescale each frame so its maximum is ``peak``."""
    img = np.rot90(img, axes=(1, 2))
    # Fancy indexing copies, so the in-place writes below never touch the caller's array.
    # NOTE(review): this reordering needs 9 frames; nch_in 6/3 inputs would fail here — confirm intended.
    img = img[[6, 7, 8, 3, 4, 5, 0, 1, 2]]  # could also do [8,7,6,5,4,3,2,1,0]
    for i in range(len(img)):
        img[i] = peak / np.max(img[i]) * img[i]
    return img


def _minmax(t, fac=1.0):
    """Linearly map tensor ``t`` to ``[0, fac]``; a constant tensor maps to zeros instead of NaN."""
    lo = torch.min(t)
    span = torch.max(t) - lo
    if span == 0:
        return torch.zeros_like(t)
    return fac * (t - lo) / span


def prepimg(stack, self):
    """Turn an input stack into the network input tensor and a widefield image.

    Args:
        stack: array indexed as (frame, row, col) — the first 9 entries along
            axis 0 are used. Assumed to be a numpy array; TODO confirm dtype.
        self: the options Namespace (uses ``nch_in`` and ``norm``).

    Returns:
        ``(inputimg, widefield)`` as float32 tensors: the selected frames and
        their mean over the frame axis, both normalised according to ``norm``.
    """
    inputimg = stack[:9]
    if self.nch_in == 6:
        inputimg = inputimg[[0, 1, 3, 4, 6, 7]]
    elif self.nch_in == 3:
        inputimg = inputimg[[0, 4, 8]]

    if inputimg.shape[1] > 512 or inputimg.shape[2] > 512:
        print('Over 512x512! Cropping')
        inputimg = inputimg[:, :512, :512]

    if self.norm == 'convert':
        # raw img from microscope, needs normalisation and correct frame ordering
        print('Raw input assumed - converting')
        inputimg = _convert_raw(inputimg, 100)
    elif 'convert' in self.norm:
        fac = float(self.norm[7:])  # 'convert<fac>'
        inputimg = _convert_raw(inputimg, fac * 255)

    inputimg = inputimg.astype('float') / np.max(inputimg)  # used to be /255
    widefield = np.mean(inputimg, 0)

    if self.norm == 'adapthist':
        for i in range(len(inputimg)):
            inputimg[i] = exposure.equalize_adapthist(inputimg[i], clip_limit=0.001)
        widefield = exposure.equalize_adapthist(widefield, clip_limit=0.001)
        return torch.tensor(inputimg).float(), torch.tensor(widefield).float()

    # Every other mode: widefield is always min-max normalised; the frames only
    # for 'minmax' / 'minmax<fac>'.
    inputimg = torch.tensor(inputimg).float()
    widefield = _minmax(torch.tensor(widefield).float())
    if self.norm == 'minmax':
        for i, frame in enumerate(inputimg):
            inputimg[i] = _minmax(frame)
    elif 'minmax' in self.norm:
        fac = float(self.norm[6:])  # 'minmax<fac>'
        for i, frame in enumerate(inputimg):
            inputimg[i] = _minmax(frame, fac)
    return inputimg, widefield


def save_image(data, filename, cmap):
    """Save a 2-D array through colormap ``cmap`` with no axes or margins.

    The figure is sized so the written file is exactly rows x cols pixels.
    """
    sizes = np.shape(data)
    rows, cols = sizes[0], sizes[1]
    fig = plt.figure()
    # width/height in inches = cols/rows : 1, rendered at `rows` dpi -> cols x rows pixels.
    fig.set_size_inches(cols / rows, 1, forward=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(data, cmap=cmap)
    fig.savefig(filename, dpi=rows)
    plt.close(fig)


def _upscale(img, factor=1.5):
    """Bicubically rescale a single-channel image (works across scikit-image versions).

    Upscaling for display should ideally be done by drawing on client side, in javascript.
    """
    try:
        return skimage.transform.rescale(img, factor, order=3, channel_axis=None)
    except TypeError:  # scikit-image < 0.19 has no `channel_axis`
        return skimage.transform.rescale(img, factor, order=3, multichannel=False)


def EvaluateModel(net, opt, stack):
    """Run ``net`` on ``stack`` and write the result images.

    Writes ``<name>_wf.png`` (upscaled widefield, colormapped),
    ``<name>.png`` (raw network output, no LUT) and ``<name>_sr.png``
    (upscaled output, colormapped), where ``<name>`` is ``ML-SIM_`` plus the
    current UTC time as HH-MM-SS.

    Returns:
        ``(sr_png, wf_png, raw_png)`` file names.
    """
    outfile = 'ML-SIM_%s' % datetime.datetime.now(datetime.timezone.utc).strftime('%H-%M-%S')
    # NOTE(review): opt.out is created but the files below are written relative
    # to the current working directory — confirm callers rely on that.
    os.makedirs(opt.out, exist_ok=True)

    print(stack.shape)
    inputimg, widefield = prepimg(stack, opt)

    if opt.norm == 'convert' or 'minmax' in opt.norm or 'adapthist' in opt.norm:
        cmap = 'viridis'
    else:
        cmap = 'gray'

    wf = (255 * widefield.numpy()).astype('uint8')
    save_image(_upscale(wf), '%s_wf.png' % outfile, cmap)

    inputimg = inputimg.unsqueeze(0)
    with torch.no_grad():
        sr = net(inputimg.to(opt.device))
    sr = torch.clamp(sr.cpu(), min=0, max=1)
    print('min max', inputimg.min(), inputimg.max())

    pil_sr_img = toPIL(sr[0])
    if opt.norm == 'convert':
        # Undo the np.rot90 applied in prepimg.
        # NOTE(review): 'convert<fac>' also rotates the input but is not rotated
        # back here, and the widefield image is never rotated back — confirm intent.
        pil_sr_img = transforms.functional.rotate(pil_sr_img, -90)

    sr_img = np.array(pil_sr_img)
    skimage.io.imsave('%s.png' % outfile, sr_img)  # true out for downloading, no LUT
    save_image(_upscale(sr_img), '%s_sr.png' % outfile, cmap)

    return outfile + '_sr.png', outfile + '_wf.png', outfile + '.png'