| import torch.nn as nn | |
| decoder = nn.Sequential( | |
| nn.ReflectionPad2d((1, 1, 1, 1)), | |
| nn.Conv2d(512, 256, (3, 3)), | |
| nn.ReLU(), | |
| nn.Upsample(scale_factor=2, mode='nearest'), | |
| nn.ReflectionPad2d((1, 1, 1, 1)), | |
| nn.Conv2d(256, 256, (3, 3)), | |
| nn.ReLU(), | |
| nn.ReflectionPad2d((1, 1, 1, 1)), | |
| nn.Conv2d(256, 256, (3, 3)), | |
| nn.ReLU(), | |
| nn.ReflectionPad2d((1, 1, 1, 1)), | |
| nn.Conv2d(256, 256, (3, 3)), | |
| nn.ReLU(), | |
| nn.ReflectionPad2d((1, 1, 1, 1)), | |
| nn.Conv2d(256, 128, (3, 3)), | |
| nn.ReLU(), | |
| nn.Upsample(scale_factor=2, mode='nearest'), | |
| nn.ReflectionPad2d((1, 1, 1, 1)), | |
| nn.Conv2d(128, 128, (3, 3)), | |
| nn.ReLU(), | |
| nn.ReflectionPad2d((1, 1, 1, 1)), | |
| nn.Conv2d(128, 64, (3, 3)), | |
| nn.ReLU(), | |
| nn.Upsample(scale_factor=2, mode='nearest'), | |
| nn.ReflectionPad2d((1, 1, 1, 1)), | |
| nn.Conv2d(64, 64, (3, 3)), | |
| nn.ReLU(), | |
| nn.ReflectionPad2d((1, 1, 1, 1)), | |
| nn.Conv2d(64, 3, (3, 3)), | |
| ) |