from copy import deepcopy

from arch.hourglass import image_transformer_v2 as itv2
from arch.hourglass.image_transformer_v2 import ImageTransformerDenoiserModelV2
from arch.swinir.swinir import SwinIR
|
|
|
|
|
def create_arch(arch, condition_channels=0):
    """Build the model selected by an ``'<name>_<size>'`` identifier.

    Args:
        arch: Identifier made of exactly two underscore-separated parts, e.g.
            ``'hdit_XL2'`` or ``'swinir_M'``. The first part selects the
            architecture family (a key of ``arch_configs`` and
            ``arch_name_to_object``); the second selects a size preset of
            that family.
        condition_channels: Added to the preset's ``in_channels`` before the
            model is constructed.

    Returns:
        Whatever the family's factory in ``arch_name_to_object`` returns.

    Raises:
        ValueError: If ``arch`` does not consist of exactly two
            underscore-separated parts.
        KeyError: If the architecture family or the size preset is unknown.
    """
    parts = arch.split('_')
    if len(parts) != 2:
        raise ValueError(
            f"arch must look like '<name>_<size>' with exactly one underscore, got {arch!r}"
        )
    arch_name, arch_size = parts

    try:
        presets = arch_configs[arch_name]
        factory = arch_name_to_object[arch_name]
    except KeyError:
        raise KeyError(
            f"unknown architecture {arch_name!r}; available: {sorted(arch_configs)}"
        ) from None

    try:
        preset = presets[arch_size]
    except KeyError:
        raise KeyError(
            f"unknown size {arch_size!r} for architecture {arch_name!r}; "
            f"available: {sorted(presets)}"
        ) from None

    # Deep copy: the presets hold nested lists/dicts, and a shallow copy would
    # hand the registry's own objects to the model being built.
    arch_config = deepcopy(preset)
    arch_config['in_channels'] += condition_channels
    return factory(**arch_config)
|
|
|
|
|
# Size presets, keyed first by architecture family and then by size name.
# Each leaf dict is expanded as keyword arguments into the family's factory
# (see ``arch_name_to_object``); ``create_arch`` adds the number of
# conditioning channels to ``in_channels`` on a copy before doing so.
arch_configs = {
    "hdit": {
        "ImageNet256Sp4": {
            "in_channels": 3,
            "out_channels": 3,
            # One entry per level in each of widths/depths/self_attns/dropout_rate.
            "widths": [256, 512, 1024],
            "depths": [2, 2, 8],
            "patch_size": [4, 4],
            "self_attns": [
                {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
                {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
                {"type": "global", "d_head": 64},
            ],
            "mapping_depth": 2,
            "mapping_width": 768,
            "dropout_rate": [0, 0, 0],
            "mapping_dropout_rate": 0.0,
        },
        "XL2": {
            "in_channels": 3,
            "out_channels": 3,
            "widths": [384, 768],
            "depths": [2, 11],
            "patch_size": [4, 4],
            "self_attns": [
                {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
                {"type": "global", "d_head": 64},
            ],
            "mapping_depth": 2,
            "mapping_width": 768,
            "dropout_rate": [0, 0],
            "mapping_dropout_rate": 0.0,
        },
    },
    "swinir": {
        "M": {
            "in_channels": 3,
            "out_channels": 3,
            "embed_dim": 120,
            "depths": [6, 6, 6, 6, 6],
            "num_heads": [6, 6, 6, 6, 6],
            "resi_connection": "1conv",
            "sf": 8,
        },
        "L": {
            "in_channels": 3,
            "out_channels": 3,
            "embed_dim": 180,
            "depths": [6, 6, 6, 6, 6, 6, 6, 6],
            "num_heads": [6, 6, 6, 6, 6, 6, 6, 6],
            "resi_connection": "1conv",
            "sf": 8,
        },
    },
}
|
|
|
|
|
def create_swinir_model(in_channels, out_channels, embed_dim, depths, num_heads, resi_connection,
                        sf):
    """Instantiate ``SwinIR`` from the values of a size preset.

    Args:
        in_channels: Forwarded to ``SwinIR`` as ``in_chans``.
        out_channels: Forwarded to ``SwinIR`` as ``num_out_ch``.
        embed_dim: Forwarded unchanged.
        depths: Forwarded unchanged.
        num_heads: Forwarded unchanged.
        resi_connection: Forwarded unchanged.
        sf: Forwarded unchanged.

    Returns:
        The constructed ``SwinIR`` instance.
    """
    # Everything not listed in the signature is pinned to a fixed value here.
    # NOTE(review): ``unshuffle_scale`` is hard-coded to 8, which equals ``sf``
    # in both ``swinir`` presets of ``arch_configs`` - confirm whether it is
    # meant to follow ``sf``.
    swinir_kwargs = {
        'img_size': 64,
        'patch_size': 1,
        'in_chans': in_channels,
        'num_out_ch': out_channels,
        'embed_dim': embed_dim,
        'depths': depths,
        'num_heads': num_heads,
        'window_size': 8,
        'mlp_ratio': 2,
        'sf': sf,
        'img_range': 1.0,
        'upsampler': "nearest+conv",
        'resi_connection': resi_connection,
        'unshuffle': True,
        'unshuffle_scale': 8,
    }
    return SwinIR(**swinir_kwargs)
|
|
|
|
|
def _self_attn_spec_from_config(self_attn):
    """Translate one ``self_attns`` entry into the matching ``itv2`` attention spec."""
    attn_type = self_attn['type']
    if attn_type == 'global':
        return itv2.GlobalAttentionSpec(self_attn.get('d_head', 64))
    if attn_type == 'neighborhood':
        return itv2.NeighborhoodAttentionSpec(
            self_attn.get('d_head', 64), self_attn.get('kernel_size', 7)
        )
    if attn_type == 'shifted-window':
        # ``window_size`` has no default: a missing key raises KeyError.
        return itv2.ShiftedWindowAttentionSpec(
            self_attn.get('d_head', 64), self_attn['window_size']
        )
    if attn_type == 'none':
        return itv2.NoAttentionSpec()
    raise ValueError(f'unsupported self attention type {attn_type}')


def create_hdit_model(widths,
                      depths,
                      self_attns,
                      dropout_rate,
                      mapping_depth,
                      mapping_width,
                      mapping_dropout_rate,
                      in_channels,
                      out_channels,
                      patch_size
                      ):
    """Build an ``ImageTransformerDenoiserModelV2`` from a flat preset.

    Args:
        widths: Per-level widths; its length defines the number of levels.
        depths: Per-level depths, same length as ``widths``.
        self_attns: Per-level attention configs, same length as ``widths``.
            Each is a dict whose ``'type'`` is ``'global'``,
            ``'neighborhood'``, ``'shifted-window'`` or ``'none'``;
            ``'d_head'`` defaults to 64 and ``'kernel_size'`` to 7, while
            ``'window_size'`` is required for ``'shifted-window'``.
        dropout_rate: Per-level dropout values, same length as ``widths``.
        mapping_depth: First argument of ``itv2.MappingSpec``.
        mapping_width: Second argument of ``itv2.MappingSpec``; three times
            this value is passed as the third argument.
        mapping_dropout_rate: Fourth argument of ``itv2.MappingSpec``.
        in_channels: Forwarded to the model unchanged.
        out_channels: Forwarded to the model unchanged.
        patch_size: Forwarded to the model unchanged.

    Returns:
        The constructed model, created with ``num_classes=0`` and
        ``mapping_cond_dim=0``.

    Raises:
        ValueError: If the per-level sequences differ in length from
            ``widths``, or an attention ``'type'`` is not supported.
        KeyError: If an attention config lacks ``'type'``, or a
            ``'shifted-window'`` config lacks ``'window_size'``.
    """
    # Explicit checks instead of ``assert``: asserts are removed under
    # ``python -O``, after which ``zip`` would silently drop extra levels.
    per_level = (
        ('depths', depths),
        ('self_attns', self_attns),
        ('dropout_rate', dropout_rate),
    )
    for arg_name, values in per_level:
        if len(values) != len(widths):
            raise ValueError(
                f'{arg_name} has {len(values)} entries but widths has {len(widths)}; '
                'all per-level settings must have the same length'
            )

    attn_specs = [_self_attn_spec_from_config(cfg) for cfg in self_attns]

    # The feed-forward size passed for each level (and for the mapping
    # network below) is three times the corresponding width.
    levels = [
        itv2.LevelSpec(depth, width, width * 3, attn_spec, dropout)
        for depth, width, attn_spec, dropout in zip(depths, widths, attn_specs, dropout_rate)
    ]
    mapping = itv2.MappingSpec(
        mapping_depth, mapping_width, mapping_width * 3, mapping_dropout_rate
    )

    return ImageTransformerDenoiserModelV2(
        levels=levels,
        mapping=mapping,
        in_channels=in_channels,
        out_channels=out_channels,
        patch_size=patch_size,
        num_classes=0,
        mapping_cond_dim=0,
    )
|
|
|
|
|
# Architecture family name -> factory. ``create_arch`` calls the factory with
# the selected preset from ``arch_configs`` expanded as keyword arguments.
arch_name_to_object = {
    "hdit": create_hdit_model,
    "swinir": create_swinir_model,
}
|
|