|
from transformers import PreTrainedModel |
|
from .AUNet import AUNet |
|
from .AUNetConfig import AUNetConfig |
|
import torch |
|
|
|
class s2l8hModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around an ``AUNet`` network.

    Wrapping ``AUNet`` in a ``PreTrainedModel`` subclass lets it use the
    ``transformers`` model API (e.g. ``save_pretrained`` /
    ``from_pretrained``) with ``AUNetConfig`` as its configuration class.
    """

    # Tells ``PreTrainedModel`` which config class to use when building or
    # loading this model (e.g. via ``from_pretrained``).
    config_class = AUNetConfig

    def __init__(self, config):
        """Build the wrapped ``AUNet`` from the fields of ``config``.

        Args:
            config: An ``AUNetConfig`` (or compatible object) providing
                ``in_channels``, ``out_channels``, ``depth``,
                ``spatial_attention``, ``growth_factor``, ``interp_mode``,
                ``up_mode`` and ``ca_layer``, which are forwarded to
                ``AUNet`` as keyword arguments of the same names.
        """
        super().__init__(config)
        # Assuming AUNet is an ``nn.Module``, the attribute name ``model``
        # becomes the ``model.`` prefix of its parameters in the state dict;
        # renaming it would break loading previously saved checkpoints.
        self.model = AUNet(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            depth=config.depth,
            spatial_attention=config.spatial_attention,
            growth_factor=config.growth_factor,
            interp_mode=config.interp_mode,
            up_mode=config.up_mode,
            ca_layer=config.ca_layer,
        )

    def forward(self, MS, PAN):
        """Run the wrapped ``AUNet`` on the two inputs.

        Args:
            MS: First input, passed to ``AUNet`` as its first positional
                argument (presumably the multispectral image tensor --
                NOTE(review): confirm against ``AUNet.forward``).
            PAN: Second input, passed as its second positional argument
                (presumably the panchromatic image tensor -- confirm).

        Returns:
            Whatever ``AUNet`` returns for ``(MS, PAN)``, unchanged.
        """
        # Call the module instance instead of ``self.model.forward`` so that
        # forward pre-/post-hooks registered on the AUNet module itself are
        # run; calling ``.forward`` directly silently skips them.
        return self.model(MS, PAN)