ZHIJI_cv_web_ui / NTED /base_module.py
zejunyang
update
9667e74
raw
history blame
3.79 kB
import math
import functools
import sys
import torch
import torch.nn as nn
from NTED.base_function import EncoderLayer, DecoderLayer, ToRGB
from NTED.edge_attention_layer import Edge_Attn
class Encoder(nn.Module):
def __init__(
self,
size,
input_dim,
channels,
num_labels=None,
match_kernels=None,
blur_kernel=[1, 3, 3, 1],
):
super().__init__()
self.first = EncoderLayer(input_dim, channels[size], 1)
self.convs = nn.ModuleList()
log_size = int(math.log(size, 2))
self.log_size = log_size
in_channel = channels[size]
for i in range(log_size-1, 3, -1):
out_channel = channels[2 ** i]
num_label = num_labels[2 ** i] if num_labels is not None else None
match_kernel = match_kernels[2 ** i] if match_kernels is not None else None
use_extraction = num_label and match_kernel
conv = EncoderLayer(
in_channel,
out_channel,
kernel_size=3,
downsample=True,
blur_kernel=blur_kernel,
use_extraction=use_extraction,
num_label=num_label,
match_kernel=match_kernel
)
self.convs.append(conv)
in_channel = out_channel
def forward(self, input, recoder=None):
out = self.first(input)
for idx, layer in enumerate(self.convs):
out = layer(out, recoder)
return out
class Decoder(nn.Module):
def __init__(
self,
size,
channels,
num_labels,
match_kernels,
blur_kernel=[1, 3, 3, 1],
):
super().__init__()
self.convs = nn.ModuleList()
# input at resolution 16*16
in_channel = channels[16]
self.log_size = int(math.log(size, 2))
for i in range(4, self.log_size + 1):
out_channel = channels[2 ** i]
num_label, match_kernel = num_labels[2 ** i], match_kernels[2 ** i]
use_distribution = num_label and match_kernel
upsample = (i != 4)
base_layer = functools.partial(
DecoderLayer,
out_channel=out_channel,
kernel_size=3,
blur_kernel=blur_kernel,
use_distribution=use_distribution,
num_label=num_label,
match_kernel=match_kernel
)
up = nn.Module()
up.conv0 = base_layer(in_channel=in_channel, upsample=upsample)
up.conv1 = base_layer(in_channel=out_channel, upsample=False)
up.to_rgb = ToRGB(out_channel, upsample=upsample)
self.convs.append(up)
in_channel = out_channel
self.num_labels, self.match_kernels = num_labels, match_kernels
self.edge_attn_block = Edge_Attn(in_channels=3)
def forward(self, input, neural_textures, recoder):
counter = 0
out, skip = input, None
for i, up in enumerate(self.convs):
if self.num_labels[2**(i+4)] and self.match_kernels[2**(i+4)]:
neural_texture_conv0 = neural_textures[counter]
neural_texture_conv1 = neural_textures[counter+1]
counter += 2
else:
neural_texture_conv0, neural_texture_conv1 = None, None
out = up.conv0(out, neural_texture=neural_texture_conv0, recoder=recoder)
out = up.conv1(out, neural_texture=neural_texture_conv1, recoder=recoder)
skip = up.to_rgb(out, skip)
image = self.edge_attn_block(skip)
# image = skip
return image