"""Factory helpers that build a denoiser backbone (HDiT or SwinIR) from a named size preset."""

from arch.hourglass import image_transformer_v2 as itv2
from arch.hourglass.image_transformer_v2 import ImageTransformerDenoiserModelV2
from arch.swinir.swinir import SwinIR


def create_arch(arch, condition_channels=0):
    # `arch` is "<name>_<size>", e.g. 'swinir_L', 'hdit_XL2', or 'hdit_ImageNet256Sp4'.
    arch_name, arch_size = arch.split('_')
    arch_config = arch_configs[arch_name][arch_size].copy()
    # Conditioning images are concatenated along the channel dimension,
    # so widen the network input accordingly.
    arch_config['in_channels'] += condition_channels
    return arch_name_to_object[arch_name](**arch_config)


arch_configs = {
    'hdit': {
        "ImageNet256Sp4": {
            'in_channels': 3,
            'out_channels': 3,
            'widths': [256, 512, 1024],
            'depths': [2, 2, 8],
            'patch_size': [4, 4],
            'self_attns': [
                {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
                {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
                {"type": "global", "d_head": 64},
            ],
            'mapping_depth': 2,
            'mapping_width': 768,
            'dropout_rate': [0, 0, 0],
            'mapping_dropout_rate': 0.0,
        },
        "XL2": {
            'in_channels': 3,
            'out_channels': 3,
            'widths': [384, 768],
            'depths': [2, 11],
            'patch_size': [4, 4],
            'self_attns': [
                {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
                {"type": "global", "d_head": 64},
            ],
            'mapping_depth': 2,
            'mapping_width': 768,
            'dropout_rate': [0, 0],
            'mapping_dropout_rate': 0.0,
        },
    },
    'swinir': {
        "M": {
            'in_channels': 3,
            'out_channels': 3,
            'embed_dim': 120,
            'depths': [6, 6, 6, 6, 6],
            'num_heads': [6, 6, 6, 6, 6],
            'resi_connection': '1conv',
            'sf': 8,
        },
        "L": {
            'in_channels': 3,
            'out_channels': 3,
            'embed_dim': 180,
            'depths': [6, 6, 6, 6, 6, 6, 6, 6],
            'num_heads': [6, 6, 6, 6, 6, 6, 6, 6],
            'resi_connection': '1conv',
            'sf': 8,
        },
    },
}


def create_swinir_model(in_channels, out_channels, embed_dim, depths, num_heads,
                        resi_connection, sf):
    return SwinIR(
        img_size=64,
        patch_size=1,
        in_chans=in_channels,
        num_out_ch=out_channels,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=8,
        mlp_ratio=2,
        sf=sf,
        img_range=1.0,
        upsampler="nearest+conv",
        resi_connection=resi_connection,
        unshuffle=True,
        unshuffle_scale=8,
    )


def create_hdit_model(widths, depths, self_attns, dropout_rate,
                      mapping_depth, mapping_width, mapping_dropout_rate,
                      in_channels, out_channels, patch_size):
    # Each entry in `widths`, `depths`, `self_attns`, and `dropout_rate`
    # describes one resolution level of the hourglass transformer.
    assert len(widths) == len(depths)
    assert len(widths) == len(self_attns)
    assert len(widths) == len(dropout_rate)

    # Feed-forward widths are 3x the model width at every level.
    mapping_d_ff = mapping_width * 3
    d_ffs = [width * 3 for width in widths]

    levels = []
    for depth, width, d_ff, self_attn, dropout in zip(depths, widths, d_ffs, self_attns, dropout_rate):
        if self_attn['type'] == 'global':
            self_attn = itv2.GlobalAttentionSpec(self_attn.get('d_head', 64))
        elif self_attn['type'] == 'neighborhood':
            self_attn = itv2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64),
                                                       self_attn.get('kernel_size', 7))
        elif self_attn['type'] == 'shifted-window':
            self_attn = itv2.ShiftedWindowAttentionSpec(self_attn.get('d_head', 64),
                                                        self_attn['window_size'])
        elif self_attn['type'] == 'none':
            self_attn = itv2.NoAttentionSpec()
        else:
            raise ValueError(f'unsupported self attention type {self_attn["type"]}')
        levels.append(itv2.LevelSpec(depth, width, d_ff, self_attn, dropout))

    mapping = itv2.MappingSpec(mapping_depth, mapping_width, mapping_d_ff, mapping_dropout_rate)
    model = ImageTransformerDenoiserModelV2(
        levels=levels,
        mapping=mapping,
        in_channels=in_channels,
        out_channels=out_channels,
        patch_size=patch_size,
        num_classes=0,
        mapping_cond_dim=0,
    )
    return model


arch_name_to_object = {
    'hdit': create_hdit_model,
    'swinir': create_swinir_model,
}
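
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; assumes this module is importable, e.g. as
# `create_arch_module`, and that the `arch` package with the HDiT/SwinIR
# implementations is on the path):
#
#   from create_arch_module import create_arch
#
#   # Unconditional HDiT at the "XL2" preset defined above.
#   model = create_arch('hdit_XL2')
#
#   # SwinIR-L with a 3-channel conditioning image concatenated to the input,
#   # so the network receives 3 + 3 = 6 input channels.
#   cond_model = create_arch('swinir_L', condition_channels=3)
# ---------------------------------------------------------------------------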