Spaces:
Running
on
Zero
Running
on
Zero
import torch.nn as nn | |
import torch | |
import math | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def masking(img, mask):
    """Multiply an image batch by a single-channel mask.

    Args:
        img: image tensor laid out as [B, C, H, W].
        mask: mask tensor laid out as [B, 1, H, W] with values expected in
            [0, 1]; its channel axis is expanded to C so the same mask is
            applied to every channel of ``img``.

    Returns:
        A new tensor holding the element-wise product of ``img`` and the
        channel-expanded mask.
    """
    n_channels = img.shape[1]
    # Tensor.expand creates a broadcast view (no copy) of the 1-channel mask.
    mask_per_channel = mask.expand(-1, n_channels, -1, -1)
    return img * mask_per_channel
def imshow(img, mask=None, vmax=None, axis=None):
    """Display the first image of a batch with matplotlib.

    Args:
        img: tensor laid out as [B, C, H, W]; only the first batch element
            is drawn.
        mask: optional [B, 1, H, W] mask; when given, ``img`` is passed
            through ``masking(img, mask)`` before display.
        vmax: upper limit of the colour scale, forwarded to ``imshow``.
            ``None`` (the default) is also matplotlib's own default, i.e.
            auto-scaling.
        axis: matplotlib ``Axes`` to draw into; when ``None`` the current
            pyplot axes are used via ``plt.imshow``.

    Channel handling: 3 (RGB) or 4 (RGBA) channels are moved to the last
    axis as matplotlib requires; a single channel is squeezed to (H, W) so a
    colormap is applied. Other channel counts are passed through unchanged
    and matplotlib will reject them.
    """
    if mask is not None:
        img = masking(img, mask)
    # First batch element as a (C, H, W) numpy array.
    img = img.detach().cpu().numpy()[0, :, :, :]
    channels = img.shape[0]
    if channels in (3, 4):
        # matplotlib expects colour images channel-last: (H, W, 3|4).
        img = img.transpose(1, 2, 0)
    elif channels == 1:
        img = img[0, :, :]
    target = plt if axis is None else axis
    # Passing vmax=None is identical to omitting it, so one call covers both
    # the "vmax given" and "vmax not given" cases.
    target.imshow(img, vmax=vmax)
def nmlshow(nml, mask=None, axis=None):
    """Draw the first map of a batch as an image via ``-0.5 * n + 0.5``.

    Each value n is remapped with ``-0.5 * n + 0.5`` before display, so -1
    maps to 1.0 and +1 maps to 0.0 (the sign is flipped).

    Args:
        nml: tensor laid out as [B, C, H, W] (presumably a normal map with
            C == 3 and components in [-1, 1] -- TODO confirm with callers);
            only the first batch element is drawn.
        mask: optional [B, 1, H, W] mask; when given, ``nml`` is passed
            through ``masking(nml, mask)`` first.
        axis: matplotlib ``Axes`` to draw into; when ``None`` the current
            pyplot axes are used via ``plt.imshow``.
    """
    source = nml if mask is None else masking(nml, mask)
    first = source.data.cpu().numpy()[0, :, :, :]
    # Channel-first (C, H, W) -> channel-last (H, W, C) for matplotlib.
    hwc = np.transpose(first, (1, 2, 0))
    shown = -0.5 * hwc + 0.5
    drawer = plt if axis is None else axis
    drawer.imshow(shown)