| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| import torchvision.transforms as T | |
class VisionTokenizer(nn.Module):
    """Turn an image into a sequence of patch-embedding tokens.

    The image is split into non-overlapping ``patch_size x patch_size``
    patches and each patch is linearly projected to ``embed_dim``, so the
    resulting tokens share the embedding width of text tokens and can be
    fed to a transformer.

    Args:
        patch_size: Side length of each square patch (default 16).
        in_channels: Number of input image channels (default 3, RGB).
        embed_dim: Output embedding dimension per patch token (default 512).
    """

    def __init__(self, patch_size=16, in_channels=3, embed_dim=512):
        super().__init__()
        self.patch_size = patch_size
        # A convolution with kernel_size == stride == patch_size is
        # equivalent to slicing the image into patches and applying one
        # shared linear projection to each flattened patch.
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size,
        )

    def forward(self, x):
        """Tokenize a batch of images.

        Args:
            x: Float tensor of shape ``[B, C, H, W]``. ``H`` and ``W`` must
               be divisible by ``patch_size``.

        Returns:
            Tensor of shape ``[B, (H // P) * (W // P), embed_dim]``.

        Raises:
            ValueError: If ``H`` or ``W`` is not a multiple of
                ``patch_size`` (the strided convolution would otherwise
                silently drop the edge pixels).
        """
        h, w = x.shape[-2], x.shape[-1]
        if h % self.patch_size or w % self.patch_size:
            raise ValueError(
                f"Input spatial size {h}x{w} must be divisible by "
                f"patch_size={self.patch_size}; edge pixels would be "
                f"silently dropped otherwise."
            )
        x = self.projection(x)            # [B, embed_dim, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        return x
# --- Example of how to use it with your Repo ---
# 1. Load an image. Force RGB: without .convert("RGB"), a grayscale,
#    palette, or RGBA file yields 1/4 channels and breaks the 3-channel
#    projection conv.
img = Image.open("cat.jpg").convert("RGB").resize((224, 224))
transform = T.Compose([T.ToTensor()])
img_tensor = transform(img).unsqueeze(0)  # [1, 3, 224, 224]
# 2. Tokenize it: 224 / 16 = 14 patches per side -> 196 visual tokens.
v_tokenizer = VisionTokenizer(patch_size=16, embed_dim=512)
visual_tokens = v_tokenizer(img_tensor)
print(f"Image transformed into {visual_tokens.shape[1]} visual tokens.")