# AI-NERD / ModelDefinitions.py
# Last commit (jhorwath, 1cd3e6c): Add model definitions and pretrained weights
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchsummary import summary
from torch.utils.data import TensorDataset, DataLoader
class recon_encoder(nn.Module):
    """Encoder half of the reconstruction autoencoder.

    A single-channel image passes through three convolutional stages.  Each
    stage is two ``3x3 conv -> dropout -> ReLU`` units followed by a
    ``pool x pool`` max-pool; the channel widths are ``nconv``, ``2 * nconv``
    and ``4 * nconv``.  The result is flattened and projected to
    ``latent_size`` features by a linear layer followed by ReLU, and
    ``forward`` returns that latent vector.

    ``decoder1`` is constructed but never called by ``forward``.
    NOTE(review): presumably it is kept so that parameter names line up with
    ``recon_model`` checkpoints -- confirm before removing it.

    The order of the layers inside every ``nn.Sequential`` is part of the
    checkpoint format (``state_dict`` keys are positional), so do not reorder
    them.

    Args:
        latent_size: Number of features produced by the bottleneck.
        nconv: Channel count of the first convolutional stage.
        pool: Max-pool kernel size (and upsampling factor of ``decoder1``).
        drop: Dropout probability applied after every convolution.
        input_size: Height/width of the square input images.  It is only used
            to size the bottleneck linear layer; with the defaults it yields
            the 1024 input features that used to be hard-coded.

    Raises:
        ValueError: If three poolings would shrink ``input_size`` below one
            pixel per side.
    """

    def __init__(self, latent_size, nconv=16, pool=4, drop=0.05, input_size=256):
        super().__init__()

        deep = nconv * 4
        # Every MaxPool2d floors the spatial size, so after three of them
        # input_size // pool**3 pixels remain per side.
        side = input_size // pool ** 3
        if side < 1:
            raise ValueError(
                "input_size=%r is too small for three %rx%r poolings"
                % (input_size, pool, pool)
            )
        self.input_size = input_size

        unit = self._conv_unit
        self.encoder = nn.Sequential(
            *unit(1, nconv, drop),
            *unit(nconv, nconv, drop),
            nn.MaxPool2d((pool, pool)),
            *unit(nconv, nconv * 2, drop),
            *unit(nconv * 2, nconv * 2, drop),
            nn.MaxPool2d((pool, pool)),
            *unit(nconv * 2, deep, drop),
            *unit(deep, deep, drop),
            nn.MaxPool2d((pool, pool)),
        )

        # Fully connected bottleneck; no dropout here on purpose.
        self.bottleneck = nn.Sequential(
            nn.Flatten(),
            nn.Linear(deep * side * side, latent_size),
            nn.ReLU(),
        )

        self.decoder1 = nn.Sequential(
            *unit(deep, deep, drop),
            *unit(deep, deep, drop),
            nn.Upsample(scale_factor=pool, mode='bilinear'),
            *unit(deep, deep, drop),
            *unit(deep, deep, drop),
            nn.Upsample(scale_factor=pool, mode='bilinear'),
            *unit(deep, nconv * 2, drop),
            *unit(nconv * 2, nconv * 2, drop),
            nn.Upsample(scale_factor=pool, mode='bilinear'),
            # Single-channel output squashed into [0, 1] (amplitude mode).
            nn.Conv2d(nconv * 2, 1, 3, stride=1, padding=(1, 1)),
            nn.Sigmoid(),
        )

    @staticmethod
    def _conv_unit(in_channels, out_channels, drop):
        """Return one size-preserving ``3x3 conv -> dropout -> ReLU`` unit."""
        return (
            nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
        )

    def forward(self, x):
        """Encode ``x`` (batch, 1, H, W) into a (batch, latent_size) tensor.

        The computation runs under CUDA autocast, so on a GPU the result may
        be in reduced precision.
        NOTE: recent PyTorch releases spell this context manager
        ``torch.amp.autocast('cuda')``; the original call is kept unchanged.
        """
        with torch.cuda.amp.autocast():
            features = self.encoder(x)
            return self.bottleneck(features)

    def calc_fc_shape(self):
        """Return the flattened size of the encoder output for one image.

        A zero image of ``input_size`` x ``input_size`` pixels is pushed
        through ``self.encoder``.  As side effects the per-sample output
        shape is stored in ``conv_bock_output_shape`` (historical spelling,
        kept for existing callers) and ``conv_block_output_shape``, and the
        flattened length in ``flattened_size``.
        """
        # Objects pickled before ``input_size`` existed lack the attribute;
        # they were always probed with 256 x 256 images.
        side = getattr(self, "input_size", 256)
        # Build the probe where the weights live so this also works after
        # the module has been moved to another device or dtype.
        ref = next(self.parameters())
        probe = torch.zeros(1, 1, side, side, device=ref.device, dtype=ref.dtype)
        with torch.no_grad():  # only the shape matters, skip autograd
            features = self.encoder(probe)
        self.conv_bock_output_shape = features.shape[1:]
        self.conv_block_output_shape = self.conv_bock_output_shape
        self.flattened_size = features[0].numel()
        return self.flattened_size
class recon_model(nn.Module):
    """Convolutional autoencoder that reconstructs a single-channel image.

    Encoder: three stages, each made of two ``3x3 conv -> dropout -> ReLU``
    units followed by a ``pool x pool`` max-pool, with channel widths
    ``nconv``, ``2 * nconv`` and ``4 * nconv``.
    Bottleneck: flatten -> linear to ``latent_size`` -> ReLU -> linear back
    to the flattened size -> ReLU -> reshape to the encoder output shape.
    Decoder: three stages of two conv units followed by bilinear upsampling
    by ``pool``, then a final 3x3 convolution to one channel and a sigmoid.

    The order of the layers inside every ``nn.Sequential`` is part of the
    checkpoint format (``state_dict`` keys are positional), so do not reorder
    them.

    Args:
        latent_size: Width of the latent vector inside the bottleneck.
        nconv: Channel count of the first convolutional stage.
        pool: Max-pool kernel size and decoder upsampling factor.
        drop: Dropout probability applied after every convolution.
        input_size: Height/width of the square input images.  It sizes the
            bottleneck linear layers and the unflatten shape; with the
            defaults it yields the 1024 features / ``(64, 4, 4)`` shape that
            used to be hard-coded.  The output has
            ``(input_size // pool**3) * pool**3`` pixels per side.

    Raises:
        ValueError: If three poolings would shrink ``input_size`` below one
            pixel per side.
    """

    def __init__(self, latent_size, nconv=16, pool=4, drop=0.05, input_size=256):
        super().__init__()

        deep = nconv * 4
        # Every MaxPool2d floors the spatial size, so after three of them
        # input_size // pool**3 pixels remain per side.
        side = input_size // pool ** 3
        if side < 1:
            raise ValueError(
                "input_size=%r is too small for three %rx%r poolings"
                % (input_size, pool, pool)
            )
        flat = deep * side * side
        self.input_size = input_size

        unit = self._conv_unit
        self.encoder = nn.Sequential(
            *unit(1, nconv, drop),
            *unit(nconv, nconv, drop),
            nn.MaxPool2d((pool, pool)),
            *unit(nconv, nconv * 2, drop),
            *unit(nconv * 2, nconv * 2, drop),
            nn.MaxPool2d((pool, pool)),
            *unit(nconv * 2, deep, drop),
            *unit(deep, deep, drop),
            nn.MaxPool2d((pool, pool)),
        )

        # Fully connected bottleneck; no dropout here on purpose.
        self.bottleneck = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat, latent_size),
            nn.ReLU(),
            nn.Linear(latent_size, flat),
            nn.ReLU(),
            # dim 0 is the batch dimension, so reshape dim 1 only.
            nn.Unflatten(1, (deep, side, side)),
        )

        self.decoder1 = nn.Sequential(
            *unit(deep, deep, drop),
            *unit(deep, deep, drop),
            nn.Upsample(scale_factor=pool, mode='bilinear'),
            *unit(deep, deep, drop),
            *unit(deep, deep, drop),
            nn.Upsample(scale_factor=pool, mode='bilinear'),
            *unit(deep, nconv * 2, drop),
            *unit(nconv * 2, nconv * 2, drop),
            nn.Upsample(scale_factor=pool, mode='bilinear'),
            # Single-channel output squashed into [0, 1] (amplitude mode).
            nn.Conv2d(nconv * 2, 1, 3, stride=1, padding=(1, 1)),
            nn.Sigmoid(),
        )

    @staticmethod
    def _conv_unit(in_channels, out_channels, drop):
        """Return one size-preserving ``3x3 conv -> dropout -> ReLU`` unit."""
        return (
            nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
        )

    def forward(self, x):
        """Reconstruct ``x`` (batch, 1, H, W); returns a (batch, 1, H', W') map.

        The computation runs under CUDA autocast, so on a GPU the result may
        be in reduced precision.
        NOTE: recent PyTorch releases spell this context manager
        ``torch.amp.autocast('cuda')``; the original call is kept unchanged.
        """
        with torch.cuda.amp.autocast():
            features = self.encoder(x)
            features = self.bottleneck(features)
            return self.decoder1(features)

    def calc_fc_shape(self):
        """Return the flattened size of the encoder output for one image.

        A zero image of ``input_size`` x ``input_size`` pixels is pushed
        through ``self.encoder``.  As side effects the per-sample output
        shape is stored in ``conv_bock_output_shape`` (historical spelling,
        kept for existing callers) and ``conv_block_output_shape``, and the
        flattened length in ``flattened_size``.
        """
        # Objects pickled before ``input_size`` existed lack the attribute;
        # they were always probed with 256 x 256 images.
        side = getattr(self, "input_size", 256)
        # Build the probe where the weights live so this also works after
        # the module has been moved to another device or dtype.
        ref = next(self.parameters())
        probe = torch.zeros(1, 1, side, side, device=ref.device, dtype=ref.dtype)
        with torch.no_grad():  # only the shape matters, skip autograd
            features = self.encoder(probe)
        self.conv_bock_output_shape = features.shape[1:]
        self.conv_block_output_shape = self.conv_bock_output_shape
        self.flattened_size = features[0].numel()
        return self.flattened_size