# Hugging Face Spaces page status: Build error
import cv2 | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
from esrgan import ESRGAN | |
# Module-level model wrapper, constructed once when this module is imported;
# main() reads its ``net`` attribute.
# NOTE(review): ESRGAN() may load weights / select a device on construction --
# confirm, since that would make merely importing this module expensive.
esrgan = ESRGAN()
def viz(module, input, *, max_maps=4):
    """Plot up to ``max_maps`` channels of a layer's input in one row.

    Meant to be registered with ``Module.register_forward_pre_hook``. PyTorch
    then calls it as ``hook(module, input)`` before ``module.forward`` runs,
    where ``input`` is the tuple of positional arguments given to the module.

    Args:
        module: The module the hook is attached to. It is unused, but the hook
            signature requires it.
        input: Tuple of positional forward arguments. Only ``input[0]`` is
            read. It is assumed to be a batched feature map shaped
            (N, C, H, W); TODO confirm for every hooked layer.
        max_maps: Maximum number of channels to display. It also sets the
            number of subplot columns, so the default of 4 reproduces the
            original fixed 1x4 layout. Keyword-only, so the two-argument hook
            call is unaffected.

    Returns:
        None, so PyTorch passes the layer's input through unmodified.

    Raises:
        ValueError: If ``max_maps`` is less than 1.
    """
    if max_maps < 1:
        raise ValueError(f"max_maps must be >= 1, got {max_maps}")
    # First sample of the batch: one 2-D map per channel.
    # detach() so the hook also works when autograd is enabled (turning a
    # tensor that requires grad into an array raises); cpu() for matplotlib.
    fmap = input[0][0].detach().cpu()
    shown = min(max_maps, fmap.size(0))
    for i in range(shown):
        plt.subplot(1, max_maps, i + 1)
        plt.imshow(fmap[i].numpy())
        plt.xticks([])  # hide x-axis tick labels
        plt.yticks([])  # hide y-axis tick labels
    plt.show()
def main(image_path='image.png'):
    """Run one forward pass of ``esrgan.net`` and plot every Conv2d's input.

    A forward pre-hook (``viz``) is attached to each ``Conv2d`` in the network,
    so each convolution's input feature map is displayed before that
    convolution runs. The hooks are removed afterwards, even if the forward
    pass fails. This stops repeated calls from stacking duplicate hooks on the
    shared module-level model.

    Args:
        image_path: Image file to feed through the network. Defaults to
            ``'image.png'``, the previously hard-coded path.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    # NOTE(review): a 224x224 resize with ImageNet mean/std is
    # classification-style preprocessing. Confirm it matches what the ESRGAN
    # generator was trained with.
    t = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    model = esrgan.net

    # Put the input wherever the model's weights already live, rather than
    # assuming CUDA availability implies the model is on cuda:0. Fall back to
    # the original choice for a parameterless model.
    param = next(model.parameters(), None)
    if param is not None:
        device = param.device
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {image_path!r}")
    # OpenCV loads images as BGR; the RGB-ordered mean/std above expect RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = t(img).unsqueeze(0).to(device)

    # Only convolution layers have their input feature maps shown.
    handles = [m.register_forward_pre_hook(viz)
               for m in model.modules()
               if isinstance(m, torch.nn.Conv2d)]
    try:
        with torch.no_grad():
            model(img)
    finally:
        for handle in handles:
            handle.remove()
# Script entry point: run the visualisation only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()