LicenseGAN / plate.py
白鹭先生
修改测试图片
7e1fccb
raw history blame
No virus
1.47 kB
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from esrgan import ESRGAN
# Module-level ESRGAN wrapper, created once when this file is imported.
# main() visualizes the network exposed as ``esrgan.net``.
# NOTE(review): whatever ESRGAN() does (presumably loading weights and
# picking a device) therefore runs at import time — confirm in esrgan.py.
esrgan = ESRGAN()
def viz(module, input, max_maps=4):
    """Plot the first few channels of the tensor entering ``module``.

    Intended as a ``register_forward_pre_hook`` callback: PyTorch calls it
    as ``hook(module, args)`` right before ``module.forward`` runs, so
    ``input`` is the tuple of positional arguments of that forward call.

    Args:
        module: The layer the hook is attached to (unused; required by the
            hook signature).
        input: Tuple of positional forward arguments. ``input[0][0]`` is the
            first sample of the batch, assumed to be a (C, H, W) feature
            map — TODO confirm for layers other than Conv2d.
        max_maps: Maximum number of channels drawn side by side in one row.
            Defaults to 4 (the previously hard-coded value).

    Returns:
        None, so the hook never replaces the layer's input.
    """
    # First positional forward argument, first sample of the batch.
    feature_map = input[0][0]
    # Show at most ``max_maps`` channels.
    num_maps = min(max_maps, feature_map.size(0))
    if num_maps <= 0:
        return None
    for i in range(num_maps):
        plt.subplot(1, max_maps, i + 1)
        # detach() + cpu() + numpy() so imshow also works when autograd is
        # recording or when the tensor lives on the GPU.
        plt.imshow(feature_map[i].detach().cpu().numpy())
        plt.xticks([])  # hide x-axis tick labels
        plt.yticks([])  # hide y-axis tick labels
    plt.show()
    return None
def main(image_path='image.png'):
    """Run one image through ``esrgan.net`` and plot each conv layer's input.

    ``viz`` is attached as a forward pre-hook to every ``torch.nn.Conv2d`` in
    the network, so the feature map entering each convolution is plotted as
    the image flows through the model. The hooks are removed afterwards, so
    calling ``main`` repeatedly does not stack duplicate plots.

    Args:
        image_path: Path of the image to feed through the network. Defaults
            to ``'image.png'`` (the previously hard-coded path).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    # NOTE(review): the 224x224 resize and ImageNet mean/std normalization
    # are kept from the original; confirm they match what esrgan.net expects.
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    model = esrgan.net

    # Send the input to wherever the weights actually live instead of
    # guessing "cuda if available", which can mismatch the model's device.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {image_path!r}")
    # OpenCV loads images as BGR; the mean/std above are in RGB order.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = transform(img).unsqueeze(0).to(device)

    # Only visualize the feature maps entering convolution layers.
    handles = [m.register_forward_pre_hook(viz)
               for m in model.modules()
               if isinstance(m, torch.nn.Conv2d)]
    try:
        with torch.no_grad():
            model(img)
    finally:
        # Detach the hooks so the shared module-level network is left clean.
        for handle in handles:
            handle.remove()
# Script entry point: run the visualization only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()