import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv
import lightning as L
import numpy as np, pandas as pd, matplotlib.pyplot as plt
from pytorch_lightning.loggers import WandbLogger
import wandb
import pytorch_lightning as pl

torch.set_float32_matmul_precision('medium')

# create the datasets and dataloaders
train_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Training.npy'  # path to training voxels, 65000 * 4800
test_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Test_Mean.npy'  # path to test voxels, 65000 * 600

train_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec_training_embeds_sorted.npy'  # path to training embeddings, 480 * 2 * 1125
test_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec_testing_embeds_sorted.npy'  # path to test embeddings, 600 * 2 * 1125


class VoxelsDataset(data.Dataset):
    def __init__(self, voxels_path, embeddings_path):
        # transpose the two dimensions of the voxels data to match the embeddings data
        self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)
        self.embeddings = torch.from_numpy(np.load(embeddings_path))
        # as each stimulus was presented for 15 seconds and the fMRI data is sampled every 1.5 seconds,
        # there are 10 samples per stimulus
        self.len = len(self.voxels) // 10
        print("The len is ", self.len)

    def __getitem__(self, index):
        # take the 10 fMRI samples that correspond to this stimulus
        voxels = self.voxels[index * 10:(index + 1) * 10]
        embeddings = self.embeddings[index]
        return voxels, embeddings

    def __len__(self):
        return self.len


class VoxelsEmbeddinsEncodecDataModule(pl.LightningDataModule):
    def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=8):
        super().__init__()
        self.train_voxels_path = train_voxels_path
        self.train_embeddings_path = train_embeddings_path
        self.test_voxels_path = test_voxels_path
        self.test_embeddings_path = test_embeddings_path
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)
        self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)

    def train_dataloader(self):
        return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
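
# ---------------------------------------------------------------------------
# Optional shape sanity check (a minimal sketch, not part of the training run).
# It builds a VoxelsDataset from small random arrays saved to temporary .npy
# files, so the real sub-001 data is not needed; the 100 voxels, 40 TRs and
# 4 stimuli below are placeholder sizes, not the real dimensions. It only
# verifies that the dataset slices 10 TRs per stimulus and pairs them with one
# [2, 1125] EnCodec token grid.
# ---------------------------------------------------------------------------
import tempfile

_tmp_dir = tempfile.mkdtemp()
_fake_voxels_path = os.path.join(_tmp_dir, 'fake_voxels.npy')
_fake_embeds_path = os.path.join(_tmp_dir, 'fake_embeds.npy')
np.save(_fake_voxels_path, np.random.rand(100, 40))                        # 100 voxels * 40 TRs
np.save(_fake_embeds_path, np.random.randint(0, 1024, size=(4, 2, 1125)))  # 4 stimuli * 2 codebooks * 1125 tokens

_check_dataset = VoxelsDataset(_fake_voxels_path, _fake_embeds_path)
_check_voxels, _check_embeds = _check_dataset[0]
print(_check_voxels.shape, _check_embeds.shape)  # expected: torch.Size([10, 100]) torch.Size([2, 1125])
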
class MLP(pl.LightningModule):
    def __init__(self, sizes, residual_conections, dropout):
        # sizes is a list with the size of each layer, e.g. [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]
        # residual_conections is a list with the same length as sizes; each element lists the indexes of the
        # layers whose output this layer receives as input, where 0 means the raw x input,
        # e.g. [[0], [1], [2, 1], [3], [4, 3], [5], [6, 5], [7]]
        # (note: not actually used by the plain Sequential stack below)
        # dropout is a list with the same length as sizes; each element is the dropout probability of that layer,
        # e.g. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
        super().__init__()
        self.sizes = sizes
        self.residual_conections = residual_conections
        self.dropout = dropout
        self.layers = nn.Sequential()
        for i in range(len(sizes) - 1):
            self.layers.add_module('linear' + str(i), nn.Linear(sizes[i], sizes[i + 1]))
            self.layers.add_module('relu' + str(i), nn.ReLU())
            self.layers.add_module('dropout' + str(i), nn.Dropout(dropout[i]))

        self.loss = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, x):
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        voxels, embeddings = batch
        # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]
        # flatten the embeddings to [batch_size, 2250]
        embeddings = embeddings.flatten(start_dim=1).long()
        # take just the first 200 embeddings
        # embeddings = embeddings[:, :200]
        # average the 10 fMRI samples per stimulus, then flatten to [batch_size, 65000]
        voxels = voxels.mean(dim=1)
        voxels = voxels.flatten(start_dim=1)
        outputs = self(voxels)
        # the outputs are [batch_size, 1024 * 2250]; reshape them to [batch_size, 1024, 2250]
        # so that the 1024 codebook entries sit on dim 1, as CrossEntropyLoss expects
        outputs = outputs.reshape(-1, 1024, 1125 * 2)
        # add a small constant to the logits
        outputs = outputs + 1e-6
        # print(outputs.shape, embeddings.shape)
        # print(outputs[0, 0, :10], embeddings[0, :10])
        loss = self.loss(outputs, embeddings)
        # print(loss)
        accuracy = self.tokens_accuracy(outputs, embeddings)
        self.log('train_loss', loss)
        self.log('train_accuracy', accuracy)
        return loss

    def tokens_accuracy(self, outputs, embeddings):
        # outputs is [batch_size, 1024, 2250]
        # embeddings is [batch_size, 2250]
        # take the index of the most likely codebook entry for each token
        outputs = outputs.argmax(dim=1)
        # compare the predicted tokens with the target tokens
        return (outputs == embeddings).float().mean()

    def validation_step(self, batch, batch_idx):
        voxels, embeddings = batch
        embeddings = embeddings.flatten(start_dim=1).long()
        # embeddings = embeddings[:, :200]
        voxels = voxels.mean(dim=1)
        voxels = voxels.flatten(start_dim=1)
        outputs = self(voxels)
        outputs = outputs.reshape(-1, 1024, 1125 * 2)
        loss = self.loss(outputs, embeddings)
        accuracy = self.tokens_accuracy(outputs, embeddings)
        self.log('val_loss', loss)
        self.log('val_accuracy', accuracy)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.trainer.model.parameters(), lr=2e-5, weight_decay=3e-3)


# create the model
sizes = [60784, 1000, 1000, 1125 * 2 * 1024]
residual_conections = [[0], [1], [2], [3]]
dropout = [0.5, 0.5, 0.5, 0.5]
model = MLP(sizes, residual_conections, dropout)

# create the data module
data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)

wandb.finish()

from pytorch_lightning.strategies import DeepSpeedStrategy

wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')

# define the trainer
trainer = pl.Trainer(accelerator="gpu", devices=[0, 1, 2, 3, 4, 5, 6, 7], max_epochs=1000, logger=wandb_logger,
                     precision='32', strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=8),
                     enable_checkpointing=False, log_every_n_steps=10)
# trainer = pl.Trainer(accelerator="gpu", devices=[0, 1, 2, 3], max_epochs=1000, logger=wandb_logger, precision='bf16', strategy='fsdp', enable_checkpointing=False, log_every_n_steps=10)

# train the model
trainer.fit(model, datamodule=data_module)
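
# ---------------------------------------------------------------------------
# Minimal decoding sketch (an illustration under stated assumptions, not part
# of the training loop above): the [2, 1125] token grids this model predicts
# (argmax over the 1024-way logits) have the layout of EnCodec codes at
# 1.5 kbps, i.e. 2 codebooks at a 75 Hz frame rate for 15 s of 24 kHz audio.
# Assuming the `encodec` package (facebookresearch/encodec) is installed, such
# a grid could be turned back into a waveform roughly as below; random codes
# stand in for real predictions, and the block runs on rank 0 only.
# ---------------------------------------------------------------------------
if trainer.is_global_zero:
    from encodec import EncodecModel

    codec = EncodecModel.encodec_model_24khz()
    codec.set_target_bandwidth(1.5)  # 1.5 kbps -> 2 codebooks of 1024 entries

    # placeholder codes with the shape of one predicted stimulus: [batch, codebooks, frames]
    predicted_codes = torch.randint(0, 1024, (1, 2, 1125))
    with torch.no_grad():
        waveform = codec.decode([(predicted_codes, None)])  # -> [1, 1, 360000] samples at 24 kHz
    print(waveform.shape)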