Spaces:
Running
on
Zero
Running
on
Zero
from arch.hourglass import image_transformer_v2 as itv2 | |
from arch.hourglass.image_transformer_v2 import ImageTransformerDenoiserModelV2 | |
from arch.swinir.swinir import SwinIR | |
def create_arch(arch, condition_channels=0):
    """Build a model from an ``"<family>_<size>"`` architecture identifier.

    ``arch`` is split on its first underscore into a family name (a key of
    ``arch_configs`` and ``arch_name_to_object``, e.g. ``"hdit"`` or
    ``"swinir"``) and a size preset within that family (e.g. ``"XL2"`` or
    ``"ImageNet256Sp4"`` for hdit, ``"M"`` or ``"L"`` for swinir).

    Args:
        arch: Identifier such as ``"hdit_XL2"`` or ``"swinir_L"``.
        condition_channels: Number added to the preset's ``in_channels``
            before the model is built.

    Returns:
        Whatever the family's factory in ``arch_name_to_object`` returns.

    Raises:
        ValueError: if ``arch`` contains no underscore.
        KeyError: if the family or the size preset is unknown.
    """
    import copy

    # partition() splits on the first '_' only, so size names may themselves
    # contain underscores.
    arch_name, sep, arch_size = arch.partition('_')
    if not sep:
        raise ValueError(f"arch must look like '<family>_<size>', got {arch!r}")

    try:
        family_presets = arch_configs[arch_name]
    except KeyError:
        raise KeyError(
            f"unknown architecture family {arch_name!r}; "
            f"expected one of {sorted(arch_configs)}"
        ) from None
    try:
        preset = family_presets[arch_size]
    except KeyError:
        raise KeyError(
            f"unknown size {arch_size!r} for architecture {arch_name!r}; "
            f"expected one of {sorted(family_presets)}"
        ) from None

    # Deep copy so that neither the in_channels adjustment nor any in-place
    # mutation of the nested lists downstream can alter the shared presets.
    arch_config = copy.deepcopy(preset)
    arch_config['in_channels'] += condition_channels
    return arch_name_to_object[arch_name](**arch_config)
# Size presets per architecture family, indexed as arch_configs[<family>][<size>].
# Each preset is the keyword-argument set passed to that family's factory in
# ``arch_name_to_object`` (after create_arch adds conditioning channels to
# ``in_channels``).
arch_configs = dict(
    hdit=dict(
        # Three levels: one entry per level in widths/depths/self_attns/dropout_rate.
        ImageNet256Sp4=dict(
            in_channels=3,
            out_channels=3,
            widths=[256, 512, 1024],
            depths=[2, 2, 8],
            patch_size=[4, 4],
            self_attns=[
                dict(type="neighborhood", d_head=64, kernel_size=7),
                dict(type="neighborhood", d_head=64, kernel_size=7),
                dict(type="global", d_head=64),
            ],
            mapping_depth=2,
            mapping_width=768,
            dropout_rate=[0, 0, 0],
            mapping_dropout_rate=0.0,
        ),
        # Two levels: neighborhood attention, then global attention.
        XL2=dict(
            in_channels=3,
            out_channels=3,
            widths=[384, 768],
            depths=[2, 11],
            patch_size=[4, 4],
            self_attns=[
                dict(type="neighborhood", d_head=64, kernel_size=7),
                dict(type="global", d_head=64),
            ],
            mapping_depth=2,
            mapping_width=768,
            dropout_rate=[0, 0],
            mapping_dropout_rate=0.0,
        ),
    ),
    swinir=dict(
        # Five stages of depth 6, embedding width 120.
        M=dict(
            in_channels=3,
            out_channels=3,
            embed_dim=120,
            depths=[6, 6, 6, 6, 6],
            num_heads=[6, 6, 6, 6, 6],
            resi_connection='1conv',
            sf=8,
        ),
        # Eight stages of depth 6, embedding width 180.
        L=dict(
            in_channels=3,
            out_channels=3,
            embed_dim=180,
            depths=[6, 6, 6, 6, 6, 6, 6, 6],
            num_heads=[6, 6, 6, 6, 6, 6, 6, 6],
            resi_connection='1conv',
            sf=8,
        ),
    ),
)
def create_swinir_model(in_channels, out_channels, embed_dim, depths, num_heads, resi_connection,
                        sf):
    """Instantiate ``SwinIR`` from a swinir preset in ``arch_configs``.

    Only the arguments of this function vary between presets; every other
    SwinIR setting is pinned to the constants in ``fixed_settings`` below.

    Args:
        in_channels: Forwarded as ``in_chans``.
        out_channels: Forwarded as ``num_out_ch``.
        embed_dim, depths, num_heads, resi_connection, sf: Forwarded unchanged.

    Returns:
        The constructed ``SwinIR`` instance.
    """
    # Settings shared by every swinir preset.
    fixed_settings = dict(
        img_size=64,
        patch_size=1,
        window_size=8,
        mlp_ratio=2,
        img_range=1.0,
        upsampler="nearest+conv",
        unshuffle=True,
        unshuffle_scale=8,
    )
    return SwinIR(
        in_chans=in_channels,
        num_out_ch=out_channels,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        sf=sf,
        resi_connection=resi_connection,
        **fixed_settings,
    )
def _attention_spec_from_config(self_attn):
    """Translate one ``self_attns`` config dict into an ``itv2`` attention spec.

    ``d_head`` defaults to 64 and ``kernel_size`` to 7 when absent;
    ``window_size`` must be present for the ``'shifted-window'`` type.

    Raises:
        ValueError: if ``self_attn['type']`` is not a supported attention type.
    """
    attn_type = self_attn['type']
    if attn_type == 'global':
        return itv2.GlobalAttentionSpec(self_attn.get('d_head', 64))
    if attn_type == 'neighborhood':
        return itv2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64), self_attn.get('kernel_size', 7))
    if attn_type == 'shifted-window':
        return itv2.ShiftedWindowAttentionSpec(self_attn.get('d_head', 64), self_attn['window_size'])
    if attn_type == 'none':
        return itv2.NoAttentionSpec()
    raise ValueError(f'unsupported self attention type {attn_type}')


def create_hdit_model(widths,
                      depths,
                      self_attns,
                      dropout_rate,
                      mapping_depth,
                      mapping_width,
                      mapping_dropout_rate,
                      in_channels,
                      out_channels,
                      patch_size
                      ):
    """Build an ``ImageTransformerDenoiserModelV2`` from per-level settings.

    ``widths``, ``depths``, ``self_attns`` and ``dropout_rate`` each hold one
    entry per level and must have equal lengths. Each level's feed-forward
    width, and the mapping network's, is set to three times its width. The
    model is constructed with ``num_classes=0`` and ``mapping_cond_dim=0``.

    Args:
        widths: Width of each level.
        depths: Number of blocks in each level.
        self_attns: One config dict per level with a ``'type'`` key of
            ``'global'``, ``'neighborhood'``, ``'shifted-window'`` or ``'none'``.
        dropout_rate: Dropout rate for each level.
        mapping_depth, mapping_width, mapping_dropout_rate: Mapping-network settings.
        in_channels, out_channels, patch_size: Forwarded to the model.

    Returns:
        The constructed ``ImageTransformerDenoiserModelV2``.

    Raises:
        ValueError: if the per-level lists differ in length, or an attention
            type is unsupported.
    """
    n_levels = len(widths)
    # Explicit checks rather than ``assert``: asserts are stripped under
    # ``python -O``, and zip() below would then silently truncate.
    for name, values in (('depths', depths), ('self_attns', self_attns), ('dropout_rate', dropout_rate)):
        if len(values) != n_levels:
            raise ValueError(
                f'{name} has {len(values)} entries but widths has {n_levels}; '
                f'expected one entry per level'
            )

    levels = []
    for depth, width, self_attn, dropout in zip(depths, widths, self_attns, dropout_rate):
        attn_spec = _attention_spec_from_config(self_attn)
        levels.append(itv2.LevelSpec(depth, width, width * 3, attn_spec, dropout))

    mapping = itv2.MappingSpec(mapping_depth, mapping_width, mapping_width * 3, mapping_dropout_rate)
    return ImageTransformerDenoiserModelV2(
        levels=levels,
        mapping=mapping,
        in_channels=in_channels,
        out_channels=out_channels,
        patch_size=patch_size,
        num_classes=0,
        mapping_cond_dim=0,
    )
# Maps an architecture family name (the part of create_arch's identifier
# before the first "_") to the factory that builds that family's model.
arch_name_to_object = dict(
    hdit=create_hdit_model,
    swinir=create_swinir_model,
)