# deepkyu's picture
# initial commit
# 1ba3df3
# raw
# history blame
# No virus
# 1.02 kB
import torch
import torch.nn as nn
from . import encoder, decoder
class Generator(nn.Module):
    """Encoder/decoder generator that fuses a content feature with a style feature.

    Two encoders and one decoder are looked up by name on the ``encoder`` and
    ``decoder`` modules, using the type strings stored in ``hp``. The forward
    pass concatenates the two encoder outputs along the channel axis and hands
    the result to the decoder.
    """

    def __init__(self, hp, in_channels=1):
        """Build the content encoder, style encoder and decoder.

        Args:
            hp: Hyper-parameter object. Must expose ``encoder.content.type``,
                ``encoder.style.type`` and ``decoder.type``, each naming a
                class available on the ``encoder`` / ``decoder`` modules.
            in_channels: Channel count passed to both encoders as their input
                size and to the decoder as its output size.
        """
        super().__init__()
        self.hp = hp

        base_width = 64
        hidden_dim = base_width * 4

        # Resolve the concrete classes from their configured names.
        content_cls = getattr(encoder, self.hp.encoder.content.type)
        style_cls = getattr(encoder, self.hp.encoder.style.type)
        decoder_cls = getattr(decoder, self.hp.decoder.type)

        self.content_encoder = content_cls(self.hp, in_channels, hidden_dim)
        self.style_encoder = style_cls(self.hp, in_channels, hidden_dim)
        # The decoder consumes content and style features concatenated on the
        # channel axis, hence twice the hidden width.
        self.decoder = decoder_cls(self.hp, hidden_dim * 2, in_channels)

    def forward(self, images):
        """Run the generator on a ``(content_images, style_images)`` pair.

        Args:
            images: Two-element iterable ``(content_images, style_images)``.
                Content images go to the content encoder unchanged; style
                images are rescaled with ``x * 2 - 1`` first.

        Returns:
            The output of the decoder applied to the concatenated features.
        """
        content_images, style_images = images
        content_feature = self.content_encoder(content_images)

        # Rescale to the -1..1 pixel range (assumes the incoming style images
        # are in 0..1 -- TODO confirm against the data pipeline).
        style_images = style_images * 2 - 1
        # Original note: K-shot style references are treated as the batch.
        style_feature = self.style_encoder(style_images)

        # Broadcast the style feature over the content feature's spatial grid.
        # ``expand`` only stretches size-1 dimensions, so this presumably
        # relies on the style encoder returning an (N, C, 1, 1) tensor.
        height, width = content_feature.shape[-2:]
        tiled_style = style_feature.expand(-1, -1, height, width)

        fused = torch.cat([content_feature, tiled_style], dim=1)
        return self.decoder(fused)