| import argparse |
| import sys |
| import unittest.mock |
| from pathlib import Path |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from omegaconf import OmegaConf |
|
|
| |
| ROOT = Path(__file__).resolve().parent.parent |
| sys.path.append(str(ROOT / "Matcha-TTS")) |
|
|
| sys.path.append(str(ROOT)) |
| import src.training |
| from src.stage1.medarc_architecture import MultiSubjectConvLinearEncoder |
| from src.stage2.CFM import CFM |
| from torch.utils.data import DataLoader, Dataset |
|
|
|
|
class MockDataset(Dataset):
    """Synthetic map-style dataset that yields random tensors for debug runs.

    Every lookup draws fresh values with ``torch.randn``, so fetching the same
    index twice returns different data; only the shapes are stable.

    Args:
        num_samples: Number of items reported by ``len()``.
        num_subjects: Size of the first dimension of the ``"fmri"`` tensor.
        time_steps: Number of rows in each feature tensor and size of the
            second ``"fmri"`` dimension.
        voxels: Size of the last ``"fmri"`` dimension.
        feat_dims: Width of each feature tensor; one tensor is produced per
            entry, in order.
    """

    def __init__(
        self, num_samples, num_subjects=4, time_steps=10, voxels=100, feat_dims=(32, 64)
    ):
        self.num_samples = num_samples
        self.num_subjects = num_subjects
        self.time_steps = time_steps
        self.voxels = voxels
        self.feat_dims = feat_dims

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return one random sample.

        Args:
            idx: Sample position; negative values count from the end.

        Returns:
            A dict with ``"features"`` (a list of ``(time_steps, dim)``
            tensors, one per entry of ``feat_dims``) and ``"fmri"`` (a
            ``(num_subjects, time_steps, voxels)`` tensor).

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        # Without this check, plain iteration (``for x in ds`` / ``list(ds)``)
        # would never stop: the sequence protocol relies on IndexError to
        # signal the end of the data.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for a dataset of size {self.num_samples}"
            )

        # Features are sampled before fmri so the RNG stream is consumed in
        # the same order as before.
        features = [torch.randn(self.time_steps, dim) for dim in self.feat_dims]
        fmri = torch.randn(self.num_subjects, self.time_steps, self.voxels)

        return {"features": features, "fmri": fmri}
|
|
|
|
def mock_make_data_loaders(cfg):
    """Stand-in loader factory that serves random data instead of real data.

    Only ``cfg.batch_size`` is read from the config. A single ``DataLoader``
    over a four-sample ``MockDataset`` (1000 voxels, feature widths 32 and 64)
    is built and returned under both the ``"train"`` and ``"val_debug"`` keys.

    Args:
        cfg: Config object exposing a ``batch_size`` attribute.

    Returns:
        Dict mapping ``"train"`` and ``"val_debug"`` to the same loader object.
    """
    print("MOCKING DATA LOADERS FOR DEBUG")

    per_batch = cfg.batch_size

    # Fixed sizes for the synthetic data used in the debug run.
    dataset_kwargs = {"num_samples": 4, "voxels": 1000, "feat_dims": (32, 64)}
    shared_loader = DataLoader(MockDataset(**dataset_kwargs), batch_size=per_batch)

    # Both splits intentionally point at the very same loader instance.
    return dict.fromkeys(("train", "val_debug"), shared_loader)
|
|
|
|
def main():
    """Run ``src.training.main`` once against mocked data loaders.

    For the duration of the run ``src.training.make_data_loaders`` is patched
    to call ``mock_make_data_loaders`` and ``sys.argv`` is swapped for the
    debug command line. Both are restored when the run finishes, so the
    process-wide ``sys.argv`` is no longer left overwritten.

    Any ``Exception`` raised by the training entry point is reported (message
    on stdout, traceback on stderr) rather than propagated; this best-effort
    behaviour is kept so the debug script always runs to completion.
    """
    import traceback

    # NOTE(review): the config path is relative, so it presumably resolves
    # against the current working directory - confirm where this is run from.
    debug_argv = ["training.py", "--cfg-path", "test/debug_config.yml"]

    with unittest.mock.patch(
        "src.training.make_data_loaders", side_effect=mock_make_data_loaders
    ), unittest.mock.patch.object(sys, "argv", debug_argv):
        # Presumably src.training.main parses sys.argv itself; patch.object
        # puts the original argv list back on exit, even if the run fails.
        try:
            src.training.main()
        except Exception as e:
            print(f"Caught exception during debug run: {e}")
            traceback.print_exc()
|
|
|
|
# Script entry point: start the mocked debug run only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|