File size: 1,828 Bytes
617d388
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: OSAG.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified:  Sunday, 23rd April 2023 3:08:49 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################


import torch.nn as nn

from .esa import ESA
from .OSA import OSA_Block


class OSAG(nn.Module):
    """Residual group of ``OSA_Block`` layers followed by an ``ESA`` module.

    The group stacks ``block_num`` ``OSA_Block`` instances plus a trailing 1x1
    convolution into ``self.residual_layer``, adds the group input back as a
    skip connection, and passes the sum through ``self.esa``.

    Args:
        channel_num: Number of feature channels entering and leaving the
            group. Every ``OSA_Block`` and the 1x1 conv are built with this
            channel count, and the skip connection in ``forward`` adds the
            input directly to the stack's output.
        bias: Passed positionally to every ``OSA_Block`` and used as the
            ``bias`` flag of the trailing 1x1 ``nn.Conv2d``.
        block_num: Number of ``OSA_Block`` instances to stack.
        ffn_bias: Forwarded to each ``OSA_Block`` as ``ffn_bias``.
        window_size: Forwarded to each ``OSA_Block`` as ``window_size``.
        pe: Forwarded to each ``OSA_Block`` as ``with_pe`` (presumably toggles
            a positional encoding; confirm in ``OSA.py``).
    """

    def __init__(
        self,
        channel_num=64,
        bias=True,
        block_num=4,
        ffn_bias=False,
        window_size=0,
        pe=False,
    ):
        super().__init__()

        # The attribute names ``residual_layer`` / ``esa`` and the layer order
        # inside the Sequential determine the state_dict keys
        # (``residual_layer.0`` ... ``residual_layer.<block_num>``, ``esa.*``).
        # Keep them stable so existing checkpoints still load.
        layers = [
            OSA_Block(
                channel_num,
                bias,
                ffn_bias=ffn_bias,
                window_size=window_size,
                with_pe=pe,
            )
            for _ in range(block_num)
        ]
        # The 1x1 conv is the last layer of the stack, applied before the skip
        # connection is added in ``forward``.
        layers.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias))
        self.residual_layer = nn.Sequential(*layers)

        # ESA's first argument is a reduced channel count: a quarter of
        # channel_num (floored), but never fewer than 16.
        esa_channel = max(channel_num // 4, 16)
        self.esa = ESA(esa_channel, channel_num)

    def forward(self, x):
        """Run the block stack with a skip connection, then apply ESA.

        Args:
            x: Feature map with ``channel_num`` channels, presumably shaped
                ``(N, channel_num, H, W)`` since it feeds ``nn.Conv2d``-based
                layers (TODO confirm against callers).

        Returns:
            The result of ``self.esa(self.residual_layer(x) + x)``.
        """
        out = self.residual_layer(x)
        out = out + x
        return self.esa(out)