# tiny_doodle_embedding/tiny_doodle_embedding_model.py
# Build Model:
import torch
import torch.nn as nn
EMBEDDING_SIZE = 64
class EmbedDoodle(nn.Module):
    def __init__(self, embedding_size: int):
        # Inputs: a batch of 32x32 binary images.
        # Outputs: an L2-normalized embedding of each image.
        super().__init__()
        latent_size = 256
        embed_depth = 5
        #self.input_conv = nn.Conv2d(kernel_size=3, in_channels=1, out_channels=16)

        def make_cell(in_size: int, hidden_size: int, out_size: int, add_dropout: bool):
            # A small MLP block: Linear -> SELU -> Linear -> (optional Dropout) -> SELU -> Linear.
            cell = nn.Sequential()
            cell.append(nn.Linear(in_size, hidden_size))
            cell.append(nn.SELU())
            cell.append(nn.Linear(hidden_size, hidden_size))
            if add_dropout:
                cell.append(nn.Dropout())
            cell.append(nn.SELU())
            cell.append(nn.Linear(hidden_size, out_size))
            return cell

        self.preprocess = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=1, out_channels=64),
            nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
            nn.SELU(),
            nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
            nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
            nn.Dropout(),
            nn.SELU(),
            #nn.AvgPool2d(kernel_size=3), # bx4097
            nn.Flatten(),
            # Four unpadded 3x3 convs shrink 32x32 to 24x24, so the flattened size is 64 * 24 * 24 = 36864.
            nn.Linear(36864, latent_size),
            nn.SELU(),
        )

        self.embedding_path = nn.ModuleList()
        for i in range(0, embed_depth):
            self.embedding_path.append(make_cell(latent_size, latent_size, latent_size, add_dropout=True))
        self.embedding_head = nn.Linear(latent_size, embedding_size)

    def forward(self, x):
        x = x.view(-1, 1, 32, 32)
        x = self.preprocess(x)
        # We should do this with a dot product to combine these to really get the effects of a highway/resnet.
        for c in self.embedding_path:
            x = x + c(x)  # Residual connection around each MLP cell.
        x = self.embedding_head(x)
        embedding = nn.functional.normalize(x, dim=-1)  # Unit-length embedding.
        return embedding
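
# A minimal usage sketch (not part of the original file), assuming only what the class above
# defines: build an untrained EmbedDoodle, feed it a dummy batch of 32x32 binary images, and
# check that each output row is unit length, so a plain dot product between two embeddings
# is their cosine similarity.
if __name__ == "__main__":
    model = EmbedDoodle(EMBEDDING_SIZE)
    model.eval()  # Disable dropout for a deterministic forward pass.
    dummy = (torch.rand(2, 32, 32) > 0.5).float()  # Two fake binary doodles.
    with torch.no_grad():
        embeddings = model(dummy)
    print(embeddings.shape)         # torch.Size([2, 64])
    print(embeddings.norm(dim=-1))  # ~1.0 per row; outputs are L2-normalized.
    print(embeddings[0] @ embeddings[1])  # Cosine similarity between the two doodles.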