File size: 1,026 Bytes
baa8e90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn as nn
import numpy as np

class LatentUpscaler(nn.Module):
	"""Convolutional upscaler built as a single ``nn.Sequential``.

	The network has three parts:

	* head: ``Conv2d(chan -> size)`` + ReLU, nearest-neighbour
	  ``nn.Upsample`` by ``fac``, then ReLU;
	* core: ``depth`` repetitions of ``Conv2d(size -> size)`` + ReLU;
	* tail: ``Conv2d(size -> chan)`` back to the input channel count.

	Every convolution uses ``padding = krn // 2``, so with an odd kernel the
	spatial size is changed only by the upsample step.

	Args:
		fac: Scale factor handed to ``nn.Upsample``.
		depth: Number of Conv2d+ReLU pairs in the core.
		size: Hidden channel width (keyword-only, default 64).
		chan: Input/output channel count (keyword-only, default 4).
		krn: Convolution kernel size; must be a positive odd integer
			(keyword-only, default 3).

	Raises:
		ValueError: If ``krn`` is not a positive odd integer.
	"""

	def __init__(self, fac, depth=16, *, size=64, chan=4, krn=3):
		super().__init__()
		if krn < 1 or krn % 2 == 0:
			# An even kernel with padding krn // 2 changes the spatial size.
			raise ValueError(f"krn must be a positive odd integer, got {krn!r}")
		self.size = size      # hidden Conv2d width
		self.chan = chan      # in/out channels
		self.depth = depth    # no. of core Conv2d+ReLU pairs
		self.fac = fac        # upsample scale factor
		self.krn = krn        # kernel size
		self.pad = krn // 2   # "same" padding for odd kernels

		# The exact layer order fixes the nn.Sequential indices, and those
		# indices appear in state_dict keys ("sequential.<i>.weight").
		# Keep the order unchanged so existing checkpoints still load.
		self.sequential = nn.Sequential(
			*self.head(),
			*self.core(),
			*self.tail(),
		)

	def head(self) -> list:
		"""Return the input conv, activation, upsample and trailing ReLU."""
		return [
			nn.Conv2d(self.chan, self.size, kernel_size=self.krn, padding=self.pad),
			nn.ReLU(),
			nn.Upsample(scale_factor=self.fac, mode="nearest"),
			# No-op: nearest upsampling of a ReLU output is already non-negative.
			# It is kept only so that later layer indices (and thus checkpoint
			# keys) stay where they were.
			nn.ReLU(),
		]

	def core(self) -> list:
		"""Return ``depth`` Conv2d+ReLU pairs at the hidden width."""
		layers = []
		for _ in range(self.depth):
			layers += [
				nn.Conv2d(self.size, self.size, kernel_size=self.krn, padding=self.pad),
				nn.ReLU(),
			]
		return layers

	def tail(self) -> list:
		"""Return the output conv that maps back to ``chan`` channels."""
		return [
			nn.Conv2d(self.size, self.chan, kernel_size=self.krn, padding=self.pad),
		]

	def forward(self, x: torch.Tensor) -> torch.Tensor:
		"""Run ``x`` through the network.

		``x`` must be a Conv2d-compatible tensor with ``chan`` channels,
		e.g. ``(N, chan, H, W)``. The result has ``chan`` channels, and each
		spatial dim is scaled by ``fac``, rounded down as ``nn.Upsample`` does.
		"""
		return self.sequential(x)