# noisy_human / cnn.py
# Author: santiviquez. Commit 29457c0 ("add everything"), 2.25 kB.
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
    """Convolutional classifier: three conv stages, global average pooling, MLP head.

    Each conv stage is Conv2d(3x3, no padding) -> ReLU -> BatchNorm2d -> MaxPool2d.
    Dropout follows the first and third stages. The resulting 128-dim feature
    vector passes through Linear(128, 256) -> ReLU -> Dropout ->
    Linear(256, 512) -> ReLU -> Dropout -> Linear(512, num_classes).

    Args:
        input_channels: Number of channels in the input tensor.
        num_classes: Number of output logits. Defaults to 10, the value the
            model previously hard-coded, so existing callers and checkpoints
            keep working unchanged.

    Input must be a batched 4-D tensor ``(batch, input_channels, H, W)``.
    The convolutions are unpadded and the pools shrink the map, so H must be
    at least 20 and W at least 44. Output is ``(batch, num_classes)`` raw
    logits; no softmax is applied.
    """

    def __init__(self, input_channels, num_classes=10):
        super().__init__()
        self.input_channels = input_channels
        self.num_classes = num_classes
        # Registration order is significant: self.apply() visits submodules in
        # this order, which fixes the order of random weight initialisation.
        # Stage 1
        self.conv1 = nn.Conv2d(self.input_channels, 32, kernel_size=(3, 3))
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 3))
        self.dropout1 = nn.Dropout(0.3)
        # Stage 2 (pools only along the last axis)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3))
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 3))
        # Stage 3
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3))
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.dropout2 = nn.Dropout(0.3)
        # Collapse whatever spatial size remains to 1x1.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Classifier head
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 512)
        self.dropout3 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(512, num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise a single submodule in place; intended for ``self.apply``.

        - Conv2d: Xavier-normal weights, zero bias.
        - BatchNorm2d: weight 1, bias 0 (only if the layer has them).
        - Linear: weights ~ U(-1/sqrt(in_features), 1/sqrt(in_features)), zero bias.
        Other module types are left untouched.
        """
        if isinstance(module, nn.Conv2d):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            # BatchNorm2d(affine=False) has weight/bias set to None.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            bound = 1.0 / module.in_features ** 0.5
            nn.init.uniform_(module.weight, -bound, bound)
            # Same guard as the Conv2d branch: Linear(bias=False) has bias None.
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Map a batched input ``(batch, input_channels, H, W)`` to ``(batch, num_classes)`` logits."""
        # NOTE: every stage applies ReLU *before* BatchNorm. This order is kept
        # because reordering would change the model's outputs.
        x = self.pool1(self.batchnorm1(F.relu(self.conv1(x))))
        x = self.dropout1(x)
        x = self.pool2(self.batchnorm2(F.relu(self.conv2(x))))
        x = self.pool3(self.batchnorm3(F.relu(self.conv3(x))))
        x = self.dropout2(x)
        # (batch, 128, 1, 1) -> (batch, 128)
        x = self.avgpool(x).flatten(1)
        # dropout3 is reused after both hidden layers; Dropout has no parameters.
        x = self.dropout3(F.relu(self.fc1(x)))
        x = self.dropout3(F.relu(self.fc2(x)))
        return self.fc3(x)