# b2m/src/mlpdummy.py
import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv
import lightning as L
import numpy as np, pandas as pd, matplotlib.pyplot as plt
from pytorch_lightning.loggers import WandbLogger
import wandb
import pytorch_lightning as pl
torch.set_float32_matmul_precision('medium')
# create the datasets and dataloaders
train_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Training.npy' # path to training voxels 65000 * 4800
test_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Test_Mean.npy' # path to test voxels 65000 * 600
train_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec_training_embeds_sorted.npy' # path to training embeddings 480 * 2 * 1125
test_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec_testing_embeds_sorted.npy' # path to test embeddings 600 * 2 * 1125
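
# Optional sanity check (a minimal sketch; assumes the .npy files above exist and the shapes in the
# path comments are accurate -- uncomment to run):
# _v = np.load(train_voxels_path, mmap_mode='r')
# _e = np.load(train_embeddings_path, mmap_mode='r')
# print(_v.shape, _e.shape)  # expected roughly (65000, 4800) and (480, 2, 1125)
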
class VoxelsDataset(data.Dataset):
    def __init__(self, voxels_path, embeddings_path):
        # transpose the voxels so that time (TRs) is the first dimension, matching the ordering of the embeddings
        self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)
        self.embeddings = torch.from_numpy(np.load(embeddings_path))
        # each stimulus is presented for 15 seconds and fMRI is sampled every 1.5 seconds, so there are 10 TRs per stimulus
        self.len = len(self.voxels) // 10
        print("The dataset length is", self.len)

    def __getitem__(self, index):
        # take the 10 TRs that correspond to stimulus `index`
        voxels = self.voxels[index*10:(index+1)*10]
        embeddings = self.embeddings[index]
        return voxels, embeddings

    def __len__(self):
        return self.len
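
# Illustrative usage of VoxelsDataset (a sketch; assumes the training files above exist):
#   ds = VoxelsDataset(train_voxels_path, train_embeddings_path)
#   v, e = ds[0]
#   # v: [10, n_voxels] float32 -- the ten 1.5 s TRs covering one 15 s stimulus
#   # e: [2, 1125]              -- two EnCodec codebooks of 1125 tokens each
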
class VoxelsEmbeddinsEncodecDataModule(pl.LightningDataModule):
    def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=8):
        super().__init__()
        self.train_voxels_path = train_voxels_path
        self.train_embeddings_path = train_embeddings_path
        self.test_voxels_path = test_voxels_path
        self.test_embeddings_path = test_embeddings_path
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)
        self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)

    def train_dataloader(self):
        return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
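
# Illustrative standalone use of the data module (a sketch; the Trainer below calls setup() itself):
#   dm = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path,
#                                         test_voxels_path, test_embeddings_path, batch_size=4)
#   dm.setup()
#   voxels, embeddings = next(iter(dm.train_dataloader()))  # [4, 10, n_voxels], [4, 2, 1125]
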
class MLP(pl.LightningModule):
    def __init__(self, sizes, residual_connections, dropout):
        # sizes: list of layer widths, e.g. [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]
        # residual_connections: list (same length as sizes) where each element lists the indices of the
        #   layers whose outputs feed this layer; 0 means the layer receives the raw input x,
        #   e.g. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]].
        #   Note: this argument is stored but not used below -- the network is a plain Sequential stack.
        # dropout: list of per-layer dropout probabilities, e.g. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
        super().__init__()
        self.sizes = sizes
        self.residual_connections = residual_connections
        self.dropout = dropout
        self.layers = nn.Sequential()
        for i in range(len(sizes)-1):
            # note: ReLU and dropout are applied after every linear layer, including the last one
            self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))
            self.layers.add_module('relu'+str(i), nn.ReLU())
            self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))
        self.loss = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, x):
        return self.layers(x)
    def training_step(self, batch, batch_idx):
        voxels, embeddings = batch  # shapes: [batch_size, 10, n_voxels] and [batch_size, 2, 1125]
        # flatten the two EnCodec codebooks into one token sequence per sample
        embeddings = embeddings.flatten(start_dim=1).long()  # shape: [batch_size, 2250]
        # take just the first 200 embeddings
        # embeddings = embeddings[:, :200]
        # average the 10 TRs that belong to the same stimulus
        voxels = voxels.mean(dim=1)
        voxels = voxels.flatten(start_dim=1)  # shape: [batch_size, n_voxels]
        outputs = self(voxels)
        # outputs come out as [batch_size, 1024*2250]; reshape to [batch_size, 1024, 2250]
        # so that CrossEntropyLoss sees the 1024 codebook entries as the class dimension
        outputs = outputs.reshape(-1, 1024, 1125*2)
        # small constant added to the logits (note: CrossEntropyLoss takes raw logits, so this is effectively a no-op)
        outputs = outputs + 1e-6
        # print(outputs.shape, embeddings.shape)
        # print(outputs[0,0,:10], embeddings[0,:10])
        loss = self.loss(outputs, embeddings)
        # print(loss)
        accuracy = self.tokens_accuracy(outputs, embeddings)
        self.log('train_loss', loss)
        self.log('train_accuracy', accuracy)
        return loss
    def tokens_accuracy(self, outputs, embeddings):
        # outputs: [batch_size, 1024, 2250] logits; embeddings: [batch_size, 2250] target token ids
        # take the most likely codebook entry at each token position
        outputs = outputs.argmax(dim=1)
        # fraction of token positions where the prediction matches the target
        return (outputs == embeddings).float().mean()
    def validation_step(self, batch, batch_idx):
        voxels, embeddings = batch
        embeddings = embeddings.flatten(start_dim=1).long()
        # embeddings = embeddings[:, :200]
        voxels = voxels.mean(dim=1)
        voxels = voxels.flatten(start_dim=1)
        outputs = self(voxels)
        outputs = outputs.reshape(-1, 1024, 1125*2)
        loss = self.loss(outputs, embeddings)
        accuracy = self.tokens_accuracy(outputs, embeddings)
        self.log('val_loss', loss)
        self.log('val_accuracy', accuracy)
        return loss
    def configure_optimizers(self):
        # use trainer.model.parameters() so the optimizer sees the strategy-wrapped parameters (relevant for DeepSpeed/FSDP)
        return torch.optim.Adam(self.trainer.model.parameters(), lr=2e-5, weight_decay=3e-3)
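
# Shape convention used by the loss/accuracy above (a minimal sketch with random tensors):
#   logits  = torch.randn(4, 1024, 1125*2)            # [batch, codebook_size, token_positions]
#   targets = torch.randint(0, 1024, (4, 1125*2))     # [batch, token_positions]
#   nn.CrossEntropyLoss()(logits, targets)            # CrossEntropyLoss expects classes on dim 1
#   (logits.argmax(dim=1) == targets).float().mean()  # token-level accuracy, as in tokens_accuracy
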
# create the model
sizes = [60784, 1000, 1000, 1125*2*1024]
residual_connections = [[0], [1], [2], [3]]
dropout = [0.5, 0.5, 0.5, 0.5]
model = MLP(sizes, residual_connections, dropout)
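
# Rough size check: the last linear layer alone has 1000 * (1125*2*1024) ≈ 2.3e9 weights,
# so the whole MLP is on the order of 2.4B parameters -- hence DeepSpeed ZeRO stage 3 below.
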
# create the data module
data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)
wandb.finish()
from pytorch_lightning.strategies import DeepSpeedStrategy
wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')
# define the trainer
trainer = pl.Trainer(
    accelerator="gpu",
    devices=[0, 1, 2, 3, 4, 5, 6, 7],
    max_epochs=1000,
    logger=wandb_logger,
    precision='32',
    strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=8),
    enable_checkpointing=False,
    log_every_n_steps=10,
)
#trainer = pl.Trainer(accelerator="gpu", devices = [0,1,2,3], max_epochs=1000, logger=wandb_logger, precision='bf16', strategy='fsdp', enable_checkpointing=False, log_every_n_steps=10)
# train the model
trainer.fit(model, datamodule=data_module)
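
# When launched (e.g. `python mlpdummy.py`), Lightning's DeepSpeed strategy spawns one process per
# GPU listed in `devices` and shards the model's parameters across them (ZeRO stage 3).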