Image_Denoising_Demo / utils /utils_model.py
sunder-ali's picture
Upload 2 files
944ec77
raw
history blame
No virus
3.7 kB
import numpy as np
import torch
from utils import utils_image as util
def infer(model, L):
    """Run ``model`` on the full input in a single call.

    Args:
        model: Callable (for example a ``torch.nn.Module``) applied to ``L``.
        L: Input handed to ``model`` unchanged.

    Returns:
        Whatever ``model(L)`` returns.
    """
    return model(L)
def inferp(model, L, modulo=16):
    """Run ``model`` on ``L`` after padding it to a multiple of ``modulo``.

    The last two dimensions of ``L`` are extended on the bottom and right by
    replicating edge values until each is a multiple of ``modulo``. The model
    output is then cropped back to the original height and width.

    Args:
        model: Callable applied to the padded tensor.
        L: Input tensor; its last two dimensions are treated as height and
            width.
        modulo: Required divisor of the padded height and width.

    Returns:
        ``model``'s output cropped to ``[..., :h, :w]``.
    """
    height, width = L.size()[-2:]

    # Extra rows/columns needed to reach the next multiple of ``modulo``.
    extra_rows = int(np.ceil(height / modulo) * modulo - height)
    extra_cols = int(np.ceil(width / modulo) * modulo - width)

    # Padding order is (left, right, top, bottom): only right/bottom grow.
    pad_layer = torch.nn.ReplicationPad2d((0, extra_cols, 0, extra_rows))
    output = model(pad_layer(L))

    # Drop whatever the model produced for the padded border.
    return output[..., :height, :width]
def inferspfn(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Recursively run ``model`` on four overlapping quadrants and stitch them.

    If the input area is at most ``min_size ** 2`` the tensor is padded on the
    bottom/right to a multiple of ``modulo``, passed through ``model`` once,
    and cropped back. Otherwise it is cut into four overlapping corner
    patches whose extent is the half-size rounded up to the next multiple of
    ``refield``. Each patch is processed directly when the area is at most
    ``4 * min_size ** 2`` and recursively otherwise. The non-overlapping half
    of every patch output is then copied into the result.

    Args:
        model: Callable mapping a ``(..., H, W)`` tensor to a tensor whose
            first two dimensions are batch and channel and whose spatial size
            is ``(sf * H, sf * W)``.
        L: Input tensor; its last two dimensions are treated as height and
            width.
            NOTE(review): the padding branch uses ``ReplicationPad2d``, so a
            3-D or 4-D tensor is presumably expected - confirm with callers.
        refield: Granularity to which the patch extent is rounded up.
        min_size: Side length of the square whose area is the threshold for
            processing without splitting.
        sf: Integer scale factor between input and output spatial sizes.
        modulo: Required divisor of the padded height/width in the unsplit
            branch. Split patches are not padded to ``modulo``.

    Returns:
        Tensor of shape ``(b, c, sf * h, sf * w)`` with the dtype/device of
        ``L`` in the split branch, or the cropped model output in the unsplit
        branch.
    """
    h, w = L.size()[-2:]

    # Small enough: pad, run once, and crop away the (scaled) padding.
    if h * w <= min_size ** 2:
        pad_bottom = int(np.ceil(h / modulo) * modulo - h)
        pad_right = int(np.ceil(w / modulo) * modulo - w)
        padded = torch.nn.ReplicationPad2d((0, pad_right, 0, pad_bottom))(L)
        return model(padded)[..., :h * sf, :w * sf]

    # Patch extent: half the image rounded up to the next multiple of refield.
    patch_h = (h // 2 // refield + 1) * refield
    patch_w = (w // 2 // refield + 1) * refield

    # Clamp the start of the bottom/right patches at 0. When a dimension is
    # shorter than the patch extent, ``h - patch_h`` is negative and would be
    # interpreted as an index from the end, producing a patch that is too
    # short to supply its half of the output.
    top = slice(0, patch_h)
    bottom = slice(max(h - patch_h, 0), h)
    left = slice(0, patch_w)
    right = slice(max(w - patch_w, 0), w)

    # Order: top-left, top-right, bottom-left, bottom-right.
    patches = [L[..., rows, cols] for rows in (top, bottom) for cols in (left, right)]

    if h * w <= 4 * (min_size ** 2):
        outputs = [model(patch) for patch in patches]
    else:
        outputs = [
            inferspfn(model, patch, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
            for patch in patches
        ]

    b, c = outputs[0].size()[:2]
    E = torch.zeros(b, c, sf * h, sf * w).type_as(L)

    # Output-space sizes of the first half and of the remainder per axis.
    half_h, half_w = h // 2 * sf, w // 2 * sf
    rest_h, rest_w = (h - h // 2) * sf, (w - w // 2) * sf

    # Take the leading half from top/left patches and the trailing remainder
    # from bottom/right patches, so overlapping regions are never mixed.
    E[..., :half_h, :half_w] = outputs[0][..., :half_h, :half_w]
    E[..., :half_h, half_w:] = outputs[1][..., :half_h, -rest_w:]
    E[..., half_h:, :half_w] = outputs[2][..., -rest_h:, :half_w]
    E[..., half_h:, half_w:] = outputs[3][..., -rest_h:, -rest_w:]
    return E
def infersp(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Run split inference by delegating to ``inferspfn``.

    All arguments are forwarded unchanged as keywords.

    Args:
        model: Callable applied to the input or its patches.
        L: Input tensor.
        refield: Forwarded to ``inferspfn``.
        min_size: Forwarded to ``inferspfn``.
        sf: Forwarded to ``inferspfn``.
        modulo: Forwarded to ``inferspfn``.

    Returns:
        The result of ``inferspfn``.
    """
    options = dict(refield=refield, min_size=min_size, sf=sf, modulo=modulo)
    return inferspfn(model, L, **options)
def inferosp(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Run ``model`` on four overlapping quadrants (one split) and stitch them.

    The input is always cut once into four overlapping corner patches whose
    extent is the half-size rounded up to the next multiple of ``refield``.
    Each patch is passed through ``model`` and the non-overlapping half of
    every output is copied into the result.

    Args:
        model: Callable mapping a ``(..., H, W)`` tensor to a tensor whose
            first two dimensions are batch and channel and whose spatial size
            is ``(sf * H, sf * W)``.
        L: Input tensor; its last two dimensions are treated as height and
            width.
        refield: Granularity to which the patch extent is rounded up.
        min_size: Accepted for signature parity; not used by this function.
        sf: Integer scale factor between input and output spatial sizes.
        modulo: Accepted for signature parity; not used by this function.

    Returns:
        Tensor of shape ``(b, c, sf * h, sf * w)`` with the dtype/device of
        ``L``.
    """
    h, w = L.size()[-2:]

    # Patch extent: half the image rounded up to the next multiple of refield.
    patch_h = (h // 2 // refield + 1) * refield
    patch_w = (w // 2 // refield + 1) * refield

    # Clamp the start of the bottom/right patches at 0. When a dimension is
    # shorter than the patch extent, ``h - patch_h`` is negative and would be
    # interpreted as an index from the end, producing a patch that is too
    # short to supply its half of the output.
    top = slice(0, patch_h)
    bottom = slice(max(h - patch_h, 0), h)
    left = slice(0, patch_w)
    right = slice(max(w - patch_w, 0), w)

    # Order: top-left, top-right, bottom-left, bottom-right.
    patches = [L[..., rows, cols] for rows in (top, bottom) for cols in (left, right)]
    outputs = [model(patch) for patch in patches]

    b, c = outputs[0].size()[:2]
    E = torch.zeros(b, c, sf * h, sf * w).type_as(L)

    # Output-space sizes of the first half and of the remainder per axis.
    half_h, half_w = h // 2 * sf, w // 2 * sf
    rest_h, rest_w = (h - h // 2) * sf, (w - w // 2) * sf

    # Take the leading half from top/left patches and the trailing remainder
    # from bottom/right patches, so overlapping regions are never mixed.
    E[..., :half_h, :half_w] = outputs[0][..., :half_h, :half_w]
    E[..., :half_h, half_w:] = outputs[1][..., :half_h, -rest_w:]
    E[..., half_h:, :half_w] = outputs[2][..., -rest_h:, :half_w]
    E[..., half_h:, half_w:] = outputs[3][..., -rest_h:, -rest_w:]
    return E
def inference(model, L, mode=0, refield=128, min_size=256, sf=1, modulo=1):
    """Dispatch to one of the inference strategies selected by ``mode``.

    Modes:
        0: ``infer``    - single forward pass on the whole input.
        1: ``inferp``   - pad to a multiple of ``modulo``, run, crop.
        2: ``infersp``  - recursive quadrant splitting.
        3: ``inferosp`` - a single quadrant split.

    Args:
        model: Callable applied to the input or its patches.
        L: Input tensor.
        mode: Strategy selector, one of 0, 1, 2, 3.
        refield: Forwarded to the splitting strategies (modes 2 and 3).
        min_size: Forwarded to the splitting strategies (modes 2 and 3).
        sf: Forwarded to the splitting strategies (modes 2 and 3).
        modulo: Forwarded to modes 1, 2 and 3.

    Returns:
        The output of the selected strategy.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    # Fail early with a clear message. Previously an unknown mode fell through
    # every branch and crashed on an unbound local at the return statement.
    if mode not in (0, 1, 2, 3):
        raise ValueError(
            "Unsupported inference mode {!r}; expected one of 0, 1, 2, 3.".format(mode)
        )

    if mode == 0:
        return infer(model, L)
    if mode == 1:
        return inferp(model, L, modulo)
    if mode == 2:
        return infersp(model, L, refield, min_size, sf, modulo)
    return inferosp(model, L, refield, min_size, sf, modulo)
if __name__ == '__main__':
    # Smoke test: push a random batch through every inference mode and print
    # the resulting shape.

    class Net(torch.nn.Module):
        """Single 3x3 convolution used as a stand-in model."""

        def __init__(self, in_channels=3, out_channels=3):
            super(Net, self).__init__()
            # padding=1 with kernel_size=3 keeps the spatial size unchanged.
            self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(x)

    # The CUDA timing events that used to be created here were never recorded
    # or read, so they have been dropped.
    # NOTE(review): constructing them presumably also requires a CUDA-enabled
    # build - confirm if timing is ever reintroduced.
    model = Net()
    model = model.eval()
    x = torch.randn((2, 3, 400, 400))
    torch.cuda.empty_cache()
    with torch.no_grad():
        # `inference` dispatches on modes 0-3 only; looping to 4 made the
        # demo crash on its last iteration.
        for mode in range(4):
            y = inference(model, x, mode)
            print(y.shape)