"""Display the tensors fed into each Conv2d layer of the ESRGAN network.

A forward pre-hook is attached to every ``torch.nn.Conv2d`` inside
``esrgan.net``; one image is then pushed through the network and each hook
plots the first few channels of its layer's input with matplotlib.
"""
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms

from esrgan import ESRGAN

# Maximum number of channels drawn per hooked layer (one subplot each).
_MAX_MAPS = 4

esrgan = ESRGAN()


def viz(module, input):
    """Forward pre-hook: plot up to ``_MAX_MAPS`` channels of a layer's input.

    Args:
        module: The module about to run (unused; required by the hook
            signature).
        input: Tuple of positional inputs passed to ``module``. Only the
            first sample of the first input is plotted.
            NOTE(review): assumes ``input[0]`` is batched (N, C, H, W), so
            that ``input[0][0]`` is (C, H, W) -- confirm for this network.

    Returns:
        None, so the layer receives its input unmodified.
    """
    x = input[0][0]
    # Show at most _MAX_MAPS images.
    shown = min(_MAX_MAPS, x.size()[0])
    for i in range(shown):
        plt.subplot(1, _MAX_MAPS, i + 1)
        # detach + cpu + numpy so plotting also works for CUDA tensors and
        # for tensors that are part of an autograd graph.
        plt.imshow(x[i].detach().cpu().numpy())
        plt.xticks([])  # hide x-axis tick values
        plt.yticks([])  # hide y-axis tick values
    plt.show()


def main(image_path="image.png"):
    """Run one image through ``esrgan.net`` and visualise Conv2d inputs.

    Args:
        image_path: Path of the image to load. Defaults to ``"image.png"``,
            the path the script originally hard-coded.

    Raises:
        FileNotFoundError: If the image cannot be read by OpenCV.
    """
    t = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Per-channel statistics listed in RGB order.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # cv2.imread signals failure by returning None rather than raising.
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"could not read image: {image_path!r}")
    # OpenCV decodes to BGR; ToPILImage and the Normalize statistics above
    # expect RGB channel order.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = t(img).unsqueeze(0).to(device)

    # The input lives on `device`, so the network must as well; `.to` is a
    # no-op if the ESRGAN wrapper already placed it there.
    model = esrgan.net.to(device)

    # Only the inputs of convolution layers are displayed. (To hook every
    # built-in leaf layer instead, filter on
    # `type(m) in torch.nn.__dict__.values()` while excluding ModuleList
    # and Sequential containers.)
    handles = [
        m.register_forward_pre_hook(viz)
        for m in model.modules()
        if isinstance(m, nn.Conv2d)
    ]
    try:
        with torch.no_grad():
            model(img)
    finally:
        # `esrgan.net` is shared module-level state: remove the hooks so a
        # later forward pass (or a second call to main) does not plot again.
        for handle in handles:
            handle.remove()


if __name__ == '__main__':
    main()