# Mini-gpt-0.000001 / vision_tokenizer.py
# AIencoder's picture
# Create vision_tokenizer.py
# 052ff39 verified
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as T
class VisionTokenizer(nn.Module):
    """Turn an image tensor into a sequence of patch embeddings.

    A single convolution whose kernel size equals its stride cuts the image
    into non-overlapping ``patch_size`` x ``patch_size`` patches and projects
    each patch to ``embed_dim`` features, i.e. the same width as the text
    token embeddings the visual tokens are meant to be used with.

    Args:
        patch_size: Side length, in pixels, of each square patch.
        in_channels: Number of channels in the input image.
        embed_dim: Size of the embedding produced for every patch.
    """

    def __init__(
        self,
        patch_size: int = 16,
        in_channels: int = 3,
        embed_dim: int = 512,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        # kernel_size == stride: each patch is visited exactly once, so the
        # convolution acts as a per-patch linear projection.
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project an image, or a batch of images, to patch tokens.

        Args:
            x: Tensor shaped ``[batch, in_channels, height, width]``, or a
                single unbatched image shaped
                ``[in_channels, height, width]``.

        Returns:
            Tensor shaped ``[batch, num_patches, embed_dim]`` for batched
            input, or ``[num_patches, embed_dim]`` for an unbatched image,
            where ``num_patches`` is
            ``(height // patch_size) * (width // patch_size)``. The
            convolution uses no padding, so trailing rows/columns that do
            not fill a whole patch are not represented in the output.
        """
        # The flatten/transpose below address dimensions by position and
        # are only correct for 4-D input; give a lone image a temporary
        # batch axis instead of silently returning a mis-shaped tensor.
        unbatched = x.dim() == 3
        if unbatched:
            x = x.unsqueeze(0)

        patches = self.projection(x)  # [batch, embed_dim, H/P, W/P]
        # Merge the two spatial axes into one sequence axis, then move the
        # embedding axis last: [batch, num_patches, embed_dim].
        tokens = patches.flatten(2).transpose(1, 2)

        return tokens.squeeze(0) if unbatched else tokens
# --- Example of how to use it with your Repo ---
# NOTE(review): these statements run whenever this file is imported, not only
# when it is executed directly; consider an `if __name__ == "__main__":`
# guard if other modules import VisionTokenizer from here.

# 1. Load an image
# Force three colour channels: a grayscale, palette or RGBA file would
# otherwise produce a tensor whose channel count does not match the
# tokenizer's default in_channels=3. The `with` block releases the file
# handle once convert() has copied the pixel data into a new image.
with Image.open("cat.jpg") as source_image:
    img = source_image.convert("RGB").resize((224, 224))
transform = T.Compose([T.ToTensor()])
img_tensor = transform(img).unsqueeze(0)  # [1, 3, 224, 224]

# 2. Tokenize it
v_tokenizer = VisionTokenizer(patch_size=16, embed_dim=512)
visual_tokens = v_tokenizer(img_tensor)
print(f"Image transformed into {visual_tokens.shape[1]} visual tokens.")