# lino/src/models/utils/show_tensor.py
# algohunt
# initial_commit
# c295391
import torch.nn as nn
import torch
import math
import numpy as np
import matplotlib.pyplot as plt
def masking(img, mask):
    """Multiply an image batch by a single-channel mask.

    Args:
        img: Tensor laid out as [B, C, H, W].
        mask: Tensor laid out as [B, 1, H, W], with values in [0, 1].

    Returns:
        The element-wise product of ``img`` and ``mask``, where the mask has
        been expanded along dimension 1 to the channel count of ``img``.
    """
    num_channels = img.shape[1]
    # -1 leaves a dimension at its current size; only dim 1 is expanded.
    channel_mask = mask.expand(-1, num_channels, -1, -1)
    return img * channel_mask
def imshow(img, mask=None, vmax=None, axis=None):
    """Display the first item of an image batch.

    Args:
        img: Tensor indexed as [B, C, H, W]; only batch entry 0 is shown.
        mask: Optional mask applied through ``masking`` before display.
        vmax: Optional upper colour limit. It is forwarded to ``imshow``
            only when it is not None.
        axis: Optional object exposing an ``imshow`` method (presumably a
            matplotlib Axes). When None, drawing goes through ``plt``.

    A 3-channel item is rearranged to (H, W, 3), a 1-channel item is reduced
    to (H, W), and any other channel count is passed on unchanged.
    """
    source = img if mask is None else masking(img, mask)
    first = source.data.cpu().numpy()[0, :, :, :]

    c, h, w = first.shape[:3]
    if c == 3:
        # Channel-first -> channel-last, the layout imshow expects for RGB.
        picture = np.reshape(first, (3, h, w)).transpose(1, 2, 0)
    elif c == 1:
        picture = first[0, :, :]
    else:
        picture = first

    target = plt if axis is None else axis
    # Only pass vmax when the caller supplied one.
    extra = {} if vmax is None else {"vmax": vmax}
    target.imshow(picture, **extra)
def nmlshow(nml, mask=None, axis=None):
    """Display the first item of a batch as a colour image.

    Args:
        nml: Tensor indexed as [B, C, H, W]; only batch entry 0 is shown.
            Presumably a normal map with components in [-1, 1] (to be
            confirmed against callers).
        mask: Optional mask applied through ``masking`` before display.
        axis: Optional object exposing an ``imshow`` method (presumably a
            matplotlib Axes). When None, drawing goes through ``plt``.

    The item is moved to channel-last layout and each value ``v`` is mapped
    to ``-0.5 * v + 0.5`` before being shown.
    """
    source = nml if mask is None else masking(nml, mask)
    channel_last = source.data.cpu().numpy()[0, :, :, :].transpose(1, 2, 0)

    # Affine remap with a sign flip: 1 -> 0, 0 -> 0.5, -1 -> 1.
    colours = -0.5 * channel_last + 0.5

    target = plt if axis is None else axis
    target.imshow(colours)