|
import torch as th |
|
from einops import rearrange |
|
|
|
from .generic_sampler import SimpleWork |
|
from .w_img import split_wimg, avg_merge_wimg |
|
|
|
class CondIndCircle(SimpleWork):
    """Diffusion worker for a horizontally wrap-around canvas built from tiles.

    The canvas has shape ``(c, h, num_img * (w - overlap_size))``. It is
    covered by ``num_img`` tiles of width ``w``. Consecutive tiles share
    ``overlap_size`` columns, and the last tile wraps around onto the first.
    Only half-width overlaps (``overlap_size == w // 2``) are supported.

    Args:
        shape: ``(c, h, w)`` of a single tile.
        eps_scalar_t_fn: per-tile noise predictor, called as
            ``fn(x, t, y=y)``; its output is rearranged with
            ``"(b n) c h w -> n b c h w"``.
        num_img: number of tiles around the loop.
        overlap_size: columns shared by neighbouring tiles; must equal
            ``w // 2``.
    """

    def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
        c, h, w = shape
        # The seam bookkeeping in eps_t_fn (right half of one tile == left
        # half of the next) only holds for half-width overlaps.
        assert overlap_size == w // 2
        self.overlap_size = overlap_size
        self.num_img = num_img
        # Each tile contributes ``w - overlap_size`` new columns; the loop is
        # closed, so there is no extra trailing overlap.
        final_img_w = num_img * (w - self.overlap_size)
        super().__init__((c, h, final_img_w), self.get_eps_t_fn(eps_scalar_t_fn))

    def circle_split(self, in_x, overlap_size=None):
        """Cut a wrap-around canvas into ``num_img`` overlapping tiles.

        The first ``overlap_size`` columns are appended to the right edge so
        the last tile can wrap onto the first. The widened canvas is then
        passed to ``split_wimg``.

        Args:
            in_x: 4-D canvas tensor whose last axis is width.
            overlap_size: wrap width; defaults to ``self.overlap_size``.
        """
        if overlap_size is None:
            overlap_size = self.overlap_size
        long_x = th.cat([in_x, in_x[:, :, :, :overlap_size]], dim=-1)
        return split_wimg(long_x, self.num_img, rtn_overlap=False)

    def circle_merge(self, xs, overlap_size=None):
        """Average overlapping tiles back into a wrap-around canvas.

        Args:
            xs: tiles in the layout produced by ``circle_split``.
            overlap_size: wrap width; defaults to ``self.overlap_size``.
        """
        if overlap_size is None:
            overlap_size = self.overlap_size
        return self._wrap_merge(xs, overlap_size, is_avg=True)

    def _wrap_merge(self, xs, overlap_size, is_avg):
        """Merge tiles with ``avg_merge_wimg``, then fold the wrap-around seam.

        The merged strip is assumed to start and end with ``overlap_size``
        columns covering the same part of the loop. This follows from the
        width arithmetic; ``avg_merge_wimg`` is defined elsewhere. Those two
        slices are averaged and placed at the start of the canvas.
        """
        long_xs = avg_merge_wimg(xs, overlap_size, n=self.num_img, is_avg=is_avg)
        seam = (long_xs[:, :, :, :overlap_size] + long_xs[:, :, :, -overlap_size:]) / 2.0
        return th.cat([seam, long_xs[:, :, :, overlap_size:-overlap_size]], dim=-1)

    def _repeat_per_tile(self, t):
        """Repeat each batch entry of ``t`` ``num_img`` times along dim 0.

        This matches the ``(b n)`` batch layout of the tile stack. A 0-d input
        becomes a 1-D tensor of length ``num_img``. A multi-dimensional input
        keeps its trailing dims instead of being flattened.
        """
        if t.dim() == 0:
            return t.repeat_interleave(self.num_img)
        return t.repeat_interleave(self.num_img, dim=0)

    def get_eps_t_fn(self, eps_scalar_t_fn):
        """Build the canvas-level noise predictor from the per-tile one.

        The returned ``eps_t_fn(in_x, scalar_t, y=None)`` works in four steps:

        1. Evaluate ``eps_scalar_t_fn`` on every tile and on every tile's
           right half.
        2. Subtract the half-tile prediction from the full-tile prediction
           over that right half. The last tile is skipped; its term is zeroed.
        3. Merge the tiles with ``is_avg=False``.
        4. Fold the wrap-around seam by averaging.
        """

        def eps_t_fn(in_x, scalar_t, y=None):
            ov = self.overlap_size
            xs = self.circle_split(in_x)

            # Per-sample inputs must line up with the (b n) tile batch.
            if y is not None:
                y = self._repeat_per_tile(y)
            scalar_t = self._repeat_per_tile(scalar_t)

            full_eps = rearrange(
                eps_scalar_t_fn(xs, scalar_t, y=y),
                "(b n) c h w -> n b c h w",
                n=self.num_img,
            )
            half_eps = rearrange(
                eps_scalar_t_fn(xs[:, :, :, -ov:], scalar_t, y=y),
                "(b n) c h w -> n b c h w",
                n=self.num_img,
            )

            # The last tile's right half wraps onto the first tile. Its
            # half-tile term is zeroed, and that seam is averaged when the
            # strip is folded. Both tensors are built out of place so the
            # predictor's outputs are never mutated (safe under autograd).
            half_eps = th.cat([half_eps[:-1], th.zeros_like(half_eps[-1:])], dim=0)
            full_eps = th.cat(
                [full_eps[..., :-ov], full_eps[..., -ov:] - half_eps], dim=-1
            )

            whole_eps = rearrange(full_eps, "n b c h w -> (b n) c h w")
            return self._wrap_merge(whole_eps, ov, is_avg=False)

        return eps_t_fn
|
|
|
|
|
class CondIndCircleSR(SimpleWork):
    """Super-resolution worker for a wrap-around canvas, conditioned on low-res input.

    The high-res canvas wraps horizontally. It is covered by ``num_img`` tiles
    of width ``w`` with half-width overlaps. ``low_res`` is the matching
    low-resolution wrap-around canvas; it is cut into tiles the same way, and
    each low-res tile is passed to the predictor with its high-res tile.

    Args:
        shape: ``(c, h, w)`` of one high-res tile.
        eps_scalar_t_fn: predictor called as ``fn(x, t, low_res_tiles)``; its
            output is rearranged with ``"(b n) c h w -> n b c h w"``.
        num_img: number of tiles around the loop.
        low_res: low-resolution canvas; its width must equal
            ``(low_res.shape[-2] // 2) * num_img``.
        overlap_size: high-res columns shared by neighbouring tiles; must
            equal ``w // 2``.
    """

    def __init__(self, shape, eps_scalar_t_fn, num_img, low_res, overlap_size=32):
        c, h, w = shape
        assert overlap_size == w // 2
        self.overlap_size = overlap_size
        # The low-res overlap is half the low-res canvas height. This
        # presumably assumes square low-res tiles (tile width == canvas
        # height) -- NOTE(review): confirm with callers.
        self.low_overlap_size = low_res.shape[-2] // 2
        self.num_img = num_img
        final_img_w = num_img * (w - self.overlap_size)
        assert low_res.shape[-1] == self.low_overlap_size * num_img

        # Both closures capture the low-res canvas, already cut into tiles.
        self.square_fn = self.get_square_sr_fn(eps_scalar_t_fn, low_res)
        self.half_fn = self.get_half_sr_fn(eps_scalar_t_fn, low_res)

        super().__init__((c, h, final_img_w), self.get_eps_t_fn())

    def circle_split(self, in_x, overlap_size=None):
        """Cut a wrap-around canvas into ``num_img`` overlapping tiles.

        Args:
            in_x: 4-D canvas tensor whose last axis is width.
            overlap_size: wrap width; defaults to ``self.overlap_size``.
        """
        if overlap_size is None:
            overlap_size = self.overlap_size
        long_x = th.cat([in_x, in_x[:, :, :, :overlap_size]], dim=-1)
        return split_wimg(long_x, self.num_img, rtn_overlap=False)

    def circle_merge(self, xs, overlap_size=None):
        """Average overlapping tiles back into a wrap-around canvas.

        Args:
            xs: tiles in the layout produced by ``circle_split``.
            overlap_size: wrap width; defaults to ``self.overlap_size``.
        """
        if overlap_size is None:
            overlap_size = self.overlap_size
        return self._wrap_merge(xs, overlap_size, is_avg=True)

    def _wrap_merge(self, xs, overlap_size, is_avg):
        """Merge tiles with ``avg_merge_wimg``, then fold the wrap-around seam.

        The merged strip is assumed to start and end with ``overlap_size``
        columns covering the same part of the loop. This follows from the
        width arithmetic; ``avg_merge_wimg`` is defined elsewhere. Those two
        slices are averaged and placed at the start of the canvas.
        """
        long_xs = avg_merge_wimg(xs, overlap_size, n=self.num_img, is_avg=is_avg)
        seam = (long_xs[:, :, :, :overlap_size] + long_xs[:, :, :, -overlap_size:]) / 2.0
        return th.cat([seam, long_xs[:, :, :, overlap_size:-overlap_size]], dim=-1)

    @staticmethod
    def _time_vector(x, t):
        """Return a float tensor of length ``x.shape[0]`` filled with ``t``.

        The tensor is created on ``x``'s device, not a hard-coded CUDA device,
        so it works on CPU and on any GPU.
        """
        return th.ones(x.shape[0], device=x.device) * t

    def get_square_sr_fn(self, eps_fn, low_res):
        """Build a predictor for full high-res tiles.

        Returns ``_fn(_x, _t, enable_grad)``. It runs ``eps_fn`` on the tile
        batch ``_x`` with a per-sample time vector and the matching low-res
        tiles, and returns the output rearranged to ``n b c h w``. Gradients
        are tracked only when ``enable_grad`` is true.
        """
        low_tiles = self.circle_split(low_res, self.low_overlap_size)

        def _fn(_x, _t, enable_grad):
            context = th.enable_grad if enable_grad else th.no_grad
            with context():
                vec_t = self._time_vector(_x, _t)
                rtn = eps_fn(_x, vec_t, low_tiles)
                return rearrange(rtn, "(b n) c h w -> n b c h w", n=self.num_img)

        return _fn

    def get_half_sr_fn(self, eps_fn, low_res):
        """Build a predictor for the right half of each high-res tile.

        Returns ``_fn(_x, _t, enable_grad)``. It runs ``eps_fn`` on the right
        ``overlap_size`` columns of each tile, paired with the right
        ``low_overlap_size`` columns of each low-res tile. The output is
        rearranged to ``n b c h w``, with the last tile's entry set to zero.
        """
        low_tiles = self.circle_split(low_res, self.low_overlap_size)

        def _fn(_x, _t, enable_grad):
            context = th.enable_grad if enable_grad else th.no_grad
            with context():
                vec_t = self._time_vector(_x, _t)
                half_eps = eps_fn(
                    _x[:, :, :, -self.overlap_size:],
                    vec_t,
                    low_tiles[:, :, :, -self.low_overlap_size:],
                )
                half_eps = rearrange(
                    half_eps, "(b n) c h w -> n b c h w", n=self.num_img
                )
                # Zero the last (wrap-around) tile's term. This is done out of
                # place so the model output is not mutated when gradients are
                # enabled.
                return th.cat(
                    [half_eps[:-1], th.zeros_like(half_eps[-1:])], dim=0
                )

        return _fn

    def get_eps_t_fn(self):
        """Build the canvas-level predictor ``eps_t_fn(in_x, scalar_t, enable_grad=False)``.

        Per-tile predictions minus right-half predictions are merged with
        ``is_avg=False``, then the wrap-around seam is folded by averaging.
        """

        def eps_t_fn(in_x, scalar_t, enable_grad=False):
            ov = self.overlap_size
            xs = self.circle_split(in_x)

            full_eps = self.square_fn(xs, scalar_t, enable_grad)
            half_eps = self.half_fn(xs, scalar_t, enable_grad)

            # Built out of place so the (possibly grad-tracked) model output
            # is not mutated.
            full_eps = th.cat(
                [full_eps[..., :-ov], full_eps[..., -ov:] - half_eps], dim=-1
            )

            whole_eps = rearrange(full_eps, "n b c h w -> (b n) c h w")
            return self._wrap_merge(whole_eps, ov, is_avg=False)

        return eps_t_fn