File size: 440 Bytes
ec9a6bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from torch import nn
from einops import rearrange

from GHA.lib.network.Upsampler import Upsampler

class SuperResolutionModule(nn.Module):
    """Thin ``nn.Module`` wrapper that owns a single ``Upsampler``.

    The module adds no logic of its own: construction forwards three
    configuration values to ``Upsampler``, and the forward pass delegates
    to that submodule unchanged.
    """

    def __init__(self, cfg):
        """Create the wrapped upsampler from ``cfg``.

        Args:
            cfg: Configuration object exposing ``input_dim``, ``output_dim``
                and ``network_capacity`` attributes. They are passed
                positionally, in that order, to ``Upsampler``.
                NOTE(review): the meaning/units of these values are defined
                by ``Upsampler`` and are not visible here -- confirm there.
        """
        super().__init__()

        in_dim = cfg.input_dim
        out_dim = cfg.output_dim
        capacity = cfg.network_capacity

        # Keep the attribute name ``upsampler``: it is the name under which
        # the submodule (and therefore its parameters) is registered.
        self.upsampler = Upsampler(in_dim, out_dim, capacity)

    def forward(self, input):
        """Run ``input`` through the wrapped upsampler and return the result.

        Args:
            input: Value handed to ``self.upsampler`` as its only argument.
                NOTE(review): presumably a tensor; the expected shape is
                dictated by ``Upsampler`` -- confirm against its definition.

        Returns:
            Whatever ``self.upsampler(input)`` returns, unmodified.
        """
        return self.upsampler(input)