|
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision as tv

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DeepSpeedStrategy

import wandb
|
|
|
# Allow TF32 matmuls on Ampere+ GPUs: faster training at slightly reduced precision
torch.set_float32_matmul_precision('medium')
|
|
|
|
|
# fMRI responses for subject sub-001 and the matching precomputed EnCodec tokens
train_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Training.npy'
test_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Test_Mean.npy'

train_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec_training_embeds_sorted.npy'
test_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec_testing_embeds_sorted.npy'
|
|
|
class VoxelsDataset(data.Dataset):
    def __init__(self, voxels_path, embeddings_path):
        # Voxel responses transposed to (n_TRs, n_voxels); the file is assumed
        # to store them as (n_voxels, n_TRs)
        self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)
        self.embeddings = torch.from_numpy(np.load(embeddings_path))
        # Each stimulus spans 10 consecutive TRs, so one sample = 10 TRs
        self.len = len(self.voxels) // 10
        print("Dataset length:", self.len)

    def __getitem__(self, index):
        # Pair the 10 TRs of voxel data with the EnCodec tokens of that stimulus
        voxels = self.voxels[index*10:(index+1)*10]
        embeddings = self.embeddings[index]
        return voxels, embeddings

    def __len__(self):
        return self.len
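
# Quick shape sanity check (commented-out sketch; the example shapes are
# assumptions about the stored arrays, not verified values):
# ds = VoxelsDataset(train_voxels_path, train_embeddings_path)
# vox, emb = ds[0]
# print(vox.shape)  # expected (10, n_voxels), e.g. (10, 60784)
# print(emb.shape)  # expected (n_codebooks, n_frames), e.g. (2, 1125)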
|
|
|
class VoxelsEmbeddingsEncodecDataModule(pl.LightningDataModule):
    def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=8):
        super().__init__()
        self.train_voxels_path = train_voxels_path
        self.train_embeddings_path = train_embeddings_path
        self.test_voxels_path = test_voxels_path
        self.test_embeddings_path = test_embeddings_path
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)
        self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)

    def train_dataloader(self):
        return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        # The held-out test set doubles as the validation set
        return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
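
# Hypothetical standalone usage (the Trainer normally drives these hooks itself):
# dm = VoxelsEmbeddingsEncodecDataModule(train_voxels_path, train_embeddings_path,
#                                        test_voxels_path, test_embeddings_path, batch_size=4)
# dm.setup()
# voxels, embeddings = next(iter(dm.train_dataloader()))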
|
|
|
|
|
class MLP(pl.LightningModule):
    def __init__(self, sizes, residual_connections, dropout):
        super().__init__()
        self.sizes = sizes
        # Stored for bookkeeping; residual connections are not yet applied in forward
        self.residual_connections = residual_connections
        self.dropout = dropout
        self.layers = nn.Sequential()
        for i in range(len(sizes)-1):
            self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))
            # No activation/dropout after the last layer: its outputs are the raw
            # logits consumed directly by CrossEntropyLoss
            if i < len(sizes) - 2:
                self.layers.add_module('relu'+str(i), nn.ReLU())
                self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))

        self.loss = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, x):
        return self.layers(x)
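
    # Shape flow (sketch): (batch, 60784) -> (batch, 1000) -> (batch, 1000)
    # -> (batch, 1125*2*1024) raw logits, reshaped in the train/val steps below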
|
|
|
    def training_step(self, batch, batch_idx):
        voxels, embeddings = batch

        # Flatten the EnCodec codes into (batch, 1125*2) integer class targets
        embeddings = embeddings.flatten(start_dim=1).long()

        # Average the 10 TRs into a single voxel pattern, then flatten for the MLP
        voxels = voxels.mean(dim=1)
        voxels = voxels.flatten(start_dim=1)
        outputs = self(voxels)

        # Reshape to (batch, n_classes=1024, positions=1125*2), the layout
        # CrossEntropyLoss expects for sequence targets
        outputs = outputs.reshape(-1, 1024, 1125*2)

        loss = self.loss(outputs, embeddings)

        accuracy = self.tokens_accuracy(outputs, embeddings)
        self.log('train_loss', loss)
        self.log('train_accuracy', accuracy)
        return loss
|
|
|
    def tokens_accuracy(self, outputs, embeddings):
        # outputs: (batch, 1024, 1125*2) logits; embeddings: (batch, 1125*2) token ids.
        # argmax over the class dimension gives the predicted token at each position.
        outputs = outputs.argmax(dim=1)
        return (outputs == embeddings).float().mean()
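
    # Worked example (sketch): with logits of shape (1, 1024, 2250), argmax(dim=1)
    # yields (1, 2250) predicted ids; if 450 of the 2250 match their targets,
    # the returned accuracy is 450/2250 = 0.2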
|
|
|
|
|
    def validation_step(self, batch, batch_idx):
        voxels, embeddings = batch
        embeddings = embeddings.flatten(start_dim=1).long()

        # Same preprocessing as training: average TRs, flatten, predict token logits
        voxels = voxels.mean(dim=1)
        voxels = voxels.flatten(start_dim=1)
        outputs = self(voxels)
        outputs = outputs.reshape(-1, 1024, 1125*2)
        loss = self.loss(outputs, embeddings)
        accuracy = self.tokens_accuracy(outputs, embeddings)
        self.log('val_loss', loss)
        self.log('val_accuracy', accuracy)
        return loss
|
|
|
|
|
    def configure_optimizers(self):
        # Under DeepSpeed stage 3 the parameters are accessed through the wrapped
        # module on self.trainer.model rather than on self
        return torch.optim.Adam(self.trainer.model.parameters(), lr=2e-5, weight_decay=3e-3)
|
|
|
|
|
|
|
# 60784 input voxels -> two hidden layers -> logits for 1125*2 positions x 1024 classes
sizes = [60784, 1000, 1000, 1125*2*1024]
residual_connections = [[0], [1], [2], [3]]
dropout = [0.5, 0.5, 0.5, 0.5]
model = MLP(sizes, residual_connections, dropout)
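
# Scale note: the final linear layer alone holds 1000 * 1125*2*1024 ≈ 2.3B weights,
# which is presumably why the sharded DeepSpeed stage-3 strategy is used below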
|
|
|
|
|
|
|
data_module = VoxelsEmbeddingsEncodecDataModule(
    train_voxels_path, train_embeddings_path,
    test_voxels_path, test_embeddings_path,
    batch_size=4,
)
|
|
|
# Close any wandb run left over from a previous session
wandb.finish()
|
|
|
|
wandb_logger = WandbLogger(project='brain2music', entity='ckadirt') |
|
|
|
|
|
trainer = pl.Trainer(
    accelerator="gpu",
    devices=[0, 1, 2, 3, 4, 5, 6, 7],
    max_epochs=1000,
    logger=wandb_logger,
    precision='32',
    # logging_batch_size_per_gpu should match the DataModule's batch size of 4
    strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=4),
    enable_checkpointing=False,
    log_every_n_steps=10,
)
|
|
|
|
|
trainer.fit(model, datamodule=data_module) |
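
# Optionally close the wandb run once training completes:
# wandb.finish()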
|
|
|
|