diff --git a/conditional-flow-matching/.github/workflows/code-quality.yaml b/conditional-flow-matching/.github/workflows/code-quality.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d1a26e85f9205630824cd583a321b3f2817d0af --- /dev/null +++ b/conditional-flow-matching/.github/workflows/code-quality.yaml @@ -0,0 +1,26 @@ +# Same as `code-quality-pr.yaml` but triggered on commit to main branch +# and runs on all files (instead of only the changed ones) + +name: Code Quality Main + +on: + push: + branches: [main] + pull_request: + branches: [main, "release/*"] + +jobs: + code-quality: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.13" + + - name: Run pre-commits + uses: pre-commit/action@v3.0.1 diff --git a/conditional-flow-matching/.github/workflows/python-publish.yml b/conditional-flow-matching/.github/workflows/python-publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..e129ad3986e12627d9155f70c6a803ca5d1f9c9b --- /dev/null +++ b/conditional-flow-matching/.github/workflows/python-publish.yml @@ -0,0 +1,38 @@ +# This workflow will upload a Python Package using Twine when a release is created +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries + +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +name: Upload Python Package + +on: + release: + types: [published] + +permissions: + contents: read + +jobs: + deploy: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: "3.x" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build + - name: Build package + run: python -m build + - name: Publish package + uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/conditional-flow-matching/.github/workflows/test.yaml b/conditional-flow-matching/.github/workflows/test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7c08b012da037e2656d58f2076ab6358847e6816 --- /dev/null +++ b/conditional-flow-matching/.github/workflows/test.yaml @@ -0,0 +1,71 @@ +name: TorchCFM Tests + +on: + push: + branches: [main] + pull_request: + branches: [main, "release/*"] + +jobs: + run_tests_ubuntu: + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest + pip install sh + pip install -e . 
+ + - name: List dependencies + run: | + python -m pip list + + - name: Run pytest + run: | + pytest -v --ignore=examples --ignore=runner + + # upload code coverage report + code-coverage-torchcfm: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Set up Python 3.10 + uses: actions/setup-python@v6 + with: + python-version: "3.10" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest + pip install pytest-cov[toml] + pip install sh + pip install -e . + + - name: Run tests and collect coverage + run: pytest . --cov torchcfm --ignore=runner --ignore=examples --ignore=torchcfm/models/ --cov-fail-under=30 + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + name: codecov-torchcfm + verbose: true diff --git a/conditional-flow-matching/.github/workflows/test_runner.yaml b/conditional-flow-matching/.github/workflows/test_runner.yaml new file mode 100644 index 0000000000000000000000000000000000000000..39961a3fc3ce96fa4d8fadfc6b2e12d50365c2ef --- /dev/null +++ b/conditional-flow-matching/.github/workflows/test_runner.yaml @@ -0,0 +1,75 @@ +name: Runner Tests + +#on: +# push: +# branches: [main] +# pull_request: +# branches: [main, "release/*"] + +jobs: + run_tests_ubuntu: + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest] + python-version: ["3.9", "3.10"] + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + # Pin pip < 24.1 due to a lightning incompatibility + python -m pip install pip==23.2.1 + pip install -r runner-requirements.txt + pip install pytest + pip install sh + pip install -e . + + - name: List dependencies + run: | + python -m pip list + + - name: Run pytest + run: | + pytest -v runner + + # upload code coverage report + code-coverage-runner: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Set up Python 3.10 + uses: actions/setup-python@v6 + with: + python-version: "3.10" + + - name: Install dependencies + run: | + # Pin pip < 24.1 due to a lightning incompatibility + python -m pip install pip==23.2.1 + pip install -r runner-requirements.txt + pip install pytest + pip install pytest-cov[toml] + pip install sh + pip install -e .
+ + - name: Run tests and collect coverage + run: pytest runner --cov runner --cov-fail-under=30 # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + name: codecov-runner + verbose: true diff --git a/conditional-flow-matching/runner/configs/callbacks/default.yaml b/conditional-flow-matching/runner/configs/callbacks/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4da3e5d139303155782632b3bfce697a0d32ccc0 --- /dev/null +++ b/conditional-flow-matching/runner/configs/callbacks/default.yaml @@ -0,0 +1,22 @@ +defaults: + - model_checkpoint.yaml + - early_stopping.yaml + - model_summary.yaml + - rich_progress_bar.yaml + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "epoch_{epoch:04d}" + monitor: "val/loss" + mode: "min" + save_last: True + auto_insert_metric_name: False + +early_stopping: + monitor: "val/loss" + patience: 100 + mode: "min" + +model_summary: + max_depth: -1 diff --git a/conditional-flow-matching/runner/configs/callbacks/early_stopping.yaml b/conditional-flow-matching/runner/configs/callbacks/early_stopping.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c2988a5f3653bd2ae407fa47a61b82fb29bf3ce9 --- /dev/null +++ b/conditional-flow-matching/runner/configs/callbacks/early_stopping.yaml @@ -0,0 +1,17 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html + +# Monitor a metric and stop training when it stops improving. +# Look at the above link for more detailed information. +early_stopping: + _target_: pytorch_lightning.callbacks.EarlyStopping + monitor: ??? # quantity to be monitored, must be specified !!! + min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement + patience: 3 # number of checks with no improvement after which training will be stopped + verbose: False # verbosity mode + mode: "min" # "min" means a lower metric value is better, "max" means higher is better + strict: True # whether to crash the training if monitor is not found in the validation metrics + check_finite: True # when set True, stops training when the monitor becomes NaN or infinite + stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold + divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold + check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch + # log_rank_zero_only: False # this keyword argument isn't available in stable version diff --git a/conditional-flow-matching/runner/configs/callbacks/model_checkpoint.yaml b/conditional-flow-matching/runner/configs/callbacks/model_checkpoint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b08630b60d0ddfe5ec735fc0ec76ed45ec17210c --- /dev/null +++ b/conditional-flow-matching/runner/configs/callbacks/model_checkpoint.yaml @@ -0,0 +1,19 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html + +# Save the model periodically by monitoring a quantity. +# Look at the above link for more detailed information.
+model_checkpoint: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + dirpath: null # directory to save the model file + filename: null # checkpoint filename + monitor: null # name of the logged metric which determines when model is improving + verbose: False # verbosity mode + save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 1 # save k best models (determined by above metric) + mode: "min" # "min" means a lower metric value is better, "max" means higher is better + auto_insert_metric_name: True # when True, the checkpoint filenames will contain the metric name + save_weights_only: False # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: null # number of epochs between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/conditional-flow-matching/runner/configs/callbacks/model_summary.yaml b/conditional-flow-matching/runner/configs/callbacks/model_summary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e46e38c08ab91f58942de1870df55407edd32b5e --- /dev/null +++ b/conditional-flow-matching/runner/configs/callbacks/model_summary.yaml @@ -0,0 +1,7 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html + +# Generates a summary of all layers in a LightningModule with rich text formatting. +# Look at the above link for more detailed information. +model_summary: + _target_: pytorch_lightning.callbacks.RichModelSummary + max_depth: 1 # the maximum depth of layer nesting that the summary will include diff --git a/conditional-flow-matching/runner/configs/callbacks/no_stopping.yaml b/conditional-flow-matching/runner/configs/callbacks/no_stopping.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b1aa4c3cbd2b6a196e47471cb1b2b6351d619a1f --- /dev/null +++ b/conditional-flow-matching/runner/configs/callbacks/no_stopping.yaml @@ -0,0 +1,15 @@ +defaults: + - model_checkpoint.yaml + - model_summary.yaml + - rich_progress_bar.yaml + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "epoch_{epoch:04d}" + save_last: True + every_n_epochs: 100 # number of epochs between checkpoints + auto_insert_metric_name: False + +model_summary: + max_depth: 3 diff --git a/conditional-flow-matching/runner/configs/callbacks/none.yaml b/conditional-flow-matching/runner/configs/callbacks/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/conditional-flow-matching/runner/configs/callbacks/rich_progress_bar.yaml b/conditional-flow-matching/runner/configs/callbacks/rich_progress_bar.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25b6750b81cd60563eb8fcb0740d911a2e3323d3 --- /dev/null +++ b/conditional-flow-matching/runner/configs/callbacks/rich_progress_bar.yaml @@ -0,0 +1,6 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html + +# Create a progress bar with rich text formatting. +# Look at the above link for more detailed information.
+rich_progress_bar: + _target_: pytorch_lightning.callbacks.RichProgressBar diff --git a/conditional-flow-matching/runner/configs/datamodule/cifar.yaml b/conditional-flow-matching/runner/configs/datamodule/cifar.yaml new file mode 100644 index 0000000000000000000000000000000000000000..368058fd5c60f270ab50c312045a196482506a52 --- /dev/null +++ b/conditional-flow-matching/runner/configs/datamodule/cifar.yaml @@ -0,0 +1,11 @@ +_target_: src.datamodules.cifar10_datamodule.CIFAR10DataModule +#_target_: pl_bolts.datamodules.cifar10_datamodule.CIFAR10DataModule +data_dir: ${paths.data_dir} +batch_size: 128 +val_split: 0.0 +num_workers: 0 +normalize: True +seed: 42 +shuffle: True +pin_memory: True +drop_last: False diff --git a/conditional-flow-matching/runner/configs/datamodule/custom_dist.yaml b/conditional-flow-matching/runner/configs/datamodule/custom_dist.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0e956c8fa726dd3f726795673936aaf003cb0c18 --- /dev/null +++ b/conditional-flow-matching/runner/configs/datamodule/custom_dist.yaml @@ -0,0 +1,17 @@ +_target_: src.datamodules.distribution_datamodule.TrajectoryNetDistributionTrajectoryDataModule +#_target_: src.datamodules.distribution_datamodule.DistributionDataModule + +data_dir: ${paths.data_dir} # data_dir is specified in config.yaml +train_val_test_split: 1000 +batch_size: 100 +num_workers: 0 +pin_memory: False + +system: ${paths.data_dir}/embryoid_anndata_small_v2.h5ad + +system_kwargs: + max_dim: 1e10 + embedding_name: "phate" + #embedding_name: "highly_variable" + whiten: True + #whiten: False diff --git a/conditional-flow-matching/runner/configs/datamodule/eb_full.yaml b/conditional-flow-matching/runner/configs/datamodule/eb_full.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17f3455f286339aa9a4e4b2b19e5ad9aac1fb6c9 --- /dev/null +++ b/conditional-flow-matching/runner/configs/datamodule/eb_full.yaml @@ -0,0 +1,14 @@ +# https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py +_target_: src.datamodules.distribution_datamodule.TorchDynDataModule +#_target_: src.datamodules.distribution_datamodule.DistributionDataModule + +train_val_test_split: [0.8, 0.1, 0.1] +batch_size: 128 +num_workers: 0 +pin_memory: False + +system: ${paths.data_dir}/eb_velocity_v5.npz + +system_kwargs: + max_dim: 100 + whiten: False diff --git a/conditional-flow-matching/runner/configs/datamodule/funnel.yaml b/conditional-flow-matching/runner/configs/datamodule/funnel.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b8e44d17a9253e15b8ac42952cc479ae36d93c20 --- /dev/null +++ b/conditional-flow-matching/runner/configs/datamodule/funnel.yaml @@ -0,0 +1,12 @@ +_target_: src.datamodules.distribution_datamodule.TorchDynDataModule +#_target_: src.datamodules.distribution_datamodule.DistributionDataModule + +train_val_test_split: [10000, 1000, 1000] +batch_size: 128 +num_workers: 0 +pin_memory: False + +system: "funnel" + +system_kwargs: + dim: 10 diff --git a/conditional-flow-matching/runner/configs/datamodule/gaussians.yaml b/conditional-flow-matching/runner/configs/datamodule/gaussians.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f22df6bf03ff3334efd99ddf8083bd0812d40a3b --- /dev/null +++ b/conditional-flow-matching/runner/configs/datamodule/gaussians.yaml @@ -0,0 +1,13 @@ +# https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py +_target_: src.datamodules.distribution_datamodule.TorchDynDataModule 
+#_target_: src.datamodules.distribution_datamodule.DistributionDataModule + +train_val_test_split: [10000, 1000, 1000] +batch_size: 128 +num_workers: 0 +pin_memory: False + +system: "gaussians" + +system_kwargs: + noise: 1e-4 diff --git a/conditional-flow-matching/runner/configs/datamodule/moons.yaml b/conditional-flow-matching/runner/configs/datamodule/moons.yaml new file mode 100644 index 0000000000000000000000000000000000000000..801ba172e049b1375f6085704190a409f86cd2f4 --- /dev/null +++ b/conditional-flow-matching/runner/configs/datamodule/moons.yaml @@ -0,0 +1,10 @@ +# https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py +_target_: src.datamodules.distribution_datamodule.SKLearnDataModule +#_target_: src.datamodules.distribution_datamodule.DistributionDataModule + +train_val_test_split: [10000, 1000, 1000] +batch_size: 128 +num_workers: 0 +pin_memory: False + +system: "moons" diff --git a/conditional-flow-matching/runner/configs/datamodule/scurve.yaml b/conditional-flow-matching/runner/configs/datamodule/scurve.yaml new file mode 100644 index 0000000000000000000000000000000000000000..35951770f6eae397aef2160157f08b08db68db07 --- /dev/null +++ b/conditional-flow-matching/runner/configs/datamodule/scurve.yaml @@ -0,0 +1,10 @@ +# https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py +_target_: src.datamodules.distribution_datamodule.SKLearnDataModule +#_target_: src.datamodules.distribution_datamodule.DistributionDataModule + +train_val_test_split: [10000, 1000, 1000] +batch_size: 128 +num_workers: 0 +pin_memory: False + +system: "scurve" diff --git a/conditional-flow-matching/runner/configs/datamodule/sklearn.yaml b/conditional-flow-matching/runner/configs/datamodule/sklearn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..35951770f6eae397aef2160157f08b08db68db07 --- /dev/null +++ b/conditional-flow-matching/runner/configs/datamodule/sklearn.yaml @@ -0,0 +1,10 @@ +# https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py +_target_: src.datamodules.distribution_datamodule.SKLearnDataModule +#_target_: src.datamodules.distribution_datamodule.DistributionDataModule + +train_val_test_split: [10000, 1000, 1000] +batch_size: 128 +num_workers: 0 +pin_memory: False + +system: "scurve" diff --git a/conditional-flow-matching/runner/configs/datamodule/time_dist.yaml b/conditional-flow-matching/runner/configs/datamodule/time_dist.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5d4273fda4c5ce03040b9dbb7f9f0b943903c261 --- /dev/null +++ b/conditional-flow-matching/runner/configs/datamodule/time_dist.yaml @@ -0,0 +1,12 @@ +_target_: src.datamodules.distribution_datamodule.CustomTrajectoryDataModule +#_target_: src.datamodules.distribution_datamodule.DistributionDataModule + +data_dir: ${paths.data_dir} # data_dir is specified in config.yaml +train_val_test_split: [0.8, 0.1, 0.1] +batch_size: 128 +num_workers: 0 +pin_memory: False +max_dim: 5 +whiten: True + +system: ${paths.data_dir}/eb_velocity_v5.npz diff --git a/conditional-flow-matching/runner/configs/datamodule/torchdyn.yaml b/conditional-flow-matching/runner/configs/datamodule/torchdyn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f6c0f82c07dfced65662cda4603af92aca83fffc --- /dev/null +++ b/conditional-flow-matching/runner/configs/datamodule/torchdyn.yaml @@ -0,0 +1,13 @@ +# https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py +_target_: 
src.datamodules.distribution_datamodule.TorchDynDataModule +#_target_: src.datamodules.distribution_datamodule.DistributionDataModule + +train_val_test_split: [10000, 1000, 1000] +batch_size: 128 +num_workers: 0 +pin_memory: False + +system: "moons" + +system_kwargs: + noise: 1e-4 diff --git a/conditional-flow-matching/runner/configs/datamodule/tree.yaml b/conditional-flow-matching/runner/configs/datamodule/tree.yaml new file mode 100644 index 0000000000000000000000000000000000000000..df9f6627e0b022d61edada085538a29b74d91257 --- /dev/null +++ b/conditional-flow-matching/runner/configs/datamodule/tree.yaml @@ -0,0 +1,8 @@ +_target_: src.datamodules.distribution_datamodule.DistributionDataModule + +data_dir: ${data_dir} # data_dir is specified in config.yaml +train_val_test_split: 1000 +batch_size: 100 +num_workers: 0 +pin_memory: False +p: 2 diff --git a/conditional-flow-matching/runner/configs/datamodule/twodim.yaml b/conditional-flow-matching/runner/configs/datamodule/twodim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ecba0d4da732e070c021113ffb2a8004b2355a3a --- /dev/null +++ b/conditional-flow-matching/runner/configs/datamodule/twodim.yaml @@ -0,0 +1,8 @@ +_target_: src.datamodules.distribution_datamodule.TwoDimDataModule + +train_val_test_split: [10000, 1000, 1000] +batch_size: 128 +num_workers: 0 +pin_memory: False + +system: "moon-8gaussians" diff --git a/conditional-flow-matching/runner/configs/debug/default.yaml b/conditional-flow-matching/runner/configs/debug/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52c2f869f4afc7339ae068f7c9046ed8b3238672 --- /dev/null +++ b/conditional-flow-matching/runner/configs/debug/default.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +# default debugging setup, runs 1 full epoch +# other debugging configs can inherit from this one + +# overwrite task name so debugging logs are stored in separate folder +task_name: "debug" + +# disable callbacks and loggers during debugging +callbacks: null +logger: null + +extras: + ignore_warnings: False + enforce_tags: False + +# sets level of all command line loggers to 'DEBUG' +# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ +hydra: + job_logging: + root: + level: DEBUG + + # use this to also set hydra loggers to 'DEBUG' + # verbose: True + +trainer: + max_epochs: 1 + accelerator: cpu # debuggers don't like gpus + devices: 1 # debuggers don't like multiprocessing + detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor + +datamodule: + num_workers: 0 # debuggers don't like multiprocessing + pin_memory: False # disable gpu memory pin diff --git a/conditional-flow-matching/runner/configs/debug/fdr.yaml b/conditional-flow-matching/runner/configs/debug/fdr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ce22e6ccf0190c07e5e21e3e51b70ed89b5272c4 --- /dev/null +++ b/conditional-flow-matching/runner/configs/debug/fdr.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +# runs 1 train, 1 validation and 1 test step + +defaults: + - default.yaml + +trainer: + fast_dev_run: true diff --git a/conditional-flow-matching/runner/configs/debug/limit.yaml b/conditional-flow-matching/runner/configs/debug/limit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9d426119483df55168837a6ebc3358ff8784ab06 --- /dev/null +++ b/conditional-flow-matching/runner/configs/debug/limit.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# uses only 1% of the training data and 5% of validation/test data + 
+defaults: + - default.yaml + +trainer: + max_epochs: 3 + limit_train_batches: 0.01 + limit_val_batches: 0.05 + limit_test_batches: 0.05 diff --git a/conditional-flow-matching/runner/configs/debug/overfit.yaml b/conditional-flow-matching/runner/configs/debug/overfit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..22cab5467f76feccd7b3e5721f3cac369b385ba0 --- /dev/null +++ b/conditional-flow-matching/runner/configs/debug/overfit.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +# overfits to 3 batches + +defaults: + - default.yaml + +trainer: + max_epochs: 20 + overfit_batches: 3 + +# model ckpt and early stopping need to be disabled during overfitting +callbacks: null diff --git a/conditional-flow-matching/runner/configs/debug/profiler.yaml b/conditional-flow-matching/runner/configs/debug/profiler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..72e70cccf778f01532b3f05f36188bc5575bbecf --- /dev/null +++ b/conditional-flow-matching/runner/configs/debug/profiler.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# runs with execution time profiling + +defaults: + - default.yaml + +trainer: + max_epochs: 1 + profiler: "simple" + # profiler: "advanced" + # profiler: "pytorch" diff --git a/conditional-flow-matching/runner/configs/eval.yaml b/conditional-flow-matching/runner/configs/eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..13655e452917cd8b476ae216069499251ef9f59a --- /dev/null +++ b/conditional-flow-matching/runner/configs/eval.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - _self_ + - datamodule: sklearn # choose datamodule with `test_dataloader()` for evaluation + - model: cfm + - logger: null + - trainer: default.yaml + - paths: default.yaml + - extras: default.yaml + - hydra: default.yaml + +task_name: "eval" + +tags: ["dev"] + +# passing checkpoint path is necessary for evaluation +ckpt_path: ??? 
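The configs above follow Hydra's composition conventions: each entry in `defaults` picks one option from a config group, `_target_`/`_partial_` drive object instantiation, and `???` marks a mandatory value (e.g. `ckpt_path` in `eval.yaml`). As a rough sketch of how such a tree is consumed at runtime — the `configs` path and `train` name mirror this repo's layout, but the script itself is illustrative and not part of the patch:

```python
# Illustrative only: how Hydra resolves configs like the ones in this diff.
# Assumes hydra-core and omegaconf are installed.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path="configs", config_name="train")
def main(cfg: DictConfig) -> None:
    # Missing values print as "???"; *accessing* one (e.g. cfg.ckpt_path in
    # eval.yaml) raises MissingMandatoryValue unless overridden on the CLI,
    # e.g.: python eval.py ckpt_path=/path/to/last.ckpt
    print(OmegaConf.to_yaml(cfg))
    # With `_partial_: true`, instantiate() returns a functools.partial, so
    # the training script can supply runtime arguments (net dims, etc.) later.
    model_factory = hydra.utils.instantiate(cfg.model)
    datamodule = hydra.utils.instantiate(cfg.datamodule)


if __name__ == "__main__":
    main()
```

Overrides compose the same way on the command line, e.g. `python train.py model=otcfm datamodule=moons trainer=gpu` picks one file from each of the config groups defined below.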
diff --git a/conditional-flow-matching/runner/configs/experiment/cfm.yaml b/conditional-flow-matching/runner/configs/experiment/cfm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e6e6194b0aaf0a9fd5e966849f495d42cd7618d --- /dev/null +++ b/conditional-flow-matching/runner/configs/experiment/cfm.yaml @@ -0,0 +1,22 @@ +# @package _global_ + +defaults: + - override /model: cfm.yaml + - override /logger: + - csv.yaml + - wandb.yaml + - override /datamodule: sklearn.yaml + +name: "cfm" +seed: 42 + +datamodule: + batch_size: 512 + +model: + optimizer: + weight_decay: 1e-5 + +trainer: + max_epochs: 1000 + check_val_every_n_epoch: 10 diff --git a/conditional-flow-matching/runner/configs/experiment/cnf.yaml b/conditional-flow-matching/runner/configs/experiment/cnf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c0731cd05853eace52971f5a1af03a47d915775c --- /dev/null +++ b/conditional-flow-matching/runner/configs/experiment/cnf.yaml @@ -0,0 +1,22 @@ +# @package _global_ + +defaults: + - override /model: cnf.yaml + - override /logger: + - csv.yaml + - wandb.yaml + - override /datamodule: sklearn.yaml + +name: "cnf" +seed: 42 + +datamodule: + batch_size: 1024 + +model: + optimizer: + weight_decay: 1e-5 + +trainer: + max_epochs: 1000 + check_val_every_n_epoch: 10 diff --git a/conditional-flow-matching/runner/configs/experiment/icnn.yaml b/conditional-flow-matching/runner/configs/experiment/icnn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..84c07367c767b70fa700cbd4486d89069f9a7739 --- /dev/null +++ b/conditional-flow-matching/runner/configs/experiment/icnn.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - override /model: icnn + - override /logger: + - csv + - wandb + - override /datamodule: sklearn + +name: "icnn" +seed: 42 + +datamodule: + batch_size: 256 + +trainer: + max_epochs: 10000 + check_val_every_n_epoch: 100 diff --git a/conditional-flow-matching/runner/configs/experiment/image_cfm.yaml b/conditional-flow-matching/runner/configs/experiment/image_cfm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5459c9bd4e11fd61acf64bbc6a730e7ca8486b39 --- /dev/null +++ b/conditional-flow-matching/runner/configs/experiment/image_cfm.yaml @@ -0,0 +1,33 @@ +# @package _global_ + +defaults: + - override /model: image_cfm.yaml + - override /callbacks: no_stopping + - override /logger: + - csv.yaml + - wandb.yaml + - override /datamodule: cifar.yaml + - override /trainer: ddp.yaml + +name: "cfm" +seed: 42 + +datamodule: + batch_size: 128 + +model: + _target_: src.models.cfm_module.CFMLitModule + sigma_min: 1e-4 + + scheduler: + _target_: timm.scheduler.PolyLRScheduler + _partial_: True + warmup_t: 200 + warmup_lr_init: 1e-8 + t_initial: 2000 + +trainer: + devices: 2 + max_epochs: 2000 + check_val_every_n_epoch: 10 + limit_val_batches: 0.01 diff --git a/conditional-flow-matching/runner/configs/experiment/image_fm.yaml b/conditional-flow-matching/runner/configs/experiment/image_fm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a55c8585c06f5d0ab8e0238a9e3bf33748572728 --- /dev/null +++ b/conditional-flow-matching/runner/configs/experiment/image_fm.yaml @@ -0,0 +1,33 @@ +# @package _global_ + +defaults: + - override /model: image_cfm.yaml + - override /callbacks: no_stopping + - override /logger: + - csv.yaml + - wandb.yaml + - override /datamodule: cifar.yaml + - override /trainer: ddp.yaml + +name: "cfm" +seed: 42 + +datamodule: + batch_size: 128 + +model: + _target_: 
src.models.cfm_module.FMLitModule + sigma_min: 1e-4 + + scheduler: + _target_: timm.scheduler.PolyLRScheduler + _partial_: True + warmup_t: 200 + warmup_lr_init: 1e-8 + t_initial: 2000 + +trainer: + devices: 2 + max_epochs: 2000 + check_val_every_n_epoch: 10 + limit_val_batches: 0.01 diff --git a/conditional-flow-matching/runner/configs/experiment/image_otcfm.yaml b/conditional-flow-matching/runner/configs/experiment/image_otcfm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e75b40545eb4f2055c735f9ca3743a36086ca67 --- /dev/null +++ b/conditional-flow-matching/runner/configs/experiment/image_otcfm.yaml @@ -0,0 +1,34 @@ +# @package _global_ + +defaults: + - override /model: image_cfm.yaml + - override /callbacks: no_stopping + - override /logger: + - csv.yaml + - wandb.yaml + - override /datamodule: cifar.yaml + - override /trainer: ddp.yaml + +name: "cfm" +seed: 42 + +datamodule: + batch_size: 128 + +model: + _target_: src.models.cfm_module.CFMLitModule + sigma_min: 1e-4 + + scheduler: + _target_: timm.scheduler.PolyLRScheduler + _partial_: True + warmup_t: 200 + warmup_lr_init: 1e-8 + t_initial: 2000 + ot_sampler: "exact" + +trainer: + devices: 2 + max_epochs: 2000 + check_val_every_n_epoch: 10 + limit_val_batches: 0.01 diff --git a/conditional-flow-matching/runner/configs/experiment/trajectorynet.yaml b/conditional-flow-matching/runner/configs/experiment/trajectorynet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..143b24848141dddeac5ef490ecc207af9f3c507f --- /dev/null +++ b/conditional-flow-matching/runner/configs/experiment/trajectorynet.yaml @@ -0,0 +1,22 @@ +# @package _global_ + +defaults: + - override /model: trajectorynet.yaml + - override /logger: + - csv.yaml + - wandb.yaml + - override /datamodule: twodim.yaml + +name: "cnf" +seed: 42 + +datamodule: + batch_size: 1024 + +model: + optimizer: + weight_decay: 1e-5 + +trainer: + max_epochs: 1000 + check_val_every_n_epoch: 10 diff --git a/conditional-flow-matching/runner/configs/extras/default.yaml b/conditional-flow-matching/runner/configs/extras/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cbb617cbbcd68fba87e84126e9f59fef8acc6e41 --- /dev/null +++ b/conditional-flow-matching/runner/configs/extras/default.yaml @@ -0,0 +1,8 @@ +# disable python warnings if they annoy you +ignore_warnings: False + +# ask user for tags if none are provided in the config +enforce_tags: True + +# pretty print config tree at the start of the run using Rich library +print_config: True diff --git a/conditional-flow-matching/runner/configs/hparams_search/optuna.yaml b/conditional-flow-matching/runner/configs/hparams_search/optuna.yaml new file mode 100644 index 0000000000000000000000000000000000000000..401e2771c43894e5fc983a6e8cfa2edcc5a8a2e9 --- /dev/null +++ b/conditional-flow-matching/runner/configs/hparams_search/optuna.yaml @@ -0,0 +1,49 @@ +# @package _global_ + +# example hyperparameter optimization of some experiment with Optuna: +# python train.py -m hparams_search=mnist_optuna experiment=example + +defaults: + - override /hydra/sweeper: optuna + +# choose metric which will be optimized by Optuna +# make sure this is the correct name of some metric logged in lightning module! 
+optimized_metric: "val/2-Wasserstein" + +# here we define Optuna hyperparameter search +# it optimizes for value returned from function with @hydra.main decorator +# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper +hydra: + mode: "MULTIRUN" # set hydra to multirun by default if this config is attached + + sweeper: + _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper + + # storage URL to persist optimization results + # for example, you can use SQLite if you set 'sqlite:///example.db' + storage: null + + # name of the study to persist optimization results + study_name: null + + # number of parallel workers + n_jobs: 1 + + # 'minimize' or 'maximize' the objective + direction: maximize + + # total number of runs that will be executed + n_trials: 20 + + # choose Optuna hyperparameter sampler + # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others + # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html + sampler: + _target_: optuna.samplers.TPESampler + seed: 1234 + n_startup_trials: 10 # number of random sampling runs before optimization starts + + # define hyperparameter search space + params: + model.optimizer.lr: interval(0.0001, 0.1) + datamodule.batch_size: choice(32, 64, 128, 256) diff --git a/conditional-flow-matching/runner/configs/hydra/default.yaml b/conditional-flow-matching/runner/configs/hydra/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a564fe48714494e59f92d369327812e3ebcfa8a8 --- /dev/null +++ b/conditional-flow-matching/runner/configs/hydra/default.yaml @@ -0,0 +1,15 @@ +# https://hydra.cc/docs/configure_hydra/intro/ + +# enable color logging +defaults: + - override hydra_logging: colorlog + - override job_logging: colorlog + +# output directory, generated dynamically on each run +run: + dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} +sweep: + dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num} +job: + chdir: true diff --git a/conditional-flow-matching/runner/configs/launcher/mila_cluster.yaml b/conditional-flow-matching/runner/configs/launcher/mila_cluster.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f4a422083817be9bf212c20bdfbedb1f83506f75 --- /dev/null +++ b/conditional-flow-matching/runner/configs/launcher/mila_cluster.yaml @@ -0,0 +1,18 @@ +# @package _global_ +# +defaults: + - override /hydra/launcher: submitit_slurm + +hydra: + launcher: + partition: long + cpus_per_task: 2 + mem_gb: 20 + gres: gpu:1 + timeout_min: 1440 + array_parallelism: 10 # max num of tasks to run in parallel (via job array) + setup: + - "module purge" + - "module load miniconda/3" + - "conda activate myenv" + - "unset CUDA_VISIBLE_DEVICES" diff --git a/conditional-flow-matching/runner/configs/launcher/mila_cpu_cluster.yaml b/conditional-flow-matching/runner/configs/launcher/mila_cpu_cluster.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4954a085a059325d30c06dc7e11a8715dfc9d0f3 --- /dev/null +++ b/conditional-flow-matching/runner/configs/launcher/mila_cpu_cluster.yaml @@ -0,0 +1,16 @@ +# @package _global_ +# +defaults: + - override /hydra/launcher: submitit_slurm + +hydra: + launcher: + partition: long-cpu + cpus_per_task: 1 + mem_gb: 5 + timeout_min: 100 + array_parallelism: 64 + setup: + - "module purge" + - "module load miniconda/3" + - "conda activate myenv" diff --git a/conditional-flow-matching/runner/configs/local/.gitkeep 
b/conditional-flow-matching/runner/configs/local/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/conditional-flow-matching/runner/configs/local/default.yaml b/conditional-flow-matching/runner/configs/local/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..195a2e8c2d346e5f8fac3514f538fb0a6594238b --- /dev/null +++ b/conditional-flow-matching/runner/configs/local/default.yaml @@ -0,0 +1,12 @@ +# path to root directory +# this requires PROJECT_ROOT environment variable to exist +# PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py` +root_dir: ${oc.env:PROJECT_ROOT} + +scratch_dir: ${oc.env:PROJECT_ROOT} + +# path to data directory +data_dir: ${local.scratch_dir}/data/ + +# path to logging directory +log_dir: ${local.scratch_dir}/logs/ diff --git a/conditional-flow-matching/runner/configs/logger/comet.yaml b/conditional-flow-matching/runner/configs/logger/comet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ea5438adfd6aea2f431b47dee31adee8c9460f8e --- /dev/null +++ b/conditional-flow-matching/runner/configs/logger/comet.yaml @@ -0,0 +1,12 @@ +# https://www.comet.ml + +comet: + _target_: pytorch_lightning.loggers.comet.CometLogger + api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable + save_dir: "${paths.output_dir}" + project_name: "lightning-hydra-template" + rest_api_key: null + # experiment_name: "" + experiment_key: null # set to resume experiment + offline: False + prefix: "" diff --git a/conditional-flow-matching/runner/configs/logger/csv.yaml b/conditional-flow-matching/runner/configs/logger/csv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e48628f1415d6a02e6507cdf23639fe18218bec --- /dev/null +++ b/conditional-flow-matching/runner/configs/logger/csv.yaml @@ -0,0 +1,7 @@ +# csv logger built in lightning + +csv: + _target_: pytorch_lightning.loggers.csv_logs.CSVLogger + save_dir: "${paths.output_dir}" + name: "csv/" + prefix: "" diff --git a/conditional-flow-matching/runner/configs/logger/many_loggers.yaml b/conditional-flow-matching/runner/configs/logger/many_loggers.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5d1843f70dfdc40e43ff19127022dc08f050b812 --- /dev/null +++ b/conditional-flow-matching/runner/configs/logger/many_loggers.yaml @@ -0,0 +1,9 @@ +# train with many loggers at once + +defaults: + # - comet.yaml + - csv.yaml + # - mlflow.yaml + # - neptune.yaml + - tensorboard.yaml + - wandb.yaml diff --git a/conditional-flow-matching/runner/configs/logger/mlflow.yaml b/conditional-flow-matching/runner/configs/logger/mlflow.yaml new file mode 100644 index 0000000000000000000000000000000000000000..58ed49956c42032c2758910e15b3b07f1434fd5d --- /dev/null +++ b/conditional-flow-matching/runner/configs/logger/mlflow.yaml @@ -0,0 +1,12 @@ +# https://mlflow.org + +mlflow: + _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger + # experiment_name: "" + # run_name: "" + tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI + tags: null + # save_dir: "./mlruns" + prefix: "" + artifact_location: null + # run_id: "" diff --git a/conditional-flow-matching/runner/configs/logger/neptune.yaml b/conditional-flow-matching/runner/configs/logger/neptune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32b98896d48e3575f99af903aeafbdb68bc0a7db --- /dev/null +++ 
b/conditional-flow-matching/runner/configs/logger/neptune.yaml @@ -0,0 +1,9 @@ +# https://neptune.ai + +neptune: + _target_: pytorch_lightning.loggers.neptune.NeptuneLogger + api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable + project: username/lightning-hydra-template + # name: "" + log_model_checkpoints: True + prefix: "" diff --git a/conditional-flow-matching/runner/configs/logger/tensorboard.yaml b/conditional-flow-matching/runner/configs/logger/tensorboard.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1888882481a40ddb61e19f4c50211304701e41ca --- /dev/null +++ b/conditional-flow-matching/runner/configs/logger/tensorboard.yaml @@ -0,0 +1,10 @@ +# https://www.tensorflow.org/tensorboard/ + +tensorboard: + _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.output_dir}/tensorboard/" + name: null + log_graph: False + default_hp_metric: True + prefix: "" + # version: "" diff --git a/conditional-flow-matching/runner/configs/logger/wandb.yaml b/conditional-flow-matching/runner/configs/logger/wandb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..45bc2719bce022a5910cc7ea5d29aaeef2c7a7b4 --- /dev/null +++ b/conditional-flow-matching/runner/configs/logger/wandb.yaml @@ -0,0 +1,16 @@ +# https://wandb.ai + +wandb: + _target_: pytorch_lightning.loggers.wandb.WandbLogger + # name: "" # name of the run (normally generated by wandb) + save_dir: "${paths.output_dir}" + offline: False + id: null # pass correct id to resume experiment! + anonymous: null # enable anonymous logging + project: "conditional-flow-model" + log_model: False # upload lightning ckpts + prefix: "" # a string to put at the beginning of metric keys + # entity: "" # set to name of your wandb team + group: "" + tags: [] + job_type: "" diff --git a/conditional-flow-matching/runner/configs/model/cfm.yaml b/conditional-flow-matching/runner/configs/model/cfm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..33ba521f223ef914fe4eaf3dd67a4eff882c4408 --- /dev/null +++ b/conditional-flow-matching/runner/configs/model/cfm.yaml @@ -0,0 +1,41 @@ +_target_: src.models.cfm_module.CFMLitModule +_partial_: true + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 0.001 + weight_decay: 1e-5 + +net: + _target_: src.models.components.simple_mlp.VelocityNet + _partial_: true + hidden_dims: [64, 64, 64] + batch_norm: False + activation: "selu" + +augmentations: + _target_: src.models.components.augmentation.AugmentationModule + cnf_estimator: null + l1_reg: 0. + l2_reg: 0. + squared_l2_reg: 0. + jacobian_frobenius_reg: 0. + jacobian_diag_frobenius_reg: 0. + jacobian_off_diag_frobenius_reg: 0. 
+ +partial_solver: + _target_: src.models.components.solver.FlowSolver + _partial_: true + ode_solver: "euler" + atol: 1e-5 + rtol: 1e-5 + +ot_sampler: null + +sigma_min: 0.1 + +# Set to integer if want to train with left out timepoint +leaveout_timepoint: -1 + +plot: False diff --git a/conditional-flow-matching/runner/configs/model/cfm_v2.yaml b/conditional-flow-matching/runner/configs/model/cfm_v2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d6e9ec440738acf1e2cf67cca12204e1d7971b66 --- /dev/null +++ b/conditional-flow-matching/runner/configs/model/cfm_v2.yaml @@ -0,0 +1,28 @@ +_target_: src.models.runner.CFMLitModule +_partial_: true + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 0.001 + weight_decay: 1e-5 + +net: + _target_: src.models.components.simple_mlp.VelocityNet + _partial_: true + hidden_dims: [64, 64, 64] + batch_norm: False + activation: "selu" + +flow_matcher: + _target_: torchcfm.ConditionalFlowMatcher + sigma: 0.0 + +solver: + _target_: src.models.components.solver.FlowSolver + _partial_: true + ode_solver: "euler" + atol: 1e-5 + rtol: 1e-5 + +plot: True diff --git a/conditional-flow-matching/runner/configs/model/cnf.yaml b/conditional-flow-matching/runner/configs/model/cnf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9271b8fa33a22dce2465a121da7e8207a2c1f2d5 --- /dev/null +++ b/conditional-flow-matching/runner/configs/model/cnf.yaml @@ -0,0 +1,28 @@ +_target_: src.models.cfm_module.CNFLitModule +_partial_: true + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 0.001 + weight_decay: 0.01 + +net: + _target_: src.models.components.simple_mlp.VelocityNet + _partial_: true + hidden_dims: [64, 64, 64] + batch_norm: False + activation: "selu" + +augmentations: + _target_: src.models.components.augmentation.AugmentationModule + cnf_estimator: "exact" + l1_reg: 0. + l2_reg: 0. + squared_l2_reg: 0. + jacobian_frobenius_reg: 0. + jacobian_diag_frobenius_reg: 0. + jacobian_off_diag_frobenius_reg: 0. + +# Set to integer if want to train with left out timepoint +leaveout_timepoint: -1 diff --git a/conditional-flow-matching/runner/configs/model/fm.yaml b/conditional-flow-matching/runner/configs/model/fm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4fc33baf1045b3ab9dc033851dad26b53ad33209 --- /dev/null +++ b/conditional-flow-matching/runner/configs/model/fm.yaml @@ -0,0 +1,37 @@ +_target_: src.models.cfm_module.FMLitModule +_partial_: true + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 0.001 + weight_decay: 1e-5 + +net: + _target_: src.models.components.simple_mlp.VelocityNet + _partial_: true + hidden_dims: [64, 64, 64] + batch_norm: False + activation: "selu" + +augmentations: + _target_: src.models.components.augmentation.AugmentationModule + cnf_estimator: null + l1_reg: 0. + l2_reg: 0. + squared_l2_reg: 0. + jacobian_frobenius_reg: 0. + jacobian_diag_frobenius_reg: 0. + jacobian_off_diag_frobenius_reg: 0. 
+ +partial_solver: + _target_: src.models.components.solver.FlowSolver + _partial_: true + ode_solver: "euler" + atol: 1e-5 + rtol: 1e-5 + +sigma_min: 0.1 + +# Set to integer if want to train with left out timepoint +leaveout_timepoint: -1 diff --git a/conditional-flow-matching/runner/configs/model/icnn.yaml b/conditional-flow-matching/runner/configs/model/icnn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de1cdf3d0b2d7d46365eafcf3d3cf5205ab9a859 --- /dev/null +++ b/conditional-flow-matching/runner/configs/model/icnn.yaml @@ -0,0 +1,25 @@ +_target_: src.models.icnn_module.ICNNLitModule +_partial_: true + +f_net: + _target_: src.models.components.icnn_model.ICNN + _partial_: true + dimh: 64 + num_hidden_layers: 4 + +g_net: + _target_: src.models.components.icnn_model.ICNN + _partial_: true + dimh: 64 + num_hidden_layers: 4 + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.0001 + betas: [0.5, 0.9] + +reg: 0.1 + +# Set to integer if want to train with left out timepoint +leaveout_timepoint: -1 diff --git a/conditional-flow-matching/runner/configs/model/image_cfm.yaml b/conditional-flow-matching/runner/configs/model/image_cfm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..42a7313fa3d42e9d6f10978bc09f94ba46c954b2 --- /dev/null +++ b/conditional-flow-matching/runner/configs/model/image_cfm.yaml @@ -0,0 +1,46 @@ +_target_: src.models.cfm_module.CFMLitModule +_partial_: true + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.0005 + +net: + _target_: src.models.components.unet.UNetModelWrapper + _partial_: true + num_res_blocks: 2 + num_channels: 256 + channel_mult: [1, 2, 2, 2] + num_heads: 4 + num_head_channels: 64 + attention_resolutions: "16" + dropout: 0 + +augmentations: + _target_: src.models.components.augmentation.AugmentationModule + cnf_estimator: null + l1_reg: 0. + l2_reg: 0. + squared_l2_reg: 0. + jacobian_frobenius_reg: 0. + jacobian_diag_frobenius_reg: 0. + jacobian_off_diag_frobenius_reg: 0. + +partial_solver: + _target_: src.models.components.solver.FlowSolver + _partial_: true + ode_solver: "euler" + atol: 1e-5 + rtol: 1e-5 + +test_nfe: 100 + +ot_sampler: null + +sigma_min: 0.1 + +# Set to integer if want to train with left out timepoint +leaveout_timepoint: -1 + +plot: True diff --git a/conditional-flow-matching/runner/configs/model/otcfm.yaml b/conditional-flow-matching/runner/configs/model/otcfm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..587206f74793c212dbc88b955cc9b6edd9bf521f --- /dev/null +++ b/conditional-flow-matching/runner/configs/model/otcfm.yaml @@ -0,0 +1,39 @@ +_target_: src.models.cfm_module.CFMLitModule +_partial_: true + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 0.001 + weight_decay: 1e-5 + +net: + _target_: src.models.components.simple_mlp.VelocityNet + _partial_: true + hidden_dims: [64, 64, 64] + batch_norm: False + activation: "selu" + +augmentations: + _target_: src.models.components.augmentation.AugmentationModule + cnf_estimator: null + l1_reg: 0. + l2_reg: 0. + squared_l2_reg: 0. + jacobian_frobenius_reg: 0. + jacobian_diag_frobenius_reg: 0. + jacobian_off_diag_frobenius_reg: 0. 
+ +partial_solver: + _target_: src.models.components.solver.FlowSolver + _partial_: true + ode_solver: "euler" + atol: 1e-5 + rtol: 1e-5 + +ot_sampler: "exact" + +sigma_min: 0.1 + +# Set to integer if want to train with left out timepoint +leaveout_timepoint: -1 diff --git a/conditional-flow-matching/runner/configs/model/sbcfm.yaml b/conditional-flow-matching/runner/configs/model/sbcfm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5ac5e60ca60b7ef12ff74300104062df6861384 --- /dev/null +++ b/conditional-flow-matching/runner/configs/model/sbcfm.yaml @@ -0,0 +1,39 @@ +_target_: src.models.cfm_module.SBCFMLitModule +_partial_: true + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 0.001 + weight_decay: 1e-5 + +net: + _target_: src.models.components.simple_mlp.VelocityNet + _partial_: true + hidden_dims: [64, 64, 64] + batch_norm: False + activation: "selu" + +augmentations: + _target_: src.models.components.augmentation.AugmentationModule + cnf_estimator: null + l1_reg: 0. + l2_reg: 0. + squared_l2_reg: 0. + jacobian_frobenius_reg: 0. + jacobian_diag_frobenius_reg: 0. + jacobian_off_diag_frobenius_reg: 0. + +partial_solver: + _target_: src.models.components.solver.FlowSolver + _partial_: true + ode_solver: "euler" + atol: 1e-5 + rtol: 1e-5 + +ot_sampler: "sinkhorn" + +sigma_min: 1.0 + +# Set to integer if want to train with left out timepoint +leaveout_timepoint: -1 diff --git a/conditional-flow-matching/runner/configs/model/trajectorynet.yaml b/conditional-flow-matching/runner/configs/model/trajectorynet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8aaf68882ae5ef44a0d69e118ab80c0a30cd2bc5 --- /dev/null +++ b/conditional-flow-matching/runner/configs/model/trajectorynet.yaml @@ -0,0 +1,28 @@ +_target_: src.models.cfm_module.CNFLitModule +_partial_: true + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 0.001 + weight_decay: 0.01 + +net: + _target_: src.models.components.simple_mlp.VelocityNet + _partial_: true + hidden_dims: [64, 64, 64] + batch_norm: False + activation: "tanh" + +augmentations: + _target_: src.models.components.augmentation.AugmentationModule + cnf_estimator: "exact" + l1_reg: 0. + l2_reg: 0. + squared_l2_reg: 1e-4 + jacobian_frobenius_reg: 1e-4 + jacobian_diag_frobenius_reg: 0. + jacobian_off_diag_frobenius_reg: 0. 
+ +# Set to integer if want to train with left out timepoint +leaveout_timepoint: -1 diff --git a/conditional-flow-matching/runner/configs/paths/default.yaml b/conditional-flow-matching/runner/configs/paths/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d47c129be278a8715e2ae0f75a8380095d6b17c1 --- /dev/null +++ b/conditional-flow-matching/runner/configs/paths/default.yaml @@ -0,0 +1,18 @@ +# path to root directory +# this requires PROJECT_ROOT environment variable to exist +# PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py` +root_dir: ${oc.env:PROJECT_ROOT} + +# path to data directory +data_dir: ${local.data_dir} + +# path to logging directory +log_dir: ${local.log_dir} + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like ckpts and metrics +output_dir: ${hydra:runtime.output_dir} + +# path to working directory +work_dir: ${hydra:runtime.cwd} diff --git a/conditional-flow-matching/runner/configs/train.yaml b/conditional-flow-matching/runner/configs/train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a60b5ae82e71bc67d0530e5fae2f41d6bfd5da63 --- /dev/null +++ b/conditional-flow-matching/runner/configs/train.yaml @@ -0,0 +1,52 @@ +# @package _global_ + +# specify here default configuration +# order of defaults determines the order in which configs override each other +defaults: + - _self_ + - datamodule: sklearn + - model: cfm + - callbacks: default + - logger: csv # set logger here or use command line (e.g. `python train.py logger=tensorboard`) + - trainer: default + - paths: default + - extras: default + - hydra: default + - launcher: null + + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: null + + # config for hyperparameter optimization + - hparams_search: null + + # optional local config for machine/user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional local: default + + # debugging config (enable through command line, e.g. 
`python train.py debug=default`) + - debug: null + +# task name, determines output directory path +task_name: "train" + +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` +# appending lists from command line is currently not supported :( +# https://github.com/facebookresearch/hydra/issues/1547 +tags: ["dev"] + +# set False to skip model training +train: True + +# evaluate on test set, using best model weights achieved during training +# lightning chooses best weights based on the metric specified in checkpoint callback +test: True + +# simply provide checkpoint path to resume training +ckpt_path: null + +# seed for random number generators in pytorch, numpy and python.random +seed: null diff --git a/conditional-flow-matching/runner/configs/trainer/cpu.yaml b/conditional-flow-matching/runner/configs/trainer/cpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9291239a68cb0ad29346fe37057046824dbbb160 --- /dev/null +++ b/conditional-flow-matching/runner/configs/trainer/cpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default.yaml + +accelerator: cpu +devices: 1 diff --git a/conditional-flow-matching/runner/configs/trainer/ddp.yaml b/conditional-flow-matching/runner/configs/trainer/ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fcf6e1053eaacff34a9481fa5e804191297b3cfa --- /dev/null +++ b/conditional-flow-matching/runner/configs/trainer/ddp.yaml @@ -0,0 +1,13 @@ +defaults: + - default.yaml + +# use "ddp_spawn" instead of "ddp", +# it's slower but normal "ddp" currently doesn't work ideally with hydra +# https://github.com/facebookresearch/hydra/issues/2070 +# https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn +strategy: ddp_spawn + +accelerator: gpu +devices: 4 +num_nodes: 1 +sync_batchnorm: True diff --git a/conditional-flow-matching/runner/configs/trainer/ddp_sim.yaml b/conditional-flow-matching/runner/configs/trainer/ddp_sim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a93c111ae129087a06a10a95b4951963d9dad881 --- /dev/null +++ b/conditional-flow-matching/runner/configs/trainer/ddp_sim.yaml @@ -0,0 +1,7 @@ +defaults: + - default.yaml + +# simulate DDP on CPU, useful for debugging +accelerator: cpu +devices: 2 +strategy: ddp_spawn diff --git a/conditional-flow-matching/runner/configs/trainer/default.yaml b/conditional-flow-matching/runner/configs/trainer/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9c3d17b51f06c1dd368a26e3fa1d66c8349a5b37 --- /dev/null +++ b/conditional-flow-matching/runner/configs/trainer/default.yaml @@ -0,0 +1,19 @@ +_target_: pytorch_lightning.Trainer + +default_root_dir: ${paths.output_dir} + +min_epochs: 1 # prevents early stopping +max_epochs: 10 + +accelerator: cpu +devices: 1 + +# mixed precision for extra speed-up +# precision: 16 + +# perform a validation loop every N training epochs +check_val_every_n_epoch: 1 + +# set True to ensure deterministic results +# makes training slower but gives more reproducibility than just setting seeds +deterministic: False diff --git a/conditional-flow-matching/runner/configs/trainer/gpu.yaml b/conditional-flow-matching/runner/configs/trainer/gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..657796cf8ef6903e2196ada03c2bb1262de3a5c1 --- /dev/null +++
b/conditional-flow-matching/runner/configs/trainer/gpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default.yaml + +accelerator: gpu +devices: 1 diff --git a/conditional-flow-matching/runner/configs/trainer/mps.yaml b/conditional-flow-matching/runner/configs/trainer/mps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..13f46d9ddd26814321f4ed221c1dd2c0f3f83acc --- /dev/null +++ b/conditional-flow-matching/runner/configs/trainer/mps.yaml @@ -0,0 +1,5 @@ +defaults: + - default.yaml + +accelerator: mps +devices: 1 diff --git a/conditional-flow-matching/runner/data/.gitkeep b/conditional-flow-matching/runner/data/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/conditional-flow-matching/runner/logs/.gitkeep b/conditional-flow-matching/runner/logs/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/conditional-flow-matching/runner/scripts/schedule.sh b/conditional-flow-matching/runner/scripts/schedule.sh new file mode 100644 index 0000000000000000000000000000000000000000..ea7e6626181ada3736a9883f32402df56d4668ae --- /dev/null +++ b/conditional-flow-matching/runner/scripts/schedule.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# Schedule execution of many runs +# Run from root folder with: bash scripts/schedule.sh + +python src/train.py trainer.max_epochs=5 logger=csv + +python src/train.py trainer.max_epochs=10 logger=csv diff --git a/conditional-flow-matching/runner/scripts/two-dim-cfm.sh b/conditional-flow-matching/runner/scripts/two-dim-cfm.sh new file mode 100644 index 0000000000000000000000000000000000000000..34bdb2ce041a74c70879f74717dda9c0a0a080c7 --- /dev/null +++ b/conditional-flow-matching/runner/scripts/two-dim-cfm.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Compares flow matching (FM), conditional flow matching (CFM), and optimal +# transport conditional flow matching (OT-CFM) on four datasets. twodim is not +# possible for the flow matching algorithm as it has a non-Gaussian source +# distribution. FM is therefore only run on three datasets.
+python src/train.py -m experiment=cfm \ + model=cfm,otcfm \ + launcher=mila_cpu_cluster \ + model.sigma_min=0.1 \ + datamodule=scurve,moons,twodim,gaussians \ + seed=42,43,44,45,46 & + +# Sleep to avoid launching jobs at the same time +sleep 1 +python src/train.py -m experiment=cfm \ + model=fm \ + launcher=mila_cpu_cluster \ + model.sigma_min=0.1 \ + datamodule=scurve,moons,gaussians \ + seed=42,43,44,45,46 & diff --git a/conditional-flow-matching/runner/src/__init__.py b/conditional-flow-matching/runner/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/conditional-flow-matching/runner/src/datamodules/__init__.py b/conditional-flow-matching/runner/src/datamodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/conditional-flow-matching/runner/src/datamodules/cifar10_datamodule.py b/conditional-flow-matching/runner/src/datamodules/cifar10_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..dd2992c44c635e788cef75abdfe3fdd279265bb5 --- /dev/null +++ b/conditional-flow-matching/runner/src/datamodules/cifar10_datamodule.py @@ -0,0 +1,15 @@ +from typing import Any, List, Union + +import pl_bolts +from torch.utils.data import DataLoader +from torchvision import transforms as transform_lib + + +class CIFAR10DataModule(pl_bolts.datamodules.cifar10_datamodule.CIFAR10DataModule): + def __init__(self, *args, **kwargs): + test_transforms = transform_lib.ToTensor() + super().__init__(*args, test_transforms=test_transforms, **kwargs) + + def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: + """The val dataloader.""" + return self._data_loader(self.dataset_train) diff --git a/conditional-flow-matching/runner/src/datamodules/components/time_dataset.py b/conditional-flow-matching/runner/src/datamodules/components/time_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a727d14faba7fda268cf189ede4485ce6d92c981 --- /dev/null +++ b/conditional-flow-matching/runner/src/datamodules/components/time_dataset.py @@ -0,0 +1,22 @@ +import numpy as np +import scanpy as sc + + +def adata_dataset(path, embed_name="X_pca", label_name="day", max_dim=100): + adata = sc.read_h5ad(path) + labels = adata.obs[label_name].astype("category") + ulabels = labels.cat.categories + return adata.obsm[embed_name][:, :max_dim], labels, ulabels + + +def tnet_dataset(path, embed_name="pcs", label_name="sample_labels", max_dim=100): + a = np.load(path, allow_pickle=True) + return a[embed_name][:, :max_dim], a[label_name], np.unique(a[label_name]) + + +def load_dataset(path, max_dim=100): + if path.endswith("h5ad"): + return adata_dataset(path, max_dim=max_dim) + if path.endswith("npz"): + return tnet_dataset(path, max_dim=max_dim) + raise NotImplementedError() diff --git a/conditional-flow-matching/runner/src/datamodules/distribution_datamodule.py b/conditional-flow-matching/runner/src/datamodules/distribution_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..b5a5ace52c799e5cd9cf457e1d43aa38eb170533 --- /dev/null +++ b/conditional-flow-matching/runner/src/datamodules/distribution_datamodule.py @@ -0,0 +1,715 @@ +import math +from functools import partial +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from pytorch_lightning import LightningDataModule +from pytorch_lightning.trainer.supporters import 
CombinedLoader +from sklearn.preprocessing import StandardScaler +from torch.utils.data import DataLoader, TensorDataset, random_split +from torchdyn.datasets import ToyDataset + +from src import utils + +from .components.base import BaseLightningDataModule +from .components.time_dataset import load_dataset +from .components.tnet_dataset import SCData +from .components.two_dim import data_distrib + +log = utils.get_pylogger(__name__) + + +class TrajectoryNetDistributionTrajectoryDataModule(LightningDataModule): + pass_to_model = True + IS_TRAJECTORY = True + + def __init__( + self, + data_dir: str = "data/", + train_val_test_split: Union[int, Tuple[int, int, int]] = 1, + system: str = "TREE", + system_kwargs: dict = {}, + batch_size: int = 64, + num_workers: int = 0, + pin_memory: bool = False, + seed=None, + ): + super().__init__() + self.save_hyperparameters(logger=True) + dataset = SCData.factory(system, system_kwargs) + self.data = dataset.get_data() + self.dim = self.data.shape[-1] + self.labels = dataset.get_times() + self.system = system + self.ulabels = dataset.get_unique_times() + + self.timepoint_data = [ + self.data[self.labels == lab].astype(np.float32) for lab in self.ulabels + ] + self.split() + log.info( + f"Loaded {self.system} with timepoints {self.ulabels} of sizes {[len(d) for d in self.timepoint_data]}." + ) + + def split(self): + """Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels.""" + train_val_test_split = self.hparams.train_val_test_split + if isinstance(train_val_test_split, int): + self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data] + return + splitter = partial( + random_split, + lengths=self.hparams.train_val_test_split, + generator=torch.Generator().manual_seed(42), + ) + self.split_timepoint_data = list(map(splitter, self.timepoint_data)) + + def combined_loader(self, index, shuffle=False): + tp_dataloaders = [ + DataLoader( + dataset=datasets[index], + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=shuffle, + drop_last=True, + ) + for datasets in self.split_timepoint_data + ] + return CombinedLoader(tp_dataloaders, mode="min_size") + + def train_dataloader(self): + return self.combined_loader(0, shuffle=True) + + def val_dataloader(self): + return self.combined_loader(1, shuffle=False) + + def test_dataloader(self): + return self.combined_loader(2, shuffle=False) + + +class CustomTrajectoryDataModule(LightningDataModule): + pass_to_model = True + IS_TRAJECTORY = True + # TODO Code copied from above, doesn't like inheritance with init. 
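+    # Note: CombinedLoader(..., mode="min_size"), used by the trajectory
+    # datamodules in this file, yields one batch per timepoint as a list
+    # [x_t0, x_t1, ...] and stops once the smallest timepoint's loader is
+    # exhausted, so the per-timepoint batches stay aligned.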
+ + def __init__( + self, + data_dir: str = "data/", + train_val_test_split: Union[Tuple[int, int, int], Tuple[float, float, float]] = ( + 0.8, + 0.1, + 0.1, + ), + max_dim: Optional[int] = None, + system: str = "", + batch_size: int = 64, + whiten: bool = False, + hvg: bool = False, + num_workers: int = 0, + pin_memory: bool = False, + seed=None, + ): + super().__init__() + self.save_hyperparameters(logger=True) + self.data, self.labels, self.ulabels = load_dataset(system) + if hvg: + import scanpy as sc + + adata = sc.read_h5ad(system) + sc.pp.highly_variable_genes(adata, n_top_genes=max_dim) + self.data = adata.X[:, adata.var["highly_variable"]].toarray() + if max_dim: + self.data = self.data[:, :max_dim] + if whiten: + self.scaler = StandardScaler() + self.scaler.fit(self.data) + self.data = self.scaler.transform(self.data) + self.dim = self.data.shape[-1] + self.system = system + + self.timepoint_data = [ + self.data[self.labels == lab].astype(np.float32) for lab in self.ulabels + ] + self.split() + log.info( + f"Loaded {self.system} with timepoints {self.ulabels} of sizes {[len(d) for d in self.timepoint_data]} with dim {self.dim}." + ) + + def split(self): + """Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels.""" + train_val_test_split = self.hparams.train_val_test_split + if isinstance(train_val_test_split, int): + self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data] + return + splitter = partial( + random_split, + lengths=self.hparams.train_val_test_split, + generator=torch.Generator().manual_seed(42), + ) + self.split_timepoint_data = list(map(splitter, self.timepoint_data)) + + def combined_loader(self, index, shuffle=False, load_full=False): + if load_full: + tp_dataloaders = [ + DataLoader( + dataset=datasets, + batch_size=1000 * self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + drop_last=False, + ) + for datasets in self.timepoint_data + ] + else: + tp_dataloaders = [ + DataLoader( + dataset=datasets[index], + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=shuffle, + drop_last=True, + ) + for datasets in self.split_timepoint_data + ] + return CombinedLoader(tp_dataloaders, mode="min_size") + + def train_dataloader(self): + return self.combined_loader(0, shuffle=True) + + def val_dataloader(self): + return self.combined_loader(1, shuffle=False, load_full=False) + + def test_dataloader(self): + return self.combined_loader(2, shuffle=False, load_full=True) + + +class CustomGeodesicTrajectoryDataModule(LightningDataModule): + HAS_JOINT_PLANS = True + pass_to_model = True + IS_TRAJECTORY = True + # TODO Code copied from above, doesn't like inheritance with init. 
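+    # Note: unlike the class above, HAS_JOINT_PLANS makes the train loaders
+    # below serve row *indices* rather than samples; the model resolves an index
+    # at time t into an (x_t, x_{t+1}) pair by sampling the precomputed coupling
+    # matrices `pi_{t}_{t+1}` stored in `adata.uns`.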
+ + def __init__( + self, + data_dir: str = "data/", + train_val_test_split: Union[Tuple[int, int, int], Tuple[float, float, float]] = ( + 0.8, + 0.1, + 0.1, + ), + max_dim: Optional[int] = None, + system: str = "", + batch_size: int = 64, + whiten: bool = False, + hvg: bool = False, + num_workers: int = 0, + pin_memory: bool = False, + seed=None, + ): + super().__init__() + self.save_hyperparameters(logger=True) + assert system.endswith("h5ad") + import scanpy as sc + + adata = sc.read_h5ad(system) + self.data, self.labels, self.ulabels = load_dataset(system) + if hvg: + sc.pp.highly_variable_genes(adata, n_top_genes=max_dim) + self.data = adata.X[:, adata.var["highly_variable"]].toarray() + + if max_dim: + self.data = self.data[:, :max_dim] + if whiten: + self.scaler = StandardScaler() + self.scaler.fit(self.data) + self.data = self.scaler.transform(self.data) + self.dim = self.data.shape[-1] + self.system = system + print(self.ulabels.unique()) + + self.pi = [adata.uns[f"pi_{t}_{t+1}"] for t in range(len(self.ulabels.unique()) - 1)] + self.pi_leaveout = [adata.uns[f"pi_{t+1}"] for t in range(len(self.ulabels.unique()) - 2)] + + self.timepoint_data = [ + self.data[self.labels == lab].astype(np.float32) for lab in self.ulabels + ] + log.info( + f"Loaded {self.system} with timepoints {self.ulabels} of sizes {[len(d) for d in self.timepoint_data]} with dim {self.dim}." + ) + log.info(f"time datasets of shape {[t.shape for t in self.timepoint_data]}") + + def combined_loader(self, index, shuffle=False, load_full=False): + if load_full: + tp_dataloaders = [ + DataLoader( + dataset=datasets, + batch_size=1000 * self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + drop_last=False, + ) + for datasets in self.timepoint_data + ] + else: + tp_dataloaders = [ + DataLoader( + dataset=torch.arange(datasets.shape[0])[:, None], + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=shuffle, + drop_last=True, + ) + for datasets in self.timepoint_data + ] + + return CombinedLoader(tp_dataloaders, mode="min_size") + + def train_dataloader(self): + return self.combined_loader(0, shuffle=True) + + def val_dataloader(self): + """Use training set for validation assuming [1,0,0] train val test split.""" + return self.combined_loader(0, shuffle=False, load_full=False) + + def test_dataloader(self): + return self.combined_loader(2, shuffle=False, load_full=True) + + +class DiffusionSchrodingerBridgeGaussians(LightningDataModule): + pass_to_model = True + IS_TRAJECTORY = True + GAUSSIAN_CLOSED_FORM = True # Has closed form SB solution + # TODO Code copied from above, doesn't like inheritance with init. 
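+    # Note: the two marginals are isotropic Gaussians, N(-a * 1, I) at t=0 and
+    # N(+a * 1, I) at t=1, for which the Schrodinger bridge marginals are known
+    # in closed form; see `closed_form_marginal` below (after Mallasto et al. 2020).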
+
+    def __init__(
+        self,
+        train_val_test_split: Union[Tuple[int, int, int], Tuple[float, float, float]] = (
+            0.8,
+            0.1,
+            0.1,
+        ),
+        dim=2,
+        a=0.1,
+        batch_size: int = 64,
+        num_workers: int = 0,
+        pin_memory: bool = False,
+        seed=None,
+    ):
+        super().__init__()
+        self.save_hyperparameters(logger=True)
+        np.random.seed(seed)
+
+        N = (
+            train_val_test_split
+            if isinstance(train_val_test_split, int)
+            else sum(train_val_test_split)
+        )
+        self.timepoint_data = [
+            torch.from_numpy(np.random.randn(N, dim) - a).type(torch.float32),
+            torch.from_numpy(np.random.randn(N, dim) + a).type(torch.float32),
+        ]
+        self.split()
+        self.dim = dim
+        self.a = a
+
+    def split(self):
+        """Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels."""
+        train_val_test_split = self.hparams.train_val_test_split
+        if isinstance(train_val_test_split, int):
+            self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data]
+            return
+        splitter = partial(
+            random_split,
+            lengths=self.hparams.train_val_test_split,
+            generator=torch.Generator().manual_seed(42),
+        )
+        self.split_timepoint_data = list(map(splitter, self.timepoint_data))
+
+    def closed_form_marginal(self, sigma, t):
+        """Simplified closed form marginal for the SB Gaussians.
+
+        Derived from Mallasto et al. 2020 https://arxiv.org/pdf/2006.03416.pdf
+        """
+        a = self.a
+        mean = (2 * a * t - a) * torch.ones(self.dim)
+        cov = (math.sqrt(4 + sigma**4) * t * (1 - t) + (1 - t) ** 2 + t**2) * torch.eye(self.dim)
+        return mean, cov
+
+    def detailed_evaluation(self, xt, sigma, t):
+        est_mean = torch.mean(xt, dim=0)
+        est_cov = torch.cov(xt.T)
+        mean, cov = self.closed_form_marginal(sigma, t)
+        mean_diff = torch.mean(est_mean - mean)
+        off_diag_frob = torch.linalg.matrix_norm(est_cov - torch.diag_embed(torch.diag(est_cov)))
+        diag_diff = torch.mean(torch.diag(est_cov - cov))
+        return torch.stack([mean_diff, off_diag_frob, diag_diff])
+
+    def KL(self, xt, sigma, t):
+        """KL divergence between the ground truth SB marginal and the estimated marginal."""
+        est_mean = torch.mean(xt, dim=0)
+        est_cov = torch.cov(xt.T)
+        mean, cov = self.closed_form_marginal(sigma, t)
+        return torch.distributions.kl.kl_divergence(
+            torch.distributions.MultivariateNormal(est_mean, est_cov),
+            torch.distributions.MultivariateNormal(mean, cov),
+        )
+
+    def combined_loader(self, index, shuffle=False, drop_last=False):
+        tp_dataloaders = [
+            DataLoader(
+                dataset=datasets[index],
+                batch_size=self.hparams.batch_size,
+                num_workers=self.hparams.num_workers,
+                pin_memory=self.hparams.pin_memory,
+                shuffle=shuffle,
+                drop_last=drop_last,
+            )
+            for datasets in self.split_timepoint_data
+        ]
+        return CombinedLoader(tp_dataloaders, mode="min_size")
+
+    def train_dataloader(self):
+        return self.combined_loader(0, shuffle=True, drop_last=True)
+
+    def val_dataloader(self):
+        return self.combined_loader(1, shuffle=False, drop_last=False)
+
+    def test_dataloader(self):
+        return self.combined_loader(2, shuffle=False, drop_last=False)
+
+
+class TwoDimDataModule(LightningDataModule):
+    pass_to_model = True
+    IS_TRAJECTORY = True
+    # TODO Code copied from above, doesn't like inheritance with init.
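+    # Note: `system` is a hyphen-separated spec with one token per timepoint
+    # (e.g. a hypothetical "gaussian-moons" would give two distributions); each
+    # token is resolved to samples by `data_distrib` in components/two_dim.py.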
+ + def __init__( + self, + train_val_test_split: Union[Tuple[int, int, int], Tuple[float, float, float]] = ( + 0.8, + 0.1, + 0.1, + ), + system: str = "", + batch_size: int = 64, + num_workers: int = 0, + pin_memory: bool = False, + seed=None, + ): + super().__init__() + self.save_hyperparameters(logger=True) + self.system = system + np.random.seed(seed) + + N = ( + train_val_test_split + if isinstance(train_val_test_split, int) + else sum(train_val_test_split) + ) + + systems = system.split("-") + self.timepoint_data = [data_distrib(N, s, seed) for s in systems] + self.split() + self.dim = 2 + + def split(self): + """Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels.""" + train_val_test_split = self.hparams.train_val_test_split + if isinstance(train_val_test_split, int): + self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data] + return + splitter = partial( + random_split, + lengths=self.hparams.train_val_test_split, + generator=torch.Generator().manual_seed(42), + ) + self.split_timepoint_data = list(map(splitter, self.timepoint_data)) + + def combined_loader(self, index, shuffle=False): + tp_dataloaders = [ + DataLoader( + dataset=datasets[index], + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=shuffle, + drop_last=True, + ) + for datasets in self.split_timepoint_data + ] + return CombinedLoader(tp_dataloaders, mode="min_size") + + def train_dataloader(self): + return self.combined_loader(0, shuffle=True) + + def val_dataloader(self): + return self.combined_loader(1, shuffle=False) + + def test_dataloader(self): + return self.combined_loader(2, shuffle=False) + + +class TorchDynDataModule(BaseLightningDataModule): + # https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py + pass_to_model = True + IS_TRAJECTORY = False + + def __init__( + self, + system: str, + system_kwargs: dict = {}, + train_val_test_split: Union[int, Tuple[int, int, int]] = 1, + batch_size: int = 64, + num_workers: int = 0, + pin_memory: bool = False, + seed=None, + ) -> None: + super().__init__() + self.save_hyperparameters(logger=True) + self.system = system + + np.random.seed(seed) + N = ( + train_val_test_split + if isinstance(train_val_test_split, int) + else sum(train_val_test_split) + ) + if system == "gaussians": + N = N // 8 + system_kwargs["n_gaussians"] = 8 + system_kwargs["radius"] = 5 + system_kwargs["std_gaussians"] = 1 + if system == "funnel": + x = torch.randn((N, system_kwargs["dim"])) + x[:, 1:] *= (x[:, :1] / 2).exp() + dataset = x + if system.endswith(".npz") or system.endswith(".h5ad"): + # Load single cell data + dataset, self.labels, self.ulabels = load_dataset(system) + if "max_dim" in system_kwargs: + max_dim = system_kwargs["max_dim"] + dataset = dataset[:, :max_dim] + if "whiten" in system_kwargs: + self.scaler = StandardScaler() + self.scaler.fit(dataset) + dataset = self.scaler.transform(dataset) + else: + dataset, self.labels = ToyDataset().generate(N, dataset_type=system, **system_kwargs) + if isinstance(self.hparams.train_val_test_split, int): + self.data_train, self.data_val, self.data_test = dataset, dataset, dataset + else: + self.data_train, self.data_val, self.data_test = random_split( + dataset=dataset, + lengths=self.hparams.train_val_test_split, + generator=torch.Generator().manual_seed(42), + ) + self.dim = dataset.shape[1] + + +class ICNNDataModule(LightningDataModule): + # DEPRECATED + pass_to_model = True + IS_TRAJECTORY 
= True + # TODO Code copied from above, doesn't like inheritance with init. + + def __init__( + self, + system: List[str], + train_val_test_split: Union[int, Tuple[int, int, int], Tuple[float, float, float]] = ( + 0.8, + 0.1, + 0.1, + ), + batch_size: int = 64, + num_workers: int = 0, + pin_memory: bool = False, + seed=None, + ): + super().__init__() + self.save_hyperparameters(logger=True) + self.data, self.labels, self.ulabels = load_dataset(system) + self.dim = self.data.shape[-1] + self.system = system + + self.timepoint_data = [ + self.data[self.labels == lab].astype(np.float32) for lab in self.ulabels + ] + self.split() + log.info( + f"Loaded {self.system} with timepoints {self.ulabels} of sizes {[len(d) for d in self.timepoint_data]}." + ) + + def split(self): + """Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels.""" + train_val_test_split = self.hparams.train_val_test_split + if isinstance(train_val_test_split, int): + self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data] + return + splitter = partial( + random_split, + lengths=self.hparams.train_val_test_split, + generator=torch.Generator().manual_seed(42), + ) + self.split_timepoint_data = list(map(splitter, self.timepoint_data)) + + def combined_loader(self, index, shuffle=False): + tp_dataloaders = [ + DataLoader( + dataset=datasets[index], + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=shuffle, + drop_last=True, + ) + for datasets in self.split_timepoint_data + ] + return CombinedLoader(tp_dataloaders, mode="min_size") + + def train_dataloader(self): + return self.combined_loader(0, shuffle=True) + + def val_dataloader(self): + return self.combined_loader(1, shuffle=False) + + def test_dataloader(self): + return self.combined_loader(2, shuffle=False) + + +class SKLearnDataModule(BaseLightningDataModule): + pass_to_model = True + IS_TRAJECTORY = False + + def __init__( + self, + system: str, + train_val_test_split: Union[int, Tuple[int, int, int]] = 1, + batch_size: int = 64, + num_workers: int = 0, + pin_memory: bool = False, + seed=42, + ) -> None: + import sklearn.datasets + + super().__init__() + self.save_hyperparameters(logger=True) + self.system = system + + np.random.seed(seed) + N = ( + train_val_test_split + if isinstance(train_val_test_split, int) + else sum(train_val_test_split) + ) + np.random.seed(seed) + if system == "circles": + self.data, _ = sklearn.datasets.make_circles( + n_samples=N, factor=0.5, noise=0.05, random_state=seed + ) + self.data *= 3.5 + elif system == "moons": + self.data, _ = sklearn.datasets.make_moons(n_samples=N, noise=0.05, random_state=seed) + self.data *= 2 + self.data[:, 0] -= 1 + elif system == "blobs": + self.data, _ = sklearn.datasets.make_blobs(n_samples=N) + elif system == "scurve": + self.data, _ = sklearn.datasets.make_s_curve( + n_samples=N, noise=0.05, random_state=seed + ) + self.data = np.vstack([self.data[:, 0], self.data[:, 2]]).T + self.data *= 1.5 + else: + raise NotImplementedError("Unknown dataset name %s" % system) + + dataset = torch.from_numpy(self.data.astype(np.float32)) + + if isinstance(self.hparams.train_val_test_split, int): + self.data_train, self.data_val, self.data_test = dataset, dataset, dataset + else: + self.data_train, self.data_val, self.data_test = random_split( + dataset=dataset, + lengths=self.hparams.train_val_test_split, + generator=torch.Generator().manual_seed(42), + ) + self.dim = dataset.shape[1] + + +class 
DistributionDataModule(BaseLightningDataModule): + """DEPRECATED: Implements loader for datasets taking the form of a sequence of distributions + over time. + + Each batch is a 3-tuple of data (data, time, causal graph) ([b x d], [b], [b x d x d]). + """ + + pass_to_model = True + HAS_GRAPH = True + + def __init__( + self, + data_dir: str = "data/", + system: str = "TREE", + train_val_test_split: Union[int, Tuple[int, int, int]] = 1, + batch_size: int = 64, + num_workers: int = 0, + pin_memory: bool = False, + T: int = 100, + system_kwargs: dict = {}, + seed=None, + ): + super().__init__() + self.save_hyperparameters(logger=True) + dataset = SCData.factory(system, system_kwargs) + self.data = dataset.get_data() + self.labels = dataset.get_times() + self.system = system + + self.grn = torch.zeros((1, 1), dtype=torch.float32) + if hasattr(dataset, "grn"): + self.grn = torch.tensor(dataset.grn, dtype=torch.float32) + else: + log.info("No network found, using dummy") + + self.timepoint_data = [self.data[self.labels == lab] for lab in dataset.get_unique_times()] + self.min_count = min(len(d) for d in self.timepoint_data) + self.nice_data = np.array([d[: self.min_count] for d in self.timepoint_data]) + self.nice_data = torch.tensor(self.nice_data, dtype=torch.float32).transpose(0, 1) + # TODO add support for jagged + self.times = torch.tensor(dataset.get_unique_times(), dtype=torch.float32).repeat( + self.min_count, 1 + ) + self.grn = self.grn.repeat(self.min_count, 1, 1) + t = len(dataset.get_unique_times()) + self.even_times = torch.linspace(1, t, t).repeat(self.min_count, 1) + dataset = TensorDataset(self.nice_data) # , self.even_times, self.grn) + # dataset = TensorDataset(self.nice_data, self.even_times, self.grn) + + if isinstance(self.hparams.train_val_test_split, int): + self.data_train, self.data_val, self.data_test = dataset, dataset, dataset + else: + self.data_train, self.data_val, self.data_test = random_split( + dataset=dataset, + lengths=self.hparams.train_val_test_split, + generator=torch.Generator().manual_seed(42), + ) + self.dim = self.nice_data.shape[-1] + + +if __name__ == "__main__": + import hydra + import omegaconf + import pyrootutils + + root = pyrootutils.setup_root(__file__, pythonpath=True) + cfg = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "torchdyn.yaml") + _ = hydra.utils.instantiate(cfg) + cfg = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "custom_dist.yaml") + cfg.system = "scurve" + datamodule = hydra.utils.instantiate(cfg) + print(datamodule.data_train.shape) diff --git a/conditional-flow-matching/runner/src/eval.py b/conditional-flow-matching/runner/src/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..76312661c47fa14cf37aaba68b0009e84e8cbd68 --- /dev/null +++ b/conditional-flow-matching/runner/src/eval.py @@ -0,0 +1,111 @@ +import pyrootutils + +root = pyrootutils.setup_root( + search_from=__file__, + indicator=[".git", "pyproject.toml"], + pythonpath=True, + dotenv=True, +) + +# ------------------------------------------------------------------------------------ # +# `pyrootutils.setup_root(...)` above is optional line to make environment more convenient +# should be placed at the top of each entry file +# +# main advantages: +# - allows you to keep all entry files in "src/" without installing project as a package +# - launching python file works no matter where is your current work dir +# - automatically loads environment variables from ".env" if exists +# +# how it works: +# - `setup_root()` above 
recursively searches for either ".git" or "pyproject.toml" in present +# and parent dirs, to determine the project root dir +# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from +# any place without installing project as a package +# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml" +# to make all paths always relative to project root +# - loads environment variables from ".env" in root dir (if `dotenv=True`) +# +# you can remove `pyrootutils.setup_root(...)` if you: +# 1. either install project as a package or move each entry file to the project root dir +# 2. remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml" +# +# https://github.com/ashleve/pyrootutils +# ------------------------------------------------------------------------------------ # + +from typing import List, Tuple + +import hydra +from omegaconf import DictConfig +from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning.loggers import LightningLoggerBase + +from src import utils + +log = utils.get_pylogger(__name__) + + +@utils.task_wrapper +def evaluate(cfg: DictConfig) -> Tuple[dict, dict]: + """Evaluates given checkpoint on a datamodule testset. + + This method is wrapped in optional @task_wrapper decorator which applies extra utilities + before and after the call. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + + Returns: + Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. + """ + assert cfg.ckpt_path + + log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule) + + log.info(f"Instantiating model <{cfg.model._target_}>") + if hasattr(datamodule, "pass_to_model"): + log.info("Passing full datamodule to model") + model: LightningModule = hydra.utils.instantiate(cfg.model)(datamodule=datamodule) + else: + if hasattr(datamodule, "dim"): + log.info("Passing datamodule.dim to model") + model: LightningModule = hydra.utils.instantiate(cfg.model)(dim=datamodule.dim) + else: + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating loggers...") + logger: List[LightningLoggerBase] = utils.instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + utils.log_hyperparameters(object_dict) + + log.info("Starting testing!") + trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path) + + # for predictions use trainer.predict(...) 
+ # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path) + + metric_dict = trainer.callback_metrics + + return metric_dict, object_dict + + +@hydra.main(version_base="1.2", config_path=root / "configs", config_name="eval.yaml") +def main(cfg: DictConfig) -> None: + evaluate(cfg) + + +if __name__ == "__main__": + main() diff --git a/conditional-flow-matching/runner/src/models/__init__.py b/conditional-flow-matching/runner/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/conditional-flow-matching/runner/src/models/cfm_module.py b/conditional-flow-matching/runner/src/models/cfm_module.py new file mode 100644 index 0000000000000000000000000000000000000000..d5c050e5ea5c7781a17729358e8fe74f874d425e --- /dev/null +++ b/conditional-flow-matching/runner/src/models/cfm_module.py @@ -0,0 +1,1454 @@ +import copy +import math +import os +from typing import Any, List, Optional, Union + +import numpy as np +import torch +from pytorch_lightning import LightningDataModule, LightningModule +from torch.distributions import MultivariateNormal +from torchdyn.core import NeuralODE +from torchvision import transforms + +from .components.augmentation import ( + AugmentationModule, + AugmentedVectorField, + Sequential, +) +from .components.distribution_distances import compute_distribution_distances +from .components.optimal_transport import OTPlanSampler +from .components.plotting import ( + plot_samples, + plot_trajectory, + store_trajectories, +) +from .components.schedule import ConstantNoiseScheduler, NoiseScheduler +from .components.solver import FlowSolver +from .utils import get_wandb_logger + + +class CFMLitModule(LightningModule): + """Conditional Flow Matching Module for training generative models and models over time.""" + + def __init__( + self, + net: Any, + optimizer: Any, + datamodule: LightningDataModule, + augmentations: AugmentationModule, + partial_solver: FlowSolver, + scheduler: Optional[Any] = None, + neural_ode: Optional[Any] = None, + ot_sampler: Optional[Union[str, Any]] = None, + sigma_min: float = 0.1, + avg_size: int = -1, + leaveout_timepoint: int = -1, + test_nfe: int = 100, + plot: bool = False, + nice_name: str = "CFM", + ) -> None: + """Initialize a conditional flow matching network either as a generative model or for a + sequence of timepoints. + + Note: DDP does not currently work with NeuralODE objects from torchdyn + in the init so we initialize them every time we need to do a sampling + step. + + Args: + net: torch module representing dx/dt = f(t, x) for t in [1, T] missing dimension. + optimizer: partial torch.optimizer missing parameters. + datamodule: datamodule object needs to have "dim", "IS_TRAJECTORY" properties. + ot_sampler: ot_sampler specified as an object or string. If none then no OT is used in minibatch. + sigma_min: sigma_min determines the width of the Gaussian smoothing of the data and interpolations. + leaveout_timepoint: which (if any) timepoint to leave out during the training phase + plot: if true, log intermediate plots during validation + """ + super().__init__() + self.save_hyperparameters( + ignore=[ + "net", + "optimizer", + "scheduler", + "datamodule", + "augmentations", + "partial_solver", + ], + logger=False, + ) + self.datamodule = datamodule + self.is_trajectory = False + if hasattr(datamodule, "IS_TRAJECTORY"): + self.is_trajectory = datamodule.IS_TRAJECTORY + # dims is either an integer or a tuple. 
This helps us to decide whether to process things as + # a vector or as an image. + if hasattr(datamodule, "dim"): + self.dim = datamodule.dim + self.is_image = False + elif hasattr(datamodule, "dims"): + self.dim = datamodule.dims + self.is_image = True + else: + raise NotImplementedError("Datamodule must have either dim or dims") + self.net = net(dim=self.dim) + self.augmentations = augmentations + self.aug_net = AugmentedVectorField(self.net, self.augmentations.regs, self.dim) + self.val_augmentations = AugmentationModule( + # cnf_estimator=None, + l1_reg=1, + l2_reg=1, + squared_l2_reg=1, + ) + self.val_aug_net = AugmentedVectorField(self.net, self.val_augmentations.regs, self.dim) + if neural_ode is not None: + self.aug_node = Sequential( + self.augmentations.augmenter, + neural_ode(self.aug_net), + ) + + self.partial_solver = partial_solver + self.optimizer = optimizer + self.scheduler = scheduler + self.ot_sampler = ot_sampler + if ot_sampler == "None": + self.ot_sampler = None + if isinstance(self.ot_sampler, str): + # regularization taken for optimal Schrodinger bridge relationship + self.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * sigma_min**2) + self.criterion = torch.nn.MSELoss() + + def forward_integrate(self, batch: Any, t_span: torch.Tensor): + """Forward pass with integration over t_span intervals. + + (t, x, t_span) -> [x_t_span]. + """ + X = self.unpack_batch(batch) + X_start = X[:, t_span[0], :] + traj = self.node.trajectory(X_start, t_span=t_span) + return traj + + def forward(self, t: torch.Tensor, x: torch.Tensor): + """Forward pass (t, x) -> dx/dt.""" + return self.net(t, x) + + def unpack_batch(self, batch): + """Unpacks a batch of data to a single tensor.""" + if self.is_trajectory: + return torch.stack(batch, dim=1) + if not isinstance(self.dim, int): + # Assume this is an image classification dataset where we need to strip the targets + return batch[0] + return batch + + def preprocess_batch(self, X, training=False): + """Converts a batch of data into matched a random pair of (x0, x1)""" + t_select = torch.zeros(1, device=X.device) + if self.is_trajectory: + batch_size, times, dim = X.shape + if not hasattr(self.datamodule, "HAS_JOINT_PLANS"): + # resample the OT plan + # list of length t of tuples of length 2 of tensors of shape + tmp_ot_list = [] + for t in range(times - 1): + if training and t + 1 == self.hparams.leaveout_timepoint: + tmp_ot = torch.stack((X[:, t], X[:, t + 2])) + else: + tmp_ot = torch.stack((X[:, t], X[:, t + 1])) + if ( + training + and self.ot_sampler is not None + and t != self.hparams.leaveout_timepoint + ): + tmp_ot = torch.stack(self.ot_sampler.sample_plan(tmp_ot[0], tmp_ot[1])) + + tmp_ot_list.append(tmp_ot) + tmp_ot_list = torch.stack(tmp_ot_list) + # randomly sample a batch + + if training and self.hparams.leaveout_timepoint > 0: + # Select random except for the leftout timepoint + t_select = torch.randint(times - 2, size=(batch_size,), device=X.device) + t_select[t_select >= self.hparams.leaveout_timepoint] += 1 + else: + t_select = torch.randint(times - 1, size=(batch_size,)) + x0 = [] + x1 = [] + for i in range(batch_size): + ti = t_select[i] + ti_next = ti + 1 + if training and ti_next == self.hparams.leaveout_timepoint: + ti_next += 1 + if hasattr(self.datamodule, "HAS_JOINT_PLANS"): + x0.append(torch.tensor(self.datamodule.timepoint_data[ti][X[i, ti]])) + pi = self.datamodule.pi[ti] + if training and ti + 1 == self.hparams.leaveout_timepoint: + pi = self.datamodule.pi_leaveout[ti] + index_batch = X[i][ti] + i_next = 
np.random.choice( + pi.shape[1], p=pi[index_batch] / pi[index_batch].sum() + ) + x1.append(torch.tensor(self.datamodule.timepoint_data[ti_next][i_next])) + else: + x0.append(tmp_ot_list[ti][0][i]) + x1.append(tmp_ot_list[ti][1][i]) + x0, x1 = torch.stack(x0), torch.stack(x1) + else: + batch_size = X.shape[0] + # If no trajectory assume generate from standard normal + x0 = torch.randn_like(X) + x1 = X + return x0, x1, t_select + + def average_ut(self, x, t, mu_t, sigma_t, ut): + pt = torch.exp(-0.5 * (torch.cdist(x, mu_t) ** 2) / (sigma_t**2)) + batch_size = x.shape[0] + ind = torch.randint( + batch_size, size=(batch_size, self.hparams.avg_size - 1) + ) # randomly (non-repreat) sample m-many index + # always include self + ind = torch.cat([ind, torch.arange(batch_size)[:, None]], dim=1) + pt_sub = torch.stack([pt[i, ind[i]] for i in range(batch_size)]) + ut_sub = torch.stack([ut[ind[i]] for i in range(batch_size)]) + p_sum = torch.sum(pt_sub, dim=1, keepdim=True) + ut = torch.sum(pt_sub[:, :, None] * ut_sub, dim=1) / p_sum + # Reduce batch size because they are all the same + return x[:1], ut[:1], t[:1] + + def calc_mu_sigma(self, x0, x1, t): + mu_t = t * x1 + (1 - t) * x0 + sigma_t = self.hparams.sigma_min + return mu_t, sigma_t + + def calc_u(self, x0, x1, x, t, mu_t, sigma_t): + del x, t, mu_t, sigma_t + return x1 - x0 + + def calc_loc_and_target(self, x0, x1, t, t_select, training): + """Computes the loss on a batch of data.""" + t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1))) + mu_t, sigma_t = self.calc_mu_sigma(x0, x1, t_xshape) + eps_t = torch.randn_like(mu_t) + x = mu_t + sigma_t * eps_t + ut = self.calc_u(x0, x1, x, t_xshape, mu_t, sigma_t) + + # if we are starting from right before the leaveout_timepoint then we + # divide the target by 2 + if training and self.hparams.leaveout_timepoint > 0: + ut[t_select + 1 == self.hparams.leaveout_timepoint] /= 2 + t[t_select + 1 == self.hparams.leaveout_timepoint] *= 2 + + # p is the pair-wise conditional probability matrix. 
Note that this has to be torch.cdist(x, mu) in that order + # t that network sees is incremented by first timepoint + t = t + t_select.reshape(-1, *t.shape[1:]) + return x, ut, t, mu_t, sigma_t, eps_t + + def step(self, batch: Any, training: bool = False): + """Computes the loss on a batch of data.""" + X = self.unpack_batch(batch) + x0, x1, t_select = self.preprocess_batch(X, training) + # Either randomly sample a single T or sample a batch of T's + if self.hparams.avg_size > 0: + t = torch.rand(1).repeat(X.shape[0]).type_as(X) + else: + t = torch.rand(X.shape[0]).type_as(X) + # Resample the plan if we are using optimal transport + if self.ot_sampler is not None and not self.is_trajectory: + x0, x1 = self.ot_sampler.sample_plan(x0, x1) + + x, ut, t, mu_t, sigma_t, eps_t = self.calc_loc_and_target(x0, x1, t, t_select, training) + + if self.hparams.avg_size > 0: + x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut) + aug_x = self.aug_net(t, x, augmented_input=False) + reg, vt = self.augmentations(aug_x) + return torch.mean(reg), self.criterion(vt, ut) + + def training_step(self, batch: Any, batch_idx: int): + reg, mse = self.step(batch, training=True) + loss = mse + reg + prefix = "train" + self.log_dict( + {f"{prefix}/loss": loss, f"{prefix}/mse": mse, f"{prefix}/reg": reg}, + on_step=True, + on_epoch=False, + prog_bar=True, + ) + return loss + + def image_eval_step(self, batch: Any, batch_idx: int, prefix: str): + import os + + from torchvision.utils import save_image + + # val_augmentations = AugmentationModule( + # cnf_estimator="hutch", + # squared_l2_reg=1, + # ) + # aug_dims = val_augmentations.aug_dims + # val_aug_net = AugmentedVectorField(self.net, val_augmentations.regs, self.dim) + # val_aug_node = Sequential( + # val_augmentations.augmenter, + # NeuralODE(val_aug_net, solver="euler", sensitivity="adjoint"), + # ) + # t_span = torch.linspace(1, 0, 101) + # x = batch[0] + # os.makedirs("regularizations", exist_ok=True) + # for k in range(0): + # x_norm = cifar10_normalization()(x + (torch.rand_like(x) / 255)) + # _, aug_traj = val_aug_node(x_norm, t_span) + # aug, traj = aug_traj[-1, :, :aug_dims], aug_traj[-1, :, aug_dims:] + # mn = MultivariateNormal( + # torch.zeros(prod(self.dim)).type_as(traj), + # torch.eye(prod(self.dim)).type_as(traj), + # ) + # aug[:, 0] += mn.log_prob(traj.reshape(traj.shape[0], -1)) + # np.save( + # f"regularizations/regs_{k}_{batch_idx}.npy", + # aug.detach().cpu().numpy(), + # ) + + solver = self.partial_solver(self.net, self.dim) + if isinstance(self.hparams.test_nfe, int): + t_span = torch.linspace(0, 1, int(self.hparams.test_nfe) + 1) + elif isinstance(self.hparams.test_nfe, str): + solver.ode_solver = "tsit5" + t_span = torch.linspace(0, 1, 2) + else: + raise NotImplementedError(f"Unknown test procedure {self.hparams.test_nfe}") + traj = solver.odeint(torch.randn(batch[0].shape[0], *self.dim).type_as(batch[0]), t_span)[ + -1 + ] + os.makedirs("images", exist_ok=True) + mean = [-x / 255.0 for x in [125.3, 123.0, 113.9]] + std = [255.0 / x for x in [63.0, 62.1, 66.7]] + inv_normalize = transforms.Compose( + [ + transforms.Normalize(mean=[0.0, 0.0, 0.0], std=std), + transforms.Normalize(mean=mean, std=[1.0, 1.0, 1.0]), + ] + ) + traj = inv_normalize(traj) + traj = torch.clip(traj, min=0, max=1.0) + for i, image in enumerate(traj): + save_image(image, fp=f"images/{batch_idx}_{i}.png") + return {"x": batch[0]} + + def eval_step(self, batch: Any, batch_idx: int, prefix: str): + if prefix == "test" and self.is_image: + self.image_eval_step(batch, 
batch_idx, prefix) + shapes = [b.shape[0] for b in batch] + + if not self.is_image and prefix == "val" and shapes.count(shapes[0]) == len(shapes): + reg, mse = self.step(batch, training=False) + loss = mse + reg + self.log_dict( + {f"{prefix}/loss": loss, f"{prefix}/mse": mse, f"{prefix}/reg": reg}, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + return {"loss": loss, "mse": mse, "reg": reg, "x": self.unpack_batch(batch)} + + return {"x": batch} + + def preprocess_epoch_end(self, outputs: List[Any], prefix: str): + """Preprocess the outputs of the epoch end function.""" + if self.is_trajectory and prefix == "test" and isinstance(outputs[0]["x"], list): + # x is jagged if doing a trajectory + x = outputs[0]["x"] + ts = len(x) + x0 = x[0] + x_rest = x[1:] + elif self.is_trajectory: + if hasattr(self.datamodule, "HAS_JOINT_PLANS"): + x = [torch.tensor(dd) for dd in self.datamodule.timepoint_data] + x0 = x[0] + x_rest = x[1:] + ts = len(x) + else: + v = {k: torch.cat([d[k] for d in outputs]) for k in ["x"]} + x = v["x"] + ts = x.shape[1] + x0 = x[:, 0, :] + x_rest = x[:, 1:] + else: + if isinstance(self.dim, int): + v = {k: torch.cat([d[k] for d in outputs]) for k in ["x"]} + x = v["x"] + else: + x = [d["x"] for d in outputs][0][0][:100] + # Sample some random points for the plotting function + rand = torch.randn_like(x) + # rand = torch.randn_like(x, generator=torch.Generator(device=x.device).manual_seed(42)) + x = torch.stack([rand, x], dim=1) + ts = x.shape[1] + x0 = x[:, 0] + x_rest = x[:, 1:] + return ts, x, x0, x_rest + + def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix): + # Build a trajectory + t_span = torch.linspace(0, 1, 101) + regs = [] + trajs = [] + full_trajs = [] + solver = self.partial_solver(self.net, self.dim) + nfe = 0 + x0_tmp = x0.clone() + + if self.is_image: + traj = solver.odeint(x0, t_span) + full_trajs.append(traj) + trajs.append(traj[0]) + trajs.append(traj[-1]) + nfe += solver.nfe + + if not self.is_image: + solver.augmentations = self.val_augmentations + for i in range(ts - 1): + traj, aug = solver.odeint(x0_tmp, t_span + i) + full_trajs.append(traj) + traj, aug = traj[-1], aug[-1] + x0_tmp = traj + regs.append(torch.mean(aug, dim=0).detach().cpu().numpy()) + trajs.append(traj) + nfe += solver.nfe + + full_trajs = torch.cat(full_trajs) + + if not self.is_image: + regs = np.stack(regs).mean(axis=0) + names = [f"{prefix}/{name}" for name in self.val_augmentations.names] + self.log_dict(dict(zip(names, regs)), sync_dist=True) + + # Evaluate the fit + if ( + self.is_trajectory + and prefix == "test" + and isinstance(outputs[0]["x"], list) + and not hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM") + ): + # Redo the solver for each timepoint + trajs = [] + full_trajs = [] + nfe = 0 + x0_tmp = x0 + for i in range(ts - 1): + traj, _ = solver.odeint(x0_tmp, t_span + i) + traj = traj[-1] + x0_tmp = x_rest[i] + trajs.append(traj) + nfe += solver.nfe + names, dists = compute_distribution_distances(trajs[:-1], x_rest[:-1]) + else: + names, dists = compute_distribution_distances(trajs, x_rest) + names = [f"{prefix}/{name}" for name in names] + d = dict(zip(names, dists)) + if self.hparams.leaveout_timepoint >= 0: + to_add = { + f"{prefix}/t_out/{key.split('/')[-1]}": val + for key, val in d.items() + if key.startswith(f"{prefix}/t{self.hparams.leaveout_timepoint}") + } + d.update(to_add) + d[f"{prefix}/nfe"] = nfe + + self.log_dict(d, sync_dist=True) + + if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"): + solver.augmentations = None + # t_span = 
torch.linspace(0, 1, 101) + # traj = solver.odeint(x0, t_span) + # t_span = t_span[::5] + # traj = traj[::5] + t_span = torch.linspace(0, 1, 21) + traj = solver.odeint(x0, t_span) + assert traj.shape[0] == t_span.shape[0] + kls = [ + self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj) + ] + self.log_dict({f"{prefix}/kl/mean": torch.stack(kls).mean().item()}, sync_dist=True) + self.log_dict({f"{prefix}/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True) + + return trajs, full_trajs + + def eval_epoch_end(self, outputs: List[Any], prefix: str): + wandb_logger = get_wandb_logger(self.loggers) + if prefix == "test" and self.is_image: + os.makedirs("images", exist_ok=True) + if len(os.listdir("images")) > 0: + path = "/home/mila/a/alexander.tong/scratch/trajectory-inference/data/fid_stats_cifar10_train.npz" + from pytorch_fid import fid_score + + fid = fid_score.calculate_fid_given_paths(["images", path], 256, "cuda", 2048, 0) + self.log(f"{prefix}/fid", fid) + + ts, x, x0, x_rest = self.preprocess_epoch_end(outputs, prefix) + trajs, full_trajs = self.forward_eval_integrate(ts, x0, x_rest, outputs, prefix) + + if self.hparams.plot: + if isinstance(self.dim, int): + plot_trajectory( + x, + full_trajs, + title=f"{self.current_epoch}_ode", + key="ode_path", + wandb_logger=wandb_logger, + ) + else: + plot_samples( + trajs[-1], + title=f"{self.current_epoch}_samples", + wandb_logger=wandb_logger, + ) + + if prefix == "test" and not self.is_image: + store_trajectories(x, self.net) + + def validation_step(self, batch: Any, batch_idx: int): + return self.eval_step(batch, batch_idx, "val") + + def validation_epoch_end(self, outputs: List[Any]): + self.eval_epoch_end(outputs, "val") + + def test_step(self, batch: Any, batch_idx: int): + return self.eval_step(batch, batch_idx, "test") + + def test_epoch_end(self, outputs: List[Any]): + self.eval_epoch_end(outputs, "test") + + def configure_optimizers(self): + """Pass model parameters to optimizer.""" + optimizer = self.optimizer(params=self.parameters()) + if self.scheduler is None: + return optimizer + + scheduler = self.scheduler(optimizer) + return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] + + def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + scheduler.step(epoch=self.current_epoch) + + +class RectifiedFlowLitModule(CFMLitModule): + def __init__( + self, + net: Any, + optimizer: Any, + datamodule: LightningDataModule, + augmentations: AugmentationModule, + partial_solver: FlowSolver, + val_augmentations: Optional[AugmentationModule] = None, + scheduler: Optional[Any] = None, + neural_ode: Optional[Any] = None, + ot_sampler: Optional[Union[str, Any]] = None, + sigma_min: float = 0.1, + rectify_epochs: Optional[List[int]] = None, + test_nfe: int = 100, + avg_size: int = -1, + leaveout_timepoint: int = -1, + plot: bool = False, + nice_name: str = "Rect", + ) -> None: + """Initialize a conditional flow matching network either as a generative model or for a + sequence of timepoints. + + Args: + net: torch module representing dx/dt = f(t, x) for t in [1, T] missing dimension. + optimizer: partial torch.optimizer missing parameters. + datamodule: datamodule object needs to have "dim", "IS_TRAJECTORY" properties. + ot_sampler: ot_sampler specified as an object or string. If none then no OT is used in minibatch. + sigma_min: sigma_min determines the width of the Gaussian smoothing of the data and interpolations. 
+ leaveout_timepoint: which (if any) timepoint to leave out during the training phase + plot: if true, log intermediate plots during validation + """ + super(CFMLitModule, self).__init__() + self.save_hyperparameters( + ignore=[ + "net", + "optimizer", + "scheduler", + "datamodule", + "augmentations", + "val_augmentations", + "partial_solver", + ], + logger=False, + ) + self.datamodule = datamodule + self.is_trajectory = False + if hasattr(datamodule, "IS_TRAJECTORY"): + self.is_trajectory = datamodule.IS_TRAJECTORY + if hasattr(datamodule, "dim"): + self.dim = datamodule.dim + self.is_image = False + elif hasattr(datamodule, "dims"): + self.dim = datamodule.dims + self.is_image = True + else: + raise NotImplementedError("Datamodule must have either dim or dims") + self.net = net(dim=self.dim) + self.frozen_net = None + self.augmentations = augmentations + self.aug_net = AugmentedVectorField(self.net, self.augmentations.regs, self.dim) + self.val_augmentations = val_augmentations + if val_augmentations is None: + self.val_augmentations = AugmentationModule( + l1_reg=1, + l2_reg=1, + squared_l2_reg=1, + ) + self.val_aug_net = AugmentedVectorField(self.net, self.val_augmentations.regs, self.dim) + if neural_ode is not None: + self.aug_node = Sequential( + self.augmentations.augmenter, + neural_ode(self.aug_net), + ) + self.partial_solver = partial_solver + self.optimizer = optimizer + self.scheduler = scheduler + self.ot_sampler = ot_sampler + if ot_sampler == "None": + self.ot_sampler = None + if isinstance(self.ot_sampler, str): + # regularization taken for optimal Schrodinger bridge relationship + self.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * sigma_min**2) + self.criterion = torch.nn.MSELoss() + + def preprocess_batch(self, X, training=False): + """Converts a batch of data into matched a random pair of (x0, x1)""" + t_select = torch.zeros(1, device=X.device) + if self.is_trajectory: + batch_size, times, dim = X.shape + if training and self.hparams.leaveout_timepoint > 0: + # Select random except for the leftout timepoint + t_select = torch.randint(times - 2, size=(batch_size,), device=X.device) + t_select[t_select >= self.hparams.leaveout_timepoint] += 1 + else: + t_select = torch.randint(times - 1, size=(batch_size,)) + x0 = [] + x1 = [] + for i in range(batch_size): + ti = t_select[i] + ti_next = ti + 1 + if training and ti_next == self.hparams.leaveout_timepoint: + ti_next += 1 + x0.append(X[i, ti]) + x1.append(X[i, ti_next]) + x0, x1 = torch.stack(x0), torch.stack(x1) + else: + batch_size = X.shape[0] + # If no trajectory assume generate from standard normal + x0 = torch.randn_like(X) + x1 = X + + if self.frozen_net is not None: + # Currently only works for 2 distributions + assert t_select[0] == 0 + t_span = torch.linspace(0, 1, 100) + val_node = NeuralODE(self.frozen_net, solver="euler") + with torch.no_grad(): + _, traj = val_node(x0, t_span) + x1 = traj[-1] + return x0, x1, t_select + + def training_epoch_end(self, training_step_outputs): + if ( + self.hparams.rectify_epochs is not None + and self.current_epoch in self.hparams.rectify_epochs + ): + self.frozen_net = copy.deepcopy(self.net) + + +class ActionMatchingLitModule(CFMLitModule): + """Implements Action Matching: Learning Stochastic Dynamics from Samples (Neklyudov et al. + 2022) + + Requires net to have a .energy function where net.energy(t, x): \\mathbb{R}^{d+1} \to + \\mathbb{R} and net.forward is equal to \nabla_x(net.energy). 
+ """ + + def step(self, batch: Any, training: bool = False): + """Computes the loss on a batch of data.""" + assert not self.is_trajectory + energy = self.net.energy + X = self.unpack_batch(batch) + x0, x1, t_select = self.preprocess_batch(X, training) + + if self.ot_sampler is not None: + x0, x1 = self.ot_sampler.sample_plan(x0, x1) + + t = torch.rand(X.shape[0]).type_as(X) + t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1))) + xt = t_xshape * x1 + (1 - t_xshape) * x0 + # t that network sees is incremented by first timepoint + t = t + t_select.reshape(-1, *t.shape[1:]) + + xt.requires_grad, t_xshape.requires_grad = True, True + with torch.set_grad_enabled(True): + st = torch.sum(energy(torch.cat([xt, t_xshape], dim=-1))) + dsdx, dsdt = torch.autograd.grad(st, (xt, t_xshape), create_graph=True) + xt.requires_grad, t_xshape.requires_grad = False, False + a0 = energy(torch.cat([x0, torch.zeros(x0.shape[0], 1)], dim=-1)) + a1 = energy(torch.cat([x1, torch.ones(x1.shape[0], 1)], dim=-1)) + loss = a0 - a1 + 0.5 * (dsdx**2).sum(1, keepdims=True) + dsdt + loss = loss.mean() + aug_x = self.aug_net(t, xt, augmented_input=False) + reg, vt = self.augmentations(aug_x) + return torch.mean(reg), loss + + +class VariancePreservingCFM(CFMLitModule): + """Implements a variance preserving time schedule as suggested in (Albergo et al. + + 2023) here we have an interpolation cos(t pi/2) x_0 + sin(t pi/2) x_1. + """ + + def calc_mu_sigma(self, x0, x1, t): + assert not self.is_trajectory + mu_t = torch.cos(math.pi / 2 * t) * x0 + torch.sin(math.pi / 2 * t) * x1 + sigma_t = self.hparams.sigma_min + return mu_t, sigma_t + + def calc_u(self, x0, x1, x, t, mu_t, sigma_t): + del x, mu_t, sigma_t + return math.pi / 2 * (torch.cos(math.pi / 2 * t) * x1 - torch.sin(math.pi / 2 * t) * x0) + + +class SBCFMLitModule(CFMLitModule): + """Implements a Schrodinger Bridge based conditional flow matching model. + + This is similar to the OTCFM loss, however with the variance varying with t*(1-t). This has + provably equal probability flow to the Schrodinger bridge solution when the transport is + computed with the squared Euclidean distance on R^d. + """ + + def calc_mu_sigma(self, x0, x1, t): + assert not self.is_trajectory + mu_t = t * x1 + (1 - t) * x0 + sigma_t = self.hparams.sigma_min * torch.sqrt(t - t**2) + return mu_t, sigma_t + + def calc_u(self, x0, x1, x, t, mu_t, sigma_t): + del sigma_t + sigma_t_prime_over_sigma_t = (1 - 2 * t) / (2 * t * (1 - t)) + ut = sigma_t_prime_over_sigma_t * (x - mu_t) + x1 - x0 + return ut + + +class SF2MLitModule(CFMLitModule): + def __init__( + self, + net: Any, + optimizer: Any, + datamodule: LightningDataModule, + augmentations: AugmentationModule, + partial_solver: FlowSolver, + score_net: Optional[Any] = None, + scheduler: Optional[Any] = None, + ot_sampler: Optional[Union[str, Any]] = None, + sigma: Optional[NoiseScheduler] = None, + sigma_min: float = 0.1, + outer_loop_epochs: Optional[int] = None, + score_weight: float = 1.0, + avg_size: int = -1, + leaveout_timepoint: int = -1, + test_nfe: int = 100, + test_sde: bool = False, + plot: bool = False, + nice_name: Optional[str] = "SF2M", + ) -> None: + """Initialize a conditional flow matching network either as a generative model or for a + sequence of timepoints. + + Args: + net: torch module representing dx/dt = f(t, x) for t in [1, T] missing dimension. + score_net: torch module representing the score function of the flow. + If not supplied it is assumed that the net contains both flow and + score. 
+ optimizer: partial torch.optimizer missing parameters. + datamodule: datamodule object needs to have "dim", "IS_TRAJECTORY" properties. + ot_sampler: ot_sampler specified as an object or string. If none then no OT is used in minibatch. + sigma: sigma determines the width of the Gaussian smoothing of the data and interpolations. + leaveout_timepoint: which (if any) timepoint to leave out during the training phase + plot: if true, log intermediate plots during validation + """ + super(CFMLitModule, self).__init__() + self.save_hyperparameters( + ignore=[ + "net", + "optimizer", + "scheduler", + "datamodule", + "augmentations", + "sigma_scheduler", + "partial_solver", + ], + logger=False, + ) + self.datamodule = datamodule + self.is_trajectory = False + if hasattr(datamodule, "IS_TRAJECTORY"): + self.is_trajectory = datamodule.IS_TRAJECTORY + # dims is either an integer or a tuple. This helps us to decide whether to process things as + # a vector or as an image. + if hasattr(datamodule, "dim"): + self.dim = datamodule.dim + self.is_image = False + elif hasattr(datamodule, "dims"): + self.dim = datamodule.dims + self.is_image = True + else: + raise NotImplementedError("Datamodule must have either dim or dims") + self.net = net(dim=self.dim) + self.separate_score = score_net is not None + self.score_net = score_net + if self.separate_score: + self.score_net = score_net(dim=self.dim) + self.partial_solver = partial_solver + self.augmentations = augmentations + self.aug_net = AugmentedVectorField(self.net, self.augmentations.regs, self.dim) + self.val_augmentations = AugmentationModule( + # cnf_estimator=None, + l1_reg=1, + l2_reg=1, + squared_l2_reg=1, + ) + self.val_aug_net = AugmentedVectorField(self.net, self.val_augmentations.regs, self.dim) + self.optimizer = optimizer + self.scheduler = scheduler + self.sigma = sigma + if sigma is None: + self.sigma = ConstantNoiseScheduler(sigma_min) + self.ot_sampler = ot_sampler + if ot_sampler == "None": + self.ot_sampler = None + if isinstance(self.ot_sampler, str): + # regularization taken for optimal Schrodinger bridge relationship + self.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * self.sigma.F(1)) + self.criterion = torch.nn.MSELoss() + + # If we are doing outer loops holds the current dataset + self.stored_data = None + self.tmp_stored_data = None + + def calc_mu_sigma(self, x0, x1, t): + # assert not self.is_trajectory + ft = self.sigma.F(t) + fone = self.sigma.F(1) + mu_t = x0 + (x1 - x0) * ft / fone + # Note this is slightly different than the notebook. Which is correct? 
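+        # (Assuming the default ConstantNoiseScheduler integrates to
+        # F(t) = sigma**2 * t, the variance below reduces to
+        # sigma**2 * t * (1 - t): the Brownian-bridge marginal, consistent with
+        # SBCFMLitModule's sigma_min * sqrt(t - t**2) above.)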
+ sigma_t = torch.sqrt(ft - ft**2 / fone) + return mu_t, sigma_t + + def calc_u(self, x0, x1, x, t, mu_t, sigma_t): + ft = self.sigma.F(t) + fone = self.sigma.F(1) + sigma_t_prime = self.sigma(t) ** 2 - 2 * ft * self.sigma(t) ** 2 / fone + sigma_t_prime_over_sigma_t = sigma_t_prime / (sigma_t + 1e-8) + mu_t_prime = (x1 - x0) * self.sigma(t) ** 2 / fone + ut = sigma_t_prime_over_sigma_t * (x - mu_t) + mu_t_prime + return ut + + def calc_loc_and_target(self, x0, x1, t, t_select, training): + t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1))) + mu_t, sigma_t = self.calc_mu_sigma(x0, x1, t_xshape) + eps_t = torch.randn_like(mu_t) + x = mu_t + sigma_t * eps_t + ut = self.calc_u(x0, x1, x, t_xshape, mu_t, sigma_t) + + # if we are starting from right before the leaveout_timepoint then we + # divide the target by 2 + if training and self.hparams.leaveout_timepoint > 0: + ut[t_select + 1 == self.hparams.leaveout_timepoint] /= 2 + t[t_select + 1 == self.hparams.leaveout_timepoint] *= 2 + + # p is the pair-wise conditional probability matrix. Note that this has to be torch.cdist(x, mu) in that order + # t that network sees is incremented by first timepoint + score_target = eps_t + # score_target = -eps_t * self.sigma(t_xshape) ** 2 / 2 + t = t + t_select.reshape(-1, *t.shape[1:]) + return x, ut, t, mu_t, sigma_t, score_target + + def forward_flow_and_score(self, t, x): + if self.separate_score: + reg, vt = self.augmentations(self.aug_net(t, x, augmented_input=False)) + st = self.score_net(t, x) + return reg, vt, st + reg, vtst = self.augmentations(self.aug_net(t, x, augmented_input=False)) + split_idx = vtst.shape[1] // 2 + vt, st = vtst[:, :split_idx], vtst[:, split_idx:] + return reg, vt, st + + def step(self, batch: Any, training: bool = False): + """Computes the loss on a batch of data.""" + X = self.unpack_batch(batch) + x0, x1, t_select = self.preprocess_batch(X, training) + # Either randomly sample a single T or sample a batch of T's + if self.hparams.avg_size > 0: + t = torch.rand(1).repeat(X.shape[0]).type_as(X) + else: + t = torch.rand(X.shape[0]).type_as(X) + # Resample the plan if we are using optimal transport + if self.ot_sampler is not None and self.stored_data is None: + x0, x1 = self.ot_sampler.sample_plan(x0, x1) + t_orig = t.clone() + + x, ut, t, mu_t, sigma_t, score_target = self.calc_loc_and_target( + x0, x1, t, t_select, training + ) + + if self.hparams.avg_size > 0: + x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut) + + reg, vt, st = self.forward_flow_and_score(t, x) + flow_loss = self.criterion(vt, ut) + score_loss = self.criterion( + -sigma_t * st / (self.sigma(t_orig.reshape(sigma_t.shape)) ** 2) * 2, + score_target, + ) + return torch.mean(reg) + self.hparams.score_weight * score_loss, flow_loss + + def forward_sde_eval(self, ts, x0, x_rest, outputs, prefix): + # Build a trajectory + t_span = torch.linspace(0, 1, 2) + solver = self.partial_solver( + self.net, self.dim, score_field=self.score_net, sigma=self.sigma + ) + if False and self.is_image: + traj = solver.sdeint(x0, t_span, logqp=False) + + trajs = [] + full_trajs = [] + nfe = 0 + kldiv_total = 0 + x0_tmp = x0.clone() + for i in range(ts - 1): + traj, kldiv = solver.sdeint(x0_tmp, t_span + i, logqp=True) + kldiv_total += torch.mean(kldiv[-1]) + x0_tmp = traj[-1] + trajs.append(traj[-1]) + full_trajs.append(traj) + nfe += solver.nfe + full_trajs = torch.cat(full_trajs) + if not self.is_image: + # Evaluate the fit + if ( + self.is_trajectory + and prefix == "test" + and isinstance(outputs[0]["x"], list) + and not 
hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM")
+ ):
+ trajs = []
+ full_trajs = []
+ nfe = 0
+ kldiv_total = 0
+ x0_tmp = x0.clone()
+ for i in range(ts - 1):
+ traj, kldiv = solver.sdeint(x0_tmp, t_span + i, logqp=True)
+ x0_tmp = x_rest[i]
+ kldiv_total += torch.mean(kldiv[-1])
+ trajs.append(traj[-1])
+ full_trajs.append(traj)
+ nfe += solver.nfe
+ names, dists = compute_distribution_distances(trajs[:-1], x_rest[:-1])
+ else:
+ names, dists = compute_distribution_distances(trajs, x_rest)
+ names = [f"{prefix}/sde/{name}" for name in names]
+ d = dict(zip(names, dists))
+ if self.hparams.leaveout_timepoint >= 0:
+ to_add = {
+ f"{prefix}/sde/t_out/{key.split('/')[-1]}": val
+ for key, val in d.items()
+ if key.startswith(f"{prefix}/sde/t{self.hparams.leaveout_timepoint}")
+ }
+ d.update(to_add)
+ d[f"{prefix}/sde/nfe"] = nfe
+ d[f"{prefix}/sde/kldiv"] = kldiv_total
+ self.log_dict(d, sync_dist=True)
+ if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"):
+ solver.augmentations = None
+ t_span = torch.linspace(0, 1, 21)
+ solver.dt = 0.05
+ # solver.dt = 0.01
+ traj = solver.sdeint(x0, t_span)
+ assert traj.shape[0] == t_span.shape[0]
+ kls = [
+ self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)
+ ]
+ self.log_dict(
+ {f"{prefix}/sde/kl/mean": torch.stack(kls).mean().item()},
+ sync_dist=True,
+ )
+ self.log_dict({f"{prefix}/sde/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True)
+ return trajs, full_trajs
+
+ def eval_epoch_end(self, outputs: List[Any], prefix: str):
+ super().eval_epoch_end(outputs, prefix)
+ wandb_logger = get_wandb_logger(self.loggers)
+ ts, x, x0, x_rest = self.preprocess_epoch_end(outputs, prefix)
+ if isinstance(self.dim, int):
+ traj, sde_traj = self.forward_sde_eval(ts, x0, x_rest, outputs, prefix)
+
+ if self.hparams.plot:
+ if isinstance(self.dim, int):
+ plot_trajectory(
+ x,
+ sde_traj,
+ title=f"{self.current_epoch}_sde_traj",
+ key="sde",
+ wandb_logger=wandb_logger,
+ )
+
+ def preprocess_batch(self, X, training=False):
+ """Converts a batch of data into a matched random pair (x0, x1)."""
+ if self.stored_data is not None and training:
+ # Randomly sample a batch from the stored data.
+ idx = torch.randint(self.stored_data.shape[0], size=(X.shape[0],))
+ X = self.stored_data[idx]
+ t_select = torch.zeros(1, device=X.device)
+ return X[:, 0], X[:, 1], t_select
+ return super().preprocess_batch(X, training)
+
+ def training_step(self, batch: Any, batch_idx: int):
+ # If we are doing outer loops we need to resample and store forward and backward batches.
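+ # Sketch of the control flow below, as read from this class (not stated explicitly
+ # in the source): every `outer_loop_epochs` epochs, half of each batch is simulated
+ # forward from x0 and half backward from x1 under the current model; the endpoint
+ # pairs are cached in `self.stored_data` at epoch end, and later epochs train on
+ # these model-generated couplings instead of the minibatch OT plan (an iterative
+ # proportional fitting style refinement of the bridge).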
+ if ( + self.hparams.outer_loop_epochs is not None + and (self.current_epoch + 1) % self.hparams.outer_loop_epochs == 0 + ): + X = self.unpack_batch(batch) + x0, x1, t_select = self.preprocess_batch(X, training=True) + assert not torch.any(t_select) # resampling outerloop can only handle 2 timepoints + solver = self.partial_solver + t_span = torch.linspace(0, 1, 2) + solver = self.partial_solver( + self.net, self.dim, score_field=self.score_net, sigma=self.sigma + ) + batch_size = x0.shape[0] + with torch.no_grad(): + forward_traj = solver.sdeint(x0[: batch_size // 2], t_span) + backward_traj = torch.flip( + solver.sdeint(x1[batch_size // 2 :], t_span, reverse=True), (0,) + ) + stored_traj = torch.cat([forward_traj, backward_traj], dim=1) + stored_traj = stored_traj.transpose(0, 1) + if batch_idx == 0: + self.tmp_stored_data = [] + self.tmp_stored_data.append(stored_traj) + return super().training_step(batch, batch_idx) + + def training_epoch_end(self, training_step_outputs): + if ( + self.hparams.outer_loop_epochs is not None + and (self.current_epoch + 1) % self.hparams.outer_loop_epochs == 0 + ): + self.stored_data = torch.cat(self.tmp_stored_data, dim=0).detach().clone() + + def image_eval_step(self, batch: Any, batch_idx: int, prefix: str): + import os + + from torchvision.utils import save_image + + solver = self.partial_solver(self.net, self.dim) + if isinstance(self.hparams.test_nfe, int): + t_span = torch.linspace(0, 1, int(self.hparams.test_nfe) + 1) + elif isinstance(self.hparams.test_nfe, str): + solver.ode_solver = "tsit5" + t_span = torch.linspace(0, 1, 2).type_as(batch[0]) + else: + raise NotImplementedError(f"Unknown test procedure {self.hparams.test_nfe}") + if self.hparams.test_sde: + solver = self.partial_solver( + self.net, self.dim, score_field=self.score_net, sigma=self.sigma + ) + solver.dt = 1 / int(self.hparams.test_nfe) + t_span = torch.linspace(0, 1, 2).type_as(batch[0]) + integrator = solver.sdeint + else: + integrator = solver.odeint + x0 = torch.randn(5 * batch[0].shape[0], *self.dim).type_as(batch[0]) + traj = integrator(x0, t_span)[-1] + os.makedirs("images", exist_ok=True) + mean = [-x / 255.0 for x in [125.3, 123.0, 113.9]] + std = [255.0 / x for x in [63.0, 62.1, 66.7]] + inv_normalize = transforms.Compose( + [ + transforms.Normalize(mean=[0.0, 0.0, 0.0], std=std), + transforms.Normalize(mean=mean, std=[1.0, 1.0, 1.0]), + ] + ) + traj = inv_normalize(traj) + traj = torch.clip(traj, min=0, max=1.0) + for i, image in enumerate(traj): + save_image(image, fp=f"images/{batch_idx}_{i}.png") + os.makedirs("compressed_images", exist_ok=True) + torch.save(traj.cpu(), f"compressed_images/{batch_idx}.pt") + return {"x": batch[0]} + + +class OneWaySF2MLitModule(SF2MLitModule): + def calc_loc_and_target(self, x0, x1, t, t_select, training): + x, ut, t, mu_t, sigma_t, score_target = super().calc_loc_and_target( + x0, x1, t, t_select, training + ) + t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1))) + eps_t = -score_target * 2 / (self.sigma(t_xshape) ** 2) + forward_target = ( + x1 - x0 - (self.sigma(t_xshape) * torch.sqrt(t_xshape / (1 - t_xshape + 1e-6))) * eps_t + ) + return x, forward_target, t, mu_t, sigma_t, None + + def step(self, batch: Any, training: bool = False): + """Computes the loss on a batch of data.""" + X = self.unpack_batch(batch) + x0, x1, t_select = self.preprocess_batch(X, training) + # Either randomly sample a single T or sample a batch of T's + if self.hparams.avg_size > 0: + t = torch.rand(1).repeat(X.shape[0]).type_as(X) + else: + t = 
torch.rand(X.shape[0]).type_as(X) + # Resample the plan if we are using optimal transport + if self.ot_sampler is not None and self.stored_data is None: + x0, x1 = self.ot_sampler.sample_plan(x0, x1) + + x, forward_target, t, _, _, _ = self.calc_loc_and_target(x0, x1, t, t_select, training) + t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1))) + forward_scaling = (1 + self.sigma(t_xshape) ** 2 * t_xshape / (1 - t_xshape + 1e-6)) ** -1 + reg, vt, st = self.forward_flow_and_score(t, x) + forward_flow_loss = torch.mean(forward_scaling * (vt - forward_target) ** 2) + return torch.mean(reg), forward_flow_loss + + def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix): + # Build a trajectory + t_span = torch.linspace(0, 1, 101).type_as(x0) + regs = [] + trajs = [] + full_trajs = [] + solver = self.partial_solver( + self.net, self.dim, score_field=self.score_net, sigma=self.sigma + ) + nfe = 0 + x0_tmp = x0.clone() + for i in range(ts - 1): + if not self.is_image: + solver.augmentations = self.val_augmentations + traj, aug = solver.sdeint(x0_tmp, t_span + i) + aug = aug[-1] + regs.append(torch.mean(aug, dim=0).detach().cpu().numpy()) + else: + traj = solver.sdeint(x0_tmp, t_span + i) + full_trajs.append(traj) + traj = traj[-1] + x0_tmp = traj + trajs.append(traj) + nfe += solver.nfe + + if not self.is_image: + regs = np.stack(regs).mean(axis=0) + names = [f"{prefix}/{name}" for name in self.val_augmentations.names] + self.log_dict(dict(zip(names, regs)), sync_dist=True) + + # Evaluate the fit + names, dists = compute_distribution_distances(trajs, x_rest) + names = [f"{prefix}/{name}" for name in names] + d = dict(zip(names, dists)) + if self.hparams.leaveout_timepoint >= 0: + to_add = { + f"{prefix}/t_out/{key.split('/')[-1]}": val + for key, val in d.items() + if key.startswith(f"{prefix}/t{self.hparams.leaveout_timepoint}") + } + d.update(to_add) + d[f"{prefix}/nfe"] = nfe + self.log_dict(d, sync_dist=True) + + if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"): + solver.augmentations = None + t_span = torch.linspace(0, 1, 21) # 101 + traj = solver.odeint(x0, t_span) + # t_span = t_span[::5] + # traj = traj[::5] + assert traj.shape[0] == t_span.shape[0] + kls = [ + self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj) + ] + # others = torch.stack([self.datamodule.detailed_evaluation(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)]) + + self.log_dict({f"{prefix}/kl/mean": torch.stack(kls).mean().item()}, sync_dist=True) + self.log_dict({f"{prefix}/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True) + + full_trajs = torch.cat(full_trajs) + return trajs, full_trajs + + +class DSBMLitModule(SF2MLitModule): + """Based on SF2M module except directly regresses against the target SDE drift rather than + separating the ODE and Score components.""" + + def calc_loc_and_target(self, x0, x1, t, t_select, training): + t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1))).clone() + x, ut, t_plus_t_select, mu_t, sigma_t, eps_t = super().calc_loc_and_target( + x0, x1, t, t_select, training + ) + forward_target = ( + x1 - x0 - (self.sigma(t_xshape) * torch.sqrt(t_xshape / (1 - t_xshape + 1e-6))) * eps_t + ) + backward_target = ( + x0 + - x1 + - (self.sigma(t_xshape) * torch.sqrt((1 - t_xshape) / (t_xshape + 1e-6))) * eps_t + ) + return x, forward_target, t_plus_t_select, mu_t, sigma_t, backward_target + + def step(self, batch: Any, training: bool = False): + """Computes the loss on a batch of data.""" + X = self.unpack_batch(batch) + x0, x1, t_select = 
self.preprocess_batch(X, training) + # Either randomly sample a single T or sample a batch of T's + if self.hparams.avg_size > 0: + t = torch.rand(1).repeat(X.shape[0]).type_as(X) + else: + t = torch.rand(X.shape[0]).type_as(X) + # Resample the plan if we are using optimal transport + if self.ot_sampler is not None and self.stored_data is None: + x0, x1 = self.ot_sampler.sample_plan(x0, x1) + + forward_scaling = (1 + self.sigma(t) ** 2 * t / (1 - t + 1e-6)) ** -1 + backward_scaling = (1 + self.sigma(t) ** 2 * (1 - t) / (t + 1e-6)) ** -1 + x, forward_target, t, _, _, backward_target = self.calc_loc_and_target( + x0, x1, t, t_select, training + ) + # print(forward_target, backward_target, x0, x1, t, t_select) + reg, vt, st = self.forward_flow_and_score(t, x) + forward_flow_loss = torch.mean(forward_scaling[:, None] * (vt - forward_target) ** 2) + backward_flow_loss = torch.mean(backward_scaling[:, None] * (st - backward_target) ** 2) + if not torch.isfinite(forward_flow_loss) or not torch.isfinite(backward_flow_loss): + raise ValueError("Loss Not Finite") + + return torch.mean(reg) + backward_flow_loss, forward_flow_loss + + def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix): + # Build a trajectory + t_span = torch.linspace(0, 1, 101) + regs = [] + trajs = [] + full_trajs = [] + solver = self.partial_solver( + self.net, self.dim, score_field=self.score_net, sigma=self.sigma + ) + nfe = 0 + x0_tmp = x0.clone() + for i in range(ts - 1): + if not self.is_image: + solver.augmentations = self.val_augmentations + traj, aug = solver.odeint(x0_tmp, t_span + i) + else: + traj = solver.odeint(x0_tmp, t_span + i) + full_trajs.append(traj) + if not self.is_image: + traj, aug = traj[-1], aug[-1] + else: + traj = traj[-1] + aug = torch.tensor(0.0) + x0_tmp = traj + regs.append(torch.mean(aug, dim=0).detach().cpu().numpy()) + trajs.append(traj) + nfe += solver.nfe + + if not self.is_image: + regs = np.stack(regs).mean(axis=0) + names = [f"{prefix}/{name}" for name in self.val_augmentations.names] + self.log_dict(dict(zip(names, regs)), sync_dist=True) + + # Evaluate the fit + names, dists = compute_distribution_distances(trajs, x_rest) + names = [f"{prefix}/{name}" for name in names] + d = dict(zip(names, dists)) + if self.hparams.leaveout_timepoint >= 0: + to_add = { + f"{prefix}/t_out/{key.split('/')[-1]}": val + for key, val in d.items() + if key.startswith(f"{prefix}/t{self.hparams.leaveout_timepoint}") + } + d.update(to_add) + d[f"{prefix}/nfe"] = nfe + self.log_dict(d, sync_dist=True) + + if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"): + solver.augmentations = None + t_span = torch.linspace(0, 1, 21) # 101 + traj = solver.odeint(x0, t_span) + # t_span = t_span[::5] + # traj = traj[::5] + assert traj.shape[0] == t_span.shape[0] + kls = [ + self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj) + ] + # others = torch.stack([self.datamodule.detailed_evaluation(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)]) + + self.log_dict({f"{prefix}/kl/mean": torch.stack(kls).mean().item()}, sync_dist=True) + self.log_dict({f"{prefix}/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True) + + full_trajs = torch.cat(full_trajs) + return trajs, full_trajs + + +class DSBMSharedLitModule(SF2MLitModule): + """Based on SF2M module except directly regresses against the target SDE drift rather than + separating the ODE and Score components.""" + + def step(self, batch: Any, training: bool = False): + """Computes the loss on a batch of data.""" + X = 
self.unpack_batch(batch)
+ x0, x1, t_select = self.preprocess_batch(X, training)
+ # Either randomly sample a single T or sample a batch of T's
+ if self.hparams.avg_size > 0:
+ t = torch.rand(1).repeat(X.shape[0]).type_as(X)
+ else:
+ t = torch.rand(X.shape[0]).type_as(X)
+ # Resample the plan if we are using optimal transport
+ if self.ot_sampler is not None:
+ x0, x1 = self.ot_sampler.sample_plan(x0, x1)
+
+ x, ut, t, mu_t, sigma_t, score_target = self.calc_loc_and_target(
+ x0, x1, t, t_select, training
+ )
+
+ if self.hparams.avg_size > 0:
+ x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut)
+ aug_x = self.aug_net(t, x, augmented_input=False)
+ reg, vt = self.augmentations(aug_x)
+ forward_flow_loss = self.criterion(vt + sigma_t * self.score_net(t, x), ut + score_target)
+ backward_flow_loss = self.criterion(
+ -vt + sigma_t * self.score_net(t, x), -ut + score_target
+ )
+ # flow_loss = self.criterion(vt + sigma_t * self.score_net, ut + score_target)
+ # score_loss = self.criterion(sigma_t * self.score_net(t, x), score_target)
+ return torch.mean(reg) + backward_flow_loss, forward_flow_loss
+
+
+class FMLitModule(CFMLitModule):
+ """Implements a Lipman et al. (2023) style flow matching loss.
+
+ This maps the standard normal distribution to the data distribution using conditional
+ flows that are the optimal transport flow from a standard N(x | 0, 1) at t = 0 to a
+ narrow Gaussian of width sigma_min around each datapoint at t = 1.
+ """
+
+ def calc_mu_sigma(self, x0, x1, t):
+ assert not self.is_trajectory
+ del x0
+ sigma_min = self.hparams.sigma_min
+ mu_t = t * x1
+ sigma_t = 1 - (1 - sigma_min) * t
+ return mu_t, sigma_t
+
+ def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
+ del x0, mu_t, sigma_t
+ sigma_min = self.hparams.sigma_min
+ ut = (x1 - (1 - sigma_min) * x) / (1 - (1 - sigma_min) * t)
+ return ut
+
+
+class SplineCFMLitModule(CFMLitModule):
+ """Implements cubic spline version of OT-CFM."""
+
+ def preprocess_batch(self, X, training=False):
+ """Converts a batch of data into a matched random pair (x0, x1)."""
+ from torchcubicspline import NaturalCubicSpline, natural_cubic_spline_coeffs
+
+ lotp = self.hparams.leaveout_timepoint
+ valid_times = torch.arange(X.shape[1]).type_as(X)
+ batch_size, times, dim = X.shape
+ # TODO handle leaveout case
+ if training and self.hparams.leaveout_timepoint > 0:
+ # Select a random timepoint, excluding the left-out one
+ t_select = torch.randint(times - 2, size=(batch_size,))
+ X = torch.cat([X[:, :lotp], X[:, lotp + 1 :]], dim=1)
+ valid_times = valid_times[valid_times != lotp]
+ else:
+ t_select = torch.randint(times - 1, size=(batch_size,))
+ traj = torch.from_numpy(self.ot_sampler.sample_trajectory(X)).type_as(X)
+ x0 = []
+ x1 = []
+ for i in range(batch_size):
+ x0.append(traj[i, t_select[i]])
+ x1.append(traj[i, t_select[i] + 1])
+ x0, x1 = torch.stack(x0), torch.stack(x1)
+ if training and self.hparams.leaveout_timepoint > 0:
+ t_select[t_select >= self.hparams.leaveout_timepoint] += 1
+
+ coeffs = natural_cubic_spline_coeffs(valid_times, traj)
+ spline = NaturalCubicSpline(coeffs)
+ return x0, x1, t_select, spline
+
+ def step(self, batch: Any, training: bool = False):
+ """Computes the loss on a batch of data."""
+ assert self.is_trajectory
+ X = self.unpack_batch(batch)
+ x0, x1, t_select, spline = self.preprocess_batch(X, training)
+
+ t = torch.rand(X.shape[0], 1).type_as(X) # match the device/dtype of X
+ # t [batch, 1]
+ # coeffs [batch, times, dims]
+ # t that network sees is incremented by first timepoint
+ t = t + t_select[:, None]
+ ut =
torch.stack([spline.derivative(b[0])[i] for i, b in enumerate(t)], dim=0) + mu_t = torch.stack([spline.evaluate(b[0])[i] for i, b in enumerate(t)], dim=0) + sigma_t = self.hparams.sigma_min + + # if we are starting from right before the leaveout_timepoint then we + # divide the target by 2 + if training and self.hparams.leaveout_timepoint > 0: + ut[t_select + 1 == self.hparams.leaveout_timepoint] /= 2 + t[t_select + 1 == self.hparams.leaveout_timepoint] *= 2 + + x = mu_t + sigma_t * torch.randn_like(x0) + aug_x = self.aug_net(t, x, augmented_input=False) + reg, vt = self.augmentations(aug_x) + return torch.mean(reg), self.criterion(vt, ut) + + +class CNFLitModule(CFMLitModule): + def forward_integrate(self, batch: Any, t_span: torch.Tensor): + """Forward pass with integration over t_span intervals. + + (t, x, t_span) -> [x_t_span]. + """ + return super().forward_integrate(batch, t_span + 1) + + def step(self, batch: Any, training: bool = False): + obs = self.unpack_batch(batch) + if not self.is_trajectory: + obs = obs[:, None, :] + even_ts = torch.arange(obs.shape[1]).to(obs) + 1 + self.prior = MultivariateNormal( + torch.zeros(self.dim).type_as(obs), torch.eye(self.dim).type_as(obs) + ) + # Minimize the log likelihood by integrating all back to the initial timepoint + reversed_ts = torch.cat([torch.flip(even_ts, [0]), torch.tensor([0]).type_as(even_ts)]) + + # If only one timepoint then Gaussian is at t0, data t1 + # If multiple timepoints then Gaussian is at t_{-1} data is at times 0 to T + if self.is_trajectory: + reversed_ts -= 1 + losses = [] + regs = [] + for t in range(len(reversed_ts) - 1): + # When leaving out a timepoint simply skip it in the backwards integration + if self.hparams.leaveout_timepoint == t: + continue + ts, x = reversed_ts[t:], obs[:, len(even_ts) - t - 1, :] + # ts, x = self.aug(reversed_ts[t:], obs[:, len(even_ts) - t - 1, :]) + _, x = self.aug_node(x, ts) + x = x[-1] + # Assume log prob is in zero spot + delta_logprob, reg, x = self.augmentations(x) + logprob = self.prior.log_prob(x).to(x) - delta_logprob + losses.append(-torch.mean(logprob)) + # negative because we are integrating backwards + regs.append(-reg) + # Predicted locations + reg = torch.mean(torch.stack(regs)) + loss = torch.mean(torch.stack(losses)) + return reg, loss diff --git a/conditional-flow-matching/runner/src/models/icnn_module.py b/conditional-flow-matching/runner/src/models/icnn_module.py new file mode 100644 index 0000000000000000000000000000000000000000..06bb89e78500e05201f580a8ccb132ceb2c6ed3e --- /dev/null +++ b/conditional-flow-matching/runner/src/models/icnn_module.py @@ -0,0 +1,245 @@ +from typing import Any, List + +import torch +import torch.nn.functional as F +from pytorch_lightning import LightningDataModule, LightningModule +from torch import autograd + +from .components.distribution_distances import compute_distribution_distances +from .utils import get_wandb_logger + + +def to_numpy(tensor): + return tensor.to("cpu").detach().numpy() + + +def plot(x, y, x_pred, y_pred, savename=None, wandb_logger=None): + x = to_numpy(x)[:, 0] + y = to_numpy(y)[:, 0] + x_pred = to_numpy(x_pred)[:, 0] + y_pred = to_numpy(y_pred)[:, 0] + + import matplotlib.pyplot as plt + + plt.scatter(y[:, 0], y[:, 1], color="C1", alpha=0.5, label=r"$Y$") + plt.scatter(x[:, 0], x[:, 1], color="C2", alpha=0.5, label=r"$X$") + plt.scatter(x_pred[:, 0], x_pred[:, 1], color="C3", alpha=0.5, label=r"$\nabla g(Y)$") + plt.scatter(y_pred[:, 0], y_pred[:, 1], color="C4", alpha=0.5, label=r"$\nabla f(X)$") + 
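# A note on this overlay (inferred from the scatter labels above): when the dual
+ # potentials are well trained, the pushforward \nabla g(Y) should land on top of X
+ # and \nabla f(X) on top of Y, so transport errors are visible directly in the plot.
+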
plt.legend()
+ if savename:
+ plt.savefig(savename)
+ if wandb_logger:
+ wandb_logger.log_image(key="match", images=[f"{savename}.png"])
+ plt.close()
+
+
+class ICNNLitModule(LightningModule):
+ """Input convex neural network (ICNN) module that learns an optimal transport map via a pair of dual potentials f and g."""
+
+ def __init__(
+ self,
+ f_net: Any,
+ g_net: Any,
+ optimizer: Any,
+ datamodule: LightningDataModule,
+ reg: float = 0.1,
+ leaveout_timepoint: int = -1,
+ ) -> None:
+ """Initialize an ICNN pair either as a generative model or for a sequence of timepoints.
+
+ Args:
+ f_net: input convex network representing the forward potential f.
+ g_net: input convex network representing the backward potential g.
+ optimizer: partial torch.optimizer missing parameters.
+ datamodule: datamodule object needs to have "dim", "IS_TRAJECTORY" properties.
+ reg: weight of the penalty that keeps the ICNN weights nonnegative.
+ leaveout_timepoint: which (if any) timepoint to leave out during the training phase
+ """
+ super().__init__()
+ self.save_hyperparameters(ignore=["net", "optimizer", "datamodule"], logger=False)
+ self.is_trajectory = datamodule.IS_TRAJECTORY
+ self.dim = datamodule.dim
+ self.f = f_net(dim=datamodule.dim)
+ self.g = g_net(dim=datamodule.dim)
+ self.optimizer = optimizer
+ self.reg = reg
+ self.criterion = torch.nn.MSELoss()
+
+ def unpack_batch(self, batch):
+ """Unpacks a batch of data to a single tensor."""
+ if self.is_trajectory:
+ return torch.stack(batch, dim=1)
+ return batch
+
+ def preprocess_batch(self, X):
+ """Converts a batch of data into a matched random pair (x0, x1)."""
+ t_select = torch.zeros(1)
+ if self.is_trajectory:
+ batch_size, times, dim = X.shape
+ if times > 2:
+ raise NotImplementedError("ICNN not implemented for times > 2")
+ t_select = torch.randint(times - 1, size=(batch_size,))
+ x0 = []
+ x1 = []
+ for i in range(batch_size):
+ x0.append(X[i, t_select[i]])
+ x1.append(X[i, t_select[i] + 1])
+ x0, x1 = torch.stack(x0), torch.stack(x1)
+ else:
+ batch_size, dim = X.shape
+ # If no trajectory assume generate from standard normal
+ x0 = torch.randn(batch_size, X.shape[1]).type_as(X)
+ x1 = X
+ x0.requires_grad_()
+ x1.requires_grad_()
+ return x0, x1, t_select
+
+ def training_step(self, batch: Any, batch_idx: int, optimizer_idx: int):
+ X = self.unpack_batch(batch)
+ x, y, t_select = self.preprocess_batch(X)
+ # default so the reg logging below is defined even when self.reg == 0
+ reg = torch.tensor(0.0).type_as(x)
+
+ if optimizer_idx == 0:
+ fx = self.f(x)
+ gy = self.g(y)
+ grad_gy = torch.autograd.grad(torch.sum(gy), y, retain_graph=True, create_graph=True)[0]
+ f_grad_gy = self.f(grad_gy)
+ y_dot_grad_gy = torch.sum(torch.mul(y, grad_gy), axis=1, keepdim=True)
+ loss = torch.mean(f_grad_gy - y_dot_grad_gy)
+ if self.reg > 0:
+ reg = self.reg * torch.sum(
+ torch.stack([torch.sum(F.relu(-w.weight) ** 2) / 2 for w in self.g.Wzs])
+ )
+ loss += reg
+ if optimizer_idx == 1:
+ fx = self.f(x)
+ gy = self.g(y)
+ grad_gy = autograd.grad(torch.sum(gy), y, retain_graph=True, create_graph=True)[0]
+ f_grad_gy = self.f(grad_gy)
+ loss = torch.mean(fx - f_grad_gy)
+ if self.reg > 0:
+ reg = self.reg * torch.sum(
+ torch.stack([torch.sum(F.relu(-w.weight) ** 2) / 2 for w in self.f.Wzs])
+ )
+ loss += reg
+
+ prefix = "train"
+ self.log_dict(
+ {f"{prefix}/loss": loss, f"{prefix}/reg": reg},
+ on_step=True,
+ on_epoch=False,
+ prog_bar=True,
+ )
+ return loss
+
+ def eval_step(self, batch: Any, batch_idx: int, prefix: str):
+ X = self.unpack_batch(batch)
+ x, y, t_select =
self.preprocess_batch(X) + w2, f_loss, g_loss = compute_w2(self.f, self.g, x, y, return_loss=True) + self.log_dict( + { + f"{prefix}/model_w2": w2, + f"{prefix}/loss": w2, + f"{prefix}/f_loss": f_loss, + f"{prefix}/g_loss": g_loss, + }, + on_step=False, + on_epoch=True, + ) + return { + "loss": w2, + "f_loss": f_loss, + "g_loss": g_loss, + "x": self.unpack_batch(batch), + } + + def eval_epoch_end(self, outputs: List[Any], prefix: str): + def transport(model, x): + return autograd.grad(torch.sum(model(x)), x)[0] + + def y_to_x(y): + return transport(self.g, y) + + def x_to_y(x): + return transport(self.f, x) + + v = {k: torch.cat([d[k] for d in outputs]) for k in ["x"]} + x = v["x"] + wandb_logger = get_wandb_logger(self.loggers) + + if not self.is_trajectory: + # Sample some random points for the plotting function + rand = torch.randn_like(x) + x = torch.stack([rand, x], dim=1) + x.requires_grad_() + x0 = x[:, :1] + x1 = x[:, 1:] + pred = x_to_y(x0) + _, dists = compute_distribution_distances(x0, pred) + w1, w2 = dists[:2] + self.log_dict({f"{prefix}/L2": w1, f"{prefix}/squared_L2": w2}) + + # Evaluate the fit + names, dists = compute_distribution_distances(pred, x[:, 1:]) + names = [f"{prefix}/{name}" for name in names] + self.log_dict(dict(zip(names, dists))) + + x_pred = y_to_x(x1) + plot( + x0, + x1, + x_pred, + pred, + savename=f"{self.current_epoch}_match", + wandb_logger=wandb_logger, + ) + + def validation_step(self, batch: Any, batch_idx: int): + return self.eval_step(batch, batch_idx, "val") + + def validation_epoch_end(self, outputs: List[Any]): + self.eval_epoch_end(outputs, "val") + + def test_step(self, batch: Any, batch_idx: int): + return self.eval_step(batch, batch_idx, "test") + + def test_epoch_end(self, outputs: List[Any]): + self.eval_epoch_end(outputs, "test") + + def configure_optimizers(self): + """Pass model parameters to optimizer.""" + f_opt = self.optimizer(params=self.f.parameters()) + g_opt = self.optimizer(params=self.g.parameters()) + return [ + {"optimizer": g_opt, "frequency": 10}, + {"optimizer": f_opt, "frequency": 1}, + ] + + def on_validation_model_eval(self, *args, **kwargs): + super().on_validation_model_eval(*args, **kwargs) + torch.set_grad_enabled(True) + + def on_test_model_eval(self, *args, **kwargs): + super().on_test_model_eval(*args, **kwargs) + torch.set_grad_enabled(True) + + +def compute_w2(f, g, x, y, return_loss=False): + fx = f(x) + gy = g(y) + grad_gy = autograd.grad(torch.sum(gy), y, retain_graph=True, create_graph=True)[0] + + f_grad_gy = f(grad_gy) + y_dot_grad_gy = torch.sum(torch.multiply(y, grad_gy), axis=1, keepdim=True) + + x_squared = torch.sum(torch.pow(x, 2), axis=1, keepdim=True) + y_squared = torch.sum(torch.pow(y, 2), axis=1, keepdim=True) + + w2 = torch.mean(f_grad_gy - fx - y_dot_grad_gy + 0.5 * x_squared + 0.5 * y_squared) + if not return_loss: + return w2 + g_loss = torch.mean(f_grad_gy - y_dot_grad_gy) + f_loss = torch.mean(fx - f_grad_gy) + return w2, f_loss, g_loss diff --git a/conditional-flow-matching/runner/src/models/runner.py b/conditional-flow-matching/runner/src/models/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..2a0c3fe95e0cc9e01a17cca380c9c4a62029321a --- /dev/null +++ b/conditional-flow-matching/runner/src/models/runner.py @@ -0,0 +1,173 @@ +from typing import Any, List, Optional + +import numpy as np +import torch +from pytorch_lightning import LightningDataModule, LightningModule + +from torchcfm import ConditionalFlowMatcher + +from .components.augmentation import 
AugmentationModule +from .components.distribution_distances import compute_distribution_distances +from .components.plotting import plot_trajectory, store_trajectories +from .components.solver import FlowSolver +from .utils import get_wandb_logger + + +class CFMLitModule(LightningModule): + def __init__( + self, + net: Any, + optimizer: Any, + datamodule: LightningDataModule, + flow_matcher: ConditionalFlowMatcher, + solver: FlowSolver, + scheduler: Optional[Any] = None, + plot: bool = False, + ) -> None: + super().__init__() + self.save_hyperparameters( + ignore=[ + "net", + "optimizer", + "scheduler", + "datamodule", + "augmentations", + "flow_matcher", + "solver", + ], + logger=False, + ) + self.datamodule = datamodule + self.is_trajectory = False + if hasattr(datamodule, "IS_TRAJECTORY"): + self.is_trajectory = datamodule.IS_TRAJECTORY + # dims is either an integer or a tuple. This helps us to decide whether to process things as + # a vector or as an image. + if hasattr(datamodule, "dim"): + self.dim = datamodule.dim + self.is_image = False + elif hasattr(datamodule, "dims"): + self.dim = datamodule.dims + self.is_image = True + else: + raise NotImplementedError("Datamodule must have either dim or dims") + self.net = net(dim=self.dim) + self.solver = solver + self.optimizer = optimizer + self.flow_matcher = flow_matcher + self.scheduler = scheduler + self.criterion = torch.nn.MSELoss() + self.val_augmentations = AugmentationModule( + # cnf_estimator=None, + l1_reg=1, + l2_reg=1, + squared_l2_reg=1, + ) + + def unpack_batch(self, batch): + """Unpacks a batch of data to a single tensor.""" + if not isinstance(self.dim, int): + # Assume this is an image classification dataset where we need to strip the targets + return batch[0] + return batch + + def preprocess_batch(self, batch, training=False): + """Converts a batch of data into matched a random pair of (x0, x1)""" + X = self.unpack_batch(batch) + # If no trajectory assume generate from standard normal + x0 = torch.randn_like(X) + x1 = X + return x0, x1 + + def step(self, batch: Any, training: bool = False): + """Computes the loss on a batch of data.""" + x0, x1 = self.preprocess_batch(batch, training) + t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x1) + vt = self.net(t, xt) + return torch.nn.functional.mse_loss(vt, ut) + + def training_step(self, batch: Any, batch_idx: int): + loss = self.step(batch, training=True) + self.log("train/loss", loss, on_step=True, prog_bar=True) + return loss + + def eval_step(self, batch: Any, batch_idx: int, prefix: str): + loss = self.step(batch, training=True) + self.log(f"{prefix}/loss", loss) + return {"loss": loss, "x": batch} + + def preprocess_epoch_end(self, outputs: List[Any], prefix: str): + """Preprocess the outputs of the epoch end function.""" + v = {k: torch.cat([d[k] for d in outputs]) for k in ["x"]} + x = v["x"] + + # Sample some random points for the plotting function + rand = torch.randn_like(x) + x = torch.stack([rand, x], dim=1) + ts = x.shape[1] + x0 = x[:, 0] + x_rest = x[:, 1:] + return ts, x, x0, x_rest + + def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix): + # Build a trajectory + t_span = torch.linspace(0, 1, 101) + solver = self.solver(self.net, self.dim) + solver.augmentations = self.val_augmentations + traj, aug = solver.odeint(x0, t_span) + full_trajs = [traj] + traj, aug = traj[-1], aug[-1] + regs = [torch.mean(aug, dim=0).detach().cpu().numpy()] + trajs = [traj] + nfe = solver.nfe + full_trajs = torch.cat(full_trajs) + + regs = 
np.stack(regs).mean(axis=0) + names = [f"{prefix}/{name}" for name in self.val_augmentations.names] + self.log_dict(dict(zip(names, regs)), sync_dist=True) + + names, dists = compute_distribution_distances(trajs, x_rest) + names = [f"{prefix}/{name}" for name in names] + d = dict(zip(names, dists)) + d[f"{prefix}/nfe"] = nfe + self.log_dict(d, sync_dist=True) + return trajs, full_trajs + + def eval_epoch_end(self, outputs: List[Any], prefix: str): + wandb_logger = get_wandb_logger(self.loggers) + ts, x, x0, x_rest = self.preprocess_epoch_end(outputs, prefix) + trajs, full_trajs = self.forward_eval_integrate(ts, x0, x_rest, outputs, prefix) + + if self.hparams.plot: + plot_trajectory( + x, + full_trajs, + title=f"{self.current_epoch}_ode", + key="ode_path", + wandb_logger=wandb_logger, + ) + store_trajectories(x, self.net) + + def validation_step(self, batch: Any, batch_idx: int): + return self.eval_step(batch, batch_idx, "val") + + def validation_epoch_end(self, outputs: List[Any]): + self.eval_epoch_end(outputs, "val") + + def test_step(self, batch: Any, batch_idx: int): + return self.eval_step(batch, batch_idx, "test") + + def test_epoch_end(self, outputs: List[Any]): + self.eval_epoch_end(outputs, "test") + + def configure_optimizers(self): + """Pass model parameters to optimizer.""" + optimizer = self.optimizer(params=self.parameters()) + if self.scheduler is None: + return optimizer + + scheduler = self.scheduler(optimizer) + return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] + + def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + scheduler.step(epoch=self.current_epoch) diff --git a/conditional-flow-matching/runner/src/models/utils.py b/conditional-flow-matching/runner/src/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c74336828d2cf9f0fb5bef3a3db91ade1a88c979 --- /dev/null +++ b/conditional-flow-matching/runner/src/models/utils.py @@ -0,0 +1,10 @@ +from pytorch_lightning.loggers import WandbLogger + + +def get_wandb_logger(loggers): + """Gets the wandb logger if it is the list of loggers otherwise returns None.""" + wandb_logger = None + for logger in loggers: + if isinstance(logger, WandbLogger): + wandb_logger = logger + return wandb_logger diff --git a/conditional-flow-matching/runner/src/train.py b/conditional-flow-matching/runner/src/train.py new file mode 100644 index 0000000000000000000000000000000000000000..9464b33e3fb856e637652102a1b0bd240b22c363 --- /dev/null +++ b/conditional-flow-matching/runner/src/train.py @@ -0,0 +1,141 @@ +import pyrootutils + +root = pyrootutils.setup_root( + search_from=__file__, + indicator=[".git", "pyproject.toml", "README.md"], + pythonpath=True, + dotenv=True, +) + +# ------------------------------------------------------------------------------------ # +# `pyrootutils.setup_root(...)` above is optional line to make environment more convenient +# should be placed at the top of each entry file +# +# main advantages: +# - allows you to keep all entry files in "src/" without installing project as a package +# - launching python file works no matter where is your current work dir +# - automatically loads environment variables from ".env" if exists +# +# how it works: +# - `setup_root()` above recursively searches for either ".git" or "pyproject.toml" in present +# and parent dirs, to determine the project root dir +# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from +# any place without installing project as a package +# - sets PROJECT_ROOT 
environment variable which is used in "configs/paths/default.yaml" +# to make all paths always relative to project root +# - loads environment variables from ".env" in root dir (if `dotenv=True`) +# +# you can remove `pyrootutils.setup_root(...)` if you: +# 1. either install project as a package or move each entry file to the project root dir +# 2. remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml" +# +# https://github.com/ashleve/pyrootutils +# ------------------------------------------------------------------------------------ # + +from typing import List, Optional, Tuple + +import hydra +import pytorch_lightning as pl +from omegaconf import DictConfig +from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer +from pytorch_lightning.loggers import LightningLoggerBase + +from src import utils + +log = utils.get_pylogger(__name__) + + +@utils.task_wrapper +def train(cfg: DictConfig) -> Tuple[dict, dict]: + """Trains the model. + + Can additionally evaluate on a testset, using best weights obtained during training. + + This method is wrapped in optional @task_wrapper decorator which applies extra utilities + before and after the call. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + + Returns: + Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. + """ + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + pl.seed_everything(cfg.seed, workers=True) + + log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule) + + log.info(f"Instantiating model <{cfg.model._target_}>") + if hasattr(datamodule, "pass_to_model"): + log.info("Passing full datamodule to model") + model: LightningModule = hydra.utils.instantiate(cfg.model)(datamodule=datamodule) + else: + if hasattr(datamodule, "dim"): + log.info("Passing datamodule.dim to model") + model: LightningModule = hydra.utils.instantiate(cfg.model)(dim=datamodule.dim) + else: + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: List[LightningLoggerBase] = utils.instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + utils.log_hyperparameters(object_dict) + + if cfg.get("train"): + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! 
Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main(version_base="1.2", config_path=root / "configs", config_name="train.yaml") +def main(cfg: DictConfig) -> Optional[float]: + # train the model + metric_dict, _ = train(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + metric_value = utils.get_metric_value( + metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") + ) + + # return optimized metric + return metric_value + + +if __name__ == "__main__": + main() diff --git a/conditional-flow-matching/runner/src/utils/__init__.py b/conditional-flow-matching/runner/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f81ee30da909f895d82490069ef8b64f52643807 --- /dev/null +++ b/conditional-flow-matching/runner/src/utils/__init__.py @@ -0,0 +1,12 @@ +from src.utils.pylogger import get_pylogger +from src.utils.rich_utils import enforce_tags, print_config_tree +from src.utils.utils import ( + close_loggers, + extras, + get_metric_value, + instantiate_callbacks, + instantiate_loggers, + log_hyperparameters, + save_file, + task_wrapper, +) diff --git a/conditional-flow-matching/runner/src/utils/pylogger.py b/conditional-flow-matching/runner/src/utils/pylogger.py new file mode 100644 index 0000000000000000000000000000000000000000..142e959e7cc1d607b822a3d1586e51b4994cab8e --- /dev/null +++ b/conditional-flow-matching/runner/src/utils/pylogger.py @@ -0,0 +1,24 @@ +import logging + +from pytorch_lightning.utilities import rank_zero_only + + +def get_pylogger(name=__name__) -> logging.Logger: + """Initializes multi-GPU-friendly python command line logger.""" + logger = logging.getLogger(name) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + logging_levels = ( + "debug", + "info", + "warning", + "error", + "exception", + "fatal", + "critical", + ) + for level in logging_levels: + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger diff --git a/conditional-flow-matching/runner/src/utils/rich_utils.py b/conditional-flow-matching/runner/src/utils/rich_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..da17b1fa660747ad7c9c0e4bc250262072c0a3c2 --- /dev/null +++ b/conditional-flow-matching/runner/src/utils/rich_utils.py @@ -0,0 +1,103 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch_lightning.utilities import rank_zero_only +from rich.prompt import Prompt + +from src.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "datamodule", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + cfg (DictConfig): Configuration composed by Hydra. 
+ print_order (Sequence[str], optional): Determines in what order config components are printed. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + save_to_file (bool, optional): Whether to export config to the hydra output folder. + """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + queue.append(field) if field in cfg else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config.""" + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) + + +if __name__ == "__main__": + from hydra import compose, initialize + + with initialize(version_base="1.2", config_path="../../configs"): + cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[]) + print_config_tree(cfg, resolve=False, save_to_file=False) diff --git a/conditional-flow-matching/runner/src/utils/utils.py b/conditional-flow-matching/runner/src/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..97d29c44fd4609777de0d4ca0457b184f96b2e84 --- /dev/null +++ b/conditional-flow-matching/runner/src/utils/utils.py @@ -0,0 +1,201 @@ +import time +import warnings +from importlib.util import find_spec +from pathlib import Path +from typing import Callable, List + +import hydra +from omegaconf import DictConfig +from pytorch_lightning import Callback +from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.utilities import rank_zero_only + +from src.utils import pylogger, rich_utils + +log = pylogger.get_pylogger(__name__) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that wraps the task function in extra utilities. + + Makes multirun more resistant to failure. 
+ + Utilities: + - Calling the `utils.extras()` before the task is started + - Calling the `utils.close_loggers()` after the task is finished + - Logging the exception if occurs + - Logging the task total execution time + - Logging the output dir + """ + + def wrap(cfg: DictConfig): + # apply extra utilities + extras(cfg) + + # execute the task + try: + start_time = time.time() + metric_dict, object_dict = task_func(cfg=cfg) + except Exception as ex: + log.exception("") # save exception to `.log` file + raise ex + finally: + path = Path(cfg.paths.output_dir, "exec_time.log") + content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)" + save_file(path, content) # save task execution time (even if exception occurs) + close_loggers() # close loggers (even if exception occurs so multirun won't fail) + + log.info(f"Output dir: {cfg.paths.output_dir}") + + return metric_dict, object_dict + + return wrap + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + """ + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +@rank_zero_only +def save_file(path: str, content: str) -> None: + """Save file in rank zero mode (only on one process in multi-GPU setup).""" + with open(path, "w+") as file: + file.write(content) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config.""" + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("Callbacks config is empty.") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[LightningLoggerBase]: + """Instantiates loggers from config.""" + logger: List[LightningLoggerBase] = [] + + if not logger_cfg: + log.warning("Logger config is empty.") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. 
+ + Additionally saves: + - Number of model parameters + """ + hparams = {} + + cfg = object_dict["cfg"] + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.loggers: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["datamodule"] = cfg["datamodule"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) + + +def get_metric_value(metric_dict: dict, metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule.""" + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") + + return metric_value + + +def close_loggers() -> None: + """Makes sure all loggers closed properly (prevents logging failure during multirun).""" + log.info("Closing loggers...") + + if find_spec("wandb"): # if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() diff --git a/conditional-flow-matching/runner/tests/__init__.py b/conditional-flow-matching/runner/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/conditional-flow-matching/runner/tests/conftest.py b/conditional-flow-matching/runner/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..527554e45a12f811a6b945f1388d604cfe201cd7 --- /dev/null +++ b/conditional-flow-matching/runner/tests/conftest.py @@ -0,0 +1,82 @@ +import pyrootutils +import pytest +from hydra import compose, initialize +from hydra.core.global_hydra import GlobalHydra +from omegaconf import DictConfig, open_dict + + +@pytest.fixture(scope="package") +def cfg_train_global() -> DictConfig: + with initialize(version_base="1.2", config_path="../configs"): + cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[]) + + # set defaults for all tests + with open_dict(cfg): + cfg.paths.root_dir = str(pyrootutils.find_root()) + cfg.trainer.max_epochs = 1 + cfg.trainer.limit_train_batches = 0.02 + cfg.trainer.limit_val_batches = 0.2 + cfg.trainer.limit_test_batches = 0.2 + cfg.trainer.accelerator = "cpu" + cfg.trainer.devices = 1 + cfg.datamodule.num_workers = 0 + cfg.datamodule.pin_memory = False + cfg.extras.print_config = False + cfg.extras.enforce_tags = False + cfg.logger = None + cfg.launcher = None + + return cfg + + +@pytest.fixture(scope="package") +def cfg_eval_global() -> DictConfig: + with 
initialize(version_base="1.2", config_path="../configs"): + cfg = compose(config_name="eval.yaml", return_hydra_config=True, overrides=["ckpt_path=."]) + + # set defaults for all tests + with open_dict(cfg): + cfg.paths.root_dir = str(pyrootutils.find_root()) + cfg.trainer.max_epochs = 1 + cfg.trainer.limit_test_batches = 0.2 + cfg.trainer.accelerator = "cpu" + cfg.trainer.devices = 1 + cfg.datamodule.num_workers = 0 + cfg.datamodule.pin_memory = False + cfg.extras.print_config = False + cfg.extras.enforce_tags = False + cfg.logger = None + + return cfg + + +# this is called by each test which uses `cfg_train` arg +# each test generates its own temporary logging path +@pytest.fixture(scope="function") +def cfg_train(cfg_train_global, tmp_path) -> DictConfig: + cfg = cfg_train_global.copy() + + with open_dict(cfg): + cfg.paths.data_dir = str(tmp_path) + cfg.paths.output_dir = str(tmp_path) + cfg.paths.log_dir = str(tmp_path) + + yield cfg + + GlobalHydra.instance().clear() + + +# this is called by each test which uses `cfg_eval` arg +# each test generates its own temporary logging path +@pytest.fixture(scope="function") +def cfg_eval(cfg_eval_global, tmp_path) -> DictConfig: + cfg = cfg_eval_global.copy() + + with open_dict(cfg): + cfg.paths.data_dir = str(tmp_path) + cfg.paths.output_dir = str(tmp_path) + cfg.paths.log_dir = str(tmp_path) + + yield cfg + + GlobalHydra.instance().clear() diff --git a/conditional-flow-matching/runner/tests/test_configs.py b/conditional-flow-matching/runner/tests/test_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f4c0e463bc837157c212200362bf756054635e --- /dev/null +++ b/conditional-flow-matching/runner/tests/test_configs.py @@ -0,0 +1,29 @@ +import hydra +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig + + +def test_train_config(cfg_train: DictConfig): + assert cfg_train + assert cfg_train.datamodule + assert cfg_train.model + assert cfg_train.trainer + + HydraConfig().set_config(cfg_train) + + hydra.utils.instantiate(cfg_train.datamodule) + hydra.utils.instantiate(cfg_train.model) + hydra.utils.instantiate(cfg_train.trainer) + + +def test_eval_config(cfg_eval: DictConfig): + assert cfg_eval + assert cfg_eval.datamodule + assert cfg_eval.model + assert cfg_eval.trainer + + HydraConfig().set_config(cfg_eval) + + hydra.utils.instantiate(cfg_eval.datamodule) + hydra.utils.instantiate(cfg_eval.model) + hydra.utils.instantiate(cfg_eval.trainer) diff --git a/conditional-flow-matching/runner/tests/test_datamodule.py b/conditional-flow-matching/runner/tests/test_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..bc9fc98fd8214b98ff002ce999303deaba599a41 --- /dev/null +++ b/conditional-flow-matching/runner/tests/test_datamodule.py @@ -0,0 +1,64 @@ +import pytest +import torch + +from src.datamodules.distribution_datamodule import ( + SKLearnDataModule, + TorchDynDataModule, + TwoDimDataModule, +) + + +@pytest.mark.parametrize("batch_size", [32, 128]) +@pytest.mark.parametrize("train_val_test_split", [400, [1000, 100, 100]]) +@pytest.mark.parametrize( + "datamodule,system", + [ + (SKLearnDataModule, "scurve"), + (SKLearnDataModule, "moons"), + (TorchDynDataModule, "gaussians"), + ], +) +def test_single_datamodule(batch_size, train_val_test_split, datamodule, system): + dm = datamodule( + batch_size=batch_size, train_val_test_split=train_val_test_split, system=system + ) + + assert dm.data_train is not None and dm.data_val is not None and dm.data_test is not None + 
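# Note: both split parametrizations above apparently total 1200 points (400 per
+ # split, or the explicit [1000, 100, 100]), which the count assertion below relies on.
+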
assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader() + + num_datapoints = len(dm.data_train) + len(dm.data_val) + len(dm.data_test) + assert num_datapoints == 1200 + + batch = next(iter(dm.train_dataloader())) + x = batch + assert x.dim() == 2 + assert x.shape[0] == batch_size + assert x.shape[-1] == 2 + assert dm.dim == 2 + assert x.dtype == torch.float32 + + +@pytest.mark.parametrize("batch_size", [32, 128]) +@pytest.mark.parametrize("train_val_test_split", [300, [200, 50, 50]]) +@pytest.mark.parametrize( + "datamodule,system", + [ + (TwoDimDataModule, "moon-8gaussians"), + ], +) +def test_trajectory_datamodule(batch_size, train_val_test_split, datamodule, system): + dm = datamodule( + batch_size=batch_size, train_val_test_split=train_val_test_split, system=system + ) + # assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader() + + batch = next(iter(dm.train_dataloader())) + x = batch + assert len(x) == 2 + for t in range(len(dm.timepoint_data)): + xt = x[t] + assert xt.dim() == 2 + assert xt.shape[0] == batch_size + assert xt.shape[-1] == 2 + assert xt.dtype == torch.float32 + assert dm.dim == 2 diff --git a/conditional-flow-matching/runner/tests/test_eval.py b/conditional-flow-matching/runner/tests/test_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..0decfda6cf36014d7b4be71f51f5aac0f90c3b53 --- /dev/null +++ b/conditional-flow-matching/runner/tests/test_eval.py @@ -0,0 +1,31 @@ +import os + +import pytest +from hydra.core.hydra_config import HydraConfig +from omegaconf import open_dict + +from src.eval import evaluate +from src.train import train + + +@pytest.mark.slow +def test_train_eval(tmp_path, cfg_train, cfg_eval): + """Train for 1 epoch with `train.py` and evaluate with `eval.py`""" + assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir + + with open_dict(cfg_train): + cfg_train.trainer.max_epochs = 1 + cfg_train.test = True + + HydraConfig().set_config(cfg_train) + train_metric_dict, _ = train(cfg_train) + + assert "last.ckpt" in os.listdir(tmp_path / "checkpoints") + + with open_dict(cfg_eval): + cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") + + HydraConfig().set_config(cfg_eval) + test_metric_dict, _ = evaluate(cfg_eval) + + assert test_metric_dict["test/2-Wasserstein"] > 0.0 diff --git a/conditional-flow-matching/runner/tests/test_sweeps.py b/conditional-flow-matching/runner/tests/test_sweeps.py new file mode 100644 index 0000000000000000000000000000000000000000..ae38fec02eb4f2c76f0457aa2eead985ebb63f40 --- /dev/null +++ b/conditional-flow-matching/runner/tests/test_sweeps.py @@ -0,0 +1,129 @@ +import pytest + +from tests.helpers.run_if import RunIf +from tests.helpers.run_sh_command import run_sh_command + +startfile = "runner/src/train.py" +overrides = ["logger=[]"] +dir_overrides = ["paths.data_dir", "hydra.sweep.dir"] + + +@RunIf(sh=True) +@pytest.mark.slow +@pytest.mark.xfail( + reason="Currently failing experiments with fast_dev_run which messes with gradients" +) +def test_xfail_fast_dev_experiments(tmp_path): + """Test running all available experiment configs with fast_dev_run=True.""" + command = ( + [ + startfile, + "-m", + "experiment=glob(*)", + "++trainer.fast_dev_run=true", + ] + + overrides + + [f"{d}={tmp_path}" for d in dir_overrides] + ) + run_sh_command(command) + + +@RunIf(sh=True) +@pytest.mark.slow +def test_experiments(tmp_path): + """Test running all available experiment configs with fast_dev_run=True.""" + command = ( + [ + 
startfile,
+                "-m",
+                "experiment=cfm",
+                "model=cfm,otcfm,sbcfm,fm",
+                "++trainer.fast_dev_run=true",
+                "++trainer.limit_val_batches=0.25",
+            ]
+            + overrides
+            + [f"{d}={tmp_path}" for d in dir_overrides]
+        )
+        run_sh_command(command)
+
+
+@RunIf(sh=True)
+@pytest.mark.slow
+def test_hydra_sweep(tmp_path):
+    """Test default hydra sweep."""
+    command = (
+        [
+            startfile,
+            "-m",
+            "hydra.sweep.dir=" + str(tmp_path),
+            "model.optimizer.lr=0.005,0.01",
+            "++trainer.fast_dev_run=true",
+        ]
+        + overrides
+        + [f"{d}={tmp_path}" for d in dir_overrides]
+    )
+
+    run_sh_command(command)
+
+
+@RunIf(sh=True)
+@pytest.mark.slow
+@pytest.mark.xfail(reason="DDP is not working yet")
+def test_hydra_sweep_ddp_sim(tmp_path):
+    """Test default hydra sweep with ddp sim."""
+    command = (
+        [
+            startfile,
+            "-m",
+            "trainer=ddp_sim",
+            "trainer.max_epochs=3",
+            "+trainer.limit_train_batches=0.01",
+            "+trainer.limit_val_batches=0.1",
+            "+trainer.limit_test_batches=0.1",
+            "model.optimizer.lr=0.005,0.01,0.02",
+        ]
+        + overrides
+        + [f"{d}={tmp_path}" for d in dir_overrides]
+    )
+    run_sh_command(command)
+
+
+@RunIf(sh=True)
+@pytest.mark.slow
+@pytest.mark.skip(reason="Too slow for easy testing, pathway currently not used")
+def test_optuna_sweep(tmp_path):
+    """Test optuna sweep."""
+    command = (
+        [
+            startfile,
+            "-m",
+            "hparams_search=optuna",
+            "hydra.sweep.dir=" + str(tmp_path),
+            "hydra.sweeper.n_trials=3",
+            "hydra.sweeper.sampler.n_startup_trials=2",
+            # "++trainer.fast_dev_run=true",
+        ]
+        + overrides
+        + [f"{d}={tmp_path}" for d in dir_overrides]
+    )
+    run_sh_command(command)
+
+
+@RunIf(wandb=True, sh=True)
+@pytest.mark.slow
+@pytest.mark.xfail(reason="wandb import is still bad without API key")
+def test_optuna_sweep_ddp_sim_wandb(tmp_path):
+    """Test optuna sweep with wandb and ddp sim."""
+    command = [
+        startfile,
+        "-m",
+        "hparams_search=optuna",
+        "hydra.sweeper.n_trials=5",
+        "trainer=ddp_sim",
+        "trainer.max_epochs=3",
+        "+trainer.limit_train_batches=0.01",
+        "+trainer.limit_val_batches=0.1",
+        "+trainer.limit_test_batches=0.1",
+        "logger=wandb",
+    ] + [f"{d}={tmp_path}" for d in dir_overrides]
+    run_sh_command(command)
diff --git a/conditional-flow-matching/runner/tests/test_train.py b/conditional-flow-matching/runner/tests/test_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..290dad33f56ef68b927b89148083b6557be477c8
--- /dev/null
+++ b/conditional-flow-matching/runner/tests/test_train.py
@@ -0,0 +1,88 @@
+import os
+
+import pytest
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import open_dict
+
+from src.train import train
+from tests.helpers.run_if import RunIf
+
+
+def test_train_fast_dev_run(cfg_train):
+    """Run for 1 train, val and test step."""
+    HydraConfig().set_config(cfg_train)
+    with open_dict(cfg_train):
+        cfg_train.trainer.fast_dev_run = True
+        cfg_train.trainer.accelerator = "cpu"
+    train(cfg_train)
+
+
+@RunIf(min_gpus=1)
+def test_train_fast_dev_run_gpu(cfg_train):
+    """Run for 1 train, val and test step on GPU."""
+    HydraConfig().set_config(cfg_train)
+    with open_dict(cfg_train):
+        cfg_train.trainer.fast_dev_run = True
+        cfg_train.trainer.accelerator = "gpu"
+    train(cfg_train)
+
+
+@RunIf(min_gpus=1)
+@pytest.mark.slow
+def test_train_epoch_gpu_amp(cfg_train):
+    """Train 1 epoch on GPU with mixed-precision."""
+    HydraConfig().set_config(cfg_train)
+    with open_dict(cfg_train):
+        cfg_train.trainer.max_epochs = 1
+        cfg_train.trainer.accelerator = "gpu"
+        cfg_train.trainer.precision = 16
+    train(cfg_train)
+
+
+@pytest.mark.slow
+def test_train_epoch_double_val_loop(cfg_train): + """Train 1 epoch with validation loop twice per epoch.""" + HydraConfig().set_config(cfg_train) + with open_dict(cfg_train): + cfg_train.trainer.max_epochs = 1 + cfg_train.trainer.val_check_interval = 0.5 + train(cfg_train) + + +@pytest.mark.slow +@pytest.mark.xfail(reason="DDP currently failing") +def test_train_ddp_sim(cfg_train): + """Simulate DDP (Distributed Data Parallel) on 2 CPU processes.""" + HydraConfig().set_config(cfg_train) + with open_dict(cfg_train): + cfg_train.trainer.max_epochs = 2 + cfg_train.trainer.accelerator = "cpu" + cfg_train.trainer.devices = 2 + cfg_train.trainer.strategy = "ddp_spawn" + train(cfg_train) + + +@pytest.mark.slow +def test_train_resume(tmp_path, cfg_train): + """Run 1 epoch, finish, and resume for another epoch.""" + with open_dict(cfg_train): + cfg_train.trainer.max_epochs = 1 + cfg_train.callbacks.model_checkpoint.save_top_k = 2 + print(cfg_train) + + HydraConfig().set_config(cfg_train) + metric_dict_1, _ = train(cfg_train) + + files = os.listdir(tmp_path / "checkpoints") + assert "last.ckpt" in files + assert "epoch_0000.ckpt" in files + + with open_dict(cfg_train): + cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") + cfg_train.trainer.max_epochs = 2 + + metric_dict_2, _ = train(cfg_train) + + files = os.listdir(tmp_path / "checkpoints") + assert "epoch_0001.ckpt" in files + assert "epoch_0002.ckpt" not in files diff --git a/conditional-flow-matching/torchcfm/models/models.py b/conditional-flow-matching/torchcfm/models/models.py new file mode 100644 index 0000000000000000000000000000000000000000..029443e0dce08ea559007d17201f3e0c3212ec99 --- /dev/null +++ b/conditional-flow-matching/torchcfm/models/models.py @@ -0,0 +1,32 @@ +import torch + + +class MLP(torch.nn.Module): + def __init__(self, dim, out_dim=None, w=64, time_varying=False): + super().__init__() + self.time_varying = time_varying + if out_dim is None: + out_dim = dim + self.net = torch.nn.Sequential( + torch.nn.Linear(dim + (1 if time_varying else 0), w), + torch.nn.SELU(), + torch.nn.Linear(w, w), + torch.nn.SELU(), + torch.nn.Linear(w, w), + torch.nn.SELU(), + torch.nn.Linear(w, out_dim), + ) + + def forward(self, x): + return self.net(x) + + +class GradModel(torch.nn.Module): + def __init__(self, action): + super().__init__() + self.action = action + + def forward(self, x): + x = x.requires_grad_(True) + grad = torch.autograd.grad(torch.sum(self.action(x)), x, create_graph=True)[0] + return grad[:, :-1] diff --git a/conditional-flow-matching/torchcfm/models/unet/__init__.py b/conditional-flow-matching/torchcfm/models/unet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..626d2fa68d2a81e289d15ed56a6054b281c1bf19 --- /dev/null +++ b/conditional-flow-matching/torchcfm/models/unet/__init__.py @@ -0,0 +1 @@ +from .unet import UNetModelWrapper as UNetModel diff --git a/conditional-flow-matching/torchcfm/models/unet/fp16_util.py b/conditional-flow-matching/torchcfm/models/unet/fp16_util.py new file mode 100644 index 0000000000000000000000000000000000000000..8c1298682bd519f7eab903d7b1f6f5ca94f28a40 --- /dev/null +++ b/conditional-flow-matching/torchcfm/models/unet/fp16_util.py @@ -0,0 +1,216 @@ +"""Helpers to train with 16-bit precision.""" + +import numpy as np +import torch as th +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from . 
import logger + +INITIAL_LOG_LOSS_SCALE = 20.0 + + +def convert_module_to_f16(l): + """Convert primitive modules to float16.""" + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """Convert primitive modules to float32, undoing convert_module_to_f16().""" + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + l.bias.data = l.bias.data.float() + + +def make_master_params(param_groups_and_shapes): + """Copy model parameters into a (differently-shaped) list of full-precision parameters.""" + master_params = [] + for param_group, shape in param_groups_and_shapes: + master_param = nn.Parameter( + _flatten_dense_tensors([param.detach().float() for (_, param) in param_group]).view( + shape + ) + ) + master_param.requires_grad = True + master_params.append(master_param) + return master_params + + +def model_grads_to_master_grads(param_groups_and_shapes, master_params): + """Copy the gradients from the model parameters into the master parameters from + make_master_params().""" + for master_param, (param_group, shape) in zip(master_params, param_groups_and_shapes): + master_param.grad = _flatten_dense_tensors( + [param_grad_or_zeros(param) for (_, param) in param_group] + ).view(shape) + + +def master_params_to_model_params(param_groups_and_shapes, master_params): + """Copy the master parameter data back into the model parameters.""" + # Without copying to a list, if a generator is passed, this will + # silently not copy any parameters. + for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): + for (_, param), unflat_master_param in zip( + param_group, unflatten_master_params(param_group, master_param.view(-1)) + ): + param.detach().copy_(unflat_master_param) + + +def unflatten_master_params(param_group, master_param): + return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) + + +def get_param_groups_and_shapes(named_model_params): + named_model_params = list(named_model_params) + scalar_vector_named_params = ( + [(n, p) for (n, p) in named_model_params if p.ndim <= 1], + (-1), + ) + matrix_named_params = ( + [(n, p) for (n, p) in named_model_params if p.ndim > 1], + (1, -1), + ) + return [scalar_vector_named_params, matrix_named_params] + + +def master_params_to_state_dict(model, param_groups_and_shapes, master_params, use_fp16): + if use_fp16: + state_dict = model.state_dict() + for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): + for (name, _), unflat_master_param in zip( + param_group, unflatten_master_params(param_group, master_param.view(-1)) + ): + assert name in state_dict + state_dict[name] = unflat_master_param + else: + state_dict = model.state_dict() + for i, (name, _value) in enumerate(model.named_parameters()): + assert name in state_dict + state_dict[name] = master_params[i] + return state_dict + + +def state_dict_to_master_params(model, state_dict, use_fp16): + if use_fp16: + named_model_params = [(name, state_dict[name]) for name, _ in model.named_parameters()] + param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) + master_params = make_master_params(param_groups_and_shapes) + else: + master_params = [state_dict[name] for name, _ in model.named_parameters()] + return master_params + + +def zero_master_grads(master_params): + for param in master_params: + param.grad = 
None + + +def zero_grad(model_params): + for param in model_params: + # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group + if param.grad is not None: + param.grad.detach_() + param.grad.zero_() + + +def param_grad_or_zeros(param): + if param.grad is not None: + return param.grad.data.detach() + else: + return th.zeros_like(param) + + +class MixedPrecisionTrainer: + def __init__( + self, + *, + model, + use_fp16=False, + fp16_scale_growth=1e-3, + initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, + ): + self.model = model + self.use_fp16 = use_fp16 + self.fp16_scale_growth = fp16_scale_growth + + self.model_params = list(self.model.parameters()) + self.master_params = self.model_params + self.param_groups_and_shapes = None + self.lg_loss_scale = initial_lg_loss_scale + + if self.use_fp16: + self.param_groups_and_shapes = get_param_groups_and_shapes( + self.model.named_parameters() + ) + self.master_params = make_master_params(self.param_groups_and_shapes) + self.model.convert_to_fp16() + + def zero_grad(self): + zero_grad(self.model_params) + + def backward(self, loss: th.Tensor): + if self.use_fp16: + loss_scale = 2**self.lg_loss_scale + (loss * loss_scale).backward() + else: + loss.backward() + + def optimize(self, opt: th.optim.Optimizer): + if self.use_fp16: + return self._optimize_fp16(opt) + else: + return self._optimize_normal(opt) + + def _optimize_fp16(self, opt: th.optim.Optimizer): + logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) + model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) + grad_norm, param_norm = self._compute_norms(grad_scale=2**self.lg_loss_scale) + if check_overflow(grad_norm): + self.lg_loss_scale -= 1 + logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") + zero_master_grads(self.master_params) + return False + + logger.logkv_mean("grad_norm", grad_norm) + logger.logkv_mean("param_norm", param_norm) + + for p in self.master_params: + p.grad.mul_(1.0 / (2**self.lg_loss_scale)) + opt.step() + zero_master_grads(self.master_params) + master_params_to_model_params(self.param_groups_and_shapes, self.master_params) + self.lg_loss_scale += self.fp16_scale_growth + return True + + def _optimize_normal(self, opt: th.optim.Optimizer): + grad_norm, param_norm = self._compute_norms() + logger.logkv_mean("grad_norm", grad_norm) + logger.logkv_mean("param_norm", param_norm) + opt.step() + return True + + def _compute_norms(self, grad_scale=1.0): + grad_norm = 0.0 + param_norm = 0.0 + for p in self.master_params: + with th.no_grad(): + param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 + if p.grad is not None: + grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 + return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) + + def master_params_to_state_dict(self, master_params): + return master_params_to_state_dict( + self.model, self.param_groups_and_shapes, master_params, self.use_fp16 + ) + + def state_dict_to_master_params(self, state_dict): + return state_dict_to_master_params(self.model, state_dict, self.use_fp16) + + +def check_overflow(value): + return (value == float("inf")) or (value == -float("inf")) or (value != value) diff --git a/conditional-flow-matching/torchcfm/models/unet/logger.py b/conditional-flow-matching/torchcfm/models/unet/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..911dc54aca9e8017e20c144241b072a48cc49b79 --- /dev/null +++ b/conditional-flow-matching/torchcfm/models/unet/logger.py @@ -0,0 
+1,468 @@ +"""Logger copied from OpenAI baselines to avoid extra RL-based dependencies: + +https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py +""" + +import datetime +import json +import os +import os.path as osp +import sys +import tempfile +import time +import warnings +from collections import defaultdict +from contextlib import contextmanager + +DEBUG = 10 +INFO = 20 +WARN = 30 +ERROR = 40 + +DISABLED = 50 + + +class KVWriter: + def writekvs(self, kvs): + raise NotImplementedError + + +class SeqWriter: + def writeseq(self, seq): + raise NotImplementedError + + +class HumanOutputFormat(KVWriter, SeqWriter): + def __init__(self, filename_or_file): + if isinstance(filename_or_file, str): + self.file = open(filename_or_file, "w") + self.own_file = True + else: + assert hasattr(filename_or_file, "read"), ( + "expected file or str, got %s" % filename_or_file + ) + self.file = filename_or_file + self.own_file = False + + def writekvs(self, kvs): + # Create strings for printing + key2str = {} + for key, val in sorted(kvs.items()): + if hasattr(val, "__float__"): + valstr = "%-8.3g" % val + else: + valstr = str(val) + key2str[self._truncate(key)] = self._truncate(valstr) + + # Find max widths + if len(key2str) == 0: + print("WARNING: tried to write empty key-value dict") + return + else: + keywidth = max(map(len, key2str.keys())) + valwidth = max(map(len, key2str.values())) + + # Write out the data + dashes = "-" * (keywidth + valwidth + 7) + lines = [dashes] + for key, val in sorted(key2str.items(), key=lambda kv: kv[0].lower()): + lines.append( + "| %s%s | %s%s |" + % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) + ) + lines.append(dashes) + self.file.write("\n".join(lines) + "\n") + + # Flush the output to the file + self.file.flush() + + def _truncate(self, s): + maxlen = 30 + return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s + + def writeseq(self, seq): + seq = list(seq) + for i, elem in enumerate(seq): + self.file.write(elem) + if i < len(seq) - 1: # add space unless this is the last one + self.file.write(" ") + self.file.write("\n") + self.file.flush() + + def close(self): + if self.own_file: + self.file.close() + + +class JSONOutputFormat(KVWriter): + def __init__(self, filename): + self.file = open(filename, "w") + + def writekvs(self, kvs): + for k, v in sorted(kvs.items()): + if hasattr(v, "dtype"): + kvs[k] = float(v) + self.file.write(json.dumps(kvs) + "\n") + self.file.flush() + + def close(self): + self.file.close() + + +class CSVOutputFormat(KVWriter): + def __init__(self, filename): + self.file = open(filename, "w+t") + self.keys = [] + self.sep = "," + + def writekvs(self, kvs): + # Add our current row to the history + extra_keys = list(kvs.keys() - self.keys) + extra_keys.sort() + if extra_keys: + self.keys.extend(extra_keys) + self.file.seek(0) + lines = self.file.readlines() + self.file.seek(0) + for i, k in enumerate(self.keys): + if i > 0: + self.file.write(",") + self.file.write(k) + self.file.write("\n") + for line in lines[1:]: + self.file.write(line[:-1]) + self.file.write(self.sep * len(extra_keys)) + self.file.write("\n") + for i, k in enumerate(self.keys): + if i > 0: + self.file.write(",") + v = kvs.get(k) + if v is not None: + self.file.write(str(v)) + self.file.write("\n") + self.file.flush() + + def close(self): + self.file.close() + + +class TensorBoardOutputFormat(KVWriter): + """Dumps key/value pairs into TensorBoard's numeric format.""" + + def __init__(self, dir): + os.makedirs(dir, exist_ok=True) + self.dir = dir + self.step = 1 + prefix = "events" + path = osp.join(osp.abspath(dir), prefix) + import tensorflow as tf + from tensorflow.core.util import event_pb2 + from tensorflow.python import pywrap_tensorflow + from tensorflow.python.util import compat + + self.tf = tf + self.event_pb2 = event_pb2 + self.pywrap_tensorflow = pywrap_tensorflow + self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) + + def writekvs(self, kvs): + def summary_val(k, v): + kwargs = {"tag": k, "simple_value": float(v)} + return self.tf.Summary.Value(**kwargs) + + summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) + event = self.event_pb2.Event(wall_time=time.time(), summary=summary) + event.step = self.step # is there any reason why you'd want to specify the step? 
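+            # (Each call appends one Event to the events file; TensorBoard uses
+            # `step` as the x-axis, so it is advanced once per dump below.)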
+ self.writer.WriteEvent(event) + self.writer.Flush() + self.step += 1 + + def close(self): + if self.writer: + self.writer.Close() + self.writer = None + + +def make_output_format(format, ev_dir, log_suffix=""): + os.makedirs(ev_dir, exist_ok=True) + if format == "stdout": + return HumanOutputFormat(sys.stdout) + elif format == "log": + return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) + elif format == "json": + return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) + elif format == "csv": + return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) + elif format == "tensorboard": + return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) + else: + raise ValueError(f"Unknown format specified: {format}") + + +# ================================================================ +# API +# ================================================================ + + +def logkv(key, val): + """Log a value of some diagnostic Call this once for each diagnostic quantity, each iteration + If called many times, last value will be used.""" + get_current().logkv(key, val) + + +def logkv_mean(key, val): + """The same as logkv(), but if called many times, values averaged.""" + get_current().logkv_mean(key, val) + + +def logkvs(d): + """Log a dictionary of key-value pairs.""" + for k, v in d.items(): + logkv(k, v) + + +def dumpkvs(): + """Write all of the diagnostics from the current iteration.""" + return get_current().dumpkvs() + + +def getkvs(): + return get_current().name2val + + +def log(*args, level=INFO): + """Write the sequence of args, with no separators, to the console and output files (if you've + configured an output file).""" + get_current().log(*args, level=level) + + +def debug(*args): + log(*args, level=DEBUG) + + +def info(*args): + log(*args, level=INFO) + + +def warn(*args): + log(*args, level=WARN) + + +def error(*args): + log(*args, level=ERROR) + + +def set_level(level): + """Set logging threshold on current logger.""" + get_current().set_level(level) + + +def set_comm(comm): + get_current().set_comm(comm) + + +def get_dir(): + """Get directory that log files are being written to. + + will be None if there is no output directory (i.e., if you didn't call start) + """ + return get_current().get_dir() + + +record_tabular = logkv +dump_tabular = dumpkvs + + +@contextmanager +def profile_kv(scopename): + logkey = "wait_" + scopename + tstart = time.time() + try: + yield + finally: + get_current().name2val[logkey] += time.time() - tstart + + +def profile(n): + """ + Usage: + @profile("my_func") + def my_func(): code + """ + + def decorator_with_name(func): + def func_wrapper(*args, **kwargs): + with profile_kv(n): + return func(*args, **kwargs) + + return func_wrapper + + return decorator_with_name + + +# ================================================================ +# Backend +# ================================================================ + + +def get_current(): + if Logger.CURRENT is None: + _configure_default_logger() + + return Logger.CURRENT + + +class Logger: + DEFAULT = None # A logger with no output files. 
(See right below class definition) + # So that you can still log to the terminal without setting up any output files + CURRENT = None # Current logger being used by the free functions above + + def __init__(self, dir, output_formats, comm=None): + self.name2val = defaultdict(float) # values this iteration + self.name2cnt = defaultdict(int) + self.level = INFO + self.dir = dir + self.output_formats = output_formats + self.comm = comm + + # Logging API, forwarded + # ---------------------------------------- + def logkv(self, key, val): + self.name2val[key] = val + + def logkv_mean(self, key, val): + oldval, cnt = self.name2val[key], self.name2cnt[key] + self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) + self.name2cnt[key] = cnt + 1 + + def dumpkvs(self): + if self.comm is None: + d = self.name2val + else: + d = mpi_weighted_mean( + self.comm, + {name: (val, self.name2cnt.get(name, 1)) for (name, val) in self.name2val.items()}, + ) + if self.comm.rank != 0: + d["dummy"] = 1 # so we don't get a warning about empty dict + out = d.copy() # Return the dict for unit testing purposes + for fmt in self.output_formats: + if isinstance(fmt, KVWriter): + fmt.writekvs(d) + self.name2val.clear() + self.name2cnt.clear() + return out + + def log(self, *args, level=INFO): + if self.level <= level: + self._do_log(args) + + # Configuration + # ---------------------------------------- + def set_level(self, level): + self.level = level + + def set_comm(self, comm): + self.comm = comm + + def get_dir(self): + return self.dir + + def close(self): + for fmt in self.output_formats: + fmt.close() + + # Misc + # ---------------------------------------- + def _do_log(self, args): + for fmt in self.output_formats: + if isinstance(fmt, SeqWriter): + fmt.writeseq(map(str, args)) + + +def get_rank_without_mpi_import(): + # check environment variables here instead of importing mpi4py + # to avoid calling MPI_Init() when this module is imported + for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: + if varname in os.environ: + return int(os.environ[varname]) + return 0 + + +def mpi_weighted_mean(comm, local_name2valcount): + """ + Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 + Perform a weighted average over dicts that are each on a different node + Input: local_name2valcount: dict mapping key -> (value, count) + Returns: key -> mean + """ + all_name2valcount = comm.gather(local_name2valcount) + if comm.rank == 0: + name2sum = defaultdict(float) + name2count = defaultdict(float) + for n2vc in all_name2valcount: + for name, (val, count) in n2vc.items(): + try: + val = float(val) + except ValueError: + if comm.rank == 0: + warnings.warn(f"WARNING: tried to compute mean on non-float {name}={val}") + else: + name2sum[name] += val * count + name2count[name] += count + return {name: name2sum[name] / name2count[name] for name in name2sum} + else: + return {} + + +def configure(dir=None, format_strs=None, comm=None, log_suffix=""): + """If comm is provided, average all numerical stats across that comm.""" + if dir is None: + dir = os.getenv("OPENAI_LOGDIR") + if dir is None: + dir = osp.join( + tempfile.gettempdir(), + datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), + ) + assert isinstance(dir, str) + dir = os.path.expanduser(dir) + os.makedirs(os.path.expanduser(dir), exist_ok=True) + + rank = get_rank_without_mpi_import() + if rank > 0: + log_suffix = log_suffix + "-rank%03i" % rank + + if format_strs is None: + if 
rank == 0: + format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") + else: + format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") + format_strs = filter(None, format_strs) + output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] + + Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) + if output_formats: + log("Logging to %s" % dir) + + +def _configure_default_logger(): + configure() + Logger.DEFAULT = Logger.CURRENT + + +def reset(): + if Logger.CURRENT is not Logger.DEFAULT: + Logger.CURRENT.close() + Logger.CURRENT = Logger.DEFAULT + log("Reset logger") + + +@contextmanager +def scoped_configure(dir=None, format_strs=None, comm=None): + prevlogger = Logger.CURRENT + configure(dir=dir, format_strs=format_strs, comm=comm) + try: + yield + finally: + Logger.CURRENT.close() + Logger.CURRENT = prevlogger diff --git a/conditional-flow-matching/torchcfm/models/unet/nn.py b/conditional-flow-matching/torchcfm/models/unet/nn.py new file mode 100644 index 0000000000000000000000000000000000000000..62d99802ae714b77d1fcf0b91566a82e8ba430b1 --- /dev/null +++ b/conditional-flow-matching/torchcfm/models/unet/nn.py @@ -0,0 +1,153 @@ +"""Various utilities for neural networks.""" + +import math + +import torch as th +import torch.nn as nn + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * th.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """Create a 1D, 2D, or 3D convolution module.""" + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """Create a linear module.""" + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """Create a 1D, 2D, or 3D average pooling module.""" + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """Update target parameters to be closer to those of source parameters using an exponential + moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """Zero out the parameters of a module and return it.""" + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """Scale the parameters of a module and return it.""" + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """Take the mean over all non-batch dimensions.""" + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """Create sinusoidal timestep embeddings. 
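+
+    Illustrative sketch (arbitrary values):
+
+        t = th.tensor([0.0, 0.5, 1.0])  # fractional timesteps are allowed
+        emb = timestep_embedding(t, dim=128)  # -> tensor of shape [3, 128]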
+ + :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = th.exp( + -math.log(max_period) + * th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device) + / half + ) + args = timesteps[:, None].float() * freqs[None] + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def checkpoint(func, inputs, params, flag): + """Evaluate a function without caching intermediate activations, allowing for reduced memory at + the expense of extra compute in the backward pass. + + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not explicitly take as + arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(th.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with th.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with th.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
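+            # view_as(x) produces a shallow alias sharing x's storage, so the
+            # in-place op hits the alias rather than the detach()'d tensors.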
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = th.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads diff --git a/conditional-flow-matching/torchcfm/models/unet/unet.py b/conditional-flow-matching/torchcfm/models/unet/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..ac0153c8c7dac5b5a42024c59b12aabe5e27fbcc --- /dev/null +++ b/conditional-flow-matching/torchcfm/models/unet/unet.py @@ -0,0 +1,924 @@ +"""From https://raw.githubusercontent.com/openai/guided-diffusion/main/guided_diffusion/unet.py.""" + +import math +from abc import abstractmethod + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .fp16_util import convert_module_to_f16, convert_module_to_f32 +from .nn import ( + avg_pool_nd, + checkpoint, + conv_nd, + linear, + normalization, + timestep_embedding, + zero_module, +) + + +class AttentionPool2d(nn.Module): + """Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py.""" + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """Any module where forward() takes timestep embeddings as a second argument.""" + + @abstractmethod + def forward(self, x, emb): + """Apply the module to `x` given `emb` timestep embeddings.""" + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """A sequential module that passes timestep embeddings to the children that support it as an + extra input.""" + + def forward(self, x, emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the + inner-two dimensions. 
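+
+    Illustrative sketch (arbitrary sizes, 2D signal by default):
+
+        up = Upsample(channels=64, use_conv=True)
+        y = up(th.randn(1, 64, 16, 16))  # -> [1, 64, 32, 32]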
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the + inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """A residual block that can optionally change the number of channels. + + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a + smaller 1x1 convolution to change the channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """A counter for the `thop` package to count the operations in an attention operation. + + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """A module which performs QKV attention. + + Matches legacy QKVAttention + input/output heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """A module which performs QKV attention and splits in a different order.""" + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """Apply QKV attention. + + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """The full UNet model with attention and timestep embedding. + + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which attention will take + place. May be a set, list, or tuple. For example, if this contains 4, then at 4x + downsampling, attention will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be class-conditional with + `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width + per attention head. + :param num_heads_upsample: works with num_heads to set a different number of heads for + upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially increased + efficiency. 
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + ch = input_ch = int(channel_mult[0] * model_channels) + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] + ) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(mult * model_channels) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=int(model_channels * mult), + 
dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = int(model_channels * mult)
+                if ds in attention_resolutions:
+                    layers.append(
+                        AttentionBlock(
+                            ch,
+                            use_checkpoint=use_checkpoint,
+                            num_heads=num_heads_upsample,
+                            num_head_channels=num_head_channels,
+                            use_new_attention_order=use_new_attention_order,
+                        )
+                    )
+                if level and i == num_res_blocks:
+                    out_ch = ch
+                    layers.append(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            up=True,
+                        )
+                        if resblock_updown
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    )
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
+        )
+
+    def convert_to_fp16(self):
+        """Convert the torso of the model to float16."""
+        self.input_blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+        self.output_blocks.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self):
+        """Convert the torso of the model to float32."""
+        self.input_blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+        self.output_blocks.apply(convert_module_to_f32)
+
+    def forward(self, t, x, y=None):
+        """Apply the model to an input batch.
+
+        :param t: a 1-D batch of timesteps.
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param y: an [N] Tensor of labels, if class-conditional.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        timesteps = t
+        assert (y is not None) == (
+            self.num_classes is not None
+        ), "must specify y if and only if the model is class-conditional"
+        while timesteps.dim() > 1:
+            timesteps = timesteps[:, 0]
+        if timesteps.dim() == 0:
+            timesteps = timesteps.repeat(x.shape[0])
+
+        hs = []
+        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+        if self.num_classes is not None:
+            assert y.shape == (x.shape[0],)
+            emb = emb + self.label_emb(y)
+
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb)
+            hs.append(h)
+        h = self.middle_block(h, emb)
+        for module in self.output_blocks:
+            h = th.cat([h, hs.pop()], dim=1)
+            h = module(h, emb)
+        h = h.type(x.dtype)
+        return self.out(h)
+
+
+class SuperResModel(UNetModel):
+    """A UNetModel that performs super-resolution.
+
+    Expects an extra kwarg `low_res` to condition on a low-resolution image.
+    """
+
+    def __init__(self, image_size, in_channels, *args, **kwargs):
+        super().__init__(image_size, in_channels * 2, *args, **kwargs)
+
+    def forward(self, x, timesteps, low_res=None, **kwargs):
+        _, _, new_height, new_width = x.shape
+        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+        x = th.cat([x, upsampled], dim=1)
+        return super().forward(timesteps, x, **kwargs)
+
+
+class EncoderUNetModel(nn.Module):
+    """The half UNet model with attention and timestep embedding.
+
+    For usage, see UNet.
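+
+    Illustrative sketch (tiny settings, with the default pool="adaptive"):
+
+        enc = EncoderUNetModel(
+            image_size=32, in_channels=3, model_channels=64, out_channels=10,
+            num_res_blocks=1, attention_resolutions=(2,), channel_mult=(1, 2),
+        )
+        logits = enc(th.randn(8, 3, 32, 32), th.rand(8))  # -> [8, 10]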
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + ch = int(channel_mult[0] * model_channels) + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] + ) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(mult * model_channels) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels), + ) + elif pool == "spatial": + self.out = 
nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """Convert the torso of the model to float16.""" + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """Convert the torso of the model to float32.""" + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + + +NUM_CLASSES = 1000 + + +class UNetModelWrapper(UNetModel): + def __init__( + self, + dim, + num_channels, + num_res_blocks, + channel_mult=None, + learn_sigma=False, + class_cond=False, + num_classes=NUM_CLASSES, + use_checkpoint=False, + attention_resolutions="16", + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + dropout=0, + resblock_updown=False, + use_fp16=False, + use_new_attention_order=False, + ): + """Dim (tuple): (C, H, W)""" + image_size = dim[-1] + if channel_mult is None: + if image_size == 512: + channel_mult = (0.5, 1, 1, 2, 2, 4, 4) + elif image_size == 256: + channel_mult = (1, 1, 2, 2, 4, 4) + elif image_size == 128: + channel_mult = (1, 1, 2, 3, 4) + elif image_size == 64: + channel_mult = (1, 2, 3, 4) + elif image_size == 32: + channel_mult = (1, 2, 2, 2) + elif image_size == 28: + channel_mult = (1, 2, 2) + else: + raise ValueError(f"unsupported image size: {image_size}") + else: + channel_mult = list(channel_mult) + + attention_ds = [] + for res in attention_resolutions.split(","): + attention_ds.append(image_size // int(res)) + + return super().__init__( + image_size=image_size, + in_channels=dim[0], + model_channels=num_channels, + out_channels=(dim[0] if not learn_sigma else dim[0] * 2), + num_res_blocks=num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=dropout, + channel_mult=channel_mult, + num_classes=(num_classes if class_cond else None), + use_checkpoint=use_checkpoint, + use_fp16=use_fp16, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + use_new_attention_order=use_new_attention_order, + ) + + def forward(self, t, x, y=None, *args, **kwargs): + return super().forward(t, x, y=y)
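+
+
+# Illustrative usage sketch (hypothetical values): the wrapper takes
+# dim=(C, H, W) and follows the vector-field convention forward(t, x).
+#
+#     import torch as th
+#     unet = UNetModelWrapper(dim=(3, 32, 32), num_channels=64, num_res_blocks=1)
+#     vt = unet(th.rand(8), th.randn(8, 3, 32, 32))  # -> [8, 3, 32, 32]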