# (viewer residue preserved as comment) file size: 2,104 bytes, commit 70bfb14
# Build Model:
import torch
import torch.nn as nn
EMBEDDING_SIZE = 64
class EmbedDoodle(nn.Module):
    """Embed a 32x32 single-channel doodle image into a unit-norm vector.

    The image passes through a small convolutional stem, is flattened to a
    latent vector, refined by a stack of residual MLP cells, and projected
    to ``embedding_size`` dimensions with L2 normalization, so outputs are
    directly comparable by cosine similarity / dot product.
    """

    def __init__(self, embedding_size: int):
        """Build the network.

        :param embedding_size: dimensionality of the output embedding.
        """
        super().__init__()
        latent_size = 256
        embed_depth = 5
        # Four unpadded 3x3 convs each shrink a spatial dim by 2: 32 -> 24.
        conv_out_side = 32 - 4 * 2
        flat_size = 64 * conv_out_side * conv_out_side  # 36864

        def make_cell(in_size: int, hidden_size: int, out_size: int,
                      add_dropout: bool) -> nn.Sequential:
            # Two-hidden-layer SELU MLP, used as a residual branch in forward().
            cell = nn.Sequential()
            cell.append(nn.Linear(in_size, hidden_size))
            cell.append(nn.SELU())
            cell.append(nn.Linear(hidden_size, hidden_size))
            if add_dropout:
                cell.append(nn.Dropout())
            cell.append(nn.SELU())
            cell.append(nn.Linear(hidden_size, out_size))
            return cell

        # Convolutional stem: 1x32x32 -> 64x24x24 -> flatten -> latent vector.
        self.preprocess = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=1, out_channels=64),
            nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
            nn.SELU(),
            nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
            nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
            nn.Dropout(),
            nn.SELU(),
            nn.Flatten(),
            nn.Linear(flat_size, latent_size),
            nn.SELU(),
        )
        # A stack of residual refinement cells applied sequentially in forward().
        self.embedding_path = nn.ModuleList(
            make_cell(latent_size, latent_size, latent_size, add_dropout=True)
            for _ in range(embed_depth)
        )
        self.embedding_head = nn.Linear(latent_size, embedding_size)

    def forward(self, x):
        """Map input images to unit-length embeddings.

        :param x: tensor reshapeable to (batch, 1, 32, 32).
        :return: (batch, embedding_size) tensor whose rows have unit L2 norm.
        """
        x = x.view(-1, 1, 32, 32)
        x = self.preprocess(x)
        # Residual refinement: each cell contributes an additive correction,
        # keeping gradients flowing straight through the skip path.
        for cell in self.embedding_path:
            x = x + cell(x)
        x = self.embedding_head(x)
        # Project onto the unit hypersphere so embeddings compare by cosine.
        embedding = nn.functional.normalize(x, dim=-1)
        return embedding