# Spaces:
# Runtime error
# Runtime error
from torch.utils.data import DataLoader | |
import pandas as pd | |
from torch import nn | |
from torch.utils.data import SubsetRandomSampler | |
from sklearn.model_selection import KFold | |
import pytorch_lightning as pl | |
from pytorch_lightning import callbacks as cb | |
from models.utils import LabelWeightedBCELoss | |
from preprocessing.dataset import SongDataset | |
from preprocessing.preprocess import get_examples | |
from models.residual import ResidualDancer, TrainingEnvironment | |
import yaml | |
from preprocessing.dataset import DanceDataModule | |
from wakepy import keepawake | |
def get_config(filepath: str) -> dict:
    """Load a YAML configuration file and return its parsed contents.

    Args:
        filepath: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file. The callers in
        this module index the result like a mapping of config sections.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the contents are not valid YAML.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the
    # platform's default locale encoding.
    with open(filepath, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    return config
def cross_validation(config, k=5):
    """Run k-fold cross-validation of ``ResidualDancer`` on the song data.

    Reads ``data/songs.csv``, builds examples from ``data/samples`` for the
    classes in ``config["global"]["dance_ids"]`` and, for every fold, fits a
    fresh model on the training indices and tests it on the held-out ones.

    Args:
        config: Parsed training configuration. Reads ``global.dance_ids``,
            ``global.seed``, ``global.device``, ``data_module.batch_size``
            and, when present, the ``model`` section.
        k: Number of folds.

    Returns:
        A list with one entry per fold, each being the value returned by
        ``Trainer.test`` for that fold.
    """
    df = pd.read_csv("data/songs.csv")
    g_config = config["global"]
    batch_size = config["data_module"]["batch_size"]
    class_list = g_config["dance_ids"]
    # Forward the same model settings train_model uses; tolerate configs
    # that have no (or an empty) "model" section.
    model_kwargs = config.get("model") or {}

    x, y = get_examples(df, "data/samples", class_list=class_list)
    dataset = SongDataset(x, y)
    splits = KFold(n_splits=k, shuffle=True, random_state=g_config["seed"])

    fold_results = []
    for fold, (train_idx, val_idx) in enumerate(splits.split(x, y)):
        print(f"Fold {fold+1}")
        # Build a fresh Trainer per fold: reusing one instance carries its
        # training-loop state over from the previous fold, so later folds
        # would not be trained from scratch.
        trainer = pl.Trainer(accelerator=g_config["device"])
        model = ResidualDancer(n_classes=len(class_list), **model_kwargs)
        # Pass config as train_model does, keeping both call sites consistent.
        train_env = TrainingEnvironment(model, nn.BCELoss(), config)

        # Both loaders share one dataset; the samplers restrict each of them
        # to the indices of the current fold.
        train_sampler = SubsetRandomSampler(train_idx)
        test_sampler = SubsetRandomSampler(val_idx)
        train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
        test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)

        trainer.fit(train_env, train_loader)
        fold_results.append(trainer.test(train_env, test_loader))
    return fold_results
def train_model(config: dict):
    """Fit a ``ResidualDancer`` and then test it, driven by ``config``.

    Sections of ``config`` that are read:
        ``global``: ``dance_ids``, ``device`` and ``seed``.
        ``data_module``: keyword arguments for ``DanceDataModule``.
        ``model``: keyword arguments for ``ResidualDancer``.
        ``trainer``: keyword arguments for ``pl.Trainer``.

    Args:
        config: Parsed training configuration; it is also handed as-is to
            ``TrainingEnvironment``.
    """
    dance_ids = config["global"]["dance_ids"]
    device = config["global"]["device"]
    seed = config["global"]["seed"]

    # Seed all RNGs before any data or model objects are constructed.
    pl.seed_everything(seed, workers=True)

    datamodule = DanceDataModule(target_classes=dance_ids, **config['data_module'])
    network = ResidualDancer(n_classes=len(dance_ids), **config['model'])

    # The loss is weighted per label, with the weights supplied by the data
    # module and moved to the configured device.
    weights = datamodule.get_label_weights().to(device)
    loss_fn = LabelWeightedBCELoss(weights)
    environment = TrainingEnvironment(network, loss_fn, config)

    # NOTE: cb.LearningRateFinder(update_attr=True) is currently not enabled.
    trainer = pl.Trainer(
        callbacks=[
            cb.EarlyStopping("val/loss", patience=5),
            cb.StochasticWeightAveraging(1e-2),
            cb.RichProgressBar(),
        ],
        **config["trainer"],
    )
    trainer.fit(environment, datamodule=datamodule)
    trainer.test(environment, datamodule=datamodule)
if __name__ == "__main__":
    # Script entry point: load the training configuration, then train.
    CONFIG_PATH = "models/config/train.yaml"
    training_config = get_config(CONFIG_PATH)
    # Run training inside wakepy's keepawake context (presumably so the
    # machine does not go to sleep mid-run; confirm against wakepy docs).
    with keepawake():
        train_model(training_config)