File size: 979 Bytes
513c07c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
import torch.nn as nn
from super_image import RcanModel, RcanConfig
class CustomRcan(RcanModel):
    """
    RCAN whose forward pass omits the sub_mean / add_mean steps.

    Intended for inputs such as physical variables (e.g. wind components
    u, v) for which image-style mean normalization is not applicable.
    """

    def forward(self, x):
        """Run head -> body (plus skip connection) -> tail on ``x``.

        ``sub_mean`` and ``add_mean`` are deliberately not applied here.
        """
        features = self.head(x)
        residual = self.body(features)
        # Global skip connection, accumulated in place on the body output.
        residual += features
        return self.tail(residual)
def load_rcan(
    pretrained_repo: str = "lschmidt/rcan-dsc",
    config_file: str = "config.json",
    weight_file: str = "pytorch_model_4x.pt",
) -> "CustomRcan":
    """Build a ``CustomRcan`` and load a state dict fetched from the Hugging Face Hub.

    Args:
        pretrained_repo: Hub repository id used both for the configuration
            lookup and for downloading the weight file.
        config_file: Configuration file name, forwarded to
            ``RcanConfig.from_pretrained`` as ``config_filename``.
        weight_file: Name of the file in the repository that holds the
            state dict.

    Returns:
        The ``CustomRcan`` instance after the state dict has been applied
        non-strictly. The model is returned as constructed; no device move
        or train/eval mode switch is performed here.

    Warns:
        UserWarning: If the checkpoint and the model disagree on parameter
            names (missing or unexpected keys). Loading still proceeds.
    """
    # Local imports keep these dependencies optional until the loader is used.
    import warnings

    from huggingface_hub import hf_hub_download

    config, _ = RcanConfig.from_pretrained(pretrained_repo, config_filename=config_file)
    model = CustomRcan(config)

    state_dict_path = hf_hub_download(repo_id=pretrained_repo, filename=weight_file)
    # NOTE(review): torch.load unpickles the downloaded file, which can execute
    # arbitrary code for an untrusted repository. Consider passing
    # weights_only=True once the minimum supported torch version allows it.
    state_dict = torch.load(state_dict_path, map_location="cpu")

    # strict=False tolerates key mismatches, but a silent mismatch would leave
    # the affected layers at their random initialisation, so report it.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        warnings.warn(
            f"load_rcan: {len(load_result.missing_keys)} model parameter(s) were not "
            f"found in '{weight_file}' and keep their initial values: "
            f"{list(load_result.missing_keys)}",
            stacklevel=2,
        )
    if load_result.unexpected_keys:
        warnings.warn(
            f"load_rcan: {len(load_result.unexpected_keys)} key(s) in '{weight_file}' "
            f"do not match any model parameter and were ignored: "
            f"{list(load_result.unexpected_keys)}",
            stacklevel=2,
        )
    return model
|