nico-x's picture
codebase without model
b54146b
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from utils.tokenizer import prepare_decoder_labels, encode, decode
class MNIST_2x2(Dataset):
    """Dataset of 2x2 image grids built from four randomly chosen base samples.

    Each item combines four samples of ``base_dataset`` into a single grid
    image and turns their four labels into decoder input/target id tensors
    via ``prepare_decoder_labels``.

    Args:
        base_dataset: Indexable dataset supporting ``len()`` whose items are
            ``(image, label)`` pairs. Images are concatenated along dims 2
            and 1, so they must be tensors with at least three dimensions
            (presumably ``(C, H, W)`` -- confirm against the base dataset).
        transform: Optional callable applied to the composed grid image.
        seed: Integer seed for the private generator that draws the index
            map, so the same seed and dataset length give the same grids.
    """

    def __init__(self, base_dataset, transform=None, seed=42):
        self.base_dataset = base_dataset
        self.transform = transform
        self.length = len(base_dataset)

        # Use a private generator rather than torch.manual_seed(): the index
        # map stays reproducible for a given seed, but constructing a dataset
        # no longer reseeds the global RNG used by the rest of the program.
        rng = torch.Generator()
        rng.manual_seed(seed)

        # One entry per item: four base-dataset indices (drawn with
        # replacement) in the order top-left, top-right, bottom-left,
        # bottom-right.
        self.index_map = [
            torch.randint(0, self.length, (4,), generator=rng)
            for _ in range(self.length)
        ]

    def __len__(self):
        """Return the number of grids, equal to the base dataset's length."""
        return self.length

    def __getitem__(self, idx):
        """Build the grid image and decoder tensors for item ``idx``.

        Returns:
            A tuple ``(grid_image, decoder_input, decoder_target)`` where
            ``grid_image`` is the 2x2 composition of the four base images and
            the other two are ``torch.long`` tensors built from the id
            sequences returned by ``prepare_decoder_labels``.
        """
        # Plain Python ints index any kind of base dataset.
        indices = self.index_map[idx].tolist()

        # Fetch each base sample exactly once so that any base transform runs
        # once per sample and the image and label come from the same access.
        samples = [self.base_dataset[i] for i in indices]
        images = [sample[0] for sample in samples]
        labels = [sample[1] for sample in samples]

        # Join along dim 2 to form each row, then along dim 1 to stack rows.
        top_row = torch.cat([images[0], images[1]], dim=2)
        bottom_row = torch.cat([images[2], images[3]], dim=2)
        grid_image = torch.cat([top_row, bottom_row], dim=1)

        # The transform used to be stored but silently ignored.
        # NOTE(review): applying it to the composed grid is an assumption
        # about the intended semantics -- confirm.
        if self.transform is not None:
            grid_image = self.transform(grid_image)

        decoder_input_ids, decoder_target_ids = prepare_decoder_labels(labels)
        decoder_input = torch.tensor(decoder_input_ids, dtype=torch.long)
        decoder_target = torch.tensor(decoder_target_ids, dtype=torch.long)
        return grid_image, decoder_input, decoder_target
# Manual smoke test: build the 2x2 datasets and display a few composed samples.
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Convert images to tensors, then normalize with mean 0.1307 / std 0.3081.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )

    mnist_train = datasets.MNIST('./data', train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST('./data', train=False, download=True, transform=transform)

    train_dataset = MNIST_2x2(mnist_train, seed=42)
    test_dataset = MNIST_2x2(mnist_test, seed=42)

    def show_grid_image(grid_tensor, decoder_target):
        """Display one grid image, titled with its decoded target digits."""
        # Work on a copy and invert the normalization so pixel values are
        # back in their pre-normalization range for display.
        pixels = (grid_tensor.clone() * 0.3081 + 0.1307).squeeze().numpy()

        # Drop the final token (<finish>) before decoding for display.
        token_ids = decoder_target.tolist()[:-1]
        caption = " ".join(decode(token_ids))

        plt.imshow(pixels, cmap="gray")
        plt.title(f"Digits: {caption}")
        plt.axis("off")
        plt.show()

    # Show the first three training grids.
    for sample_idx in range(3):
        grid_image, decoder_input, decoder_target = train_dataset[sample_idx]
        show_grid_image(grid_image, decoder_target)