nico-x's picture
codebase without model
b54146b
import torch
import torch.nn as nn
from .feature_extractor import FeatureExtractor
from .encoder import TransformerEncoder
from .decoder import TransformerDecoder
class ImageToDigitTransformer(nn.Module):
    """Compose a feature extractor, an encoder and a decoder into one model.

    The forward pass feeds an image batch through ``FeatureExtractor``, passes
    the result through ``TransformerEncoder``, and hands the encoder output,
    together with the decoder input token ids, to ``TransformerDecoder``.

    Args:
        vocab_size: Forwarded to the decoder as ``vocab_size``.
        d_model: Forwarded to the feature extractor as ``emb_dim`` and to the
            encoder and decoder as ``d_model``.
        n_heads: Forwarded to the encoder and decoder as ``n_heads``.
        ff_dim: Forwarded to the encoder and decoder as ``ff_dim``.
        encoder_depth: Forwarded to the encoder as ``depth``.
        decoder_depth: Forwarded to the decoder as ``depth``.
        num_patches: Forwarded to the encoder as ``num_patches``.
        max_seq_len: Forwarded to the decoder as ``max_len``.
        patch_size: Forwarded to the feature extractor as ``patch_size``.
            Defaults to 14, the value that used to be hard-coded.

    NOTE(review): ``num_patches`` presumably has to agree with the number of
    patches the feature extractor emits for the chosen ``patch_size`` and
    image size (16 for the defaults, per the shape notes in ``forward``) --
    confirm against ``FeatureExtractor`` / ``TransformerEncoder``.
    """

    def __init__(self, vocab_size: int = 13, d_model: int = 64,
                 n_heads: int = 4, ff_dim: int = 128,
                 encoder_depth: int = 4, decoder_depth: int = 2,
                 num_patches: int = 16, max_seq_len: int = 5,
                 patch_size: int = 14) -> None:
        super().__init__()
        # patch_size is now configurable instead of being fixed at 14, so the
        # extractor can be kept in step with a non-default num_patches.
        self.feature_extractor = FeatureExtractor(
            patch_size=patch_size,
            emb_dim=d_model,
        )
        self.encoder = TransformerEncoder(
            depth=encoder_depth,
            d_model=d_model,
            n_heads=n_heads,
            ff_dim=ff_dim,
            num_patches=num_patches,
        )
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            max_len=max_seq_len,
            d_model=d_model,
            n_heads=n_heads,
            ff_dim=ff_dim,
            depth=decoder_depth,
        )

    def forward(self, image_tensor: torch.Tensor,
                decoder_input_ids: torch.Tensor) -> torch.Tensor:
        """Run extractor -> encoder -> decoder and return the decoder output.

        Shapes below are the ones noted by the original author for the
        default hyperparameters; they depend on the sub-modules, which are
        defined elsewhere -- TODO confirm for non-default settings.

        Args:
            image_tensor: Image batch, (B, 1, 56, 56) with the defaults.
            decoder_input_ids: Decoder input token ids, (B, 5) with the
                defaults.

        Returns:
            Logits, (B, 5, vocab_size) with the defaults.
        """
        patch_embeddings = self.feature_extractor(image_tensor)  # (B, 16, 64) with defaults
        encoder_output = self.encoder(patch_embeddings)  # (B, 16, 64) with defaults
        # The decoder receives the token ids first, then the encoder output.
        logits = self.decoder(decoder_input_ids, encoder_output)  # (B, 5, 13) with defaults
        return logits
if __name__ == '__main__':
    # Smoke test: build the model with its default hyperparameters, push one
    # random batch through it, and print the shape of the result.
    batch_size, seq_len, vocab = 4, 5, 13

    # The model is built before the inputs are sampled so random-number
    # consumption happens in the same order as before.
    net = ImageToDigitTransformer()
    images = torch.randn(batch_size, 1, 56, 56)
    token_ids = torch.randint(0, vocab, (batch_size, seq_len))

    out = net(images, token_ids)
    print(out.shape)  # Expected: (4, 5, 13)