test / modules /xadapter /adapter.py
bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
import torch
import torch.nn as nn
from collections import OrderedDict
from diffusers.models.embeddings import (
TimestepEmbedding,
Timesteps,
)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def get_parameter_dtype(parameter: torch.nn.Module):
try:
params = tuple(parameter.parameters())
if len(params) > 0:
return params[0].dtype
buffers = tuple(parameter.buffers())
if len(buffers) > 0:
return buffers[0].dtype
except StopIteration:
# For torch.nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].dtype
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
from torch.nn import MaxUnpool2d
self.op = MaxUnpool2d(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class Upsample(nn.Module):
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = nn.ConvTranspose2d(self.channels, self.out_channels, 3, stride=stride, padding=1)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x, output_size):
assert x.shape[1] == self.channels
return self.op(x, output_size)
class Linear(nn.Module):
def __init__(self, temb_channels, out_channels):
super(Linear, self).__init__()
self.linear = nn.Linear(temb_channels, out_channels)
def forward(self, x):
return self.linear(x)
class ResnetBlock(nn.Module):
def __init__(self, in_c, out_c, down, up, ksize=3, sk=False, use_conv=True, enable_timestep=False, temb_channels=None, use_norm=False):
super().__init__()
self.use_norm = use_norm
self.enable_timestep = enable_timestep
ps = ksize // 2
if in_c != out_c or sk == False:
self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
else:
self.in_conv = None
self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
self.act = nn.ReLU()
if use_norm:
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=out_c, eps=1e-6, affine=True)
self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
if sk == False:
self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
else:
self.skep = None
self.down = down
self.up = up
if self.down:
self.down_opt = Downsample(in_c, use_conv=use_conv)
if self.up:
self.up_opt = Upsample(in_c, use_conv=use_conv)
if enable_timestep:
self.timestep_proj = Linear(temb_channels, out_c)
def forward(self, x, output_size=None, temb=None):
if self.down == True:
x = self.down_opt(x)
if self.up == True:
x = self.up_opt(x, output_size)
if self.in_conv is not None: # edit
x = self.in_conv(x)
h = self.block1(x)
if temb is not None:
temb = self.timestep_proj(temb)[:, :, None, None]
h = h + temb
if self.use_norm:
h = self.norm1(h)
h = self.act(h)
h = self.block2(h)
if self.skep is not None:
return h + self.skep(x)
else:
return h + x
class Adapter_XL(nn.Module):
def __init__(self, in_channels=[1280, 640, 320], out_channels=[1280, 1280, 640], nums_rb=3, ksize=3, sk=True, use_conv=False, use_zero_conv=True,
enable_timestep=False, use_norm=False, temb_channels=None, fusion_type='ADD'):
super(Adapter_XL, self).__init__()
self.channels = in_channels
self.nums_rb = nums_rb
self.body = []
self.out = []
self.use_zero_conv = use_zero_conv
self.fusion_type = fusion_type
self.gamma = []
self.beta = []
self.norm = []
if fusion_type == "SPADE":
self.use_zero_conv = False
for i in range(len(self.channels)):
if self.fusion_type == 'SPADE':
# Corresponding to SPADE <Semantic Image Synthesis with Spatially-Adaptive Normalization>
self.gamma.append(nn.Conv2d(out_channels[i], out_channels[i], 1, padding=0))
self.beta.append(nn.Conv2d(out_channels[i], out_channels[i], 1, padding=0))
self.norm.append(nn.BatchNorm2d(out_channels[i]))
elif use_zero_conv:
self.out.append(self.make_zero_conv(out_channels[i]))
else:
self.out.append(nn.Conv2d(out_channels[i], out_channels[i], 1, padding=0))
for j in range(nums_rb):
if i==0:
# 1280, 32, 32 -> 1280, 32, 32
self.body.append(
ResnetBlock(in_channels[i], out_channels[i], down=False, up=False, ksize=ksize, sk=sk, use_conv=use_conv,
enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm))
# 1280, 32, 32 -> 1280, 32, 32
elif i==1:
# 640, 64, 64 -> 1280, 64, 64
if j==0:
self.body.append(
ResnetBlock(in_channels[i], out_channels[i], down=False, up=False, ksize=ksize, sk=sk,
use_conv=use_conv, enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm))
else:
self.body.append(
ResnetBlock(out_channels[i], out_channels[i], down=False, up=False, ksize=ksize,sk=sk,
use_conv=use_conv, enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm))
else:
# 320, 64, 64 -> 640, 128, 128
if j==0:
self.body.append(
ResnetBlock(in_channels[i], out_channels[i], down=False, up=True, ksize=ksize, sk=sk,
use_conv=True, enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm))
# use convtranspose2d
else:
self.body.append(
ResnetBlock(out_channels[i], out_channels[i], down=False, up=False, ksize=ksize, sk=sk,
use_conv=use_conv, enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm))
self.body = nn.ModuleList(self.body)
if self.use_zero_conv:
self.zero_out = nn.ModuleList(self.out)
# if self.fusion_type == 'SPADE':
# self.norm = nn.ModuleList(self.norm)
# self.gamma = nn.ModuleList(self.gamma)
# self.beta = nn.ModuleList(self.beta)
# else:
# self.zero_out = nn.ModuleList(self.out)
# if enable_timestep:
# a = 320
#
# time_embed_dim = a * 4
# self.time_proj = Timesteps(a, True, 0)
# timestep_input_dim = a
#
# self.time_embedding = TimestepEmbedding(
# timestep_input_dim,
# time_embed_dim,
# act_fn='silu',
# post_act_fn=None,
# cond_proj_dim=None,
# )
def make_zero_conv(self, channels):
return zero_module(nn.Conv2d(channels, channels, 1, padding=0))
@property
def dtype(self) -> torch.dtype:
"""
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
"""
return get_parameter_dtype(self)
def forward(self, x, t=None):
# extract features
features = []
b, c, _, _ = x[-1].shape
if t is not None:
if not torch.is_tensor(t):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = x[0].device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
t = torch.tensor([t], dtype=dtype, device=x[0].device)
elif len(t.shape) == 0:
t = t[None].to(x[0].device)
t = t.expand(b)
t = self.time_proj(t) # b, 320
t = t.to(dtype=x[0].dtype)
t = self.time_embedding(t) # b, 1280
output_size = (b, 640, 128, 128) # last CA layer output
for i in range(len(self.channels)):
for j in range(self.nums_rb):
idx = i * self.nums_rb + j
if j == 0:
if i < 2:
out = self.body[idx](x[i], temb=t)
else:
out = self.body[idx](x[i], output_size=output_size, temb=t)
else:
out = self.body[idx](out, temb=t)
if self.fusion_type == 'SPADE':
out_gamma = self.gamma[i](out)
out_beta = self.beta[i](out)
out = [out_gamma, out_beta]
else:
out = self.zero_out[i](out)
features.append(out)
return features
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module