# IntrinsicAnything / operators.py
# Last commit: d72c37e "Initial commit" (burningdust)
import numpy as np
import scipy
import scipy.ndimage
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
class Blurkernel(nn.Module):
    """Depthwise 2-D blur with a fixed kernel, applied to 3-channel images.

    The input is reflection-padded by ``kernel_size // 2`` on each side and passed
    through a bias-free ``Conv2d(3, 3, kernel_size, groups=3)``, so every channel is
    convolved with the same ``kernel_size x kernel_size`` kernel. For odd
    ``kernel_size`` the output keeps the input's spatial size; for even sizes each
    spatial dimension grows by one pixel.

    Args:
        blur_type: Kernel family. Only ``"gaussian"`` builds a kernel here; for any
            other value the conv keeps PyTorch's default initialisation and
            ``self.k`` is never set, so a kernel must be loaded with
            ``update_weights`` before use.
        kernel_size: Side length of the square kernel.
        std: Standard deviation, in pixels, of the Gaussian kernel.
        device: Device that non-tensor kernels passed to ``update_weights`` are
            placed on. This constructor does not move the module itself.
    """

    def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
        super().__init__()
        self.blur_type = blur_type
        self.kernel_size = kernel_size
        self.std = std
        self.device = device
        self.seq = nn.Sequential(
            nn.ReflectionPad2d(self.kernel_size // 2),
            nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),
        )
        self.weights_init()

    def forward(self, x):
        """Blur ``x``, a tensor whose channel dimension has size 3."""
        return self.seq(x)

    def weights_init(self):
        """Build the kernel selected by ``blur_type`` and load it into the conv.

        For ``"gaussian"`` the kernel is a centred unit impulse smoothed with
        ``scipy.ndimage.gaussian_filter(sigma=std)``; the resulting float64 tensor
        is stored in ``self.k``. Other blur types are left untouched.
        """
        if self.blur_type == "gaussian":
            impulse = np.zeros((self.kernel_size, self.kernel_size))
            impulse[self.kernel_size // 2, self.kernel_size // 2] = 1
            k = torch.from_numpy(scipy.ndimage.gaussian_filter(impulse, sigma=self.std))
            self.k = k
            self._copy_into_params(k)

    def update_weights(self, k):
        """Overwrite the conv kernel with ``k``.

        Args:
            k: Tensor, NumPy array or other array-like broadcastable to the conv
                weight shape ``(3, 1, kernel_size, kernel_size)``; a 2-D
                ``(kernel_size, kernel_size)`` kernel is shared by all three
                channels. Non-tensor input is converted with ``torch.as_tensor`` on
                ``self.device``; tensors are used as given (``copy_`` handles any
                dtype/device conversion).

        ``self.k`` (what ``get_kernel`` returns) is not modified.
        """
        if not torch.is_tensor(k):
            k = torch.as_tensor(k, device=self.device)
        self._copy_into_params(k)

    def get_kernel(self):
        """Return the kernel built by ``weights_init`` (only set for ``"gaussian"``)."""
        return self.k

    def _copy_into_params(self, k):
        """Copy ``k`` in place into every parameter (here: the single conv weight)."""
        # no_grad + in-place copy instead of writing through ``.data``: same values,
        # but autograd's version tracking is not bypassed.
        with torch.no_grad():
            for param in self.parameters():
                param.copy_(k)
class GaussialBlurOperator:
    """Linear forward operator that blurs images with a fixed Gaussian kernel.

    Internally wraps a ``Blurkernel`` built with ``blur_type='gaussian'`` and moved to
    ``device``. The "Gaussial" spelling of the class name is kept so existing callers
    keep working.

    Args:
        kernel_size: Side length of the square Gaussian kernel.
        intensity: Standard deviation of the Gaussian, in pixels.
        device: Device the blur module is moved to.
    """

    def __init__(self, kernel_size, intensity, device):
        self.device = device
        self.kernel_size = kernel_size
        blur = Blurkernel(
            blur_type='gaussian',
            kernel_size=kernel_size,
            std=intensity,
            device=device,
        )
        self.conv = blur.to(device)
        # Keep Blurkernel's own kernel tensor, and reload it into the conv as float32.
        kernel = self.conv.get_kernel()
        self.kernel = kernel
        self.conv.update_weights(kernel.to(torch.float32))

    def forward(self, data, **kwargs):
        """Apply the blur to ``data``; extra keyword arguments are ignored."""
        blurred = self.conv(data)
        return blurred

    def transpose(self, data, **kwargs):
        """Return ``data`` unchanged; extra keyword arguments are ignored.

        NOTE(review): this is the identity, not the adjoint of ``forward`` — confirm
        that callers only need it as an approximation.
        """
        return data

    def get_kernel(self):
        """Return the stored kernel viewed as ``(1, 1, kernel_size, kernel_size)``."""
        side = self.kernel_size
        return self.kernel.view(1, 1, side, side)
def read_img(img_path, read_alpha=False, device=0):
    """Load an image file as a float32 tensor of shape ``(H, W, C)`` scaled by 1/255.

    Args:
        img_path: Path of the image file, read with ``imageio.imread``.
        read_alpha: If True, return the channels from index 3 on (the alpha channel
            of an RGBA image); otherwise return at most the first three channels.
        device: Torch device for the result. Defaults to ``0`` (CUDA device 0), the
            previously hard-coded value; pass ``"cpu"`` on machines without CUDA.

    Returns:
        float32 tensor ``(H, W, C)``. Grayscale images get a single channel axis.
        Values are the raw pixel values divided by 255 (8-bit input assumed).

    Raises:
        ValueError: If ``read_alpha`` is True but the image has fewer than 4 channels.
    """
    # Imported here: at module level imageio was only imported under ``__main__``,
    # so this function raised NameError whenever the module was imported instead.
    import imageio

    # np.asarray yields a plain ndarray (this replaces a PIL round-trip that only
    # did the same, and that rejected dtypes PIL cannot represent).
    img = np.asarray(imageio.imread(img_path))
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if read_alpha:
        if img.shape[2] < 4:
            # Slicing [3:] would otherwise silently return an empty channel axis.
            raise ValueError(
                f"read_alpha=True but {img_path!r} has only {img.shape[2]} channel(s)"
            )
        img = img[:, :, 3:] / 255.0
    else:
        img = img[:, :, :3] / 255.0
    # Cast before the transfer so half as many bytes are copied to the device.
    return torch.from_numpy(img).float().to(device)
if __name__ == "__main__":
    # Visual smoke test: blur a 256x256 crop of an image and save it next to the
    # original as one side-by-side PNG.
    # Usage: python operators.py [input_image] [output_png]
    import os
    import sys

    # Imported at module scope (this block is module-level), so helpers in this file
    # that look these names up as globals can also see them.
    from PIL import Image
    import imageio

    default_input = "/home/chenxi/code/ml-hypersim/downloads/ai_001_001/images/scene_cam_00_final_preview/frame.0000.diffuse_reflectance.jpg"
    in_path = sys.argv[1] if len(sys.argv) > 1 else default_input
    out_path = sys.argv[2] if len(sys.argv) > 2 else "dbg/blurred.png"

    operator = GaussialBlurOperator(33, 3.0, 0)
    img = read_img(in_path)  # (H, W, C), already placed on CUDA device 0
    img = img[:256, :256]
    # (H, W, C) -> (1, C, H, W) for the conv, then back to (H, W, C).
    img_blurred = operator.forward(img[None].permute(0, 3, 1, 2))[0].permute(1, 2, 0)
    img_out = torch.cat([img, img_blurred], dim=1)

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)  # saving used to fail if dbg/ was missing
    # Clamp and round: a plain uint8 cast truncates, turning e.g. 199.99998 into 199.
    pixels = (img_out.detach().clamp(0, 1).cpu().numpy() * 255).round().astype(np.uint8)
    Image.fromarray(pixels).save(out_path)