dkoshman committed
Commit c29b35f
Parent: 6e82d4a

image embedding and encoding

Files changed (1):
  model.py  +57 -0
model.py CHANGED
@@ -0,0 +1,57 @@
+ from einops.layers.torch import Rearrange
+ import einops
+ import math
+ import torch.nn as nn
+ import torch
+
+
+ class ImageEmbedding(nn.Module):
+     """Reshape image into patches and project into given dimension"""
+
+     def __init__(self, d_model, input_height, input_width, patch_size=16):
+         super().__init__()
+         assert input_height % patch_size == 0 and input_width % patch_size == 0, \
+             "Cannot split image into patches"
+
+         # (batch, channels, height, width) -> (batch, num_patches, patch_size ** 2)
+         self.tokenize = Rearrange(
+             'b c (h1 h2) (w1 w2) -> b (c h1 w1) (h2 w2)',
+             h2=patch_size,
+             w2=patch_size
+         )
+         self.projection = nn.Linear(patch_size ** 2, d_model)
+
+     def forward(self, image_batch):
+         image_batch = self.tokenize(image_batch)
+         image_batch = self.projection(image_batch)
+         return image_batch
+
+
+ class PositionalEncoding(nn.Module):
+     """Sinusoidal positional encodings, stored as a (max_sequence_len, 1, d_model) buffer"""
+
+     def __init__(self, d_model, max_sequence_len=5000):
+         super().__init__()
+
+         # pos - position in sequence, i - index of element embedding
+         # PE(pos, 2i) = sin(pos / 10000**(2i / d_model)) = sin(pos * e**(2i * (-log(10000))/d_model))
+         # PE(pos, 2i+1) = cos(pos / 10000**(2i / d_model)) = cos(pos * e**(2i * (-log(10000))/d_model))
+
+         positions = torch.arange(max_sequence_len, dtype=torch.float)
+         even_embedding_indices = torch.arange(0, d_model, 2)
+
+         expression = torch.exp(even_embedding_indices * (-math.log(10000.0) / d_model))
+         expression = torch.einsum("i, j -> ij", positions, expression)
+
+         even_encodings = torch.sin(expression)
+         odd_encodings = torch.cos(expression)
+
+         # interleave so that embedding dimension 2i holds sin and 2i+1 holds cos
+         positional_encodings = einops.rearrange(
+             [even_encodings, odd_encodings],
+             'even_odd pos embed -> pos 1 (embed even_odd)'
+         )
+
+         self.register_buffer('positional_encodings', positional_encodings)
+
+     def forward(self, image_batch):
+         # the buffer is sequence-first, so slice by the input's sequence length
+         sequence_len = image_batch.size(0)
+         positional_encodings = self.positional_encodings[:sequence_len]
+         return positional_encodings
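
A minimal usage sketch of the two new modules, assuming hypothetical 64x64 single-channel images, a batch of 8, and d_model=128; the sequence-first transpose before adding the encodings is an assumption about how the modules are meant to be combined, since the encoding buffer has shape (pos, 1, d_model):

import torch
from model import ImageEmbedding, PositionalEncoding

d_model = 128
embedding = ImageEmbedding(d_model, input_height=64, input_width=64, patch_size=16)
positional_encoding = PositionalEncoding(d_model)

images = torch.randn(8, 1, 64, 64)                # (batch, channels, height, width)
patches = embedding(images)                       # (8, 16, 128): 16 patches, each projected to d_model
patches = patches.transpose(0, 1)                 # (16, 8, 128): sequence-first, matching the buffer layout
encoded = patches + positional_encoding(patches)  # encodings broadcast over the batch dimension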