Spaces:
Runtime error
Runtime error
dkoshman
commited on
Commit
•
c29b35f
1
Parent(s):
6e82d4a
image embedding and encoding
Browse files
model.py
CHANGED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from einops.layers.torch import Rearrange
|
2 |
+
import einops
|
3 |
+
import math
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
class ImageEmbedding(nn.Module):
|
9 |
+
"""Reshape image into patches and project into given dimension"""
|
10 |
+
|
11 |
+
def __init__(self, d_model, input_height, input_width, patch_size=16):
|
12 |
+
super().__init__()
|
13 |
+
assert input_height % patch_size == 0 and input_width % patch_size == 0, \
|
14 |
+
"Cannot split image in patches"
|
15 |
+
|
16 |
+
self.tokenize = Rearrange(
|
17 |
+
'b c (h1 h2) (w1 w2) -> b (c h1 w1) (h2 w2)',
|
18 |
+
h2=patch_size,
|
19 |
+
w2=patch_size
|
20 |
+
)
|
21 |
+
self.projection = nn.Linear(patch_size ** 2, d_model)
|
22 |
+
|
23 |
+
def forward(self, image_batch):
|
24 |
+
image_batch = self.tokenize(image_batch)
|
25 |
+
image_batch = self.projection(image_batch)
|
26 |
+
return image_batch
|
27 |
+
|
28 |
+
|
29 |
+
class PositionalEncoding(nn.Module):
|
30 |
+
|
31 |
+
def __init__(self, d_model, max_sequence_len=5000):
|
32 |
+
super().__init__()
|
33 |
+
|
34 |
+
# pos - position in sequence, i - index of element embedding
|
35 |
+
# PE(pos, 2i) = sin(pos / 10000**(2i / d_model)) = sin(pos * e**(2i * (-log(10000))/d_model))
|
36 |
+
# PE(pos, 2i+1) = cos(pos / 10000**(2i / d_model)) = cos(pos * e**(2i * (-log(10000))/d_model))
|
37 |
+
|
38 |
+
positions = torch.arange(max_sequence_len)
|
39 |
+
even_embedding_indices = torch.arange(0, d_model, 2)
|
40 |
+
|
41 |
+
expression = torch.exp(even_embedding_indices * (-math.log(10000.0) / d_model))
|
42 |
+
expression = torch.einsum("i, j -> ij", positions, expression)
|
43 |
+
|
44 |
+
even_encodings = torch.sin(expression)
|
45 |
+
odd_encodings = torch.cos(expression)
|
46 |
+
|
47 |
+
positional_encodings = einops.rearrange(
|
48 |
+
[even_encodings, odd_encodings],
|
49 |
+
'even_odd pos embed -> pos 1 (embed even_odd)'
|
50 |
+
)
|
51 |
+
|
52 |
+
self.register_buffer('positional_encodings', positional_encodings)
|
53 |
+
|
54 |
+
def forward(self, image_batch):
|
55 |
+
batch_size = image_batch.size(0)
|
56 |
+
positional_encodings = self.positional_encodings[:batch_size]
|
57 |
+
return positional_encodings
|