|
import torch.nn as nn |
|
|
|
|
|
def _build_encoder():
    """Assemble the convolutional encoder as one flat ``nn.Sequential``.

    Layout, in order:

    * a 1x1 ``Conv2d`` mapping 3 channels to 3 channels (its role is not
      visible here; NOTE(review): presumably an input/colour transform
      baked into the weights this is paired with -- confirm);
    * five stages of ``ReflectionPad2d(1) -> Conv2d(3x3) -> ReLU`` triples,
      with 2, 2, 4, 4 and 4 triples and 64, 128, 256, 512 and 512 output
      channels respectively;
    * a ``MaxPool2d`` (2x2, stride 2, ``ceil_mode=True``) between consecutive
      stages, i.e. four pools in total and none after the last stage.

    Each pad + 3x3 conv pair keeps height and width unchanged. Each pool
    halves them, rounding up because ``ceil_mode=True``.

    The modules are appended in a fixed order. ``nn.Sequential`` names its
    children by position, so state_dict keys (``"0.weight"``,
    ``"2.weight"``, ...) depend on that order. Do not reorder.

    Returns:
        nn.Sequential: the 53-module encoder.
    """
    # (output channels, number of pad/conv/relu triples) for each stage.
    stage_plan = ((64, 2), (128, 2), (256, 4), (512, 4), (512, 4))

    modules = [nn.Conv2d(3, 3, (1, 1))]
    channels = 3
    for position, (width, repeats) in enumerate(stage_plan):
        # Pool only on the boundary *between* stages, never after the last.
        if position:
            modules.append(nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True))
        for _ in range(repeats):
            modules.append(nn.ReflectionPad2d((1, 1, 1, 1)))
            modules.append(nn.Conv2d(channels, width, (3, 3)))
            modules.append(nn.ReLU())
            channels = width
    return nn.Sequential(*modules)


# Module-level encoder instance; indices 0..52 follow _build_encoder's layout.
encoder = _build_encoder()