carmelog committed on
Commit
830a558
·
0 Parent(s):

init: magnetohydrodynamics with physicsnemo

Browse files
.gitignore ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ checkpoints/
2
+ logs/
3
+ outputs/
4
+ launch.log
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ **__pycache__/**
8
+ *.py[codz]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py.cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ # Pipfile.lock
101
+
102
+ # UV
103
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # uv.lock
107
+
108
+ # poetry
109
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
110
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
111
+ # commonly ignored for libraries.
112
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
113
+ # poetry.lock
114
+ # poetry.toml
115
+
116
+ # pdm
117
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
118
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
119
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
120
+ # pdm.lock
121
+ # pdm.toml
122
+ .pdm-python
123
+ .pdm-build/
124
+
125
+ # pixi
126
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
127
+ # pixi.lock
128
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
129
+ # in the .venv directory. It is recommended not to include this directory in version control.
130
+ .pixi
131
+
132
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
133
+ __pypackages__/
134
+
135
+ # Celery stuff
136
+ celerybeat-schedule
137
+ celerybeat.pid
138
+
139
+ # Redis
140
+ *.rdb
141
+ *.aof
142
+ *.pid
143
+
144
+ # RabbitMQ
145
+ mnesia/
146
+ rabbitmq/
147
+ rabbitmq-data/
148
+
149
+ # ActiveMQ
150
+ activemq-data/
151
+
152
+ # SageMath parsed files
153
+ *.sage.py
154
+
155
+ # Environments
156
+ .env
157
+ .envrc
158
+ .venv
159
+ env/
160
+ venv/
161
+ ENV/
162
+ env.bak/
163
+ venv.bak/
164
+
165
+ # Spyder project settings
166
+ .spyderproject
167
+ .spyproject
168
+
169
+ # Rope project settings
170
+ .ropeproject
171
+
172
+ # mkdocs documentation
173
+ /site
174
+
175
+ # mypy
176
+ .mypy_cache/
177
+ .dmypy.json
178
+ dmypy.json
179
+
180
+ # Pyre type checker
181
+ .pyre/
182
+
183
+ # pytype static type analyzer
184
+ .pytype/
185
+
186
+ # Cython debug symbols
187
+ cython_debug/
188
+
189
+ # PyCharm
190
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
191
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
192
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
193
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
194
+ # .idea/
195
+
196
+ # Abstra
197
+ # Abstra is an AI-powered process automation framework.
198
+ # Ignore directories containing user credentials, local state, and settings.
199
+ # Learn more at https://abstra.io/docs
200
+ .abstra/
201
+
202
+ # Visual Studio Code
203
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
204
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
205
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
206
+ # you could uncomment the following to ignore the entire vscode folder
207
+ # .vscode/
208
+
209
+ # Ruff stuff:
210
+ .ruff_cache/
211
+
212
+ # PyPI configuration file
213
+ .pypirc
214
+
215
+ # Marimo
216
+ marimo/_static/
217
+ marimo/_lsp/
218
+ __marimo__/
219
+
220
+ # Streamlit
221
+ .streamlit/secrets.toml
.pre-commit-config.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.4.0
4
+ hooks:
5
+ - id: trailing-whitespace
6
+ - id: end-of-file-fixer
7
+ - id: check-yaml
8
+ - id: debug-statements
9
+ - repo: https://github.com/astral-sh/ruff-pre-commit
10
+ rev: v0.4.0
11
+ hooks:
12
+ - id: ruff
13
+ args: [ --fix ]
14
+ types_or: [ python, pyi, jupyter ]
15
+ - id: ruff-format
16
+ types_or: [ python, pyi, jupyter ]
Dockerfile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvcr.io/nvidia/physicsnemo/physicsnemo:25.08
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive
4
+
5
+ USER root
6
+ # Create non-root user and set up directories
7
+ RUN useradd -m -u 1001 user && \
8
+ mkdir -p /home/user/.cache /home/user/.config /home/user/.local /home/user/.local/share/jupyter && \
9
+ chmod -R 777 /home/user && \
10
+ mkdir /mhd-demo && chown user:user /mhd-demo && chmod 777 /mhd-demo
11
+
12
+ USER user
13
+ ENV HOME=/home/user
14
+ ENV PATH=/home/user/.local/bin:$PATH
15
+ WORKDIR $HOME/app
16
+
17
+
18
+ # Upgrade pip
19
+ RUN python -m pip install --upgrade pip
20
+
21
+ # # Copy all files at once
22
+ COPY --chown=user on_startup.sh README.md start_server.sh requirements.txt ./
23
+ COPY --chown=user login.html /usr/local/lib/python3.12/dist-packages/jupyter_server/templates/login.html
24
+ COPY --chown=user magnetohydrodynamics.ipynb /mhd-demo/
25
+ COPY --chown=user mhd /mhd-demo/mhd/
26
+
27
+
28
+ RUN chmod +x start_server.sh && \
29
+ chmod -R 777 /mhd-demo/ && \
30
+ pip install -r requirements.txt
31
+
32
+ EXPOSE 7860
33
+ CMD ["./start_server.sh"]
Dockerfile.dedalus ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu24.04
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive
4
+
5
+ RUN apt-get update -qq && \
6
+ apt-get autoremove -y -qq && \
7
+ apt-get install -y -qq apt-file \
8
+ vim \
9
+ wget \
10
+ git \
11
+ software-properties-common \
12
+ make \
13
+ g++ \
14
+ gcc \
15
+ gpg-agent && \
16
+ apt-get clean && rm -rf /var/cache/apt/archives /var/lib/apt/lists/*
17
+
18
+ RUN useradd -m -u 1001 user && \
19
+ mkdir -p /home/user/.cache /home/user/.config /home/user/.local && \
20
+ chmod -R 777 /home/user && \
21
+ mkdir /mhd-demo && chown user:user /mhd-demo && chmod 777 /mhd-demo
22
+
23
+ USER user
24
+ ENV HOME=/home/user
25
+ ENV PATH=/home/user/.local/bin:$PATH
26
+ WORKDIR $HOME/app
27
+
28
+ ENV CONDA_DIR=$HOME/conda
29
+ ENV PATH=$CONDA_DIR/bin:$PATH
30
+ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
31
+ /bin/bash ~/miniconda.sh -b -p $HOME/conda && \
32
+ rm ~/miniconda.sh && \
33
+ conda config --add channels conda-forge && \
34
+ conda config --set channel_priority strict && \
35
+ conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \
36
+ conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r && \
37
+ conda create -n env python=3.12 -y --quiet && \
38
+ conda run -n env conda env config vars set OMP_NUM_THREADS=1 && \
39
+ conda run -n env conda env config vars set NUMEXPR_MAX_THREADS=1 && \
40
+ conda run -n env conda install -c conda-forge dedalus jupyter jupyterlab torch hydra-core imageio -y --quiet
41
+
42
+ ENV PATH=$HOME/conda/envs/env/bin:$PATH
43
+
44
+ # # Copy all files at once
45
+ COPY --chown=user on_startup.sh README.md start_server.sh requirements.txt ./
46
+ COPY --chown=user magnetohydrodynamics.ipynb mhd /mhd-demo/
47
+
48
+ RUN chmod +x start_server.sh && \
49
+ chmod -R 777 /mhd-demo/
50
+
51
+ EXPOSE 7860
52
+ CMD ["./start_server.sh"]
README.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Modeling Magnetohydrodynamics with PhysicsNeMo
3
+ emoji: 🟢
4
+ colorFrom: green
5
+ colorTo: gray
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: false
9
+ tags:
10
+ - physics
11
+ - cfd
12
+ - machine-learning
13
+ - neural-operators
14
+ - magnetohydrodynamics
15
+ - scientific-computing
16
+ ---
17
+
login.html ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% extends "page.html" %}
2
+
3
+
4
+ {% block stylesheet %}
5
+ {% endblock %}
6
+
7
+ {% block site %}
8
+
9
+ <div id="jupyter-main-app" class="container">
10
+
11
+ <img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="Hugging Face Logo">
12
+ <h4>Welcome to JupyterLab</h4>
13
+
14
+ <h5>The default token is <span style="color:orange;">huggingface</span></h5>
15
+
16
+ {% if login_available %}
17
+ {# login_available means password-login is allowed. Show the form. #}
18
+ <div class="row">
19
+ <div class="navbar col-sm-8">
20
+ <div class="navbar-inner">
21
+ <div class="container">
22
+ <div class="center-nav">
23
+ <form action="{{base_url}}login?next={{next}}" method="post" class="navbar-form pull-left">
24
+ {{ xsrf_form_html() | safe }}
25
+ {% if token_available %}
26
+ <label for="password_input"><strong>{% trans %}Jupyter token <span title="This is the secret you set up when deploying your JupyterLab space">ⓘ</span> {% endtrans
27
+ %}</strong></label>
28
+ {% else %}
29
+ <label for="password_input"><strong>{% trans %}Jupyter password:{% endtrans %}</strong></label>
30
+ {% endif %}
31
+ <input type="password" name="password" id="password_input" class="form-control">
32
+ <button type="submit" class="btn btn-default" id="login_submit">{% trans %}Log in{% endtrans
33
+ %}</button>
34
+ </form>
35
+ </div>
36
+ </div>
37
+ </div>
38
+ </div>
39
+ </div>
40
+ {% else %}
41
+ <p>{% trans %}No login available, you shouldn't be seeing this page.{% endtrans %}</p>
42
+ {% endif %}
43
+
44
+ <h5>If you don't have the credentials for this Jupyter space, <a target="_blank" href="https://huggingface.co/spaces/SpacesExamples/jupyterlab?duplicate=true">create your own.</a></h5>
45
+ <br>
46
+
47
+ <p>This template was created by <a href="https://twitter.com/camenduru" target="_blank" >camenduru</a> and <a href="https://huggingface.co/nateraw" target="_blank" >nateraw</a>, with contributions of <a href="https://huggingface.co/osanseviero" target="_blank" >osanseviero</a> and <a href="https://huggingface.co/azzr" target="_blank" >azzr</a> </p>
48
+ {% if message %}
49
+ <div class="row">
50
+ {% for key in message %}
51
+ <div class="message {{key}}">
52
+ {{message[key]}}
53
+ </div>
54
+ {% endfor %}
55
+ </div>
56
+ {% endif %}
57
+ {% if token_available %}
58
+ {% block token_message %}
59
+
60
+ {% endblock token_message %}
61
+ {% endif %}
62
+ </div>
63
+
64
+ {% endblock %}
65
+
66
+
67
+ {% block script %}
68
+ {% endblock %}
magnetohydrodynamics.ipynb ADDED
@@ -0,0 +1,1592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "3902fdc0",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Modeling Magnetohydrodynamics with Physics Informed Neural Operators"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "982a76df",
14
+ "metadata": {},
15
+ "source": [
16
 + "In this notebook, we will study the application of physics informed data-driven modeling to the incompressible magnetohydrodynamics (MHD) equations representing an incompressible fluid in the presence of a magnetic field $\\mathbf{B}$. Our model will be built using a Tensor Factorized Fourier Neural Operator (tFNO), and trained in conjunction with the PDEs representing our system. The model is physics-informed during training by encoding known information about the physical system into the loss functions, enabling generalization of the resulting model to a variety of settings in the solution space. Specifically, the PDEs and initial conditions are used as soft constraints learned by the neural network as it trains. Models covering different data regimes governed by the Reynolds number are trained using transfer learning to showcase how our model may be applied to both laminar and turbulent flows. The AI-accelerated surrogate model is compared to classical simulations to assess its throughput and accuracy.\n",
17
+ "\n",
18
+ "Note that while the majority of the code needed to run this example is provided in the notebook, the lower barrier to entry for training and evaluating models will be to run the scripts in the source directory, and the material referenced here should be used as a base for learning the underlying components leading to model training and evaluation. "
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "markdown",
23
+ "id": "19f35947",
24
+ "metadata": {},
25
+ "source": [
26
+ "#### Learning Outcomes\n",
27
+ "* How to apply physics constraints to neural networks\n",
28
+ "* Learn how the Tensor Factorized Fourier Neural Operator can be applied to physics based problems\n",
29
+ "* Learn how to define PDEs with PhysicsNeMo\n",
30
+ "* Train PINOs with PhysicsNeMo Core\n",
31
+ "* Learn how data driven modeling can help build computationally efficient surrogates for physics problems"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "id": "03be824a",
37
+ "metadata": {},
38
+ "source": [
39
+ "## Pre-Requisites\n",
40
+ "This workshop is derived primarily from the informative paper [Magnetohydrodynamics with physics informed neural operators\n",
41
+ "](https://iopscience.iop.org/article/10.1088/2632-2153/ace30a)[1]. Reading the paper will provide both context and an overview of what will be presented in this workshop. Additionally, the paper serves as a great reference if more details are needed on any specific section. It is encouraged to read through the paper before continuing.\n",
42
+ "\n",
43
+ "[1] Rosofsky, S. G., & Huerta, E. A. (2023). Magnetohydrodynamics with physics informed neural operators. Machine Learning: Science and Technology, 4(3), 035002."
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "id": "e5a666a2",
49
+ "metadata": {},
50
+ "source": [
51
+ "## Problem Overview\n",
52
+ "\n",
53
 + "To examine the properties of PINOs with multiple complex equations, we examined the ability of the networks to reproduce the incompressible magnetohydrodynamics (MHD) equations representing an incompressible fluid in the presence of a magnetic field $\\mathbf{B}$. These equations are present in several astrophysical phenomena, including black hole accretion and binary neutron star mergers. Additionally, MHD has applications to nuclear power engineering and plasma modeling. \n",
54
+ "\n",
55
+ "These equations for incompressible MHD are given by:\n",
56
+ "\n",
57
+ "$$\\begin{align*}\n",
58
+ "\\partial_t \\mathbf{u}+\\mathbf{u} \\cdot \\nabla \\mathbf{u} &=\n",
59
+ "-\\nabla \\left( p+\\frac{B^2}{2} \\right)/\\rho_0 +\\mathbf{B}\n",
60
+ "\\cdot \\nabla \\mathbf{B}+\\nu \\nabla^2 \\mathbf{u}, \\\\\n",
61
+ "\\partial_t \\mathbf{B}+\\mathbf{u} \\cdot \\nabla \\mathbf{B} &=\n",
62
+ "\\mathbf{B} \\cdot \\nabla \\mathbf{u}+\\eta \\nabla^2 \\mathbf{B}, \\\\\n",
63
+ "\\nabla \\cdot \\mathbf{u} &= 0, \\\\\n",
64
+ "\\nabla \\cdot \\mathbf{B} &= 0,\n",
65
+ "\\end{align*}$$\n",
66
+ " \n",
67
+ "where $\\mathbf{u}$ is the velocity field, $p$ is the pressure, $B^2$ is the magnitude of the magnetic field, $\\rho_0=1$ is the density of the fluid, $\\nu$ is the kinetic viscosity, and $\\eta$ is the magnetic resistivity. We have two equations for evolution and two constraint equations.\n",
68
+ "\n",
69
+ "\n",
70
+ "For the magnetic field divergence equation, we can either include it in the loss function or instead evolve the magnetic vector potential $\\mathbf{A}$. This quantity is defined such that\n",
71
+ "\n",
72
+ "$$\\begin{align*}\n",
73
+ "\\mathbf{B} = \\nabla \\times \\mathbf{A},\n",
74
+ "\\end{align*}$$\n",
75
+ "\n",
76
+ "which ensures that the divergence of $\\mathbf{B}$ is zero. By evolving magnetic vector potential $\\mathbf{A}$ instead of the magnetic field $\\mathbf{B}$, we have a new evolution equation for the vector potential $\\mathbf{A}$. This equation is given by \n",
77
+ "\n",
78
+ "$$\\begin{align*}\n",
79
+ "\\partial_t \\mathbf{A} + \\mathbf{u} \\cdot \\nabla \\mathbf{A}=\\eta \\nabla^2 \\mathbf{A}.\n",
80
+ "\\end{align*}$$\n",
81
+ "\n",
82
+ "In practice, using the magnetic vector potential representation leads to better model performance. "
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "id": "5331b9b9",
88
+ "metadata": {},
89
+ "source": [
90
+ "## Data Creation\n",
91
+ "Note that in this HuggingFace Space, the data are available at `/data/mhd_data`. There are:\n",
92
+ "1000 samples for Re=100, 100 samples for Re=250 and 100 samples for Re=1,000.\n",
93
+ "\n",
94
+ "To train our model, a representative dataset is first created that gives enough coverage of the solution space to train a surrogate model to make predictions on new data points. To obtain interesting results without additional computational difficulty, we will solve the equations in 2D with periodic boundary conditions. This results in solving a total of 3 evolution PDEs at each time step. Two for the velocity evolution, and one for the magnetic vector potential. \n",
95
+ "\n",
96
+ "The solution space to this problem can be obtained numerically by solving the PDEs from above with a numerical solver such as `dedalus`. To generate this data, `dedalus` is used to simulate a 2D periodic incompressible MHD flow with a passive tracer field for visualization. The initial flow is in the $x$-direction and depends only on $z$. The problem is non-dimensionalized using the shear-layer spacing and velocity jump, so the resulting viscosity and tracer diffusivity are related to the Reynolds and\n",
97
+ "Schmidt numbers as:\n",
98
+ "\n",
99
+ "$$\\begin{align}\n",
100
+ "\\nu &= \\frac{1}{\\text{Re}} \\\\\n",
101
+ "\\eta &= \\frac{1}{\\text{Re}_M} \\\\\n",
102
+ "D &= \\frac{\\nu}{\\text{Sc}}\n",
103
+ "\\end{align}$$\n",
104
+ "\n",
105
 + "The initial data field for running the simulation is produced using the Gaussian Random Field method in which the radial basis function kernel (RBF) is transformed into Fourier space to obey the desired periodic boundary conditions. Finally, two initial data fields, the vorticity potential and magnetic potential, are used to guarantee the initial velocity and magnetic fields are divergence free. \n",
106
+ "\n",
107
+ "The dataset is produced by running 1,000 simulations with different initial conditions, and evolving the system for 1,000 time steps. The time step used is $\\Delta t=0.001s$, however output data is saved at an interval of $t=0.01$ for a total time of $1$ second, resulting in 101 samples per simulation. \n",
108
+ "\n",
109
+ "Scripts to generate the dataset are in the `generate_mhd_data` folder. Make sure to source this environment with `source activate env` to make use of the environment. To generate the dataset, run the command: \n",
110
+ "```bash\n",
111
+ "python dedalus_mhd_parallel.py\n",
112
+ "```\n",
113
+ "Note that depending on system resources, this process may take up to a few hours to complete. Once data generation is finished, we can exit the env with `conda deactivate`. "
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "id": "69a099f7",
119
+ "metadata": {},
120
+ "source": [
121
+ "## Defining our Constraints - Setting up the PDE\n",
122
+ "\n",
123
 + "Constraints are used to define the objectives for training our model. They house a set of nodes from which a computational graph is built for execution as well as the loss function. [PhysicsNeMo Sym](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/index.html) provides utilities tailored for physics-informed machine learning, and uses abstracted APIs that allow users to think and model the problem from the lens of equations, constraints, etc. In this example, we will only leverage the physics-informed utilities to see how we can add physics to an existing data-driven model with ease while still maintaining the flexibility to define our own training loop and other details. The types of constraints used will be problem dependent. For this example, we can define the following constraints: \n",
124
+ "\n",
125
+ "**Data Loss**: Obtain simulation data and compare it to the PINO output.\n",
126
+ "\n",
127
+ "**PDE Loss**: Use the known PDEs of the system to describe violations of the time evolution of our system\n",
128
+ "\n",
129
+ "**Constraint Loss**: This loss describes constraints from the PDE. Specifically, the velocity divergence free condition and magnetic divergence free condition.\n",
130
+ "\n",
131
+ "**Initial Condition Loss**: Input field compared to the output at $t=0$\n",
132
+ "\n",
133
+ "**Boundary Condition Loss**: Difference in boundary terms. In our case, we have a periodic boundary constraint.\n",
134
+ "\n",
135
+ "\n",
136
+ "\n",
137
+ "To begin setting up our constraints, we can start by defining the MHD equations using the `PDE` class from `physicsnemo.sym.eq.pde`. The process of converting our PDEs into a form that is compatible with `PhysicsNeMo` involves defining a class to hold our equations, called `MHD_PDE`, and including each term of the equations. Each variable of the equations is set up as a `Sympy` `Function`, which is then used to create an attribute of our `MHD_PDE` class that holds the final `equations`.\n",
138
+ "\n",
139
 + "Because we have elected to solve the equations in two dimensions, we only have the input variables $x$, $y$, $t$ and the Laplacian operator. \n",
140
+ "\n",
141
+ "In PhysicsNeMo, it is preferable to represent our equations by isolating our target terms on the left, and moving the rest of the equation to the right-hand-side. To do this, various components of each equation are compartmentalized, and the final set of equations is composed from these parts.\n",
142
+ "\n",
143
+ "```python\n",
144
+ "from physicsnemo.sym.eq.pde import PDE\n",
145
+ "from sympy import Symbol, Function, Number\n",
146
+ "\n",
147
+ "\n",
148
+ "class MHD_PDE(PDE):\n",
149
+ " \"\"\"MHD PDEs using PhysicsNeMo Sym\"\"\"\n",
150
+ "\n",
151
+ " name = \"MHD_PDE\"\n",
152
+ "\n",
153
+ " def __init__(self, nu=1e-4, eta=1e-4, rho0=1.0):\n",
154
+ "\n",
155
+ " # x, y, time\n",
156
+ " x, y, t, lap = Symbol(\"x\"), Symbol(\"y\"), Symbol(\"t\"), Symbol(\"lap\")\n",
157
+ "\n",
158
+ " # make input variables\n",
159
+ " input_variables = {\"x\": x, \"y\": y, \"t\": t, \"lap\": lap}\n",
160
+ "\n",
161
+ " # make functions\n",
162
+ " u = Function(\"u\")(*input_variables)\n",
163
+ " v = Function(\"v\")(*input_variables)\n",
164
+ " Bx = Function(\"Bx\")(*input_variables)\n",
165
+ " By = Function(\"By\")(*input_variables)\n",
166
+ " A = Function(\"A\")(*input_variables)\n",
167
+ " # pressure\n",
168
+ " ptot = Function(\"ptot\")(*input_variables)\n",
169
+ "\n",
170
+ " u_rhs = Function(\"u_rhs\")(*input_variables)\n",
171
+ " v_rhs = Function(\"v_rhs\")(*input_variables)\n",
172
+ " Bx_rhs = Function(\"Bx_rhs\")(*input_variables)\n",
173
+ " By_rhs = Function(\"By_rhs\")(*input_variables)\n",
174
+ " A_rhs = Function(\"A_rhs\")(*input_variables)\n",
175
+ "\n",
176
+ " # initialize constants\n",
177
+ " nu = Number(nu)\n",
178
+ " eta = Number(eta)\n",
179
+ " rho0 = Number(rho0)\n",
180
+ "\n",
181
+ " # set equations\n",
182
+ " self.equations = {}\n",
183
+ "\n",
184
+ " # u · ∇u\n",
185
+ " self.equations[\"vel_grad_u\"] = u * u.diff(x) + v * u.diff(y)\n",
186
+ " self.equations[\"vel_grad_v\"] = u * v.diff(x) + v * v.diff(y)\n",
187
+ " # B · ∇u\n",
188
 + " self.equations[\"B_grad_u\"] = Bx * u.diff(x) + By * u.diff(y)\n",
189
+ " self.equations[\"B_grad_v\"] = Bx * v.diff(x) + By * v.diff(y)\n",
190
+ " # u · ∇B\n",
191
+ " self.equations[\"vel_grad_Bx\"] = u * Bx.diff(x) + v * Bx.diff(y)\n",
192
+ " self.equations[\"vel_grad_By\"] = u * By.diff(x) + v * By.diff(y)\n",
193
+ " # B · ∇B\n",
194
+ " self.equations[\"B_grad_Bx\"] = Bx * Bx.diff(x) + By * Bx.diff(y)\n",
195
+ " self.equations[\"B_grad_By\"] = Bx * By.diff(x) + By * By.diff(y)\n",
196
+ " # ∇ × (u × B) = u(∇ · B) - B(∇ · u) + B · ∇u − u · ∇B\n",
197
+ " self.equations[\"uBy_x\"] = u * By.diff(x) + By * u.diff(x)\n",
198
+ " self.equations[\"uBy_y\"] = u * By.diff(y) + By * u.diff(y)\n",
199
+ " self.equations[\"vBx_x\"] = v * Bx.diff(x) + Bx * v.diff(x)\n",
200
+ " self.equations[\"vBx_y\"] = v * Bx.diff(y) + Bx * v.diff(y)\n",
201
+ " # ∇ · B \n",
202
+ " self.equations[\"div_B\"] = Bx.diff(x) + By.diff(y)\n",
203
+ " # ∇ · u \n",
204
+ " self.equations[\"div_vel\"] = u.diff(x) + v.diff(y)\n",
205
+ "\n",
206
+ " # RHS of MHD equations\n",
207
+ " # = u · ∇u - p/rho + B · ∇B + ν * ∇^2(u)\n",
208
+ " self.equations[\"u_rhs\"] = (\n",
209
+ " -self.equations[\"vel_grad_u\"]\n",
210
+ " - ptot.diff(x) / rho0\n",
211
+ " + self.equations[\"B_grad_Bx\"] / rho0\n",
212
+ " + nu * u.diff(lap)\n",
213
+ " )\n",
214
+ " self.equations[\"v_rhs\"] = (\n",
215
+ " -self.equations[\"vel_grad_v\"]\n",
216
+ " - ptot.diff(y) / rho0\n",
217
+ " + self.equations[\"B_grad_By\"] / rho0\n",
218
+ " + nu * v.diff(lap)\n",
219
+ " )\n",
220
+ " # Uses identity above\n",
221
+ " # = ∇ × (u × B) + η * ∇^2(B)\n",
222
+ " self.equations[\"Bx_rhs\"] = (\n",
223
+ " self.equations[\"uBy_y\"] - self.equations[\"vBx_y\"] + eta * Bx.diff(lap)\n",
224
+ " )\n",
225
+ " self.equations[\"By_rhs\"] = -(\n",
226
+ " self.equations[\"uBy_x\"] - self.equations[\"vBx_x\"]\n",
227
+ " ) + eta * By.diff(lap)\n",
228
+ "\n",
229
+ " # Final equations move all terms to RHS\n",
230
+ " # Node 18, 19, 20, 21\n",
231
+ " self.equations[\"Du\"] = u.diff(t) - u_rhs\n",
232
+ " self.equations[\"Dv\"] = v.diff(t) - v_rhs\n",
233
+ " self.equations[\"DBx\"] = Bx.diff(t) - Bx_rhs\n",
234
+ " self.equations[\"DBy\"] = By.diff(t) - By_rhs\n",
235
+ "\n",
236
+ " # Vec potential equations\n",
237
+ " # Node 22, 23, 24\n",
238
+ " self.equations[\"vel_grad_A\"] = u * A.diff(x) + v * A.diff(y)\n",
239
+ " self.equations[\"A_rhs\"] = -self.equations[\"vel_grad_A\"] + eta * A.diff(lap)\n",
240
+ " self.equations[\"DA\"] = A.diff(t) - A_rhs\n",
241
+ "```\n",
242
+ "\n",
243
+ "Our model's output can then be used to compute the loss between prediction and true values, and for computing loss based on initial conditions, PDEs, and simulation data.\n"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "markdown",
248
+ "id": "677d49b4",
249
+ "metadata": {
250
+ "vscode": {
251
+ "languageId": "plaintext"
252
+ }
253
+ },
254
+ "source": [
255
+ "## Defining our Constraints - Loss Functions \n",
256
+ "\n",
257
+ "Now that we have defined our PDE, we can define all of the constraints that make up the loss function for our problem. The loss functions are defined inside of a class called `LossMHD_PhysicsNeMo`, which can use a weighted sum of individual losses for training. Additionally, all of the fixed and constant parameters needed are added to the class definition.\n",
258
+ "\n",
259
+ "```python\n",
260
+ "import torch\n",
261
+ "import torch.nn.functional as F\n",
262
+ "from physicsnemo.models.layers.spectral_layers import fourier_derivatives\n",
263
+ "\n",
264
+ "from .losses import (LpLoss, fourier_derivatives_lap, fourier_derivatives_ptot,\n",
265
+ " fourier_derivatives_vec_pot)\n",
266
+ "from .mhd_pde import MHD_PDE\n",
267
+ "\n",
268
+ "\n",
269
+ "class LossMHDVecPot_PhysicsNeMo(object):\n",
270
+ " \"Calculate loss for MHD equations with vector potential, using physicsnemo derivatives\"\n",
271
+ "\n",
272
+ " def __init__(\n",
273
+ " self,\n",
274
+ " nu=1e-4,\n",
275
+ " eta=1e-4,\n",
276
+ " rho0=1.0,\n",
277
+ " data_weight=1.0,\n",
278
+ " ic_weight=1.0,\n",
279
+ " pde_weight=1.0,\n",
280
+ " constraint_weight=1.0,\n",
281
+ " use_data_loss=True,\n",
282
+ " use_ic_loss=True,\n",
283
+ " use_pde_loss=True,\n",
284
+ " use_constraint_loss=True,\n",
285
+ " u_weight=1.0,\n",
286
+ " v_weight=1.0,\n",
287
+ " A_weight=1.0,\n",
288
+ " Du_weight=1.0,\n",
289
+ " Dv_weight=1.0,\n",
290
+ " DA_weight=1.0,\n",
291
+ " div_B_weight=1.0,\n",
292
+ " div_vel_weight=1.0,\n",
293
+ " Lx=1.0,\n",
294
+ " Ly=1.0,\n",
295
+ " tend=1.0,\n",
296
+ " use_weighted_mean=False,\n",
297
+ " **kwargs,\n",
298
+ " ): # add **kwargs so that we ignore unexpected kwargs when passing a config dict):\n",
299
+ "\n",
300
+ " self.nu = nu\n",
301
+ " self.eta = eta\n",
302
+ " self.rho0 = rho0\n",
303
+ " self.data_weight = data_weight\n",
304
+ " self.ic_weight = ic_weight\n",
305
+ " self.pde_weight = pde_weight\n",
306
+ " self.constraint_weight = constraint_weight\n",
307
+ " self.use_data_loss = use_data_loss\n",
308
+ " self.use_ic_loss = use_ic_loss\n",
309
+ " self.use_pde_loss = use_pde_loss\n",
310
+ " self.use_constraint_loss = use_constraint_loss\n",
311
+ " self.u_weight = u_weight\n",
312
+ " self.v_weight = v_weight\n",
313
+ " self.Du_weight = Du_weight\n",
314
+ " self.Dv_weight = Dv_weight\n",
315
+ " self.div_B_weight = div_B_weight\n",
316
+ " self.div_vel_weight = div_vel_weight\n",
317
+ " self.Lx = Lx\n",
318
+ " self.Ly = Ly\n",
319
+ " self.tend = tend\n",
320
+ " self.use_weighted_mean = use_weighted_mean\n",
321
+ " self.A_weight = A_weight\n",
322
+ " self.DA_weight = DA_weight\n",
323
+ " # Define 2D MHD PDEs\n",
324
+ " self.mhd_pde_eq = MHD_PDE(self.nu, self.eta, self.rho0)\n",
325
+ " self.mhd_pde_node = self.mhd_pde_eq.make_nodes()\n",
326
+ "\n",
327
+ " if not self.use_data_loss:\n",
328
+ " self.data_weight = 0\n",
329
+ " if not self.use_ic_loss:\n",
330
+ " self.ic_weight = 0\n",
331
+ " if not self.use_pde_loss:\n",
332
+ " self.pde_weight = 0\n",
333
+ " if not self.use_constraint_loss:\n",
334
+ " self.constraint_weight = 0\n",
335
+ "\n",
336
+ " def __call__(self, pred, true, inputs, return_loss_dict=False):\n",
337
+ " loss, loss_dict = self.compute_losses(pred, true, inputs)\n",
338
+ " return loss, loss_dict\n",
339
+ "\n",
340
+ " def compute_losses(self, pred, true, inputs):\n",
341
+ " \"Compute weighted loss and dictionary\"\n",
342
+ " pred = pred.reshape(true.shape)\n",
343
+ " u = pred[..., 0]\n",
344
+ " v = pred[..., 1]\n",
345
+ " A = pred[..., 2]\n",
346
+ "\n",
347
+ " loss_dict = {}\n",
348
+ "\n",
349
+ " # Data\n",
350
+ " if self.use_data_loss:\n",
351
+ " loss_data, loss_u, loss_v, loss_A = self.data_loss(\n",
352
+ " pred, true, return_all_losses=True\n",
353
+ " )\n",
354
+ " loss_dict[\"loss_data\"] = loss_data\n",
355
+ " loss_dict[\"loss_u\"] = loss_u\n",
356
+ " loss_dict[\"loss_v\"] = loss_v\n",
357
+ " loss_dict[\"loss_A\"] = loss_A\n",
358
+ " else:\n",
359
+ " loss_data = 0\n",
360
+ " # IC\n",
361
+ " if self.use_ic_loss:\n",
362
+ " loss_ic, loss_u_ic, loss_v_ic, loss_A_ic = self.ic_loss(\n",
363
+ " pred, inputs, return_all_losses=True\n",
364
+ " )\n",
365
+ " loss_dict[\"loss_ic\"] = loss_ic\n",
366
+ " loss_dict[\"loss_u_ic\"] = loss_u_ic\n",
367
+ " loss_dict[\"loss_v_ic\"] = loss_v_ic\n",
368
+ " loss_dict[\"loss_A_ic\"] = loss_A_ic\n",
369
+ " else:\n",
370
+ " loss_ic = 0\n",
371
+ "\n",
372
+ " # PDE\n",
373
+ " if self.use_pde_loss:\n",
374
+ " Du, Dv, DA = self.mhd_pde(u, v, A)\n",
375
+ " loss_pde, loss_Du, loss_Dv, loss_DA = self.mhd_pde_loss(\n",
376
+ " Du, Dv, DA, return_all_losses=True\n",
377
+ " )\n",
378
+ " loss_dict[\"loss_pde\"] = loss_pde\n",
379
+ " loss_dict[\"loss_Du\"] = loss_Du\n",
380
+ " loss_dict[\"loss_Dv\"] = loss_Dv\n",
381
+ " loss_dict[\"loss_DA\"] = loss_DA\n",
382
+ " else:\n",
383
+ " loss_pde = 0\n",
384
+ "\n",
385
+ " # Constraints\n",
386
+ " if self.use_constraint_loss:\n",
387
+ " div_vel, div_B = self.mhd_constraint(u, v, A)\n",
388
+ " loss_constraint, loss_div_vel, loss_div_B = self.mhd_constraint_loss(\n",
389
+ " div_vel, div_B, return_all_losses=True\n",
390
+ " )\n",
391
+ " loss_dict[\"loss_constraint\"] = loss_constraint\n",
392
+ " loss_dict[\"loss_div_vel\"] = loss_div_vel\n",
393
+ " loss_dict[\"loss_div_B\"] = loss_div_B\n",
394
+ " else:\n",
395
+ " loss_constraint = 0\n",
396
+ "\n",
397
+ " if self.use_weighted_mean:\n",
398
+ " weight_sum = (\n",
399
+ " self.data_weight\n",
400
+ " + self.ic_weight\n",
401
+ " + self.pde_weight\n",
402
+ " + self.constraint_weight\n",
403
+ " )\n",
404
+ " else:\n",
405
+ " weight_sum = 1.0\n",
406
+ "\n",
407
+ " loss = (\n",
408
+ " self.data_weight * loss_data\n",
409
+ " + self.ic_weight * loss_ic\n",
410
+ " + self.pde_weight * loss_pde\n",
411
+ " + self.constraint_weight * loss_constraint\n",
412
+ " ) / weight_sum\n",
413
+ " loss_dict[\"loss\"] = loss\n",
414
+ " return loss, loss_dict\n",
415
+ "\n",
416
+ "```\n",
417
+ "\n",
418
+ "The MDH equations that we defined before are initialized for use within the following loss functions. "
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "markdown",
423
+ "id": "c6c1a66a",
424
+ "metadata": {},
425
+ "source": [
426
+ "### Data Loss\n",
427
+ "The data loss is used to compare simulation data to the output of our model. The velocity in $x$ and $y$, as well as magnetic vector potential $\\mathbf{A}$ is directly compared to the ground truth data through the `Lp-Loss`, and the relative mean squared error is returned. \n",
428
+ "\n",
429
+ "\n",
430
+ "```python\n",
431
+ "def data_loss(self, pred, true, return_all_losses=False):\n",
432
+ " \"Compute data loss\"\n",
433
+ " lploss = LpLoss(size_average=True)\n",
434
+ " u_pred = pred[..., 0]\n",
435
+ " v_pred = pred[..., 1]\n",
436
+ " A_pred = pred[..., 2]\n",
437
+ "\n",
438
+ " u_true = true[..., 0]\n",
439
+ " v_true = true[..., 1]\n",
440
+ " A_true = true[..., 2]\n",
441
+ "\n",
442
+ " loss_u = lploss(u_pred, u_true)\n",
443
+ " loss_v = lploss(v_pred, v_true)\n",
444
+ " loss_A = lploss(A_pred, A_true)\n",
445
+ "\n",
446
+ " if self.use_weighted_mean:\n",
447
+ " weight_sum = self.u_weight + self.v_weight + self.A_weight\n",
448
+ " else:\n",
449
+ " weight_sum = 1.0\n",
450
+ "\n",
451
+ " loss_data = (\n",
452
+ " self.u_weight * loss_u + self.v_weight * loss_v + self.A_weight * loss_A\n",
453
+ " ) / weight_sum\n",
454
+ "\n",
455
+ " if return_all_losses:\n",
456
+ " return loss_data, loss_u, loss_v, loss_A\n",
457
+ " else:\n",
458
+ " return loss_data\n",
459
+ "```"
460
+ ]
461
+ },
462
+ {
463
+ "cell_type": "markdown",
464
+ "id": "fe2f1f44",
465
+ "metadata": {},
466
+ "source": [
467
+ "## PDE Loss\n",
468
+ "The PDE loss describes violations of the time evolution of the PDEs and the PINO outputs. In order to make this comparison, the spatial and temporal derivatives of the output fields need to be computed. To do so, Fourier differentiation is used to calculate the spacial derivatives, and second order finite differencing is used for temporal derivatives. The output fields are the velocity in the $x$ direction ($u$), the velocity in the $y$ direction ($v$), and the magnetic vector potential ($\\mathbf{A}$). The PDE loss is then defined as the MSE loss between zero and the PDE, after putting all the terms on the same side of the equation.\n",
469
+ "\n",
470
+ "Specifically, this loss covers the following equations: \n",
471
+ "$$\\begin{align*}\n",
472
+ "\\partial_t \\mathbf{u}+\\mathbf{u} \\cdot \\nabla \\mathbf{u} &=\n",
473
+ "-\\nabla \\left( p+\\frac{B^2}{2} \\right)/\\rho_0 +\\mathbf{B}\n",
474
+ "\\cdot \\nabla \\mathbf{B}+\\nu \\nabla^2 \\mathbf{u}, \\\\\n",
475
+ "\n",
476
+ "\\partial_t \\mathbf{A} + \\mathbf{u} \\cdot \\nabla \\mathbf{A} &=\\eta \\nabla^2 \\mathbf{A}\n",
477
+ "\\end{align*}$$\n",
478
+ "\n",
479
+ "```python\n",
480
+ "def mhd_pde(self, u, v, A, p=None):\n",
481
+ " \"Compute PDEs for MHD using vector potential\"\n",
482
+ " nt = u.size(1)\n",
483
+ " nx = u.size(2)\n",
484
+ " ny = u.size(3)\n",
485
+ " dt = self.tend / (nt - 1)\n",
486
+ "\n",
487
+ " # compute fourier derivatives\n",
488
+ " f_du, _ = fourier_derivatives(u, [self.Lx, self.Ly])\n",
489
+ " f_dv, _ = fourier_derivatives(v, [self.Lx, self.Ly])\n",
490
+ " f_dBx, f_dBy, f_dA, f_dB, B2_h = fourier_derivatives_vec_pot(\n",
491
+ " A, [self.Lx, self.Ly]\n",
492
+ " )\n",
493
+ "\n",
494
+ " u_x = f_du[:, 0:nt, :nx, :ny]\n",
495
+ " u_y = f_du[:, nt : 2 * nt, :nx, :ny]\n",
496
+ " v_x = f_dv[:, 0:nt, :nx, :ny]\n",
497
+ " v_y = f_dv[:, nt : 2 * nt, :nx, :ny]\n",
498
+ " A_x = f_dA[:, 0:nt, :nx, :ny]\n",
499
+ " A_y = f_dA[:, nt : 2 * nt, :nx, :ny]\n",
500
+ "\n",
501
+ " Bx = f_dB[:, 0:nt, :nx, :ny]\n",
502
+ " By = f_dB[:, nt : 2 * nt, :nx, :ny]\n",
503
+ " Bx_x = f_dBx[:, 0:nt, :nx, :ny]\n",
504
+ " Bx_y = f_dBx[:, nt : 2 * nt, :nx, :ny]\n",
505
+ " By_x = f_dBy[:, 0:nt, :nx, :ny]\n",
506
+ " By_y = f_dBy[:, nt : 2 * nt, :nx, :ny]\n",
507
+ "\n",
508
+ " u_lap = fourier_derivatives_lap(u, [self.Lx, self.Ly])\n",
509
+ " v_lap = fourier_derivatives_lap(v, [self.Lx, self.Ly])\n",
510
+ " A_lap = fourier_derivatives_lap(A, [self.Lx, self.Ly])\n",
511
+ "\n",
512
+ " # note that for pressure, the zero mode (the mean) cannot be zero for invertability so it is set to 1\n",
513
+ " div_vel_grad_vel = u_x**2 + 2 * u_y * v_x + v_y**2\n",
514
+ " div_B_grad_B = Bx_x**2 + 2 * Bx_y * By_x + By_y**2\n",
515
+ " f_dptot = fourier_derivatives_ptot(\n",
516
+ " p, div_vel_grad_vel, div_B_grad_B, B2_h, self.rho0, [self.Lx, self.Ly]\n",
517
+ " )\n",
518
+ " ptot_x = f_dptot[:, 0:nt, :nx, :ny]\n",
519
+ " ptot_y = f_dptot[:, nt : 2 * nt, :nx, :ny]\n",
520
+ "\n",
521
+ " # Plug inputs into dictionary\n",
522
+ " all_inputs = {\n",
523
+ " \"u\": u,\n",
524
+ " \"u__x\": u_x,\n",
525
+ " \"u__y\": u_y,\n",
526
+ " \"v\": v,\n",
527
+ " \"v__x\": v_x,\n",
528
+ " \"v__y\": v_y,\n",
529
+ " \"Bx\": Bx,\n",
530
+ " \"Bx__x\": Bx_x,\n",
531
+ " \"Bx__y\": Bx_y,\n",
532
+ " \"By\": By,\n",
533
+ " \"By__x\": By_x,\n",
534
+ " \"By__y\": By_y,\n",
535
+ " \"A__x\": A_x,\n",
536
+ " \"A__y\": A_y,\n",
537
+ " \"ptot__x\": ptot_x,\n",
538
+ " \"ptot__y\": ptot_y,\n",
539
+ " \"u__lap\": u_lap,\n",
540
+ " \"v__lap\": v_lap,\n",
541
+ " \"A__lap\": A_lap,\n",
542
+ " }\n",
543
+ "\n",
544
+ " # Substitute values into PDE equations\n",
545
+ " u_rhs = self.mhd_pde_node[14].evaluate(all_inputs)[\"u_rhs\"]\n",
546
+ " v_rhs = self.mhd_pde_node[15].evaluate(all_inputs)[\"v_rhs\"]\n",
547
+ " A_rhs = self.mhd_pde_node[23].evaluate(all_inputs)[\"A_rhs\"]\n",
548
+ "\n",
549
+ " u_t = self.Du_t(u, dt)\n",
550
+ " v_t = self.Du_t(v, dt)\n",
551
+ " A_t = self.Du_t(A, dt)\n",
552
+ "\n",
553
+ " # Find difference\n",
554
+ " Du = self.mhd_pde_node[18].evaluate({\"u__t\": u_t, \"u_rhs\": u_rhs[:, 1:-1]})[\n",
555
+ " \"Du\"\n",
556
+ " ]\n",
557
+ " Dv = self.mhd_pde_node[19].evaluate({\"v__t\": v_t, \"v_rhs\": v_rhs[:, 1:-1]})[\n",
558
+ " \"Dv\"\n",
559
+ " ]\n",
560
+ " DA = self.mhd_pde_node[24].evaluate({\"A__t\": A_t, \"A_rhs\": A_rhs[:, 1:-1]})[\n",
561
+ " \"DA\"\n",
562
+ " ]\n",
563
+ "\n",
564
+ " return Du, Dv, DA\n",
565
+ "\n",
566
+ "\n",
567
+ "def mhd_pde_loss(self, Du, Dv, DA, return_all_losses=None):\n",
568
+ " \"Compute PDE loss\"\n",
569
+ " Du_val = torch.zeros_like(Du)\n",
570
+ " Dv_val = torch.zeros_like(Dv)\n",
571
+ " DA_val = torch.zeros_like(DA)\n",
572
+ "\n",
573
+ " loss_Du = F.mse_loss(Du, Du_val)\n",
574
+ " loss_Dv = F.mse_loss(Dv, Dv_val)\n",
575
+ " loss_DA = F.mse_loss(DA, DA_val)\n",
576
+ "\n",
577
+ " if self.use_weighted_mean:\n",
578
+ " weight_sum = self.Du_weight + self.Dv_weight + self.DA_weight\n",
579
+ " else:\n",
580
+ " weight_sum = 1.0\n",
581
+ "\n",
582
+ " loss_pde = (\n",
583
+ " self.Du_weight * loss_Du\n",
584
+ " + self.Dv_weight * loss_Dv\n",
585
+ " + self.DA_weight * loss_DA\n",
586
+ " ) / weight_sum\n",
587
+ "\n",
588
+ " if return_all_losses:\n",
589
+ " return loss_pde, loss_Du, loss_Dv, loss_DA\n",
590
+ " else:\n",
591
+ " return loss_pde\n",
592
+ "```"
593
+ ]
594
+ },
595
+ {
596
+ "cell_type": "markdown",
597
+ "id": "cd40c1ed",
598
+ "metadata": {},
599
+ "source": [
600
+ "## Constraint Loss\n",
601
+ "The constraint illustrates the deviations of the velocity divergence free condition and the magnetic divergence free condition. These conditions are implemented similarly to the PDE loss, but without time derivative terms. The constraint loss is then the MSE between each of the constraint equations and zero. \n",
602
+ "\n",
603
+ "Specifically, the equations used for constraint loss are:\n",
604
+ "$$\\begin{align*}\n",
605
+ "\\nabla \\cdot \\mathbf{u} &= 0, \\\\\n",
606
+ "\\nabla \\cdot \\mathbf{B} &= 0\n",
607
+ "\\end{align*}$$\n",
608
+ "\n",
609
+ "\n",
610
+ "```python\n",
611
+ "def mhd_constraint(self, u, v, A):\n",
612
+ " \"Compute constraints\"\n",
613
+ " nt = u.size(1)\n",
614
+ " nx = u.size(2)\n",
615
+ " ny = u.size(3)\n",
616
+ "\n",
617
+ " f_du, _ = fourier_derivatives(u, [self.Lx, self.Ly])\n",
618
+ " f_dv, _ = fourier_derivatives(v, [self.Lx, self.Ly])\n",
619
+ " f_dBx, f_dBy, _, _, _ = fourier_derivatives_vec_pot(A, [self.Lx, self.Ly])\n",
620
+ "\n",
621
+ " u_x = f_du[:, 0:nt, :nx, :ny]\n",
622
+ " v_y = f_dv[:, nt : 2 * nt, :nx, :ny]\n",
623
+ " Bx_x = f_dBx[:, 0:nt, :nx, :ny]\n",
624
+ " By_y = f_dBy[:, nt : 2 * nt, :nx, :ny]\n",
625
+ "\n",
626
+ " div_B = self.mhd_pde_node[12].evaluate({\"Bx__x\": Bx_x, \"By__y\": By_y})[\"div_B\"]\n",
627
+ " div_vel = self.mhd_pde_node[13].evaluate({\"u__x\": u_x, \"v__y\": v_y})[\"div_vel\"]\n",
628
+ "\n",
629
+ " return div_vel, div_B\n",
630
+ "\n",
631
+ "def mhd_constraint_loss(self, div_vel, div_B, return_all_losses=False):\n",
632
+ " \"Compute constraint loss\"\n",
633
+ " div_vel_val = torch.zeros_like(div_vel)\n",
634
+ " div_B_val = torch.zeros_like(div_B)\n",
635
+ "\n",
636
+ " loss_div_vel = F.mse_loss(div_vel, div_vel_val)\n",
637
+ " loss_div_B = F.mse_loss(div_B, div_B_val)\n",
638
+ "\n",
639
+ " if self.use_weighted_mean:\n",
640
+ " weight_sum = self.div_vel_weight + self.div_B_weight\n",
641
+ " else:\n",
642
+ " weight_sum = 1.0\n",
643
+ "\n",
644
+ " loss_constraint = (\n",
645
+ " self.div_vel_weight * loss_div_vel + self.div_B_weight * loss_div_B\n",
646
+ " ) / weight_sum\n",
647
+ "\n",
648
+ " if return_all_losses:\n",
649
+ " return loss_constraint, loss_div_vel, loss_div_B\n",
650
+ " else:\n",
651
+ " return loss_constraint\n",
652
+ "```\n"
653
+ ]
654
+ },
655
+ {
656
+ "cell_type": "markdown",
657
+ "id": "627ae018",
658
+ "metadata": {},
659
+ "source": [
660
+ "## Initial Condition Loss\n",
661
+ "The initial condition loss encourages the model to associate the input field with the output field specifically at $t=0$. This constraint can usually be achieved with data loss, however this approach emphasized the importance of correct initial condition prediction, and enables training in the absence of data. Training without data and the significance of the initial condition term stem from the PDE loss term. \n",
662
+ "\n",
663
+ "```python\n",
664
+ "def ic_loss(self, pred, input, return_all_losses=False):\n",
665
+ " \"Compute initial condition loss\"\n",
666
+ " lploss = LpLoss(size_average=True)\n",
667
+ " ic_pred = pred[:, 0]\n",
668
+ " ic_true = input[:, 0, ..., 3:]\n",
669
+ " u_ic_pred = ic_pred[..., 0]\n",
670
+ " v_ic_pred = ic_pred[..., 1]\n",
671
+ " A_ic_pred = ic_pred[..., 2]\n",
672
+ "\n",
673
+ " u_ic_true = ic_true[..., 0]\n",
674
+ " v_ic_true = ic_true[..., 1]\n",
675
+ " A_ic_true = ic_true[..., 2]\n",
676
+ "\n",
677
+ " loss_u_ic = lploss(u_ic_pred, u_ic_true)\n",
678
+ " loss_v_ic = lploss(v_ic_pred, v_ic_true)\n",
679
+ " loss_A_ic = lploss(A_ic_pred, A_ic_true)\n",
680
+ "\n",
681
+ " if self.use_weighted_mean:\n",
682
+ " weight_sum = self.u_weight + self.v_weight + self.A_weight\n",
683
+ " else:\n",
684
+ " weight_sum = 1.0\n",
685
+ "\n",
686
+ " loss_ic = (\n",
687
+ " self.u_weight * loss_u_ic\n",
688
+ " + self.v_weight * loss_v_ic\n",
689
+ " + self.A_weight * loss_A_ic\n",
690
+ " ) / weight_sum\n",
691
+ "\n",
692
+ " if return_all_losses:\n",
693
+ " return loss_ic, loss_u_ic, loss_v_ic, loss_A_ic\n",
694
+ " else:\n",
695
+ " return loss_ic\n",
696
+ "```\n",
697
+ "\n",
698
+ "Similar to the initial condition loss, boundary condition loss can be used to describe violations of the boundary terms. In this specific case, the tFNO architecture ensures that the periodic boundary conditions are satisfied, thus the term is not used in this example. "
699
+ ]
700
+ },
701
+ {
702
+ "cell_type": "markdown",
703
+ "id": "58809a2f",
704
+ "metadata": {},
705
+ "source": [
706
+ "In theory, training can be done by correctly predicting the initial conditions, boundary conditions and correctly evolving the PDE forward in time. In practice, having data helps the model converge more quickly. However, an incorrect initial condition results in the PDE evolving the wrong state forward in time, which is why it is emphasized as its own term. The initial condition loss is calculated by taking the input fields and computing the relative MSE with output fields at $t=0$. \n"
707
+ ]
708
+ },
709
+ {
710
+ "cell_type": "markdown",
711
+ "id": "f14c9c48",
712
+ "metadata": {},
713
+ "source": [
714
+ "## Dataset and Dataloaders\n",
715
+ "To use the data we have generated, we need to define a dataset and a dataloader that can ingest the files and parse them based on the relevant content. \n",
716
+ "\n",
717
+ "```python\n",
718
+ "import glob\n",
719
+ "import os\n",
720
+ "\n",
721
+ "import h5py\n",
722
+ "from torch.utils import data\n",
723
+ "\n",
724
+ "\n",
725
+ "class Dedalus2DDataset(data.Dataset):\n",
726
+ " \"Dataset for MHD 2D Dataset\"\n",
727
+ "\n",
728
+ " def __init__(\n",
729
+ " self,\n",
730
+ " data_path,\n",
731
+ " output_names=\"output-\",\n",
732
+ " field_names=[\"magnetic field\", \"velocity\"],\n",
733
+ " num_train=None,\n",
734
+ " num_test=None,\n",
735
+ " num=None,\n",
736
+ " use_train=True,\n",
737
+ " ):\n",
738
+ " self.data_path = data_path\n",
739
+ " output_names = \"output-\" + \"?\"*len(str(len(os.listdir(data_path))))\n",
740
+ " self.output_names = output_names\n",
741
+ " raw_path = os.path.join(data_path, output_names, \"*.h5\")\n",
742
+ " files_raw = sorted(glob.glob(raw_path))\n",
743
+ " self.files_raw = files_raw\n",
744
+ " self.num_files_raw = num_files_raw = len(files_raw)\n",
745
+ " self.field_names = field_names\n",
746
+ " self.use_train = use_train\n",
747
+ "\n",
748
+ " # Handle num parameter: -1 means use full dataset, otherwise limit to specified number\n",
749
+ " if num is not None and num > 0:\n",
750
+ " num_files_raw = min(num, num_files_raw)\n",
751
+ " files_raw = files_raw[:num_files_raw]\n",
752
+ " self.files_raw = files_raw\n",
753
+ " self.num_files_raw = num_files_raw\n",
754
+ "\n",
755
+ " # Handle percentage-based splits\n",
756
+ " if num_train is not None and num_train <= 1.0:\n",
757
+ " # num_train is a percentage\n",
758
+ " num_train = int(num_train * num_files_raw)\n",
759
+ " elif num_train is None or num_train > num_files_raw:\n",
760
+ " num_train = num_files_raw\n",
761
+ "\n",
762
+ " if num_test is not None and num_test <= 1.0:\n",
763
+ " # num_test is a percentage\n",
764
+ " num_test = int(num_test * num_files_raw)\n",
765
+ " elif num_test is None or num_test > (num_files_raw - num_train):\n",
766
+ " num_test = num_files_raw - num_train\n",
767
+ "\n",
768
+ " self.num_train = num_train\n",
769
+ " self.train_files = self.files_raw[:num_train]\n",
770
+ " self.num_test = num_test\n",
771
+ " self.test_end = test_end = num_train + num_test\n",
772
+ " self.test_files = self.files_raw[num_train:test_end]\n",
773
+ " \n",
774
+ " if (self.use_train) or (self.test_files is None):\n",
775
+ " files = self.train_files\n",
776
+ " else:\n",
777
+ " files = self.test_files\n",
778
+ " self.files = files\n",
779
+ " self.num_files = num_files = len(files)\n",
780
+ "\n",
781
+ " def __len__(self):\n",
782
+ " length = len(self.files)\n",
783
+ " return length\n",
784
+ "\n",
785
+ " def __getitem__(self, index):\n",
786
+ " \"Gets item for dataloader\"\n",
787
+ " file = self.files[index]\n",
788
+ "\n",
789
+ " field_names = self.field_names\n",
790
+ " fields = {}\n",
791
+ " coords = []\n",
792
+ " with h5py.File(file, mode=\"r\") as h5file:\n",
793
+ " data_file = h5file[\"tasks\"]\n",
794
+ " keys = list(data_file.keys())\n",
795
+ " if field_names is None:\n",
796
+ " field_names = keys\n",
797
+ " for field_name in field_names:\n",
798
+ " if field_name in data_file:\n",
799
+ " field = data_file[field_name][:]\n",
800
+ " fields[field_name] = field\n",
801
+ " else:\n",
802
+ " print(f\"field name {field_name} not found\")\n",
803
+ " dataset = fields\n",
804
+ " return dataset\n",
805
+ "\n",
806
+ " def get_coords(self, index):\n",
807
+ " \"Gets coordinates of t, x, y for dataloader\"\n",
808
+ " file = self.files[index]\n",
809
+ " with h5py.File(file, mode=\"r\") as h5file:\n",
810
+ " data_file = h5file[\"tasks\"]\n",
811
+ " keys = list(data_file.keys())\n",
812
+ " dims = data_file[keys[0]].dims\n",
813
+ "\n",
814
+ " ndims = len(dims)\n",
815
+ " t = dims[0][\"sim_time\"][:]\n",
816
+ " x = dims[ndims - 2][0][:]\n",
817
+ " y = dims[ndims - 1][0][:]\n",
818
+ " return t, x, y\n",
819
+ "```\n",
820
+ "\n",
821
+ "And the dataloader which is sampled from during training.\n",
822
+ "\n",
823
+ "```python\n",
824
+ "class MHDDataloaderVecPot(Dataset):\n",
825
+ " \"Dataloader for MHD Dataset with vector potential\"\n",
826
+ "\n",
827
+ " def __init__(\n",
828
+ " self, dataset: Dedalus2DDataset, sub_x=1, sub_t=1, ind_x=None, ind_t=None\n",
829
+ " ):\n",
830
+ " self.dataset = dataset\n",
831
+ " self.sub_x = sub_x\n",
832
+ " self.sub_t = sub_t\n",
833
+ " self.ind_x = ind_x\n",
834
+ " self.ind_t = ind_t\n",
835
+ " t, x, y = dataset.get_coords(0)\n",
836
+ " self.x = x[:ind_x:sub_x]\n",
837
+ " self.y = y[:ind_x:sub_x]\n",
838
+ " self.t = t[:ind_t:sub_t]\n",
839
+ " self.nx = len(self.x)\n",
840
+ " self.ny = len(self.y)\n",
841
+ " self.nt = len(self.t)\n",
842
+ " self.num = num = len(self.dataset)\n",
843
+ " self.x_slice = slice(0, self.ind_x, self.sub_x)\n",
844
+ " self.t_slice = slice(0, self.ind_t, self.sub_t)\n",
845
+ "\n",
846
+ " def __len__(self):\n",
847
+ " length = len(self.dataset)\n",
848
+ " return length\n",
849
+ "\n",
850
+ " def __getitem__(self, index):\n",
851
+ " \"Gets input of dataloader, including data, t, x, and y\"\n",
852
+ " fields = self.dataset[index]\n",
853
+ "\n",
854
+ " # Data includes velocity and vector potential\n",
855
+ " velocity = fields[\"velocity\"]\n",
856
+ " vector_potential = fields[\"vector potential\"]\n",
857
+ "\n",
858
+ " u = torch.from_numpy(\n",
859
+ " velocity[\n",
860
+ " : self.ind_t : self.sub_t,\n",
861
+ " 0,\n",
862
+ " : self.ind_x : self.sub_x,\n",
863
+ " : self.ind_x : self.sub_x,\n",
864
+ " ]\n",
865
+ " )\n",
866
+ " v = torch.from_numpy(\n",
867
+ " velocity[\n",
868
+ " : self.ind_t : self.sub_t,\n",
869
+ " 1,\n",
870
+ " : self.ind_x : self.sub_x,\n",
871
+ " : self.ind_x : self.sub_x,\n",
872
+ " ]\n",
873
+ " )\n",
874
+ " A = torch.from_numpy(\n",
875
+ " vector_potential[\n",
876
+ " : self.ind_t : self.sub_t,\n",
877
+ " : self.ind_x : self.sub_x,\n",
878
+ " : self.ind_x : self.sub_x,\n",
879
+ " ]\n",
880
+ " )\n",
881
+ "\n",
882
+ " # shape is now (self.nt, self.nx, self.ny, nfields)\n",
883
+ " data = torch.stack([u, v, A], dim=-1)\n",
884
+ " data0 = data[0].reshape(1, self.nx, self.ny, -1).repeat(self.nt, 1, 1, 1)\n",
885
+ "\n",
886
+ " grid_t = (\n",
887
+ " torch.from_numpy(self.t)\n",
888
+ " .reshape(self.nt, 1, 1, 1)\n",
889
+ " .repeat(1, self.nx, self.ny, 1)\n",
890
+ " )\n",
891
+ " grid_x = (\n",
892
+ " torch.from_numpy(self.x)\n",
893
+ " .reshape(1, self.nx, 1, 1)\n",
894
+ " .repeat(self.nt, 1, self.ny, 1)\n",
895
+ " )\n",
896
+ " grid_y = (\n",
897
+ " torch.from_numpy(self.y)\n",
898
+ " .reshape(1, 1, self.ny, 1)\n",
899
+ " .repeat(self.nt, self.nx, 1, 1)\n",
900
+ " )\n",
901
+ "\n",
902
+ " inputs = torch.cat([grid_t, grid_x, grid_y, data0], dim=-1)\n",
903
+ " outputs = data\n",
904
+ "\n",
905
+ " return inputs, outputs\n",
906
+ " \n",
907
+ " def create_dataloader(\n",
908
+ " self,\n",
909
+ " batch_size=1,\n",
910
+ " shuffle=False,\n",
911
+ " num_workers=0,\n",
912
+ " pin_memory=False,\n",
913
+ " distributed=False,\n",
914
+ " ):\n",
915
+ " \"Creates dataloader and sampler based on whether distributed training is on\"\n",
916
+ " if distributed:\n",
917
+ " sampler = torch.utils.data.DistributedSampler(self)\n",
918
+ " dataloader = DataLoader(\n",
919
+ " self,\n",
920
+ " batch_size=batch_size,\n",
921
+ " shuffle=False,\n",
922
+ " sampler=sampler,\n",
923
+ " num_workers=num_workers,\n",
924
+ " pin_memory=pin_memory,\n",
925
+ " )\n",
926
+ " else:\n",
927
+ " sampler = None\n",
928
+ " dataloader = DataLoader(\n",
929
+ " self,\n",
930
+ " batch_size=batch_size,\n",
931
+ " shuffle=shuffle,\n",
932
+ " num_workers=num_workers,\n",
933
+ " pin_memory=pin_memory,\n",
934
+ " )\n",
935
+ "\n",
936
+ " return dataloader, sampler\n",
937
+ "```"
938
+ ]
939
+ },
940
+ {
941
+ "cell_type": "markdown",
942
+ "id": "bf33d763",
943
+ "metadata": {},
944
+ "source": [
945
+ "## Model Architecture\n",
946
+ "<div style=\"display: flex; justify-content: center; gap: 10px;\">\n",
947
+ " <figure style=\"text-align: center;\">\n",
948
+ " <img src=\"https://raw.githubusercontent.com/openhackathons-org/End-to-End-AI-for-Science/main/workspace/python/jupyter_notebook/MagnetoHydrodynamics/images/model_arch.png\" style=\"width: 100%; height: auto;\">\n",
949
+ " <figcaption>Model architecture overview.</figcaption>\n",
950
+ " </figure>\n",
951
+ "</div>\n",
952
+ "\n",
953
+ "<!-- ![model_arch](images/model_arch.png) -->\n",
954
+ "\n",
955
+ "Our PINO model is composed of Tensor Factorized Neural Operators as the core component. Input fields are fed in as the input, which are composed of $u$, $v$, and $A$ initial conditions. The data is first lifted into a higher dimension representation by the neural network, P1. The data then enters the Fourier layers ($F_1$,...,$F_n$). Each Fourier layer consists of a sequence of non-logical integral operators, and nonlinear activation functions. $T_1$ represents a linear transform that employs CP decomposed tensors as weights, and $T_2$ represents a local linear transform. $\\sigma$ is the activation function, and $\\mathcal{F}$, $\\mathcal{F}^{-1}$ represent the Fourier transfrom and inverse Fourier transform respectively. At the end, $P_2$ projects back down into the input space, producing the output shown on the right which describe the\n",
956
+ "time evolution of the system. \n"
957
+ ]
958
+ },
959
+ {
960
+ "cell_type": "markdown",
961
+ "id": "5dacfbfb",
962
+ "metadata": {},
963
+ "source": [
964
+ "## Training our Model\n",
965
+ "\n",
966
+ "PhysicsNeMo has two distinct styles, namely Core and Sym. PhysicsNeMo Sym is a framework providing pythonic APIs, algorithms and utilities to be used with PhysicsNeMo Core, while PhysicsNeMo Core interoperates with PyTorch directly. Working with PhysicsNeMo Core looks and feels more like a PyTorch workflow with some key utils like models, utils, and datapipes imported directly from `physicsnemo` itself. While some components of this workflow so far have borrowed from PhysicsNeMo Sym (`MHD_PDE`), the training workflow for this problem will be build primarily using the Core style. This will provide more flexibility over our training loop, and allow for further customizations to our workflow. The training script follows the standard flow of training models using pytorch. \n"
967
+ ]
968
+ },
969
+ {
970
+ "cell_type": "markdown",
971
+ "id": "3339921a",
972
+ "metadata": {},
973
+ "source": [
974
+ "## Hydra Config\n",
975
+ "\n",
976
+ "Training in PhysicsNeMo is facilitated by Hydra configs, which allow us to set and manager parameters from a single file, updating parameters for components such as our model, datasets, optimizer, logger, loss function, and dataloaders. The first step in getting set up for training is defining this yaml file and loading the config.\n",
977
+ "\n",
978
+ "\n",
979
+ "```yaml \n",
980
+ "## Training options\n",
981
+ "# Reynolds number parameter\n",
982
+ "reynolds_number: 100\n",
983
+ "\n",
984
+ "load_ckpt: False\n",
985
+ "output_dir: './checkpoints/MHDVecPot_TFNO/MHDVecPot_TFNO_PINO_Re${reynolds_number}/figures/'\n",
986
+ "\n",
987
+ "###################\n",
988
+ "## Model options\n",
989
+ "model_params:\n",
990
+ " layers: 8\n",
991
+ " modes: 8\n",
992
+ " num_fno_layers: 4\n",
993
+ " fc_dim: 128\n",
994
+ " decoder_layers: 1\n",
995
+ " in_dim: 6 # 3 + in_fields\n",
996
+ " out_dim: 3\n",
997
+ " dimension: 3\n",
998
+ " activation: 'gelu'\n",
999
+ " pad_x: 5\n",
1000
+ " pad_y: 0\n",
1001
+ " pad_z: 0\n",
1002
+ " input_norm: [1.0, 1.0, 1.0, 1.0, 1.0, 0.00025]\n",
1003
+ " output_norm: [1.0, 1.0, 0.00025]\n",
1004
+ "\n",
1005
+ " #TensorLy arguments\n",
1006
+ " rank: 0.5\n",
1007
+ " factorization: 'cp'\n",
1008
+ " fixed_rank_modes: null\n",
1009
+ " decomposition_kwargs: {}\n",
1010
+ "\n",
1011
+ "###################\n",
1012
+ "## Dataset options\n",
1013
+ "dataset_params:\n",
1014
+ " data_dir: '/Datasets/mhd_data/simulation_outputs_Re${reynolds_number}'\n",
1015
+ " field_names: ['velocity', 'vector potential']\n",
1016
+ " output_names: 'output-????'\n",
1017
+ " dataset_type: 'mhd'\n",
1018
+ " name: 'MHDVecPot_TFNO_Re${reynolds_number}'\n",
1019
+ " num: -1 # -1 means use full dataset, otherwise specify total number\n",
1020
+ " num_train: 0.8 # percentage of dataset for training\n",
1021
+ " num_test: 0.2 # percentage of dataset for testing\n",
1022
+ " sub_x: 1\n",
1023
+ " sub_t: 1\n",
1024
+ " ind_x: null\n",
1025
+ " ind_t: null\n",
1026
+ " nin: 3\n",
1027
+ " nout: 3\n",
1028
+ " fields: ['u', 'v', 'A']\n",
1029
+ "\n",
1030
+ "###################\n",
1031
+ "## Dataloader options\n",
1032
+ "train_loader_params:\n",
1033
+ " batch_size: 1\n",
1034
+ " shuffle: True\n",
1035
+ " num_workers: 4\n",
1036
+ " pin_memory: True\n",
1037
+ "\n",
1038
+ "val_loader_params:\n",
1039
+ " batch_size: 1\n",
1040
+ " shuffle: False\n",
1041
+ " num_workers: 4\n",
1042
+ " pin_memory: True\n",
1043
+ "\n",
1044
+ "test_loader_params:\n",
1045
+ " batch_size: 1\n",
1046
+ " shuffle: False\n",
1047
+ " num_workers: 4\n",
1048
+ " pin_memory: True\n",
1049
+ "\n",
1050
+ "###################\n",
1051
+ "## Loss options\n",
1052
+ "loss_params:\n",
1053
+ " nu: 0.004\n",
1054
+ " eta: 0.004\n",
1055
+ " rho0: 1.0\n",
1056
+ "\n",
1057
+ " data_weight: 5.0\n",
1058
+ " ic_weight: 1.0\n",
1059
+ " pde_weight: 1.0\n",
1060
+ " constraint_weight: 10.0\n",
1061
+ "\n",
1062
+ " use_data_loss: True\n",
1063
+ " use_ic_loss: True\n",
1064
+ " use_pde_loss: True\n",
1065
+ " use_constraint_loss: True\n",
1066
+ "\n",
1067
+ " u_weight: 1.0\n",
1068
+ " v_weight: 1.0\n",
1069
+ " A_weight: 1.0\n",
1070
+ "\n",
1071
+ " Du_weight: 1.0\n",
1072
+ " Dv_weight: 1.0\n",
1073
+ " DA_weight: 1_000_000\n",
1074
+ "\n",
1075
+ " div_B_weight: 1.0\n",
1076
+ " div_vel_weight: 1.0\n",
1077
+ "\n",
1078
+ " Lx: 1.0\n",
1079
+ " Ly: 1.0\n",
1080
+ " tend: 1.0\n",
1081
+ "\n",
1082
+ " use_weighted_mean: False\n",
1083
+ "\n",
1084
+ "###################\n",
1085
+ "## Optimizer options\n",
1086
+ "optimizer_params:\n",
1087
+ " betas: [0.9, 0.999]\n",
1088
+ " lr: 5.0e-4\n",
1089
+ " milestones: [20, 40, 60, 80, 100]\n",
1090
+ " gamma: 0.5\n",
1091
+ "\n",
1092
+ "\n",
1093
+ "###################\n",
1094
+ "## Train params\n",
1095
+ "train_params:\n",
1096
+ " epochs: 100\n",
1097
+ " ckpt_freq: 10\n",
1098
+ " ckpt_path: 'checkpoints/MHDVecPot_TFNO/MHDVecPot_TFNO_PINO_Re${reynolds_number}/'\n",
1099
+ "\n",
1100
+ "###################\n",
1101
+ "## log params\n",
1102
+ "log_params:\n",
1103
+ " log_dir: 'logs'\n",
1104
+ " log_project: 'MHD_PINO'\n",
1105
+ " log_group: 'MHDVecPot_TFNO_Re${reynolds_number}'\n",
1106
+ " log_num_plots: 1\n",
1107
+ " log_plot_freq: 5\n",
1108
+ " log_plot_types: ['ic', 'pred', 'true', 'error']\n",
1109
+ "\n",
1110
+ "test:\n",
1111
+ " batchsize: 1\n",
1112
+ " ckpt_path: 'checkpoints/MHDVecPot_TFNO/MHDVecPot_TFNO_PINO_Re${reynolds_number}/'\n",
1113
+ "\n",
1114
+ "```"
1115
+ ]
1116
+ },
1117
+ {
1118
+ "cell_type": "markdown",
1119
+ "id": "c9c370fa",
1120
+ "metadata": {},
1121
+ "source": [
1122
+ "## Training Setup\n",
1123
+ "\n",
1124
+ "We begin with importing the required modules, capturing our hydra config, and initializing some utilities to facilitate the model training. Most of this initial setup is boilerplate that is common across PhysicsNeMo training scripts. \n",
1125
+ "\n",
1126
+ "```python\n",
1127
+ "import os\n",
1128
+ "\n",
1129
+ "import hydra\n",
1130
+ "from omegaconf import ListConfig, OmegaConf\n",
1131
+ "import torch\n",
1132
+ "from omegaconf import DictConfig\n",
1133
+ "from physicsnemo.distributed import DistributedManager\n",
1134
+ "from physicsnemo.launch.logging import LaunchLogger, PythonLogger\n",
1135
+ "from physicsnemo.launch.utils import load_checkpoint, save_checkpoint\n",
1136
+ "from physicsnemo.sym.hydra import to_absolute_path\n",
1137
+ "from torch.nn.parallel import DistributedDataParallel\n",
1138
+ "from torch.optim import AdamW\n",
1139
+ "\n",
1140
+ "from dataloaders import Dedalus2DDataset, MHDDataloaderVecPot\n",
1141
+ "from losses import LossMHDVecPot_PhysicsNeMo\n",
1142
+ "from tfno import TFNO\n",
1143
+ "from utils.plot_utils import plot_predictions_mhd, plot_predictions_mhd_plotly\n",
1144
+ "\n",
1145
+ "dtype = torch.float\n",
1146
+ "torch.set_default_dtype(dtype)\n",
1147
+ "\n",
1148
+ "\n",
1149
+ "@hydra.main(\n",
1150
+ " version_base=\"1.3\", config_path=\"config\", config_name=\"train_mhd_vec_pot_tfno.yaml\"\n",
1151
+ ")\n",
1152
+ "def main(cfg: DictConfig) -> None:\n",
1153
+ " DistributedManager.initialize() # Only call this once in the entire script!\n",
1154
+ " dist = DistributedManager() # call if required elsewhere\n",
1155
+ " cfg = OmegaConf.to_container(cfg, resolve=True)\n",
1156
+ "\n",
1157
+ " # initialize monitoring\n",
1158
+ " log = PythonLogger(name=\"mhd_pino\")\n",
1159
+ " log.file_logging()\n",
1160
+ "\n",
1161
+ " log_params = cfg[\"log_params\"]\n",
1162
+ "\n",
1163
+ " # Load config file parameters\n",
1164
+ " model_params = cfg[\"model_params\"]\n",
1165
+ " dataset_params = cfg[\"dataset_params\"]\n",
1166
+ " train_loader_params = cfg[\"train_loader_params\"]\n",
1167
+ " val_loader_params = cfg[\"val_loader_params\"]\n",
1168
+ " loss_params = cfg[\"loss_params\"]\n",
1169
+ " optimizer_params = cfg[\"optimizer_params\"]\n",
1170
+ " train_params = cfg[\"train_params\"]\n",
1171
+ "\n",
1172
+ " load_ckpt = cfg[\"load_ckpt\"]\n",
1173
+ " output_dir = cfg[\"output_dir\"]\n",
1174
+ "\n",
1175
+ " output_dir = to_absolute_path(output_dir)\n",
1176
+ " os.makedirs(output_dir, exist_ok=True)\n",
1177
+ "\n",
1178
+ " data_dir = dataset_params[\"data_dir\"]\n",
1179
+ " ckpt_path = train_params[\"ckpt_path\"]\n",
1180
+ "```\n"
1181
+ ]
1182
+ },
1183
+ {
1184
+ "cell_type": "markdown",
1185
+ "id": "bacc38c7",
1186
+ "metadata": {},
1187
+ "source": [
1188
+ "## Datasets and Dataloaders\n",
1189
+ "\n",
1190
+ "Datasets and dataloaders are initialized using parameters from the hydra config.\n",
1191
+ "\n",
1192
+ "```python\n",
1193
+ "# Construct dataloaders\n",
1194
+ "dataset_train = Dedalus2DDataset(\n",
1195
+ " dataset_params[\"data_dir\"],\n",
1196
+ " output_names=dataset_params[\"output_names\"],\n",
1197
+ " field_names=dataset_params[\"field_names\"],\n",
1198
+ " num_train=dataset_params[\"num_train\"],\n",
1199
+ " num_test=dataset_params[\"num_test\"],\n",
1200
+ " num=dataset_params[\"num\"],\n",
1201
+ " use_train=True,\n",
1202
+ ")\n",
1203
+ "dataset_val = Dedalus2DDataset(\n",
1204
+ " data_dir,\n",
1205
+ " output_names=dataset_params[\"output_names\"],\n",
1206
+ " field_names=dataset_params[\"field_names\"],\n",
1207
+ " num_train=dataset_params[\"num_train\"],\n",
1208
+ " num_test=dataset_params[\"num_test\"],\n",
1209
+ " num=dataset_params[\"num\"],\n",
1210
+ " use_train=False,\n",
1211
+ ")\n",
1212
+ "\n",
1213
+ "mhd_dataloader_train = MHDDataloaderVecPot(\n",
1214
+ " dataset_train,\n",
1215
+ " sub_x=dataset_params[\"sub_x\"],\n",
1216
+ " sub_t=dataset_params[\"sub_t\"],\n",
1217
+ " ind_x=dataset_params[\"ind_x\"],\n",
1218
+ " ind_t=dataset_params[\"ind_t\"],\n",
1219
+ ")\n",
1220
+ "mhd_dataloader_val = MHDDataloaderVecPot(\n",
1221
+ " dataset_val,\n",
1222
+ " sub_x=dataset_params[\"sub_x\"],\n",
1223
+ " sub_t=dataset_params[\"sub_t\"],\n",
1224
+ " ind_x=dataset_params[\"ind_x\"],\n",
1225
+ " ind_t=dataset_params[\"ind_t\"],\n",
1226
+ ")\n",
1227
+ "\n",
1228
+ "dataloader_train, sampler_train = mhd_dataloader_train.create_dataloader(\n",
1229
+ " batch_size=train_loader_params[\"batch_size\"],\n",
1230
+ " shuffle=train_loader_params[\"shuffle\"],\n",
1231
+ " num_workers=train_loader_params[\"num_workers\"],\n",
1232
+ " pin_memory=train_loader_params[\"pin_memory\"],\n",
1233
+ " distributed=dist.distributed,\n",
1234
+ ")\n",
1235
+ "dataloader_val, sampler_val = mhd_dataloader_val.create_dataloader(\n",
1236
+ " batch_size=val_loader_params[\"batch_size\"],\n",
1237
+ " shuffle=val_loader_params[\"shuffle\"],\n",
1238
+ " num_workers=val_loader_params[\"num_workers\"],\n",
1239
+ " pin_memory=val_loader_params[\"pin_memory\"],\n",
1240
+ " distributed=dist.distributed,\n",
1241
+ ")\n",
1242
+ "```"
1243
+ ]
1244
+ },
1245
+ {
1246
+ "cell_type": "markdown",
1247
+ "id": "4826da07",
1248
+ "metadata": {},
1249
+ "source": [
1250
+ "## Model Construction\n",
1251
+ "For a relatively simple model such as `FNO`, we can directly use an architecture pre-defined by PhysicsNeMo. Hyper-parameters are set directly from the hydra config, which makes it straightforward to configure hyperparameter optimization if necessary. For a more complex model such as `tFNO`, we can leverage a combination of PhysicsNeMo primitives and third party packages to build a model in pytorch. \n",
1252
+ "\n",
1253
+ "```python\n",
1254
+ "# Define the model\n",
1255
+ "model = TFNO(\n",
1256
+ " in_channels=model_params[\"in_dim\"],\n",
1257
+ " out_channels=model_params[\"out_dim\"],\n",
1258
+ " decoder_layers=model_params[\"decoder_layers\"],\n",
1259
+ " decoder_layer_size=model_params[\"fc_dim\"],\n",
1260
+ " dimension=model_params[\"dimension\"],\n",
1261
+ " latent_channels=model_params[\"layers\"],\n",
1262
+ " num_fno_layers=model_params[\"num_fno_layers\"],\n",
1263
+ " num_fno_modes=model_params[\"modes\"],\n",
1264
+ " padding=[model_params[\"pad_z\"], model_params[\"pad_y\"], model_params[\"pad_x\"]],\n",
1265
+ " rank=model_params[\"rank\"],\n",
1266
+ " factorization=model_params[\"factorization\"],\n",
1267
+ " fixed_rank_modes=model_params[\"fixed_rank_modes\"],\n",
1268
+ " decomposition_kwargs=model_params[\"decomposition_kwargs\"],\n",
1269
+ ").to(dist.device)\n",
1270
+ "# Set up DistributedDataParallel if using more than a single process.\n",
1271
+ "# The `distributed` property of DistributedManager can be used to\n",
1272
+ "# check this.\n",
1273
+ "if dist.distributed:\n",
1274
+ " ddps = torch.cuda.Stream()\n",
1275
+ " with torch.cuda.stream(ddps):\n",
1276
+ " model = DistributedDataParallel(\n",
1277
+ " model,\n",
1278
+ " device_ids=[dist.local_rank], # Set the device_id to be\n",
1279
+ " # the local rank of this process on\n",
1280
+ " # this node\n",
1281
+ " output_device=dist.device,\n",
1282
+ " broadcast_buffers=dist.broadcast_buffers,\n",
1283
+ " find_unused_parameters=dist.find_unused_parameters,\n",
1284
+ " )\n",
1285
+ " torch.cuda.current_stream().wait_stream(ddps)\n",
1286
+ "\n",
1287
+ "```"
1288
+ ]
1289
+ },
1290
+ {
1291
+ "cell_type": "markdown",
1292
+ "id": "99604fbc",
1293
+ "metadata": {},
1294
+ "source": [
1295
+ "## Optimizer, Scheduler, Loss Functions and Check-pointing\n",
1296
+ "\n",
1297
+ "\n",
1298
+ "```python\n",
1299
+ "# Construct optimizer and scheduler\n",
1300
+ "optimizer = AdamW(\n",
1301
+ " model.parameters(),\n",
1302
+ " betas=optimizer_params[\"betas\"],\n",
1303
+ " lr=optimizer_params[\"lr\"],\n",
1304
+ " weight_decay=0.1,\n",
1305
+ ")\n",
1306
+ "\n",
1307
+ "scheduler = torch.optim.lr_scheduler.MultiStepLR(\n",
1308
+ " optimizer,\n",
1309
+ " milestones=optimizer_params[\"milestones\"],\n",
1310
+ " gamma=optimizer_params[\"gamma\"],\n",
1311
+ ")\n",
1312
+ "\n",
1313
+ "# Construct Loss class\n",
1314
+ "mhd_loss = LossMHDVecPot_PhysicsNeMo(**loss_params)\n",
1315
+ "\n",
1316
+ "# Load model from checkpoint (if exists)\n",
1317
+ "loaded_epoch = 0\n",
1318
+ "if load_ckpt:\n",
1319
+ " loaded_epoch = load_checkpoint(\n",
1320
+ " ckpt_path, model, optimizer, scheduler, device=dist.device\n",
1321
+ " )\n",
1322
+ "```\n"
1323
+ ]
1324
+ },
1325
+ {
1326
+ "cell_type": "markdown",
1327
+ "id": "a775b128",
1328
+ "metadata": {},
1329
+ "source": [
1330
+ "## Training Loop\n",
1331
+ "Finally, the main training loop iterates through the dataset for our defined number of epochs, saving checkpoints and visualizations of our training along the way.\n",
1332
+ "\n",
1333
+ "```python\n",
1334
+ "# Training Loop\n",
1335
+ "epochs = train_params[\"epochs\"]\n",
1336
+ "ckpt_freq = train_params[\"ckpt_freq\"]\n",
1337
+ "names = dataset_params[\"fields\"]\n",
1338
+ "input_norm = torch.tensor(model_params[\"input_norm\"]).to(dist.device)\n",
1339
+ "output_norm = torch.tensor(model_params[\"output_norm\"]).to(dist.device)\n",
1340
+ "for epoch in range(max(1, loaded_epoch + 1), epochs + 1):\n",
1341
+ " with LaunchLogger(\n",
1342
+ " \"train\",\n",
1343
+ " epoch=epoch,\n",
1344
+ " num_mini_batch=len(dataloader_train),\n",
1345
+ " epoch_alert_freq=1,\n",
1346
+ " ) as log:\n",
1347
+ " if dist.distributed:\n",
1348
+ " sampler_train.set_epoch(epoch)\n",
1349
+ "\n",
1350
+ " # Train Loop\n",
1351
+ " model.train()\n",
1352
+ "\n",
1353
+ " for i, (inputs, outputs) in enumerate(dataloader_train):\n",
1354
+ " inputs = inputs.type(torch.FloatTensor).to(dist.device)\n",
1355
+ " outputs = outputs.type(torch.FloatTensor).to(dist.device)\n",
1356
+ " # Zero Gradients\n",
1357
+ " optimizer.zero_grad()\n",
1358
+ " # Compute Predictions\n",
1359
+ " pred = (\n",
1360
+ " model((inputs / input_norm).permute(0, 4, 1, 2, 3)).permute(\n",
1361
+ " 0, 2, 3, 4, 1\n",
1362
+ " )\n",
1363
+ " * output_norm\n",
1364
+ " )\n",
1365
+ " # Compute Loss\n",
1366
+ " loss, loss_dict = mhd_loss(pred, outputs, inputs, return_loss_dict=True)\n",
1367
+ " # Compute Gradients for Back Propagation\n",
1368
+ " loss.backward()\n",
1369
+ " # Update Weights\n",
1370
+ " optimizer.step()\n",
1371
+ "\n",
1372
+ " log.log_minibatch(loss_dict)\n",
1373
+ "\n",
1374
+ " log.log_epoch({\"Learning Rate\": optimizer.param_groups[0][\"lr\"]})\n",
1375
+ " scheduler.step()\n",
1376
+ "\n",
1377
+ " with LaunchLogger(\"valid\", epoch=epoch) as log:\n",
1378
+ " # Val loop\n",
1379
+ " model.eval()\n",
1380
+ " plot_count = 0\n",
1381
+ " with torch.no_grad():\n",
1382
+ " for i, (inputs, outputs) in enumerate(dataloader_val):\n",
1383
+ " inputs = inputs.type(dtype).to(dist.device)\n",
1384
+ " outputs = outputs.type(dtype).to(dist.device)\n",
1385
+ "\n",
1386
+ " # Compute Predictions\n",
1387
+ " pred = (\n",
1388
+ " model((inputs / input_norm).permute(0, 4, 1, 2, 3)).permute(\n",
1389
+ " 0, 2, 3, 4, 1\n",
1390
+ " )\n",
1391
+ " * output_norm\n",
1392
+ " )\n",
1393
+ " # Compute Loss\n",
1394
+ " loss, loss_dict = mhd_loss(\n",
1395
+ " pred, outputs, inputs, return_loss_dict=True\n",
1396
+ " )\n",
1397
+ "\n",
1398
+ " log.log_minibatch(loss_dict)\n",
1399
+ "\n",
1400
+ " # Get prediction plots to log\n",
1401
+ " # Do for number of batches specified in the config file\n",
1402
+ " if (i < log_params[\"log_num_plots\"]) and (\n",
1403
+ " epoch % log_params[\"log_plot_freq\"] == 0\n",
1404
+ " ):\n",
1405
+ " # Add all predictions in batch\n",
1406
+ " for j, _ in enumerate(pred):\n",
1407
+ " # Make plots for each field\n",
1408
+ " for index, name in enumerate(names):\n",
1409
+ " # Generate figure\n",
1410
+ " _ = plot_predictions_mhd_plotly(\n",
1411
+ " pred[j].cpu(),\n",
1412
+ " outputs[j].cpu(),\n",
1413
+ " inputs[j].cpu(),\n",
1414
+ " index=index,\n",
1415
+ " name=name,\n",
1416
+ " )\n",
1417
+ " plot_count += 1\n",
1418
+ "\n",
1419
+ " # Get prediction plots and save images locally\n",
1420
+ " if (i < 2) and (epoch % log_params[\"log_plot_freq\"] == 0):\n",
1421
+ " # Add all predictions in batch\n",
1422
+ " for j, _ in enumerate(pred):\n",
1423
+ " # Generate figure\n",
1424
+ " plot_predictions_mhd(\n",
1425
+ " pred[j].cpu(),\n",
1426
+ " outputs[j].cpu(),\n",
1427
+ " inputs[j].cpu(),\n",
1428
+ " names=names,\n",
1429
+ " save_path=os.path.join(\n",
1430
+ " output_dir,\n",
1431
+ " \"MHD_physicsnemo\" + \"_\" + str(dist.rank),\n",
1432
+ " ),\n",
1433
+ " save_suffix=i,\n",
1434
+ " )\n",
1435
+ "\n",
1436
+ " if epoch % ckpt_freq == 0 and dist.rank == 0:\n",
1437
+ " save_checkpoint(ckpt_path, model, optimizer, scheduler, epoch=epoch)\n",
1438
+ "\n",
1439
+ "```"
1440
+ ]
1441
+ },
1442
+ {
1443
+ "cell_type": "markdown",
1444
+ "id": "ca49a425",
1445
+ "metadata": {},
1446
+ "source": [
1447
+ "## Running the Training Script\n",
1448
+ "\n",
1449
+ "The full set of python code to start training is available in the folder `./mhd`. Configs, data generation, dataloaders, loss functions, model architectures, and training scripts are all available here. If utilizing the scripts outside of this HuggingFace Space, you can launch training with:\n",
1450
+ "\n",
1451
+ "```bash\n",
1452
+ "torchrun --standalone --nnodes=1 --nproc_per_node=1 train_mhd_vec_pot_tfno.py\n",
1453
+ "```\n",
1454
+ "\n",
1455
+ "With the default set of parameters, the model will take up around 5.2GB of GPU memory, and a full training run up to 100 epochs will take around 1.5 hours."
1456
+ ]
1457
+ },
1458
+ {
1459
+ "cell_type": "markdown",
1460
+ "id": "39c969ce",
1461
+ "metadata": {},
1462
+ "source": [
1463
+ "## End-to-End Training\n",
1464
+ "\n",
1465
+ "All of the code that was detailed above is available to explore in the \"./mhd\" folder. There are also two scripts that execute the end-to-end workflow for training and evaluation. "
1466
+ ]
1467
+ },
1468
+ {
1469
+ "cell_type": "code",
1470
+ "execution_count": null,
1471
+ "id": "6013cf5d-8232-45bf-8e79-1757bb29d3fe",
1472
+ "metadata": {},
1473
+ "outputs": [],
1474
+ "source": [
1475
+ "!python mhd/train_mhd_vec_pot_tfno.py"
1476
+ ]
1477
+ },
1478
+ {
1479
+ "cell_type": "markdown",
1480
+ "id": "daf003ec",
1481
+ "metadata": {},
1482
+ "source": [
1483
+ "## Transfer Learning to New Reynolds Number\n",
1484
+ "In practice, our system may not follow smooth, laminar flows described with low Reynolds numbers. In MHD systems, much of the magnetic field energy is stored at high wave numbers, which occur at smaller scales. Models must then be able to characterize high frequency features in order to successfully reproduce the trajectories of the system. These turbulent flows at higher Reynolds number are simulated, which will in turn produce higher frequency features that a model trained on smooth flows may not be able to resolve with good accuracy. To this end, transfer learning can be used to take a base model and adapt it to the new data domain by using a pre-trained checkpoint as the starting point of a new iteration of model training. \n",
1485
+ "\n",
1486
+ "To run transfer learning, we need a dataset of points from our new domain. For example, our default model is trained on data using $Re=100$, so we can use the model checkpoint from this domain to start off transfer learning to a new dataset with $Re=250$. In the Hydra config, we can update the following parameters:\n",
1487
+ "\n",
1488
+ "```yaml\n",
1489
+ "load_ckpt: True\n",
1490
+ "output_dir: \"/path/to/new/output_dir\"\n",
1491
+ "\n",
1492
+ "dataset_params:\n",
1493
+ " data_dir: \"/path/to/new/dataset\"\n",
1494
+ " name: 'Dataset Name'\n",
1495
+ "\n",
1496
+ "train_params:\n",
1497
+ " ckpt_path: \"/path/to/starting_checkpoint\"\n",
1498
+ "```\n",
1499
+ "\n",
1500
+ "<div style=\"display: flex; justify-content: center; gap: 10px;\">\n",
1501
+ " <figure style=\"text-align: center;\">\n",
1502
+ " <img src=\"https://raw.githubusercontent.com/openhackathons-org/End-to-End-AI-for-Science/main/workspace/python/jupyter_notebook/MagnetoHydrodynamics/images/high_frequency.png\" style=\"width: 100%; height: auto;\">\n",
1503
+ " <figcaption>Predictions with a large Reynolds Number.</figcaption>\n",
1504
+ " </figure>\n",
1505
+ "</div>"
1506
+ ]
1507
+ },
1508
+ {
1509
+ "cell_type": "markdown",
1510
+ "id": "93c7d926",
1511
+ "metadata": {},
1512
+ "source": [
1513
+ "## Evaluation\n",
1514
+ "When solving the MHD equations with `dedalus`, the average time per simulation is about 37 seconds. On the other hand, our physics informed model has an average inference time of 0.15 seconds, a 246x speedup. This comes at the cost of decreased accuracy in our solution, as it is an approximation to the system equations. Furthermore, our model's performance will vary, depending on the Reynolds number. \n",
1515
+ "\n",
1516
+ "Evaluation can be run a few different ways. If there are many systems to evaluate, we can load them into a dataloader and do batch processing. In this example, we will use a standalone script, which is a stripped down version of the training script that will run our model with a single sample. \n",
1517
+ "\n",
1518
+ "To run evaluation we can use the following command, which pulls in a config that points to a specific pre-trained checkpoint and dataset. The config is found in `eval_mhd_vec_pot_tfno.yaml`\n",
1519
+ "\n",
1520
+ "\n",
1521
+ "```bash\n",
1522
+ "torchrun --standalone --nnodes=1 --nproc_per_node=1 evaluate_mhd_vec_pot_tfno.py\n",
1523
+ "```\n",
1524
+ "\n",
1525
+ "In evaluations, our model is able to accurately simulate flows at $Re<250$. Specifically, for $Re=100$, our surrogate model has less than 4% error at $t=1$ for all fields. At $Re=250$, the velocity field and vector potential are accurately described, with MSEs <7% and <10%, respectively. At higher Reynolds numbers, our model starts to break down. An example for $Re=100$ is shown below, as well as some plots showing $MSE$ vs $Re$.\n",
1526
+ "\n",
1527
+ "<div style=\"display: flex; justify-content: center; gap: 10px;\">\n",
1528
+ " <figure style=\"text-align: center;\">\n",
1529
+ " <img src=\"https://raw.githubusercontent.com/openhackathons-org/End-to-End-AI-for-Science/main/workspace/python/jupyter_notebook/MagnetoHydrodynamics/images/re100.png\" style=\"width: 100%; height: auto;\">\n",
1530
+ " <figcaption>Predictions with a low Reynolds Number.</figcaption>\n",
1531
+ " </figure>\n",
1532
+ "</div>\n",
1533
+ "<div style=\"display: flex; justify-content: center; gap: 10px;\">\n",
1534
+ " <figure style=\"text-align: center;\">\n",
1535
+ " <img src=\"https://raw.githubusercontent.com/openhackathons-org/End-to-End-AI-for-Science/main/workspace/python/jupyter_notebook/MagnetoHydrodynamics/images/mse_vs_re.png\" style=\"width: 100%; height: auto;\">\n",
1536
+ " <figcaption>Error vs. Reynolds Number.</figcaption>\n",
1537
+ " </figure>\n",
1538
+ "</div>"
1539
+ ]
1540
+ },
1541
+ {
1542
+ "cell_type": "markdown",
1543
+ "id": "9fec5e93",
1544
+ "metadata": {},
1545
+ "source": [
1546
+ "## End-to-End Evaluation\n",
1547
+ "To run evaluation, use the following script:"
1548
+ ]
1549
+ },
1550
+ {
1551
+ "cell_type": "code",
1552
+ "execution_count": null,
1553
+ "id": "b44a1bfd",
1554
+ "metadata": {},
1555
+ "outputs": [],
1556
+ "source": [
1557
+ "!python mhd/evaluate_mhd_vec_pot_tfno.py"
1558
+ ]
1559
+ },
1560
+ {
1561
+ "cell_type": "markdown",
1562
+ "id": "26acf9de",
1563
+ "metadata": {},
1564
+ "source": [
1565
+ "## Shortcomings and areas for improvement\n",
1566
+ "\n",
1567
+ "Physics informed machine learning shows promising results when applied to certain regions of parameter space as governed by the Reynolds number. While models such as tFNOs are able to accurately capture and simulate systems, they do not always perform well when the underlying physics begin to shift into regions of high frequency features. A tradeoff is present in accuracy and throughput: these AI surrogate models accelerate simulations over 200x, however they remain accurate only in the low Reynolds number parameter space. To this end, applying physics informed ML to the MHD equations shows both promise and room for improvement. For example, increased model size, additional physical loss functions from energy spectra, and higher resolution datasets may be a few areas in which the development and application of these models may be improved. In conclusion, the efficacy of physics informed machine learning has been demonstrated in the modeling of magnetohydrodynamics, and researchers, scientists, and engineers are encouraged to build on this foundation to enhance these techniques further. "
1568
+ ]
1569
+ }
1570
+ ],
1571
+ "metadata": {
1572
+ "kernelspec": {
1573
+ "display_name": "Python 3 (ipykernel)",
1574
+ "language": "python",
1575
+ "name": "python3"
1576
+ },
1577
+ "language_info": {
1578
+ "codemirror_mode": {
1579
+ "name": "ipython",
1580
+ "version": 3
1581
+ },
1582
+ "file_extension": ".py",
1583
+ "mimetype": "text/x-python",
1584
+ "name": "python",
1585
+ "nbconvert_exporter": "python",
1586
+ "pygments_lexer": "ipython3",
1587
+ "version": "3.12.3"
1588
+ }
1589
+ },
1590
+ "nbformat": 4,
1591
+ "nbformat_minor": 5
1592
+ }
mhd/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ *_results
2
+ outputs
3
+ logs
4
+ mhd_data
5
+ checkpoints
6
+ README.md
7
+ launch.log
8
+ requirements.txt
mhd/config/eval_mhd_vec_pot_tfno.yaml ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ ## Training options
19
+ # Reynolds number parameter
20
+ reynolds_number: 100
21
+
22
+ load_ckpt: False
23
+ use_log: True
24
+ output_dir: './checkpoints/MHDVecPot_TFNO/MHDVecPot_TFNO_PINO_Re${reynolds_number}/figures/'
25
+ derivative: 'physicsnemo'
26
+
27
+ ###################
28
+ ## Model options
29
+ model_params:
30
+ layers: 8
31
+ modes: 8
32
+ num_fno_layers: 4
33
+ fc_dim: 128
34
+ decoder_layers: 1
35
+ in_dim: 6 # 3 + in_fields
36
+ out_dim: 3
37
+ dimension: 3
38
+ activation: 'gelu'
39
+ pad_x: 5
40
+ pad_y: 0
41
+ pad_z: 0
42
+ input_norm: [1.0, 1.0, 1.0, 1.0, 1.0, 0.00025]
43
+ output_norm: [1.0, 1.0, 0.00025]
44
+
45
+ #TensorLy arguments
46
+ rank: 0.5
47
+ factorization: 'cp'
48
+ fixed_rank_modes: null
49
+
50
+ ###################
51
+ ## Dataset options
52
+ dataset_params:
53
+ data_dir: '/data/mhd_data/simulation_outputs_Re${reynolds_number}'
54
+ field_names: ['velocity', 'vector potential']
55
+ output_names: 'output-????'
56
+ dataset_type: 'mhd'
57
+ name: 'MHDVecPot_TFNO_Re${reynolds_number}'
58
+ num: -1 # -1 means use full dataset for evaluation
59
+ num_train: 0.8 # percentage of dataset for training (not used in eval)
60
+ num_test: 0.2 # percentage of dataset for testing (not used in eval)
61
+ sub_x: 1
62
+ sub_t: 1
63
+ ind_x: null
64
+ ind_t: null
65
+ nin: 3
66
+ nout: 3
67
+ fields: ['u', 'v', 'A']
68
+
69
+ ###################
70
+ ## Dataloader options
71
+ test_loader_params:
72
+ batch_size: 1
73
+ shuffle: False
74
+ num_workers: 4
75
+ pin_memory: True
76
+
77
+ ###################
78
+ ## Loss options
79
+ loss_params:
80
+ nu: 0.004
81
+ eta: 0.004
82
+ rho0: 1.0
83
+
84
+ data_weight: 5.0
85
+ ic_weight: 1.0
86
+ pde_weight: 1.0
87
+ constraint_weight: 10.0
88
+
89
+ use_data_loss: True
90
+ use_ic_loss: True
91
+ use_pde_loss: True
92
+ use_constraint_loss: True
93
+
94
+ u_weight: 1.0
95
+ v_weight: 1.0
96
+ A_weight: 1.0
97
+
98
+ Du_weight: 1.0
99
+ Dv_weight: 1.0
100
+ DA_weight: 1_000_000
101
+
102
+ div_B_weight: 1.0
103
+ div_vel_weight: 1.0
104
+
105
+ Lx: 1.0
106
+ Ly: 1.0
107
+ tend: 1.0
108
+
109
+ use_weighted_mean: False
110
+
111
+ ###################
112
+ ## Optimizer options
113
+ optimizer_params:
114
+ betas: [0.9, 0.999]
115
+ lr: 5.0e-4
116
+ milestones: [20, 40, 60, 80, 100]
117
+ gamma: 0.5
118
+
119
+
120
+ ###################
121
+ ## Train params
122
+ train_params:
123
+ epochs: 100
124
+ ckpt_freq: 50
125
+ ckpt_path: 'checkpoints/MHDVecPot_TFNO/MHDVecPot_TFNO_PINO_Re${reynolds_number}/'
126
+
127
+ ###################
128
+ ## log params
129
+ log_params:
130
+ log_dir: 'logs'
131
+ log_project: 'MHD_PINO'
132
+ log_group: 'MHDVecPot_TFNO_Re${reynolds_number}'
133
+ log_num_plots: 1
134
+ log_plot_freq: 5
135
+ log_plot_types: ['ic', 'pred', 'true', 'error']
136
+
137
+ test:
138
+ batchsize: 1
139
+ ckpt_path: 'checkpoints/MHDVecPot_TFNO/MHDVecPot_TFNO_PINO_Re${reynolds_number}/'
mhd/config/train_mhd_vec_pot_tfno.yaml ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ ## Training options
19
+ # Reynolds number parameter
20
+ reynolds_number: 100
21
+
22
+ load_ckpt: False
23
+ output_dir: './outputs/Re${reynolds_number}/figures/'
24
+
25
+ ###################
26
+ ## Model options
27
+ model_params:
28
+ layers: 8
29
+ modes: 8
30
+ num_fno_layers: 4
31
+ fc_dim: 128
32
+ decoder_layers: 1
33
+ in_dim: 6 # 3 + in_fields
34
+ out_dim: 3
35
+ dimension: 3
36
+ activation: 'gelu'
37
+ pad_x: 5
38
+ pad_y: 0
39
+ pad_z: 0
40
+ input_norm: [1.0, 1.0, 1.0, 1.0, 1.0, 0.00025]
41
+ output_norm: [1.0, 1.0, 0.00025]
42
+
43
+ #TensorLy arguments
44
+ rank: 0.5
45
+ factorization: 'cp'
46
+ fixed_rank_modes: null
47
+ decomposition_kwargs: {}
48
+
49
+ ###################
50
+ ## Dataset options
51
+ dataset_params:
52
+ data_dir: '/data/mhd_data/simulation_outputs_Re${reynolds_number}'
53
+ field_names: ['velocity', 'vector potential']
54
+ output_names: 'output-????'
55
+ dataset_type: 'mhd'
56
+ name: 'MHDVecPot_TFNO_Re${reynolds_number}'
57
+ num: -1 # -1 means use full dataset, otherwise specify total number
58
+ num_train: 0.8 # percentage of dataset for training
59
+ num_test: 0.2 # percentage of dataset for testing
60
+ sub_x: 1
61
+ sub_t: 1
62
+ ind_x: null
63
+ ind_t: null
64
+ nin: 3
65
+ nout: 3
66
+ fields: ['u', 'v', 'A']
67
+
68
+ ###################
69
+ ## Dataloader options
70
+ train_loader_params:
71
+ batch_size: 4
72
+ shuffle: True
73
+ num_workers: 8
74
+ pin_memory: True
75
+
76
+ val_loader_params:
77
+ batch_size: 4
78
+ shuffle: False
79
+ num_workers: 8
80
+ pin_memory: True
81
+
82
+ test_loader_params:
83
+ batch_size: 4
84
+ shuffle: False
85
+ num_workers: 8
86
+ pin_memory: True
87
+
88
+ ###################
89
+ ## Loss options
90
+ loss_params:
91
+ nu: 0.004
92
+ eta: 0.004
93
+ rho0: 1.0
94
+
95
+ data_weight: 5.0
96
+ ic_weight: 1.0
97
+ pde_weight: 1.0
98
+ constraint_weight: 10.0
99
+
100
+ use_data_loss: True
101
+ use_ic_loss: True
102
+ use_pde_loss: True
103
+ use_constraint_loss: True
104
+
105
+ u_weight: 1.0
106
+ v_weight: 1.0
107
+ A_weight: 1.0
108
+
109
+ Du_weight: 1.0
110
+ Dv_weight: 1.0
111
+ DA_weight: 1_000_000
112
+
113
+ div_B_weight: 1.0
114
+ div_vel_weight: 1.0
115
+
116
+ Lx: 1.0
117
+ Ly: 1.0
118
+ tend: 1.0
119
+
120
+ use_weighted_mean: False
121
+
122
+ ###################
123
+ ## Optimizer options
124
+ optimizer_params:
125
+ betas: [0.9, 0.999]
126
+ lr: 5.0e-4
127
+ milestones: [20, 40, 60, 80, 100]
128
+ gamma: 0.5
129
+
130
+
131
+ ###################
132
+ ## Train params
133
+ train_params:
134
+ epochs: 100
135
+ ckpt_freq: 10
136
+ ckpt_path: './outputs/checkpoints/Re${reynolds_number}/'
137
+
138
+ ###################
139
+ ## log params
140
+ log_params:
141
+ log_dir: 'logs'
142
+ log_project: 'MHD_PINO'
143
+ log_group: 'MHDVecPot_TFNO_Re${reynolds_number}'
144
+ log_num_plots: 1
145
+ log_plot_freq: 5
146
+ log_plot_types: ['ic', 'pred', 'true', 'error']
147
+
148
+ test:
149
+ batchsize: 1
150
+ ckpt_path: './outputs/checkpoints/Re${reynolds_number}/'
mhd/dataloaders/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from .datasets import Dedalus2DDataset
18
+ from .dataloaders import MHDDataloader, MHDDataloaderVecPot
mhd/dataloaders/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (271 Bytes). View file
 
mhd/dataloaders/__pycache__/dataloaders.cpython-312.pyc ADDED
Binary file (9.62 kB). View file
 
mhd/dataloaders/__pycache__/datasets.cpython-312.pyc ADDED
Binary file (4.14 kB). View file
 
mhd/dataloaders/dataloaders.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import torch
18
+ from IPython.display import display
19
+ from torch.utils.data import DataLoader, Dataset
20
+
21
try:
    from .datasets import Dedalus2DDataset
except ImportError:
    # Fallback so this module also works when executed directly as a script
    # (no package context for the relative import). Narrowed from a bare
    # `except:` which also swallowed SystemExit/KeyboardInterrupt and masked
    # unrelated errors raised while importing datasets.
    from datasets import Dedalus2DDataset
25
+
26
+
27
class MHDDataloader(Dataset):
    """Map-style dataset turning Dedalus MHD runs (velocity + magnetic field)
    into (inputs, outputs) tensor pairs.

    inputs:  (nt, nx, ny, 7) — [t, x, y] coordinate grids concatenated with
             the initial condition [u0, v0, Bx0, By0] repeated over time.
    outputs: (nt, nx, ny, 4) — the full [u, v, Bx, By] trajectory.
    """

    def __init__(
        self, dataset: Dedalus2DDataset, sub_x=1, sub_t=1, ind_x=None, ind_t=None
    ):
        # sub_x / sub_t: subsampling strides in space and time.
        # ind_x / ind_t: optional end indices (None keeps the full extent).
        self.dataset = dataset
        self.sub_x = sub_x
        self.sub_t = sub_t
        self.ind_x = ind_x
        self.ind_t = ind_t
        t, x, y = dataset.get_coords(0)
        self.x = x[:ind_x:sub_x]
        # NOTE(review): y is sliced with ind_x/sub_x as well — assumes a
        # square grid with identical x/y resolution (the spatial slicing in
        # __getitem__ makes the same assumption); confirm for rectangular grids.
        self.y = y[:ind_x:sub_x]
        self.t = t[:ind_t:sub_t]
        self.nx = len(self.x)
        self.ny = len(self.y)
        self.nt = len(self.t)
        self.num = num = len(self.dataset)
        self.x_slice = slice(0, self.ind_x, self.sub_x)
        self.t_slice = slice(0, self.ind_t, self.sub_t)

    def __len__(self):
        # One sample per simulation run in the underlying dataset.
        length = len(self.dataset)
        return length

    def __getitem__(self, index):
        "Gets input of dataloader, including data, t, x, and y"
        fields = self.dataset[index]

        # Data includes velocity and magnetic field; component arrays are
        # stored as (time, component, x, y) and subsampled per-axis below.
        velocity = fields["velocity"]
        magnetic_field = fields["magnetic field"]

        u = torch.from_numpy(
            velocity[
                : self.ind_t : self.sub_t,
                0,
                : self.ind_x : self.sub_x,
                : self.ind_x : self.sub_x,
            ]
        )
        v = torch.from_numpy(
            velocity[
                : self.ind_t : self.sub_t,
                1,
                : self.ind_x : self.sub_x,
                : self.ind_x : self.sub_x,
            ]
        )
        Bx = torch.from_numpy(
            magnetic_field[
                : self.ind_t : self.sub_t,
                0,
                : self.ind_x : self.sub_x,
                : self.ind_x : self.sub_x,
            ]
        )
        By = torch.from_numpy(
            magnetic_field[
                : self.ind_t : self.sub_t,
                1,
                : self.ind_x : self.sub_x,
                : self.ind_x : self.sub_x,
            ]
        )

        # shape is now (nt, nx, ny, nfields)
        data = torch.stack([u, v, Bx, By], dim=-1)
        # Initial condition (t=0 slice) broadcast along the time axis.
        data0 = data[0].reshape(1, self.nx, self.ny, -1).repeat(self.nt, 1, 1, 1)

        # Coordinate grids, each of shape (nt, nx, ny, 1).
        grid_t = (
            torch.from_numpy(self.t)
            .reshape(self.nt, 1, 1, 1)
            .repeat(1, self.nx, self.ny, 1)
        )
        grid_x = (
            torch.from_numpy(self.x)
            .reshape(1, self.nx, 1, 1)
            .repeat(self.nt, 1, self.ny, 1)
        )
        grid_y = (
            torch.from_numpy(self.y)
            .reshape(1, 1, self.ny, 1)
            .repeat(self.nt, self.nx, 1, 1)
        )

        inputs = torch.cat([grid_t, grid_x, grid_y, data0], dim=-1)
        outputs = data

        return inputs, outputs

    def create_dataloader(
        self,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
        distributed=False,
    ):
        "Creates dataloader and sampler based on whether distributed training is on"
        if distributed:
            # The sampler shards (and shuffles) across ranks, so the loader
            # itself must not shuffle.
            sampler = torch.utils.data.DistributedSampler(self)
            dataloader = DataLoader(
                self,
                batch_size=batch_size,
                shuffle=False,
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
            )
        else:
            sampler = None
            dataloader = DataLoader(
                self,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                pin_memory=pin_memory,
            )

        return dataloader, sampler
149
+
150
+
151
class MHDDataloaderVecPot(MHDDataloader):
    """Variant of MHDDataloader using the magnetic vector potential A instead
    of the B-field components.

    inputs:  (nt, nx, ny, 6) — [t, x, y] grids + initial [u0, v0, A0].
    outputs: (nt, nx, ny, 3) — the full [u, v, A] trajectory.

    ``__init__``, ``__len__`` and ``create_dataloader`` are inherited from
    MHDDataloader — the previous definitions here were byte-for-byte
    duplicates of the parent's, so behavior is unchanged.
    """

    def __getitem__(self, index):
        "Gets input of dataloader, including data, t, x, and y"
        fields = self.dataset[index]

        # Data includes velocity and vector potential
        velocity = fields["velocity"]
        vector_potential = fields["vector potential"]

        u = torch.from_numpy(
            velocity[
                : self.ind_t : self.sub_t,
                0,
                : self.ind_x : self.sub_x,
                : self.ind_x : self.sub_x,
            ]
        )
        v = torch.from_numpy(
            velocity[
                : self.ind_t : self.sub_t,
                1,
                : self.ind_x : self.sub_x,
                : self.ind_x : self.sub_x,
            ]
        )
        # The vector potential is a scalar field in 2D: (time, x, y),
        # so there is no component axis to index.
        A = torch.from_numpy(
            vector_potential[
                : self.ind_t : self.sub_t,
                : self.ind_x : self.sub_x,
                : self.ind_x : self.sub_x,
            ]
        )

        # shape is now (self.nt, self.nx, self.ny, nfields)
        data = torch.stack([u, v, A], dim=-1)
        # Initial condition broadcast along the time axis.
        data0 = data[0].reshape(1, self.nx, self.ny, -1).repeat(self.nt, 1, 1, 1)

        # Coordinate grids, each of shape (nt, nx, ny, 1).
        grid_t = (
            torch.from_numpy(self.t)
            .reshape(self.nt, 1, 1, 1)
            .repeat(1, self.nx, self.ny, 1)
        )
        grid_x = (
            torch.from_numpy(self.x)
            .reshape(1, self.nx, 1, 1)
            .repeat(self.nt, 1, self.ny, 1)
        )
        grid_y = (
            torch.from_numpy(self.y)
            .reshape(1, 1, self.ny, 1)
            .repeat(self.nt, self.nx, 1, 1)
        )

        inputs = torch.cat([grid_t, grid_x, grid_y, data0], dim=-1)
        outputs = data

        return inputs, outputs
264
+
265
+
266
+
267
if __name__ == "__main__":
    # Smoke test: build both dataloader variants over a local Dedalus run
    # directory and display one sample from the B-field variant.
    demo_dataset = Dedalus2DDataset(
        data_path="../mhd_data/simulation_outputs_Re250",
        output_names="output-????",
        field_names=["magnetic field", "velocity", "vector potential"],
    )
    mhd_dataloader = MHDDataloader(demo_dataset)
    mhd_vec_pot_dataloader = MHDDataloaderVecPot(demo_dataset)

    sample = mhd_dataloader[0]
    display(sample)
mhd/dataloaders/datasets.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import glob
18
+ import os
19
+
20
+ import h5py
21
+ from torch.utils import data
22
+
23
+
24
class Dedalus2DDataset(data.Dataset):
    """Dataset over HDF5 snapshot files written by the Dedalus MHD generator
    (one ``output-NN`` directory per simulation run).

    Parameters
    ----------
    data_path : str
        Root directory containing the per-run output directories.
    output_names : str
        Glob pattern for the per-run directories.  BUGFIX: this argument used
        to be unconditionally overwritten; it is now honored when it matches
        files, and only when it matches nothing (or is left at the ``"output-"``
        placeholder default) is a pattern derived automatically from the number
        of entries under ``data_path`` (matching the generator's zero-padded
        directory names).
    field_names : list of str
        HDF5 task names to load per sample (the default list is never mutated).
        ``None`` loads every task found in the file.
    num_train, num_test : int or float, optional
        Split sizes.  Values <= 1.0 are interpreted as fractions of the raw
        file count; ``None`` takes everything (or everything remaining).
    num : int, optional
        Cap on the total number of raw files; ``None`` or non-positive values
        (e.g. -1) use the full dataset.
    use_train : bool
        Whether this instance serves the train split or the test split.
    """

    def __init__(
        self,
        data_path,
        output_names="output-",
        field_names=["magnetic field", "velocity"],
        num_train=None,
        num_test=None,
        num=None,
        use_train=True,
    ):
        self.data_path = data_path
        # Pattern matching the generator's zero-padded directory names; the
        # digit count is inferred from the entry count under data_path.
        auto_names = "output-" + "?" * len(str(len(os.listdir(data_path))))

        files_raw = []
        if output_names:
            files_raw = sorted(
                glob.glob(os.path.join(data_path, output_names, "*.h5"))
            )
        if not files_raw:
            # Caller's pattern (or the default placeholder) matched nothing:
            # fall back to the auto-derived pattern, preserving the previous
            # auto-detection behavior.
            output_names = auto_names
            files_raw = sorted(
                glob.glob(os.path.join(data_path, output_names, "*.h5"))
            )
        self.output_names = output_names
        self.files_raw = files_raw
        self.num_files_raw = num_files_raw = len(files_raw)
        self.field_names = field_names
        self.use_train = use_train

        # Handle num parameter: None/-1 means full dataset, otherwise cap it.
        if num is not None and num > 0:
            num_files_raw = min(num, num_files_raw)
            files_raw = files_raw[:num_files_raw]
            self.files_raw = files_raw
            self.num_files_raw = num_files_raw

        # Handle percentage-based splits (values <= 1.0 are fractions).
        if num_train is not None and num_train <= 1.0:
            num_train = int(num_train * num_files_raw)
        elif num_train is None or num_train > num_files_raw:
            num_train = num_files_raw

        if num_test is not None and num_test <= 1.0:
            num_test = int(num_test * num_files_raw)
        elif num_test is None or num_test > (num_files_raw - num_train):
            num_test = num_files_raw - num_train

        # Train files come first, test files follow — splits never overlap.
        self.num_train = num_train
        self.train_files = self.files_raw[:num_train]
        self.num_test = num_test
        self.test_end = test_end = num_train + num_test
        self.test_files = self.files_raw[num_train:test_end]

        if (self.use_train) or (self.test_files is None):
            files = self.train_files
        else:
            files = self.test_files
        self.files = files
        self.num_files = num_files = len(files)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """Load the requested fields of one run as a {name: ndarray} dict."""
        file = self.files[index]

        field_names = self.field_names
        fields = {}
        with h5py.File(file, mode="r") as h5file:
            data_file = h5file["tasks"]
            keys = list(data_file.keys())
            if field_names is None:
                field_names = keys
            for field_name in field_names:
                if field_name in data_file:
                    # Materialize into memory so the file can be closed.
                    fields[field_name] = data_file[field_name][:]
                else:
                    print(f"field name {field_name} not found")
        return fields

    def get_coords(self, index):
        """Return (t, x, y) coordinate arrays of one run via HDF5 dim scales."""
        file = self.files[index]
        with h5py.File(file, mode="r") as h5file:
            data_file = h5file["tasks"]
            keys = list(data_file.keys())
            dims = data_file[keys[0]].dims

            ndims = len(dims)
            t = dims[0]["sim_time"][:]
            # Last two dims are the spatial axes regardless of field rank
            # (scalar fields are (t, x, y); vector fields (t, comp, x, y)).
            x = dims[ndims - 2][0][:]
            y = dims[ndims - 1][0][:]
        return t, x, y
mhd/evaluate_mhd_vec_pot_tfno.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
import os
import time

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from physicsnemo.distributed import DistributedManager
from physicsnemo.launch.logging import LaunchLogger, PythonLogger
from physicsnemo.launch.utils import load_checkpoint
from physicsnemo.sym.hydra import to_absolute_path
from torch.nn.parallel import DistributedDataParallel
from torch.optim import AdamW

from dataloaders import Dedalus2DDataset, MHDDataloaderVecPot
from losses import LossMHDVecPot_PhysicsNeMo
from tfno import TFNO
from utils.plot_utils import plot_predictions_mhd, plot_predictions_mhd_plotly

dtype = torch.float
torch.set_default_dtype(dtype)
36
+
37
+
38
@hydra.main(
    version_base="1.3", config_path="config", config_name="eval_mhd_vec_pot_tfno.yaml"
)
def main(cfg: DictConfig) -> None:
    """Evaluate a trained TFNO model on the MHD (vector-potential) test set.

    Builds the test dataloader, optionally restores the model from a
    checkpoint, computes the physics-informed losses over the test set, and
    writes prediction/error plots to cfg.output_dir.
    """
    DistributedManager.initialize()  # Only call this once in the entire script!
    dist = DistributedManager()  # call if required elsewhere
    cfg = OmegaConf.to_container(cfg, resolve=True)
    # initialize monitoring
    log = PythonLogger(name="mhd_pino")
    log.file_logging()
    # Load config file parameters
    model_params = cfg["model_params"]
    dataset_params = cfg["dataset_params"]
    test_loader_params = cfg["test_loader_params"]
    loss_params = cfg["loss_params"]
    optimizer_params = cfg["optimizer_params"]

    output_dir = cfg["output_dir"]
    test_params = cfg["test"]
    # BUGFIX: the flag used to be stored in a variable named `load_checkpoint`,
    # shadowing the checkpoint-restore utility and turning the call below into
    # `True(...)` (TypeError) whenever the flag was set.
    load_ckpt = cfg.get("load_ckpt", False)

    output_dir = to_absolute_path(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    data_dir = dataset_params["data_dir"]

    # Construct dataloaders (test split only)
    dataset_test = Dedalus2DDataset(
        data_dir,
        output_names=dataset_params["output_names"],
        field_names=dataset_params["field_names"],
        num_train=dataset_params["num_train"],
        num_test=dataset_params["num_test"],
        num=dataset_params["num"],
        use_train=False,
    )
    mhd_dataloader_test = MHDDataloaderVecPot(
        dataset_test,
        sub_x=dataset_params["sub_x"],
        sub_t=dataset_params["sub_t"],
        ind_x=dataset_params["ind_x"],
        ind_t=dataset_params["ind_t"],
    )
    dataloader_test, sampler_test = mhd_dataloader_test.create_dataloader(
        batch_size=test_loader_params["batch_size"],
        shuffle=test_loader_params["shuffle"],
        num_workers=test_loader_params["num_workers"],
        pin_memory=test_loader_params["pin_memory"],
        distributed=dist.distributed,
    )

    # define FNO model
    model = TFNO(
        in_channels=model_params["in_dim"],
        out_channels=model_params["out_dim"],
        decoder_layers=model_params["decoder_layers"],
        decoder_layer_size=model_params["fc_dim"],
        dimension=model_params["dimension"],
        latent_channels=model_params["layers"],
        num_fno_layers=model_params["num_fno_layers"],
        num_fno_modes=model_params["modes"],
        padding=[model_params["pad_z"], model_params["pad_y"], model_params["pad_x"]],
        rank=model_params["rank"],
        factorization=model_params["factorization"],
        fixed_rank_modes=model_params["fixed_rank_modes"],
    ).to(dist.device)

    # Set up DistributedDataParallel if using more than a single process.
    # The `distributed` property of DistributedManager can be used to
    # check this.
    if dist.distributed:
        ddps = torch.cuda.Stream()
        with torch.cuda.stream(ddps):
            model = DistributedDataParallel(
                model,
                device_ids=[dist.local_rank],  # local rank of this process on this node
                output_device=dist.device,
                broadcast_buffers=dist.broadcast_buffers,
                find_unused_parameters=dist.find_unused_parameters,
            )
        torch.cuda.current_stream().wait_stream(ddps)

    # Optimizer and scheduler are constructed only so checkpoint restore can
    # repopulate their states — no training happens in this script.
    optimizer = AdamW(
        model.parameters(),
        betas=optimizer_params["betas"],
        lr=optimizer_params["lr"],
        weight_decay=0.1,
    )

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=optimizer_params["milestones"],
        gamma=optimizer_params["gamma"],
    )

    # Construct Loss class
    mhd_loss = LossMHDVecPot_PhysicsNeMo(**loss_params)

    # Load model from checkpoint (if requested)
    if load_ckpt:
        _ = load_checkpoint(
            test_params["ckpt_path"], model, optimizer, scheduler, device=dist.device
        )

    # Eval Loop
    names = dataset_params["fields"]
    input_norm = torch.tensor(model_params["input_norm"]).to(dist.device)
    output_norm = torch.tensor(model_params["output_norm"]).to(dist.device)

    with LaunchLogger("test") as log:
        # Val loop
        model.eval()
        plot_count = 0
        with torch.no_grad():
            for i, (inputs, outputs) in enumerate(dataloader_test):
                inputs = inputs.type(dtype).to(dist.device)
                outputs = outputs.type(dtype).to(dist.device)
                start_time = time.time()
                # Model is channels-first; the data pipeline is channels-last,
                # hence the permutes around the forward pass.
                pred = (
                    model((inputs / input_norm).permute(0, 4, 1, 2, 3)).permute(
                        0, 2, 3, 4, 1
                    )
                    * output_norm
                )
                end_time = time.time()
                print(f"Inference Time: {end_time-start_time}")
                # Compute Loss
                loss, loss_dict = mhd_loss(pred, outputs, inputs, return_loss_dict=True)

                log.log_minibatch(loss_dict)

                # Interactive (plotly) prediction plots per sample and field
                for j, _ in enumerate(pred):
                    for index, name in enumerate(names):
                        # Generate figure
                        _ = plot_predictions_mhd_plotly(
                            pred[j].cpu(),
                            outputs[j].cpu(),
                            inputs[j].cpu(),
                            index=index,
                            name=name,
                        )

                    plot_count += 1

                # Static prediction plots saved to disk, one file per sample
                for j, _ in enumerate(pred):
                    plot_predictions_mhd(
                        pred[j].cpu(),
                        outputs[j].cpu(),
                        inputs[j].cpu(),
                        names=names,
                        save_path=os.path.join(
                            output_dir,
                            "MHD_eval_" + str(dist.rank),
                        ),
                        save_suffix=i,
                    )


if __name__ == "__main__":
    main()
mhd/generate_mhd_data/dedalus_mhd_parallel.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ Dedalus script simulating a 2D periodic incompressible MHD flow with a passive
19
+ tracer field for visualization. This script demonstrates solving a 2D periodic
20
+ initial value problem. This script is meant to be ran in parallel, and uses the
21
+ built-in analysis framework to save data snapshots to HDF5 files.
22
+ The simulation should take at least 100 gpu-minutes to run.
23
+
24
+ The initial flow is in the x-direction and depends only on z. The problem is
25
+ non-dimensionalized usign the shear-layer spacing and velocity jump, so the
26
+ resulting viscosity and tracer diffusivity are related to the Reynolds and
27
+ Schmidt numbers as:
28
+
29
+ nu = 1 / Re
30
+ eta = 1 / ReM
31
+ D = nu / Schmidt
32
+
33
+ To run this script:
34
+ $ python dedalus_mhd_parallel.py
35
+ """
36
+
37
+
38
+ import os
39
+ import glob
40
+ import h5py
41
+ import numpy as np
42
+ import functools
43
+ from functools import partial
44
+ import matplotlib
45
+ import matplotlib.pyplot as plt
46
+ import argparse
47
+ import multiprocessing as mp
48
+ import dedalus
49
+ import dedalus.public as d3
50
+ from dedalus.extras import plot_tools
51
+ import pathlib
52
+ from docopt import docopt
53
+ from dedalus.tools import logging
54
+ from dedalus.tools import post
55
+ from dedalus.tools.parallel import Sync
56
+ import logging
57
+ import math
58
+ from IPython.display import display
59
+ import imageio
60
+ from importlib import reload
61
+ from my_random_fields import GRF_Mattern
62
+ import torch
63
+ from functorch import vmap
64
+ from hydra import compose, initialize
65
+ from hydra.utils import get_class
66
+
67
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
68
+ # display(device)
69
+
70
+
71
def check_if_complete(sim_outputs, Nt=101):
    """Return True when the first HDF5 file matching ``sim_outputs`` contains
    exactly ``Nt`` saved time steps.

    Any failure (no matching files yet, unreadable or partially written file)
    is treated as "not complete" and yields False.
    """
    try:
        first_file = sorted(glob.glob(sim_outputs))[0]
        with h5py.File(first_file, mode="r") as h5file:
            tasks = h5file["tasks"]
            first_task = tasks[list(tasks.keys())[0]]
            sim_times = first_task.dims[0]["sim_time"][:]
        return len(sim_times) == Nt
    except Exception:
        return False
86
+
87
+
88
if __name__ == "__main__":
    import sys

    # Parse command line args before Hydra initialization: argparse consumes
    # --Re/--N, everything else is handed on to Hydra via sys.argv below.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--Re', type=float, help='Reynolds number')
    parser.add_argument('--N', type=int, help='Number of samples')
    args, remaining_argv = parser.parse_known_args()

    # Initialize Hydra with remaining args
    sys.argv = [sys.argv[0]] + remaining_argv
    initialize(version_base=None, config_path=".", job_name="generate_mhd_field")
    cfg = compose(config_name="mhd_field")

    # Parameters - override with command line args if provided
    Lx, Ly = cfg.Lx, cfg.Ly
    Nx, Ny = cfg.Nx, cfg.Ny
    Re = args.Re if args.Re is not None else cfg.Re  # Use CLI arg or default to config
    Re = int(Re)
    ReM = Re  # magnetic Reynolds number is tied to the kinetic one
    Schmidt = cfg.Schmidt  # 1
    rho0 = cfg.rho0  # 1.0
    dealias = cfg.dealias  # 3/2
    stop_sim_time = cfg.tend
    timestepper = get_class(cfg.timestepper)  # d3.RK443 #d3.RK222
    Dt = cfg.Dt  # 1e-3
    max_timestep = cfg.max_timestep  # 1e-2
    output_dt = cfg.output_dt  # 1e-2 # 1e-1
    log_iter = cfg.log_iter  # 10
    dtype = get_class(cfg.dtype)  # np.float64
    max_writes = cfg.max_writes  # None
    logger = logging.getLogger(__name__)
    output_dir = f"/Datasets/mhd_data/simulation_outputs_Re{Re}"
    movie_dir = f"{output_dir}/movie"
    use_cfl = cfg.use_cfl  # False
    skip_exists = cfg.skip_exists  # False

    ## ID Parameters (Gaussian-random-field initial data)
    L = cfg.L  # 1
    dim = 2
    Nsamples = args.N if args.N is not None else cfg.N  # Use CLI arg or default to config
    l_u = cfg.l_u  # 0.1
    l_A = cfg.l_A  # 0.1
    Nu = cfg.Nu  # None
    sigma_u = cfg.sigma_u  # 0.1
    sigma_A = cfg.sigma_A  # 5e-3

    # Generate Random Initial Data
    grf_u = GRF_Mattern(
        dim,
        Nx,
        length=Lx,
        nu=Nu,
        l=l_u,
        sigma=sigma_u,
        boundary="periodic",
        device=device,
    )
    grf_A = GRF_Mattern(
        dim,
        Nx,
        length=Lx,
        nu=Nu,
        l=l_A,
        sigma=sigma_A,
        boundary="periodic",
        device=device,
    )

    # Sample all initial potentials up front on `device`, then move to host
    # memory so the (CPU, multiprocessing) simulations can read them.
    u0_pot = grf_u.sample(Nsamples).cpu().numpy().reshape(Nsamples, Nx, Ny)
    A0 = grf_A.sample(Nsamples).cpu().numpy().reshape(Nsamples, Nx, Ny)
    digits = int(math.log10(Nsamples)) + 1  # zero-padding width for output dirs

    # expected number of time steps
    Nt = len(np.arange(0, stop_sim_time + Dt, output_dt))
    indices = list(range(Nsamples))

    if skip_exists:
        # Re-run only the samples whose outputs are missing or incomplete.
        completed_list = []
        for j in range(Nsamples):
            sim_output_dir = os.path.join(output_dir, f"output-{j:0{digits}}")
            sim_outputs = os.path.join(sim_output_dir, "*.h5")
            # skip if the next output directory exists and if the output is complete
            if os.path.exists(sim_output_dir):
                completed = check_if_complete(sim_outputs, Nt=Nt)
            else:
                completed = False
            completed_list.append(completed)
        indices = [j for j, completed in enumerate(completed_list) if not completed]
        print(indices)
180
    def run_simulation(
        i,
        Lx=Lx,
        Ly=Ly,
        Nx=Nx,
        Ny=Ny,
        Re=Re,
        ReM=ReM,
        Schmidt=Schmidt,
        rho0=rho0,
        dealias=dealias,
        stop_sim_time=stop_sim_time,
        timestepper=timestepper,
        Dt=Dt,
        max_timestep=max_timestep,
        output_dt=output_dt,
        log_iter=log_iter,
        dtype=dtype,
        max_writes=max_writes,
        logger=logger,
        output_dir=output_dir,
        use_cfl=use_cfl,
        L=L,
        dim=dim,
        Nsamples=Nsamples,
        l_u=l_u,
        l_A=l_A,
        Nu=Nu,
        sigma_u=sigma_u,
        sigma_A=sigma_A,
        grf_u=grf_u,
        grf_A=grf_A,
        u0_pot=u0_pot,
        A0=A0,
        digits=digits,
        Nt=Nt,
    ):
        """Run one Dedalus 2D incompressible-MHD simulation (sample index i)
        and write HDF5 snapshots under ``output_dir/output-<i>``.

        All parameters are bound as defaults at definition time so the
        function can be dispatched through multiprocessing with only ``i``.
        """
        sim_output_dir = os.path.join(output_dir, f"output-{i:0{digits}}")
        sim_outputs = os.path.join(sim_output_dir, "*.h5")
        print(
            f"Running simulation {i:0{digits}} with outputs in {sim_output_dir}",
            flush=True,
        )
        # Bases: doubly periodic Fourier bases on [0, Lx] x [0, Ly]
        coords = d3.CartesianCoordinates("x", "y")
        dist = d3.Distributor(coords, dtype=dtype)
        xbasis = d3.RealFourier(coords["x"], size=Nx, bounds=(0, Lx), dealias=dealias)
        ybasis = d3.RealFourier(coords["y"], size=Ny, bounds=(0, Ly), dealias=dealias)

        # Fields
        p = dist.Field(name="p", bases=(xbasis, ybasis))
        s = dist.Field(name="s", bases=(xbasis, ybasis))
        u = dist.VectorField(coords, name="u", bases=(xbasis, ybasis))
        B = dist.VectorField(coords, name="B", bases=(xbasis, ybasis))
        A = dist.Field(name="A", bases=(xbasis, ybasis))
        B2 = dist.Field(name="B2", bases=(xbasis, ybasis))
        u_pot = dist.Field(name="u_pot", bases=(xbasis, ybasis))
        Ax = dist.Field(name="Ax", bases=(xbasis, ybasis))
        Ay = dist.Field(name="Ay", bases=(xbasis, ybasis))
        Bx = dist.Field(name="Bx", bases=(xbasis, ybasis))
        By = dist.Field(name="By", bases=(xbasis, ybasis))
        u0 = dist.VectorField(coords, name="u0", bases=(xbasis, ybasis))
        ux = dist.Field(name="ux", bases=(xbasis, ybasis))
        uy = dist.Field(name="uy", bases=(xbasis, ybasis))
        tau_p = dist.Field(name="tau_p")

        # Substitutions
        nu = 1 / Re
        D = nu / Schmidt
        eta = 1 / ReM
        x, y = dist.local_grids(xbasis, ybasis)
        X, Y = np.meshgrid(x, y, indexing="ij")
        ex, ey = coords.unit_vector_fields(dist)
        # 2D curls: of a scalar -> vector, of a vector -> scalar.
        curl2d_scalar = lambda x: -d3.skew(d3.grad(x))
        curl2d_vector = lambda x: -d3.div(d3.skew(x))
        # NOTE(review): the assignments below rebind B/B2/Bx/By/ux/uy from the
        # Field objects above to derived operator expressions; B is computed
        # from the vector potential A, which keeps div(B) = 0 by construction.
        B = curl2d_scalar(A)
        B2 = d3.dot(B, B)
        Bx = B @ ex
        By = B @ ey
        ux = u @ ex
        uy = u @ ey

        # Problem: incompressible MHD with Lorentz force, plus passive tracer s
        problem = d3.IVP([u, p, A, tau_p, s], namespace=locals())
        problem.add_equation(
            "dt(u) + grad(p)/rho0 - nu*lap(u) = - 0.5*grad(B2)/rho0 - u@grad(u) + B@grad(B)/rho0"
        )
        problem.add_equation("dt(s) - D*lap(s) = - u@grad(s)")
        problem.add_equation("dt(A) - eta*lap(A) = - u@grad(A)")
        problem.add_equation("div(u) + tau_p = 0")
        problem.add_equation("integ(p) = 0")  # Pressure gauge

        # Solver
        solver = problem.build_solver(timestepper)
        solver.stop_sim_time = (
            stop_sim_time + Dt
        )  # Make sure we record the last timestep

        # Initial conditions: velocity is the curl of a random potential
        # (hence divergence-free); the tracer reuses that same potential.
        u_pot["g"] = u0_pot[i]
        u0 = curl2d_scalar(u_pot).evaluate()
        u0.change_scales(1)
        u["g"] = u0["g"]
        ux = u @ ex
        uy = u @ ey
        B2 = d3.dot(B, B)
        s["g"] = u0_pot[i]
        A["g"] = A0[i]

        # Analysis (This overwrites existing files)
        os.makedirs(sim_output_dir, exist_ok=True)
        snapshots = solver.evaluator.add_file_handler(
            sim_output_dir, sim_dt=output_dt, max_writes=max_writes
        )

        snapshots.add_task(s, name="tracer")
        snapshots.add_task(A, name="vector potential")
        snapshots.add_task(B, name="magnetic field")

        snapshots.add_task(u, name="velocity")
        snapshots.add_task(p, name="pressure")

        # CFL (Don't actually use this. Use constant timestep instead, so all
        # runs share identical output time grids.)
        CFL = d3.CFL(
            solver,
            initial_dt=max_timestep,
            cadence=10,
            safety=0.2,
            threshold=0.1,
            max_change=1.5,
            min_change=0.5,
            max_dt=max_timestep,
        )
        CFL.add_velocity(u)

        # Flow properties, sampled every 10 iterations for logging
        flow = d3.GlobalFlowProperty(solver, cadence=10)
        flow.add_property(d3.dot(u, u), name="w2")
        flow.add_property(d3.dot(B, B), name="B2")
        flow.add_property(d3.div(B), name="divB")  # diagnostic: should stay ~0

        # Main loop
        try:
            logger.info("Starting main loop")
            while solver.proceed:
                if use_cfl:
                    timestep = CFL.compute_timestep()
                else:
                    timestep = Dt
                solver.step(timestep)
                if (solver.iteration) % 10 == 0:
                    max_w = np.sqrt(flow.max("w2"))
                    max_B = np.sqrt(flow.max("B2"))
                    max_divB = flow.max("divB")
                    logger.info(
                        f"Iteration={solver.iteration}, Time={solver.sim_time:#.3g}, dt={timestep:#.3g}, max(w)={max_w:#.3g}, max(B)={max_B:#.3g}, max(div_B)={max_divB:#.3g}"
                    )
            print(
                f"Finished simulation {i:0{digits}} with outputs in {sim_output_dir}",
                flush=True,
            )
        # NOTE(review): bare except also traps KeyboardInterrupt/SystemExit,
        # but it re-raises immediately after logging, so nothing is swallowed.
        except:
            logger.error("Exception raised, triggering end of main loop.")
            raise
        solver.log_stats()
349
+
350
    # Run in parallel: one worker process per CPU (minus one).  Workers only
    # read the numpy arrays and config values captured in run_simulation's
    # default arguments — no CUDA work happens in the children.
    with mp.Pool(mp.cpu_count() - 1) as pool:
        pool.map(run_simulation, indices, chunksize=10)
mhd/generate_mhd_data/mhd_field.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ Lx: 1.0 # Length of domain in x direction
18
+ Ly: 1.0 # Length of domain in y direction
19
+ Nx: 128 # Number of points in x direction
20
+ Ny: 128 # Number of points in y direction
21
 + Schmidt: 1.0 # Schmidt number
22
+ rho0: 1.0 # Density of fluid
23
+ dealias: 1.5 # Dealiasing factor
24
+ tend: 1.0 # End time of simulation
25
+ Dt: 1.0e-3 # Timestep size
26
+ timestepper: dedalus.public.RK443 # Timestepper type
27
+ max_timestep: 1.0e-2 # Maximum timestep for CFL control
28
+ output_dt: 1.0e-2 # Time between outputs
29
+ log_iter: 10 # Iterations between logging
30
+ dtype: numpy.float64 # Datatype for simulation
31
+ max_writes: null # Maximum file writes
32
+ L: 1.0 # Length of domain for generating data
33
+ l_u: 0.1 # Length of typical spatial deviations for velocity potential
34
+ l_A: 0.1 # Length of typical spatial deviations for magnetic vector potential
35
+ sigma_u: 0.1 # Typical amplitude of velocity potential
36
+ sigma_A: 0.5e-3 # Typical amplitude of magnetic vector potential
37
+ Nu: null # Smoothness parameter for GRF
38
+ use_cfl: false # Whether to use timestep computed based on CFL
39
+ skip_exists: true # Skip existing output files
mhd/generate_mhd_data/my_random_fields.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import math
20
+ from math import pi, gamma, sqrt
21
+ import numpy as np
22
+
23
+ torch.manual_seed(0)
24
+
25
+
26
class GRF_Mattern(object):
    """Gaussian random field sampler (Matern-like / squared-exponential spectrum).

    Builds the square-root eigenvalues of a covariance operator on a uniform
    grid of ``size`` points per dimension (dim in {1, 2, 3}) so that samples
    can be drawn spectrally in :meth:`sample`.

    Parameters:
        dim: spatial dimension (1, 2 or 3).
        size: number of grid points per dimension.
        length: physical domain length per dimension.
        nu: Matern smoothness; if None, a squared-exponential spectrum is used.
        l: correlation length scale.
        sigma: amplitude scale of the field.
        boundary: "periodic", "dirichlet" or "neumann"; controls the
            wavenumber constant and which part of the random coefficients
            is zeroed in :meth:`sample`.
        constant_eig: if not None, value assigned to the zero (mean) mode;
            forced to None for Dirichlet boundaries.
        device: torch device for the precomputed spectrum and samples.
    """

    def __init__(
        self,
        dim,
        size,
        length=1.0,
        nu=None,
        l=0.1,
        sigma=1.0,
        boundary="periodic",
        constant_eig=None,
        device=None,
    ):

        self.dim = dim
        self.device = device
        self.bc = boundary

        # NOTE(review): `a` is computed but never used below — confirm it can be removed.
        a = sqrt(2 / length)
        # Dirichlet fields have zero mean by construction, so drop any constant mode.
        if self.bc == "dirichlet":
            constant_eig = None

        if nu is not None:
            # Matern spectrum: kappa is the inverse correlation length,
            # alpha the spectral decay exponent.
            kappa = sqrt(2 * nu) / l
            alpha = nu + 0.5 * dim
            # Normalization so that samples have amplitude ~ sigma.
            self.eta2 = (
                size**dim
                * sigma
                * (4.0 * pi) ** (0.5 * dim)
                * gamma(alpha)
                / (kappa**dim * gamma(nu))
            )
        else:
            # Squared-exponential (Gaussian) spectrum normalization.
            self.eta2 = size**dim * sigma * (sqrt(2.0 * pi) * l) ** dim

        k_max = size // 2
        # Wavenumber scaling constant depends on the boundary condition.
        if self.bc == "periodic":
            const = (4.0 * (pi**2)) / (length**2)
        else:
            const = (pi**2) / (length**2)

        if dim == 1:
            # FFT-ordered integer wavenumbers: [0..k_max-1, -k_max..-1].
            k = torch.cat(
                (
                    torch.arange(start=0, end=k_max, step=1, device=device),
                    torch.arange(start=-k_max, end=0, step=1, device=device),
                ),
                0,
            )

            k2 = k**2
            if nu is not None:
                eigs = 1.0 + (const / (kappa * length) ** 2 * k2)
                self.sqrt_eig = self.eta2 / (length**dim) * eigs ** (-alpha / 2.0)
            else:
                self.sqrt_eig = (
                    self.eta2
                    / (length**dim)
                    * torch.exp(-((l) ** 2) * const * k2 / 4.0)
                )

            # Zero (mean) mode: either a fixed constant or removed entirely.
            if constant_eig is not None:
                self.sqrt_eig[0] = constant_eig  # (size**dim)*sigma*(tau**(-alpha))
            else:
                self.sqrt_eig[0] = 0.0

        elif dim == 2:
            # 2D wavenumber grids built by repeating the 1D FFT ordering.
            wavenumers = torch.cat(
                (
                    torch.arange(start=0, end=k_max, step=1, device=device),
                    torch.arange(start=-k_max, end=0, step=1, device=device),
                ),
                0,
            ).repeat(size, 1)

            k_x = wavenumers.transpose(0, 1)
            k_y = wavenumers

            k2 = k_x**2 + k_y**2
            if nu is not None:
                eigs = 1.0 + (const / (kappa * length) ** 2 * k2)
                self.sqrt_eig = self.eta2 / (length**dim) * eigs ** (-alpha / 2.0)
            else:
                self.sqrt_eig = (
                    self.eta2
                    / (length**dim)
                    * torch.exp(-((l) ** 2) * const * k2 / 4.0)
                )

            if constant_eig is not None:
                self.sqrt_eig[0, 0] = constant_eig  # (size**dim)*sigma*(tau**(-alpha))
            else:
                self.sqrt_eig[0, 0] = 0.0

        elif dim == 3:
            # 3D wavenumber grids via repetition and axis transposes.
            wavenumers = torch.cat(
                (
                    torch.arange(start=0, end=k_max, step=1, device=device),
                    torch.arange(start=-k_max, end=0, step=1, device=device),
                ),
                0,
            ).repeat(size, size, 1)

            k_x = wavenumers.transpose(1, 2)
            k_y = wavenumers
            k_z = wavenumers.transpose(0, 2)

            k2 = k_x**2 + k_y**2 + k_z**2
            if nu is not None:
                eigs = 1.0 + (const / (kappa * length) ** 2 * k2)
                self.sqrt_eig = self.eta2 / (length**dim) * eigs ** (-alpha / 2.0)
            else:
                self.sqrt_eig = (
                    self.eta2
                    / (length**dim)
                    * torch.exp(-((l) ** 2) * const * k2 / 4.0)
                )

            if constant_eig is not None:
                self.sqrt_eig[
                    0, 0, 0
                ] = constant_eig  # (size**dim)*sigma*(tau**(-alpha))
            else:
                self.sqrt_eig[0, 0, 0] = 0.0

        # Store the spatial shape as a tuple (size repeated dim times).
        self.size = []
        for j in range(self.dim):
            self.size.append(size)

        self.size = tuple(self.size)

    def sample(self, N):
        """Draw N random field samples; returns a real tensor of shape (N, *self.size)."""

        # Complex white noise coefficients in Fourier space.
        coeff = torch.randn(N, *self.size, dtype=torch.cfloat, device=self.device)
        # Dirichlet: odd (sine-only) field -> drop real part of coefficients.
        if self.bc == "dirichlet":
            coeff.real[:] = 0
        # Neumann: even (cosine-only) field -> drop imaginary part.
        if self.bc == "neumann":
            coeff.imag[:] = 0
        # Color the noise by the square-root eigenvalues of the covariance.
        coeff = self.sqrt_eig * coeff

        # NOTE(review): irfftn is applied to a full (not half) spectrum here —
        # confirm this truncation is intended rather than ifftn(...).real.
        u = torch.fft.irfftn(coeff, self.size, norm="backward")
        return u
170
+
171
+
172
if __name__ == "__main__":
    # Script entry point: sample random fields per a Hydra config and dump
    # each realization to its own HDF5 file, optionally plotting the first one.
    from hydra import compose, initialize
    import h5py
    import os
    import matplotlib.pyplot as plt

    initialize(version_base=None, config_path=".", job_name="generate_random_field")
    cfg = compose(config_name="example_field")

    N = cfg.num_samples
    n = cfg.num_points
    dim = cfg.dim
    L = cfg.length
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    grf = GRF_Mattern(
        dim=cfg.dim,
        size=cfg.num_points,
        length=cfg.length,
        nu=cfg.nu,
        l=cfg.length_scale,
        sigma=cfg.sigma,
        boundary=cfg.boundary_condition,
        constant_eig=cfg.mean,
        device=device,
    )
    U = grf.sample(N)
    # convert to pad periodically: duplicate the first grid point at the end
    # of every spatial axis so the saved field covers the closed interval [0, L].
    pad_width = [(0, 0)] + [(0, 1) for _ in range(dim)]

    u = np.pad(U.cpu().numpy(), pad_width, mode="wrap")
    x = np.linspace(0, L, n + 1)
    # Zero-pad width for file indices (e.g. N=100 -> 3 digits).
    digits = int(math.log10(N)) + 1
    basefile = cfg.file
    if basefile:
        filedir, file = os.path.split(basefile)
        if filedir:
            os.makedirs(filedir, exist_ok=True)

    # NOTE(review): the original rendering loses indentation — this loop appears
    # to run even when cfg.file is empty, producing files named "-<i>.h5" in the
    # working directory; confirm whether it should sit inside `if basefile:`.
    for i, u0 in enumerate(u):
        filename = f"{basefile}-{i:0{digits}d}.h5"
        with h5py.File(filename, "w") as hf:
            hf.create_dataset("u", data=u0)
            # One coordinate dataset per spatial dimension (x1, x2, ...).
            for j in range(dim):
                coord_name = f"x{j+1}"
                hf.create_dataset(coord_name, data=x)

    if cfg.plot:
        # coords = [x for _ in dim]
        # X = np.meshgrid(*coords, indexing='ij')
        # Quick visual sanity check of the first sample (2D only).
        if dim == 2:
            X, Y = np.meshgrid(x, x, indexing="ij")
            plt.close("all")
            fig = plt.figure()
            pmesh = plt.pcolormesh(X, Y, u[0], cmap="jet", shading="gouraud")
            plt.colorbar(pmesh)
            plt.axis("square")
            plt.title("Random Initial Data")
            plt.show()
mhd/losses/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from .losses import LpLoss
18
+ from .loss_mhd_vec_pot_physicsnemo import LossMHDVecPot_PhysicsNeMo
mhd/losses/loss_mhd_vec_pot_physicsnemo.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from physicsnemo.models.layers.spectral_layers import fourier_derivatives
20
+
21
+ from .losses import (LpLoss, fourier_derivatives_lap, fourier_derivatives_ptot,
22
+ fourier_derivatives_vec_pot)
23
+ from .mhd_pde import MHD_PDE
24
+
25
+
26
class LossMHDVecPot_PhysicsNeMo(object):
    "Calculate loss for MHD equations with vector potential, using physicsnemo derivatives"

    # Weighted combination of four loss families:
    #   data (match supervised targets), ic (match initial condition),
    #   pde (residual of the 2D incompressible MHD equations in vector-potential
    #   form), constraint (div(u)=0 and div(B)=0).
    # Predictions are expected with channel layout [..., 0]=u, [..., 1]=v, [..., 2]=A.

    def __init__(
        self,
        nu=1e-4,
        eta=1e-4,
        rho0=1.0,
        data_weight=1.0,
        ic_weight=1.0,
        pde_weight=1.0,
        constraint_weight=1.0,
        use_data_loss=True,
        use_ic_loss=True,
        use_pde_loss=True,
        use_constraint_loss=True,
        u_weight=1.0,
        v_weight=1.0,
        A_weight=1.0,
        Du_weight=1.0,
        Dv_weight=1.0,
        DA_weight=1.0,
        div_B_weight=1.0,
        div_vel_weight=1.0,
        Lx=1.0,
        Ly=1.0,
        tend=1.0,
        use_weighted_mean=False,
        **kwargs,
    ):  # add **kwargs so that we ignore unexpected kwargs when passing a config dict):

        # Physical parameters: viscosity, resistivity, reference density.
        self.nu = nu
        self.eta = eta
        self.rho0 = rho0
        # Top-level weights for each loss family.
        self.data_weight = data_weight
        self.ic_weight = ic_weight
        self.pde_weight = pde_weight
        self.constraint_weight = constraint_weight
        self.use_data_loss = use_data_loss
        self.use_ic_loss = use_ic_loss
        self.use_pde_loss = use_pde_loss
        self.use_constraint_loss = use_constraint_loss
        # Per-channel weights inside the data/ic losses.
        self.u_weight = u_weight
        self.v_weight = v_weight
        # Per-residual weights inside the pde loss.
        self.Du_weight = Du_weight
        self.Dv_weight = Dv_weight
        # Weights for the divergence-free constraints.
        self.div_B_weight = div_B_weight
        self.div_vel_weight = div_vel_weight
        # Domain extents and final simulation time (used for spectral/time derivatives).
        self.Lx = Lx
        self.Ly = Ly
        self.tend = tend
        # If True, each weighted sum is normalized by the sum of its weights.
        self.use_weighted_mean = use_weighted_mean
        self.A_weight = A_weight
        self.DA_weight = DA_weight
        # Define 2D MHD PDEs
        self.mhd_pde_eq = MHD_PDE(self.nu, self.eta, self.rho0)
        # NOTE(review): later methods index into this node list by position
        # (e.g. node[12] = div_B) — order-dependent on MHD_PDE.make_nodes().
        self.mhd_pde_node = self.mhd_pde_eq.make_nodes()

        # Disabled loss families contribute zero weight to the total.
        if not self.use_data_loss:
            self.data_weight = 0
        if not self.use_ic_loss:
            self.ic_weight = 0
        if not self.use_pde_loss:
            self.pde_weight = 0
        if not self.use_constraint_loss:
            self.constraint_weight = 0

    def __call__(self, pred, true, inputs, return_loss_dict=False):
        # NOTE(review): `return_loss_dict` is ignored — both the scalar loss
        # and the per-term dictionary are always returned.
        loss, loss_dict = self.compute_losses(pred, true, inputs)
        return loss, loss_dict

    def compute_loss(self, pred, true, inputs):
        "Compute weighted loss"
        # Scalar-only variant of compute_losses (no per-term dictionary).
        pred = pred.reshape(true.shape)
        u = pred[..., 0]
        v = pred[..., 1]
        A = pred[..., 2]

        # Data
        if self.use_data_loss:
            loss_data = self.data_loss(pred, true)
        else:
            loss_data = 0
        # IC
        if self.use_ic_loss:
            loss_ic = self.ic_loss(pred, inputs)
        else:
            loss_ic = 0

        # PDE
        if self.use_pde_loss:
            Du, Dv, DA = self.mhd_pde(u, v, A)
            loss_pde = self.mhd_pde_loss(Du, Dv, DA)
        else:
            loss_pde = 0

        # Constraints
        if self.use_constraint_loss:
            div_vel, div_B = self.mhd_constraint(u, v, A)
            loss_constraint = self.mhd_constraint_loss(div_vel, div_B)
        else:
            loss_constraint = 0

        if self.use_weighted_mean:
            weight_sum = (
                self.data_weight
                + self.ic_weight
                + self.pde_weight
                + self.constraint_weight
            )
        else:
            weight_sum = 1.0

        loss = (
            self.data_weight * loss_data
            + self.ic_weight * loss_ic
            + self.pde_weight * loss_pde
            + self.constraint_weight * loss_constraint
        ) / weight_sum
        return loss

    def compute_losses(self, pred, true, inputs):
        "Compute weighted loss and dictionary"
        # Same structure as compute_loss, but also records each sub-loss
        # in loss_dict for logging/diagnostics.
        pred = pred.reshape(true.shape)
        u = pred[..., 0]
        v = pred[..., 1]
        A = pred[..., 2]

        loss_dict = {}

        # Data
        if self.use_data_loss:
            loss_data, loss_u, loss_v, loss_A = self.data_loss(
                pred, true, return_all_losses=True
            )
            loss_dict["loss_data"] = loss_data
            loss_dict["loss_u"] = loss_u
            loss_dict["loss_v"] = loss_v
            loss_dict["loss_A"] = loss_A
        else:
            loss_data = 0
        # IC
        if self.use_ic_loss:
            loss_ic, loss_u_ic, loss_v_ic, loss_A_ic = self.ic_loss(
                pred, inputs, return_all_losses=True
            )
            loss_dict["loss_ic"] = loss_ic
            loss_dict["loss_u_ic"] = loss_u_ic
            loss_dict["loss_v_ic"] = loss_v_ic
            loss_dict["loss_A_ic"] = loss_A_ic
        else:
            loss_ic = 0

        # PDE
        if self.use_pde_loss:
            Du, Dv, DA = self.mhd_pde(u, v, A)
            loss_pde, loss_Du, loss_Dv, loss_DA = self.mhd_pde_loss(
                Du, Dv, DA, return_all_losses=True
            )
            loss_dict["loss_pde"] = loss_pde
            loss_dict["loss_Du"] = loss_Du
            loss_dict["loss_Dv"] = loss_Dv
            loss_dict["loss_DA"] = loss_DA
        else:
            loss_pde = 0

        # Constraints
        if self.use_constraint_loss:
            div_vel, div_B = self.mhd_constraint(u, v, A)
            loss_constraint, loss_div_vel, loss_div_B = self.mhd_constraint_loss(
                div_vel, div_B, return_all_losses=True
            )
            loss_dict["loss_constraint"] = loss_constraint
            loss_dict["loss_div_vel"] = loss_div_vel
            loss_dict["loss_div_B"] = loss_div_B
        else:
            loss_constraint = 0

        if self.use_weighted_mean:
            weight_sum = (
                self.data_weight
                + self.ic_weight
                + self.pde_weight
                + self.constraint_weight
            )
        else:
            weight_sum = 1.0

        loss = (
            self.data_weight * loss_data
            + self.ic_weight * loss_ic
            + self.pde_weight * loss_pde
            + self.constraint_weight * loss_constraint
        ) / weight_sum
        loss_dict["loss"] = loss
        return loss, loss_dict

    def data_loss(self, pred, true, return_all_losses=False):
        "Compute data loss"
        # Relative Lp loss per channel (u, v, A), combined with per-channel weights.
        lploss = LpLoss(size_average=True)
        u_pred = pred[..., 0]
        v_pred = pred[..., 1]
        A_pred = pred[..., 2]

        u_true = true[..., 0]
        v_true = true[..., 1]
        A_true = true[..., 2]

        loss_u = lploss(u_pred, u_true)
        loss_v = lploss(v_pred, v_true)
        loss_A = lploss(A_pred, A_true)

        if self.use_weighted_mean:
            weight_sum = self.u_weight + self.v_weight + self.A_weight
        else:
            weight_sum = 1.0

        loss_data = (
            self.u_weight * loss_u + self.v_weight * loss_v + self.A_weight * loss_A
        ) / weight_sum

        if return_all_losses:
            return loss_data, loss_u, loss_v, loss_A
        else:
            return loss_data

    def ic_loss(self, pred, input, return_all_losses=False):
        "Compute initial condition loss"
        # Compare the predicted first time slice against the IC channels of the
        # network input (input channels 3: are assumed to hold the IC fields —
        # NOTE(review): confirm layout against the data pipeline).
        lploss = LpLoss(size_average=True)
        ic_pred = pred[:, 0]
        ic_true = input[:, 0, ..., 3:]
        u_ic_pred = ic_pred[..., 0]
        v_ic_pred = ic_pred[..., 1]
        A_ic_pred = ic_pred[..., 2]

        u_ic_true = ic_true[..., 0]
        v_ic_true = ic_true[..., 1]
        A_ic_true = ic_true[..., 2]

        loss_u_ic = lploss(u_ic_pred, u_ic_true)
        loss_v_ic = lploss(v_ic_pred, v_ic_true)
        loss_A_ic = lploss(A_ic_pred, A_ic_true)

        if self.use_weighted_mean:
            weight_sum = self.u_weight + self.v_weight + self.A_weight
        else:
            weight_sum = 1.0

        loss_ic = (
            self.u_weight * loss_u_ic
            + self.v_weight * loss_v_ic
            + self.A_weight * loss_A_ic
        ) / weight_sum

        if return_all_losses:
            return loss_ic, loss_u_ic, loss_v_ic, loss_A_ic
        else:
            return loss_ic

    def mhd_pde_loss(self, Du, Dv, DA, return_all_losses=None):
        "Compute PDE loss"
        # MSE of each residual against zero.
        Du_val = torch.zeros_like(Du)
        Dv_val = torch.zeros_like(Dv)
        DA_val = torch.zeros_like(DA)

        loss_Du = F.mse_loss(Du, Du_val)
        loss_Dv = F.mse_loss(Dv, Dv_val)
        loss_DA = F.mse_loss(DA, DA_val)

        if self.use_weighted_mean:
            weight_sum = self.Du_weight + self.Dv_weight + self.DA_weight
        else:
            weight_sum = 1.0

        loss_pde = (
            self.Du_weight * loss_Du
            + self.Dv_weight * loss_Dv
            + self.DA_weight * loss_DA
        ) / weight_sum

        if return_all_losses:
            return loss_pde, loss_Du, loss_Dv, loss_DA
        else:
            return loss_pde

    def mhd_constraint(self, u, v, A):
        "Compute constraints"
        # Evaluate div(u) and div(B) with spectral derivatives; B is derived
        # from the vector potential A.
        nt = u.size(1)
        nx = u.size(2)
        ny = u.size(3)

        f_du, _ = fourier_derivatives(u, [self.Lx, self.Ly])
        f_dv, _ = fourier_derivatives(v, [self.Lx, self.Ly])
        f_dBx, f_dBy, _, _, _ = fourier_derivatives_vec_pot(A, [self.Lx, self.Ly])

        # fourier_derivatives stacks d/dx and d/dy along the channel axis:
        # the first nt channels are x-derivatives, the next nt are y-derivatives.
        u_x = f_du[:, 0:nt, :nx, :ny]
        v_y = f_dv[:, nt : 2 * nt, :nx, :ny]
        Bx_x = f_dBx[:, 0:nt, :nx, :ny]
        By_y = f_dBy[:, nt : 2 * nt, :nx, :ny]

        # Positional node lookup: node[12] = div_B, node[13] = div_vel
        # (order set by MHD_PDE.make_nodes()).
        div_B = self.mhd_pde_node[12].evaluate({"Bx__x": Bx_x, "By__y": By_y})["div_B"]
        div_vel = self.mhd_pde_node[13].evaluate({"u__x": u_x, "v__y": v_y})["div_vel"]

        return div_vel, div_B

    def mhd_constraint_loss(self, div_vel, div_B, return_all_losses=False):
        "Compute constraint loss"
        # MSE of each divergence against zero.
        div_vel_val = torch.zeros_like(div_vel)
        div_B_val = torch.zeros_like(div_B)

        loss_div_vel = F.mse_loss(div_vel, div_vel_val)
        loss_div_B = F.mse_loss(div_B, div_B_val)

        if self.use_weighted_mean:
            weight_sum = self.div_vel_weight + self.div_B_weight
        else:
            weight_sum = 1.0

        loss_constraint = (
            self.div_vel_weight * loss_div_vel + self.div_B_weight * loss_div_B
        ) / weight_sum

        if return_all_losses:
            return loss_constraint, loss_div_vel, loss_div_B
        else:
            return loss_constraint

    def mhd_pde(self, u, v, A, p=None):
        "Compute PDEs for MHD using vector potential"
        # Fields are (batch, nt, nx, ny); spatial derivatives are spectral,
        # the time derivative is a centered finite difference (interior only).
        nt = u.size(1)
        nx = u.size(2)
        ny = u.size(3)
        dt = self.tend / (nt - 1)

        # compute fourier derivatives
        f_du, _ = fourier_derivatives(u, [self.Lx, self.Ly])
        f_dv, _ = fourier_derivatives(v, [self.Lx, self.Ly])
        f_dBx, f_dBy, f_dA, f_dB, B2_h = fourier_derivatives_vec_pot(
            A, [self.Lx, self.Ly]
        )

        # Unpack channel-stacked derivatives (first nt = d/dx, next nt = d/dy).
        u_x = f_du[:, 0:nt, :nx, :ny]
        u_y = f_du[:, nt : 2 * nt, :nx, :ny]
        v_x = f_dv[:, 0:nt, :nx, :ny]
        v_y = f_dv[:, nt : 2 * nt, :nx, :ny]
        A_x = f_dA[:, 0:nt, :nx, :ny]
        A_y = f_dA[:, nt : 2 * nt, :nx, :ny]

        # Magnetic field components from the vector potential and their gradients.
        Bx = f_dB[:, 0:nt, :nx, :ny]
        By = f_dB[:, nt : 2 * nt, :nx, :ny]
        Bx_x = f_dBx[:, 0:nt, :nx, :ny]
        Bx_y = f_dBx[:, nt : 2 * nt, :nx, :ny]
        By_x = f_dBy[:, 0:nt, :nx, :ny]
        By_y = f_dBy[:, nt : 2 * nt, :nx, :ny]

        u_lap = fourier_derivatives_lap(u, [self.Lx, self.Ly])
        v_lap = fourier_derivatives_lap(v, [self.Lx, self.Ly])
        A_lap = fourier_derivatives_lap(A, [self.Lx, self.Ly])

        # note that for pressure, the zero mode (the mean) cannot be zero for invertability so it is set to 1
        # Source terms for the pressure Poisson solve.
        div_vel_grad_vel = u_x**2 + 2 * u_y * v_x + v_y**2
        div_B_grad_B = Bx_x**2 + 2 * Bx_y * By_x + By_y**2
        f_dptot = fourier_derivatives_ptot(
            p, div_vel_grad_vel, div_B_grad_B, B2_h, self.rho0, [self.Lx, self.Ly]
        )
        ptot_x = f_dptot[:, 0:nt, :nx, :ny]
        ptot_y = f_dptot[:, nt : 2 * nt, :nx, :ny]

        # Plug inputs into dictionary
        all_inputs = {
            "u": u,
            "u__x": u_x,
            "u__y": u_y,
            "v": v,
            "v__x": v_x,
            "v__y": v_y,
            "Bx": Bx,
            "Bx__x": Bx_x,
            "Bx__y": Bx_y,
            "By": By,
            "By__x": By_x,
            "By__y": By_y,
            "A__x": A_x,
            "A__y": A_y,
            "ptot__x": ptot_x,
            "ptot__y": ptot_y,
            "u__lap": u_lap,
            "v__lap": v_lap,
            "A__lap": A_lap,
        }

        # Substitute values into PDE equations
        # Positional node lookup (see mhd_constraint): 14=u_rhs, 15=v_rhs, 23=A_rhs.
        u_rhs = self.mhd_pde_node[14].evaluate(all_inputs)["u_rhs"]
        v_rhs = self.mhd_pde_node[15].evaluate(all_inputs)["v_rhs"]
        A_rhs = self.mhd_pde_node[23].evaluate(all_inputs)["A_rhs"]

        # Centered time derivatives on the interior time slices.
        u_t = self.Du_t(u, dt)
        v_t = self.Du_t(v, dt)
        A_t = self.Du_t(A, dt)

        # Find difference
        # Residuals D* = d/dt - rhs; rhs is trimmed to the interior slices [1:-1]
        # to align with the centered difference.
        Du = self.mhd_pde_node[18].evaluate({"u__t": u_t, "u_rhs": u_rhs[:, 1:-1]})[
            "Du"
        ]
        Dv = self.mhd_pde_node[19].evaluate({"v__t": v_t, "v_rhs": v_rhs[:, 1:-1]})[
            "Dv"
        ]
        DA = self.mhd_pde_node[24].evaluate({"A__t": A_t, "A_rhs": A_rhs[:, 1:-1]})[
            "DA"
        ]
        return Du, Dv, DA

    def Du_t(self, u, dt):
        "Compute time derivative"
        # Second-order centered difference along the time axis; output has
        # nt-2 slices (endpoints dropped).
        u_t = (u[:, 2:] - u[:, :-2]) / (2 * dt)
        return u_t
mhd/losses/losses.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import math
21
+ from torch import Tensor
22
+ from typing import List
23
+
24
+
25
class LpLoss(object):
    """Relative/absolute Lp-norm loss over batched tensors.

    Each sample (first dimension) is flattened and measured with the Lp norm;
    `rel` normalizes by the norm of the target, `abs` scales by the mesh
    spacing inferred from the second dimension. Calling the object computes
    the relative loss.
    """

    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        # Spatial dimension and norm order must both be positive.
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def abs(self, x, y):
        """Absolute Lp loss, scaled by h**(d/p) for a uniform mesh."""
        batch = x.size()[0]

        # Assume a uniform mesh; spacing comes from the second dimension.
        spacing = 1.0 / (x.size()[1] - 1.0)
        scale = spacing ** (self.d / self.p)

        per_sample = scale * torch.norm(
            x.view(batch, -1) - y.view(batch, -1), self.p, 1
        )

        if not self.reduction:
            return per_sample
        if self.size_average:
            return torch.mean(per_sample)
        return torch.sum(per_sample)

    def rel(self, x, y):
        """Relative Lp loss: ||x - y||_p / ||y||_p per sample."""
        batch = x.size()[0]

        flat_x = x.reshape(batch, -1)
        flat_y = y.reshape(batch, -1)
        diff_norms = torch.norm(flat_x - flat_y, self.p, 1)
        target_norms = torch.norm(flat_y, self.p, 1)
        ratios = diff_norms / target_norms

        if not self.reduction:
            return ratios
        if self.size_average:
            return torch.mean(ratios)
        return torch.sum(ratios)

    def __call__(self, x, y):
        return self.rel(x, y)
77
+
78
+
79
def fourier_derivatives_lap(x: Tensor, ell: List[float]) -> Tensor:
    """
    Spectral Laplacian of a periodic field.

    x has shape (batch, channel, *spatial); ell gives the physical domain
    length of each spatial dimension. Returns the real-valued Laplacian with
    the same shape as x.
    """

    # Number of spatial axes must agree with the number of domain lengths.
    if len(x.shape) - 2 != len(ell):
        raise ValueError("input shape doesn't match domain dims")

    two_pi = 2.0 * float(np.pi)
    spatial = x.shape[2:]
    dim = len(ell)
    device = x.device
    fft_dims = list(range(2, dim + 2))

    # Forward transform over the spatial axes only.
    x_h = torch.fft.fftn(x, dim=fft_dims)

    # Accumulate -sum_i k_i**2 with broadcasting, one axis at a time.
    lap = None
    for axis, nx in enumerate(spatial):
        freqs = torch.cat(
            (
                torch.arange(start=0, end=nx // 2, step=1, device=device),
                torch.arange(start=-nx // 2, end=0, step=1, device=device),
            ),
            0,
        )
        shape = (axis + 2) * [1] + [nx] + (dim - axis - 1) * [1]
        k = (two_pi / ell[axis]) * freqs.reshape(shape)
        lap = -(k**2) if lap is None else lap - k**2

    # Multiply in Fourier space and transform back; keep the real part.
    return torch.fft.ifftn(lap * x_h, dim=fft_dims).real
124
+
125
+
126
def fourier_derivatives_ptot(
    p: Tensor,
    div_vel_grad_vel: Tensor,
    div_B_grad_B: Tensor,
    B2_h: Tensor,
    rho0: float,
    ell: List[float],
) -> List[Tensor]:
    """
    Fourier derivative function to calculate ptot in MHD equations.

    If ``p`` is None, the total pressure is recovered spectrally by solving a
    Poisson equation from the velocity/magnetic source terms; otherwise the
    given pressure is combined with the magnetic pressure B2_h/2. Returns the
    gradient of ptot with d/dx and d/dy stacked along the channel dimension.
    """

    # check that input shape matches domain length
    if len(div_vel_grad_vel.shape) - 2 != len(ell):
        raise ValueError("input shape doesn't match domain dims")

    # set pi from numpy
    pi = float(np.pi)

    # get needed dims
    n = div_vel_grad_vel.shape[2:]
    dim = len(ell)

    # get device
    device = div_vel_grad_vel.device

    # make wavenumbers (FFT ordering: [0..nx/2-1, -nx/2..-1]), each reshaped
    # so it broadcasts along its own spatial axis
    k_x = []
    for i, nx in enumerate(n):
        k_x.append(
            torch.cat(
                (
                    torch.arange(start=0, end=nx // 2, step=1, device=device),
                    torch.arange(start=-nx // 2, end=0, step=1, device=device),
                ),
                0,
            ).reshape((i + 2) * [1] + [nx] + (dim - i - 1) * [1])
        )
    # note that for pressure, the zero mode (the mean) cannot be zero for invertability so it is set to 1
    lap = torch.zeros_like(k_x[0])
    for i, k_x_i in enumerate(k_x):
        lap = lap - ((2 * pi / ell[i]) * k_x_i) ** 2
    # Replace the zero eigenvalue so dividing by `lap` is well defined.
    lap[..., 0, 0] = -1.0

    if p is None:
        # compute fourier transform of the Poisson source terms
        div_vel_grad_vel_h = torch.fft.fftn(
            div_vel_grad_vel, dim=list(range(2, dim + 2))
        )
        div_B_grad_B_h = torch.fft.fftn(div_B_grad_B, dim=list(range(2, dim + 2)))
        # Solve lap(ptot) = div_B_grad_B - rho0 * div_vel_grad_vel in Fourier space.
        ptot_h = (div_B_grad_B_h - rho0 * div_vel_grad_vel_h) / lap
        # Mean of ptot is fixed to the mean magnetic pressure.
        ptot_h[..., 0, 0] = B2_h[..., 0, 0] / 2.0
    else:
        # Pressure supplied directly: ptot = p + |B|^2 / 2 (in spectral form).
        p_h = torch.fft.fftn(p, dim=list(range(2, dim + 2)))
        ptot_h = p_h + B2_h / 2.0

    # compute gradient in fourier space (multiply by i*k per axis)
    j = torch.complex(
        torch.tensor([0.0], device=device), torch.tensor([1.0], device=device)
    )  # Cuda graphs does not work here
    wx_h = [j * k_x_i * ptot_h * (2 * pi / ell[i]) for i, k_x_i in enumerate(k_x)]

    # inverse fourier transform out; stack d/dx then d/dy along channels
    wx = torch.cat(
        [torch.fft.ifftn(wx_h_i, dim=list(range(2, dim + 2))).real for wx_h_i in wx_h],
        dim=1,
    )
    return wx
194
+
195
+
196
def fourier_derivatives_vec_pot(x: Tensor, ell: List[float]) -> List[Tensor]:
    """
    Fourier derivative function for the magnetic vector potential.

    Given A (the scalar z-component of the vector potential, shape
    (batch, channel, nx, ny)), computes B = curl(A zhat) = (A_y, -A_x) and its
    spatial derivatives spectrally. Returns:
        wx   - gradients of Bx (d/dx, d/dy stacked along channels)
        wy   - gradients of By (d/dx, d/dy stacked along channels)
        wA   - gradient of A (A_x, A_y stacked along channels)
        wB   - the field itself (Bx, By stacked along channels)
        B2_h - spectral quantity Ay_h**2 + (-Ax_h)**2
               NOTE(review): this is a sum of squared spectral coefficients,
               not the transform of |B|^2 — confirm intended use downstream.
    """

    # check that input shape matches domain length
    if len(x.shape) - 2 != len(ell):
        raise ValueError("input shape doesn't match domain dims")

    # set pi from numpy
    pi = float(np.pi)

    # get needed dims
    n = x.shape[2:]
    dim = len(ell)

    # get device
    device = x.device

    # compute fourier transform over the spatial axes
    x_h = torch.fft.fftn(x, dim=list(range(2, dim + 2)))

    # make wavenumbers (FFT ordering), each broadcastable along its own axis
    k_x = []
    for i, nx in enumerate(n):
        k_x.append(
            torch.cat(
                (
                    torch.arange(start=0, end=nx // 2, step=1, device=device),
                    torch.arange(start=-nx // 2, end=0, step=1, device=device),
                ),
                0,
            ).reshape((i + 2) * [1] + [nx] + (dim - i - 1) * [1])
        )

    # first derivatives of A in fourier space (multiply by i*k)
    j = torch.complex(
        torch.tensor([0.0], device=device), torch.tensor([1.0], device=device)
    )  # Cuda graphs does not work here
    Ax_h = j * k_x[0] * x_h * (2 * pi / ell[0])
    Ay_h = j * k_x[1] * x_h * (2 * pi / ell[1])

    B2_h = (Ay_h) ** 2 + (-Ax_h) ** 2

    # gradients of Bx = A_y and By = -A_x, again via i*k multiplication
    Bx_h = [j * k_x_i * Ay_h * (2 * pi / ell[i]) for i, k_x_i in enumerate(k_x)]
    By_h = [j * k_x_i * -Ax_h * (2 * pi / ell[i]) for i, k_x_i in enumerate(k_x)]

    # inverse fourier transform out
    # gradient of A: (A_x, A_y)
    wA = torch.cat(
        [
            torch.fft.ifftn(w_h_i, dim=list(range(2, dim + 2))).real
            for w_h_i in [Ax_h, Ay_h]
        ],
        dim=1,
    )
    # the field B = (A_y, -A_x)
    wB = torch.cat(
        [
            torch.fft.ifftn(w_h_i, dim=list(range(2, dim + 2))).real
            for w_h_i in [Ay_h, -Ax_h]
        ],
        dim=1,
    )
    # gradients of Bx and By
    wx = torch.cat(
        [torch.fft.ifftn(wx_h_i, dim=list(range(2, dim + 2))).real for wx_h_i in Bx_h],
        dim=1,
    )
    wy = torch.cat(
        [torch.fft.ifftn(wx_h_i, dim=list(range(2, dim + 2))).real for wx_h_i in By_h],
        dim=1,
    )

    return wx, wy, wA, wB, B2_h
mhd/losses/mhd_pde.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from physicsnemo.sym.eq.pde import PDE
18
+ from sympy import Function, Number, Symbol
19
+
20
+
21
class MHD_PDE(PDE):
    """2D incompressible MHD equations using PhysicsNeMo Sym.

    Registers symbolic expressions for the advective terms, divergence
    constraints, the right-hand sides of the momentum / induction /
    vector-potential equations, and the time residuals ``Du``, ``Dv``,
    ``DBx``, ``DBy``, ``DA`` used as PDE losses.

    Diffusion terms are written as first derivatives with respect to a
    symbolic ``lap`` variable, i.e. ``u.diff(lap)`` stands for the
    Laplacian of ``u`` supplied by the derivative pipeline.

    Parameters
    ----------
    nu : float, optional
        Kinematic viscosity, by default 1e-4
    eta : float, optional
        Magnetic diffusivity, by default 1e-4
    rho0 : float, optional
        Constant mass density, by default 1.0
    """

    name = "MHD_PDE"

    def __init__(self, nu=1e-4, eta=1e-4, rho0=1.0):

        # space, time, and the symbolic Laplacian coordinate
        x, y, t, lap = Symbol("x"), Symbol("y"), Symbol("t"), Symbol("lap")

        # make input variables
        input_variables = {"x": x, "y": y, "t": t, "lap": lap}

        # unknown fields: velocity (u, v), magnetic field (Bx, By),
        # vector potential A, and total pressure ptot
        u = Function("u")(*input_variables)
        v = Function("v")(*input_variables)
        Bx = Function("Bx")(*input_variables)
        By = Function("By")(*input_variables)
        A = Function("A")(*input_variables)
        ptot = Function("ptot")(*input_variables)

        # externally supplied right-hand sides used in the residuals below
        u_rhs = Function("u_rhs")(*input_variables)
        v_rhs = Function("v_rhs")(*input_variables)
        Bx_rhs = Function("Bx_rhs")(*input_variables)
        By_rhs = Function("By_rhs")(*input_variables)
        A_rhs = Function("A_rhs")(*input_variables)

        # physical constants as sympy Numbers
        nu = Number(nu)
        eta = Number(eta)
        rho0 = Number(rho0)

        # set equations
        self.equations = {}

        # (vel . grad) of the velocity components
        self.equations["vel_grad_u"] = u * u.diff(x) + v * u.diff(y)
        self.equations["vel_grad_v"] = u * v.diff(x) + v * v.diff(y)

        # (B . grad) of the velocity components.
        # BUGFIX: was `Bx * u.diff(x) + v * Bx.diff(y)`, which mixed the
        # advecting field with the velocity; (B . grad) u = Bx u_x + By u_y,
        # matching the form used for B_grad_v below.
        self.equations["B_grad_u"] = Bx * u.diff(x) + By * u.diff(y)
        self.equations["B_grad_v"] = Bx * v.diff(x) + By * v.diff(y)

        # (vel . grad) of the magnetic field components
        self.equations["vel_grad_Bx"] = u * Bx.diff(x) + v * Bx.diff(y)
        self.equations["vel_grad_By"] = u * By.diff(x) + v * By.diff(y)

        # (B . grad) of the magnetic field components (magnetic tension)
        self.equations["B_grad_Bx"] = Bx * Bx.diff(x) + By * Bx.diff(y)
        self.equations["B_grad_By"] = Bx * By.diff(x) + By * By.diff(y)

        # product-rule expansions used by the induction equation
        self.equations["uBy_x"] = u * By.diff(x) + By * u.diff(x)
        self.equations["uBy_y"] = u * By.diff(y) + By * u.diff(y)
        self.equations["vBx_x"] = v * Bx.diff(x) + Bx * v.diff(x)
        self.equations["vBx_y"] = v * Bx.diff(y) + Bx * v.diff(y)

        # divergence-free constraints for the magnetic field and velocity
        self.equations["div_B"] = Bx.diff(x) + By.diff(y)
        self.equations["div_vel"] = u.diff(x) + v.diff(y)

        # RHS of MHD equations
        self.equations["u_rhs"] = (
            -self.equations["vel_grad_u"]
            - ptot.diff(x) / rho0
            + self.equations["B_grad_Bx"] / rho0
            + nu * u.diff(lap)
        )
        self.equations["v_rhs"] = (
            -self.equations["vel_grad_v"]
            - ptot.diff(y) / rho0
            + self.equations["B_grad_By"] / rho0
            + nu * v.diff(lap)
        )
        self.equations["Bx_rhs"] = (
            self.equations["uBy_y"] - self.equations["vBx_y"] + eta * Bx.diff(lap)
        )
        self.equations["By_rhs"] = -(
            self.equations["uBy_x"] - self.equations["vBx_x"]
        ) + eta * By.diff(lap)
        # time residuals: d(field)/dt minus the externally supplied RHS
        self.equations["Du"] = u.diff(t) - u_rhs
        self.equations["Dv"] = v.diff(t) - v_rhs
        self.equations["DBx"] = Bx.diff(t) - Bx_rhs
        self.equations["DBy"] = By.diff(t) - By_rhs
        # vector-potential advection-diffusion equation
        self.equations["vel_grad_A"] = u * A.diff(x) + v * A.diff(y)
        # cleanup: dropped the redundant unary `+` that appeared as `+ +eta`
        self.equations["A_rhs"] = -self.equations["vel_grad_A"] + eta * A.diff(lap)
        self.equations["DA"] = A.diff(t) - A_rhs
mhd/tfno/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from .tfno import TFNO, TFNO1DEncoder, TFNO2DEncoder, TFNO3DEncoder, TFNO4DEncoder
18
+ from .spectral_layers import (
19
+ FactorizedSpectralConv1d,
20
+ FactorizedSpectralConv2d,
21
+ FactorizedSpectralConv3d,
22
+ FactorizedSpectralConv4d,
23
+ )
mhd/tfno/spectral_layers.py ADDED
@@ -0,0 +1,657 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from typing import List, Tuple
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch import Tensor
24
+ import tltorch
25
+
26
+
27
class FactorizedSpectralConv1d(nn.Module):
    """1D factorized Fourier layer: FFT, truncated linear transform, inverse FFT.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    modes1 : int
        Number of Fourier modes to multiply, at most floor(N/2) + 1
    rank : float
        Rank of the decomposition
    factorization : {'CP', 'TT', 'Tucker'}
        Tensor factorization used to decompose the weight tensor
    fixed_rank_modes : List[int]
        A list of modes for which the initial value is not modified.
        The last mode cannot be fixed due to error computation.
    decomposition_kwargs : dict
        Additional arguments to initialization of factorized tensors
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        modes1: int,
        rank: float,
        factorization: str,
        fixed_rank_modes: bool,
        decomposition_kwargs: dict,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # number of retained Fourier modes, at most floor(N/2) + 1
        self.modes1 = modes1

        self.scale = 1 / (in_channels * out_channels)
        # complex weights stored as a factorized real tensor whose trailing
        # dimension of size 2 holds the real/imaginary parts
        self.weights1 = tltorch.FactorizedTensor.new(
            (in_channels, out_channels, self.modes1, 2),
            rank=rank,
            factorization=factorization,
            fixed_rank_modes=fixed_rank_modes,
            **decomposition_kwargs
        )
        self.reset_parameters()

    def compl_mul1d(
        self,
        input: Tensor,
        weights: Tensor,
    ) -> Tensor:
        """Complex multiplication

        Parameters
        ----------
        input : Tensor
            Input tensor
        weights : Tensor
            Weights tensor

        Returns
        -------
        Tensor
            Product of complex multiplication
        """
        # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)
        dense = weights.to_tensor().contiguous()
        return torch.einsum("bix,iox->box", input, torch.view_as_complex(dense))

    def forward(self, x: Tensor) -> Tensor:
        n_batch = x.shape[0]
        n_x = x.size(-1)

        # forward real FFT along the spatial dimension
        spectrum = torch.fft.rfft(x)

        # spectral multiply on the lowest `modes1` modes only; all other
        # modes of the output spectrum stay zero
        out_spectrum = torch.zeros(
            n_batch,
            self.out_channels,
            n_x // 2 + 1,
            device=x.device,
            dtype=torch.cfloat,
        )
        out_spectrum[:, :, : self.modes1] = self.compl_mul1d(
            spectrum[:, :, : self.modes1],
            self.weights1,
        )

        # back to physical space
        return torch.fft.irfft(out_spectrum, n=n_x)

    def reset_parameters(self):
        """Reset spectral weights with distribution scale*N(0,1)"""
        self.weights1.normal_(0, self.scale)
125
+
126
+
127
class FactorizedSpectralConv2d(nn.Module):
    """2D factorized Fourier layer: FFT, truncated linear transform, inverse FFT.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    modes1 : int
        Number of Fourier modes to multiply in first dimension, at most floor(N/2) + 1
    modes2 : int
        Number of Fourier modes to multiply in second dimension, at most floor(N/2) + 1
    rank : float
        Rank of the decomposition
    factorization : {'CP', 'TT', 'Tucker'}
        Tensor factorization used to decompose the weight tensor
    fixed_rank_modes : List[int]
        A list of modes for which the initial value is not modified.
        The last mode cannot be fixed due to error computation.
    decomposition_kwargs : dict
        Additional arguments to initialization of factorized tensors
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        modes1: int,
        modes2: int,
        rank: float,
        factorization: str,
        fixed_rank_modes: bool,
        decomposition_kwargs: dict,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # number of retained Fourier modes per dimension, at most floor(N/2) + 1
        self.modes1 = modes1
        self.modes2 = modes2

        self.scale = 1 / (in_channels * out_channels)
        # two factorized complex weight tensors: one for the positive and one
        # for the negative wavenumbers of the first (two-sided) FFT dimension
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2, 2)
        self.weights1 = tltorch.FactorizedTensor.new(
            weight_shape,
            rank=rank,
            factorization=factorization,
            fixed_rank_modes=fixed_rank_modes,
            **decomposition_kwargs
        )
        self.weights2 = tltorch.FactorizedTensor.new(
            weight_shape,
            rank=rank,
            factorization=factorization,
            fixed_rank_modes=fixed_rank_modes,
            **decomposition_kwargs
        )
        self.reset_parameters()

    def compl_mul2d(self, input: Tensor, weights: Tensor) -> Tensor:
        """Complex multiplication

        Parameters
        ----------
        input : Tensor
            Input tensor
        weights : Tensor
            Weights tensor

        Returns
        -------
        Tensor
            Product of complex multiplication
        """
        # (batch, in_channel, x, y), (in_channel, out_channel, x, y) -> (batch, out_channel, x, y)
        dense = weights.to_tensor().contiguous()
        return torch.einsum("bixy,ioxy->boxy", input, torch.view_as_complex(dense))

    def forward(self, x: Tensor) -> Tensor:
        n_batch = x.shape[0]
        n_x, n_y = x.size(-2), x.size(-1)

        # forward real 2D FFT
        spectrum = torch.fft.rfft2(x)

        # truncated spectral multiply: keep the low-frequency corners only
        out_spectrum = torch.zeros(
            n_batch,
            self.out_channels,
            n_x,
            n_y // 2 + 1,
            dtype=torch.cfloat,
            device=x.device,
        )
        out_spectrum[:, :, : self.modes1, : self.modes2] = self.compl_mul2d(
            spectrum[:, :, : self.modes1, : self.modes2],
            self.weights1,
        )
        out_spectrum[:, :, -self.modes1 :, : self.modes2] = self.compl_mul2d(
            spectrum[:, :, -self.modes1 :, : self.modes2],
            self.weights2,
        )

        # back to physical space
        return torch.fft.irfft2(out_spectrum, s=(n_x, n_y))

    def reset_parameters(self):
        """Reset spectral weights with distribution scale*N(0,1)"""
        self.weights1.normal_(0, self.scale)
        self.weights2.normal_(0, self.scale)
238
+
239
+
240
class FactorizedSpectralConv3d(nn.Module):
    """3D Factorized Fourier layer. It does FFT, linear transform, and Inverse FFT.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    modes1 : int
        Number of Fourier modes to multiply in first dimension, at most floor(N/2) + 1
    modes2 : int
        Number of Fourier modes to multiply in second dimension, at most floor(N/2) + 1
    modes3 : int
        Number of Fourier modes to multiply in third dimension, at most floor(N/2) + 1
    rank : float
        Rank of the decomposition
    factorization : {'CP', 'TT', 'Tucker'}
        Tensor factorization to use to decompose the tensor
    fixed_rank_modes : List[int]
        A list of modes for which the initial value is not modified
        The last mode cannot be fixed due to error computation.
    decomposition_kwargs : dict
        Additional arguments to initialization of factorized tensors
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        modes1: int,
        modes2: int,
        modes3: int,
        rank: float,
        factorization: str,
        fixed_rank_modes: bool,
        decomposition_kwargs: dict,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3

        self.scale = 1 / (in_channels * out_channels)
        # Four factorized complex weight tensors (trailing dim of size 2 holds
        # real/imaginary parts), one per sign combination of the two two-sided
        # FFT dimensions; the last dimension is one-sided (rfft).
        weight_shape = (
            in_channels,
            out_channels,
            self.modes1,
            self.modes2,
            self.modes3,
            2,
        )
        self.weights1 = tltorch.FactorizedTensor.new(
            weight_shape,
            rank=rank,
            factorization=factorization,
            fixed_rank_modes=fixed_rank_modes,
            **decomposition_kwargs
        )
        self.weights2 = tltorch.FactorizedTensor.new(
            weight_shape,
            rank=rank,
            factorization=factorization,
            fixed_rank_modes=fixed_rank_modes,
            **decomposition_kwargs
        )
        self.weights3 = tltorch.FactorizedTensor.new(
            weight_shape,
            rank=rank,
            factorization=factorization,
            fixed_rank_modes=fixed_rank_modes,
            **decomposition_kwargs
        )
        self.weights4 = tltorch.FactorizedTensor.new(
            weight_shape,
            rank=rank,
            factorization=factorization,
            fixed_rank_modes=fixed_rank_modes,
            **decomposition_kwargs
        )
        self.reset_parameters()

    def compl_mul3d(self, input: Tensor, weights: Tensor) -> Tensor:
        """Complex multiplication

        Parameters
        ----------
        input : Tensor
            Input tensor
        weights : Tensor
            Weights tensor

        Returns
        -------
        Tensor
            Product of complex multiplication
        """
        # (batch, in_channel, x, y, z), (in_channel, out_channel, x, y, z) -> (batch, out_channel, x, y, z)
        cweights = torch.view_as_complex(weights.to_tensor().contiguous())
        return torch.einsum("bixyz,ioxyz->boxyz", input, cweights)

    def forward(self, x: Tensor) -> Tensor:
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1])

        # Multiply relevant Fourier modes; untouched modes stay zero
        out_ft = torch.zeros(
            batchsize,
            self.out_channels,
            x.size(-3),
            x.size(-2),
            x.size(-1) // 2 + 1,
            dtype=torch.cfloat,
            device=x.device,
        )

        out_ft[:, :, : self.modes1, : self.modes2, : self.modes3] = self.compl_mul3d(
            x_ft[:, :, : self.modes1, : self.modes2, : self.modes3], self.weights1
        )
        out_ft[:, :, -self.modes1 :, : self.modes2, : self.modes3] = self.compl_mul3d(
            x_ft[:, :, -self.modes1 :, : self.modes2, : self.modes3], self.weights2
        )
        out_ft[:, :, : self.modes1, -self.modes2 :, : self.modes3] = self.compl_mul3d(
            x_ft[:, :, : self.modes1, -self.modes2 :, : self.modes3], self.weights3
        )
        out_ft[:, :, -self.modes1 :, -self.modes2 :, : self.modes3] = self.compl_mul3d(
            x_ft[:, :, -self.modes1 :, -self.modes2 :, : self.modes3], self.weights4
        )

        # Return to physical space
        x = torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1)))
        return x

    def reset_parameters(self):
        """Reset spectral weights with distribution scale*N(0,1)"""
        # DOCFIX: previous docstring said U(0,1), but normal_ samples a
        # normal distribution, matching the 1D/2D/4D variants.
        self.weights1.normal_(0, self.scale)
        self.weights2.normal_(0, self.scale)
        self.weights3.normal_(0, self.scale)
        self.weights4.normal_(0, self.scale)
377
+
378
+
379
class FactorizedSpectralConv4d(nn.Module):
    """4D Factorized Fourier layer. It does FFT, linear transform, and Inverse FFT.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    modes1 : int
        Number of Fourier modes to multiply in first dimension, at most floor(N/2) + 1
    modes2 : int
        Number of Fourier modes to multiply in second dimension, at most floor(N/2) + 1
    modes3 : int
        Number of Fourier modes to multiply in third dimension, at most floor(N/2) + 1
    modes4 : int
        Number of Fourier modes to multiply in fourth dimension, at most floor(N/2) + 1
    rank : float
        Rank of the decomposition
    factorization : {'CP', 'TT', 'Tucker'}
        Tensor factorization to use to decompose the tensor
    fixed_rank_modes : List[int]
        A list of modes for which the initial value is not modified
        The last mode cannot be fixed due to error computation.
    decomposition_kwargs : dict
        Additional arguments to initialization of factorized tensors
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        modes1: int,
        modes2: int,
        modes3: int,
        modes4: int,
        rank: float,
        factorization: str,
        fixed_rank_modes: bool,
        decomposition_kwargs: dict,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.modes4 = modes4

        self.scale = 1 / (in_channels * out_channels)

        def _new_weight():
            # One factorized complex weight tensor; the trailing dimension of
            # size 2 holds the real/imaginary parts. The original repeated
            # this identical constructor call eight times verbatim.
            return tltorch.FactorizedTensor.new(
                (
                    in_channels,
                    out_channels,
                    self.modes1,
                    self.modes2,
                    self.modes3,
                    self.modes4,
                    2,
                ),
                rank=rank,
                factorization=factorization,
                fixed_rank_modes=fixed_rank_modes,
                **decomposition_kwargs
            )

        # One weight per sign combination of the three two-sided FFT
        # dimensions; the last dimension is one-sided (rfftn).
        self.weights1 = _new_weight()
        self.weights2 = _new_weight()
        self.weights3 = _new_weight()
        self.weights4 = _new_weight()
        self.weights5 = _new_weight()
        self.weights6 = _new_weight()
        self.weights7 = _new_weight()
        self.weights8 = _new_weight()
        self.reset_parameters()

    def _all_weights(self):
        """Return the eight corner weights in their fixed order."""
        return (
            self.weights1,
            self.weights2,
            self.weights3,
            self.weights4,
            self.weights5,
            self.weights6,
            self.weights7,
            self.weights8,
        )

    def compl_mul4d(
        self,
        input: Tensor,
        weights: Tensor,
    ) -> Tensor:
        """Complex multiplication

        Parameters
        ----------
        input : Tensor
            Input tensor
        weights : Tensor
            Weights tensor

        Returns
        -------
        Tensor
            Product of complex multiplication
        """
        # DOCFIX: this is the 4D case (the original comment described 3D).
        # (batch, in_channel, x, y, z, t), (in_channel, out_channel, x, y, z, t) -> (batch, out_channel, x, y, z, t)
        cweights = torch.view_as_complex(weights.to_tensor().contiguous())
        return torch.einsum("bixyzt,ioxyzt->boxyzt", input, cweights)

    def forward(self, x: Tensor) -> Tensor:
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfftn(x, dim=[-4, -3, -2, -1])

        # Multiply relevant Fourier modes; untouched modes stay zero
        out_ft = torch.zeros(
            batchsize,
            self.out_channels,
            x.size(-4),
            x.size(-3),
            x.size(-2),
            x.size(-1) // 2 + 1,
            dtype=torch.cfloat,
            device=x.device,
        )

        out_ft[
            :, :, : self.modes1, : self.modes2, : self.modes3, : self.modes4
        ] = self.compl_mul4d(
            x_ft[:, :, : self.modes1, : self.modes2, : self.modes3, : self.modes4],
            self.weights1,
        )
        out_ft[
            :, :, -self.modes1 :, : self.modes2, : self.modes3, : self.modes4
        ] = self.compl_mul4d(
            x_ft[:, :, -self.modes1 :, : self.modes2, : self.modes3, : self.modes4],
            self.weights2,
        )
        out_ft[
            :, :, : self.modes1, -self.modes2 :, : self.modes3, : self.modes4
        ] = self.compl_mul4d(
            x_ft[:, :, : self.modes1, -self.modes2 :, : self.modes3, : self.modes4],
            self.weights3,
        )
        out_ft[
            :, :, : self.modes1, : self.modes2, -self.modes3 :, : self.modes4
        ] = self.compl_mul4d(
            x_ft[:, :, : self.modes1, : self.modes2, -self.modes3 :, : self.modes4],
            self.weights4,
        )
        out_ft[
            :, :, -self.modes1 :, -self.modes2 :, : self.modes3, : self.modes4
        ] = self.compl_mul4d(
            x_ft[:, :, -self.modes1 :, -self.modes2 :, : self.modes3, : self.modes4],
            self.weights5,
        )
        out_ft[
            :, :, -self.modes1 :, : self.modes2, -self.modes3 :, : self.modes4
        ] = self.compl_mul4d(
            x_ft[:, :, -self.modes1 :, : self.modes2, -self.modes3 :, : self.modes4],
            self.weights6,
        )
        out_ft[
            :, :, : self.modes1, -self.modes2 :, -self.modes3 :, : self.modes4
        ] = self.compl_mul4d(
            x_ft[:, :, : self.modes1, -self.modes2 :, -self.modes3 :, : self.modes4],
            self.weights7,
        )
        out_ft[
            :, :, -self.modes1 :, -self.modes2 :, -self.modes3 :, : self.modes4
        ] = self.compl_mul4d(
            x_ft[:, :, -self.modes1 :, -self.modes2 :, -self.modes3 :, : self.modes4],
            self.weights8,
        )

        # Return to physical space
        x = torch.fft.irfftn(out_ft, s=(x.size(-4), x.size(-3), x.size(-2), x.size(-1)))
        return x

    def reset_parameters(self):
        """Reset spectral weights with distribution scale*N(0,1)"""
        for weight in self._all_weights():
            weight.normal_(0, self.scale)
mhd/tfno/tfno.py ADDED
@@ -0,0 +1,1043 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from dataclasses import dataclass
18
+ from typing import List, Tuple, Union
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch import Tensor
24
+
25
+ import physicsnemo # noqa: F401 for docs
26
+ import physicsnemo.models.layers as layers
27
+ from .spectral_layers import (
28
+ FactorizedSpectralConv1d,
29
+ FactorizedSpectralConv2d,
30
+ FactorizedSpectralConv3d,
31
+ FactorizedSpectralConv4d,
32
+ )
33
+
34
+ from physicsnemo.models.meta import ModelMetaData
35
+ from physicsnemo.models.mlp import FullyConnected
36
+ from physicsnemo.models.module import Module
37
+
38
+ # ===================================================================
39
+ # ===================================================================
40
+ # 1D TFNO
41
+ # ===================================================================
42
+ # ===================================================================
43
+
44
+
45
class TFNO1DEncoder(nn.Module):
    """1D spectral encoder for TFNO.

    Lifts the input to a latent space with pointwise convolutions, then applies
    a stack of factorized (tensor-decomposed) spectral convolutions, each with
    a parallel 1x1 convolution skip path.

    Parameters
    ----------
    in_channels : int, optional
        Number of input channels, by default 1
    num_fno_layers : int, optional
        Number of spectral convolutional layers, by default 4
    fno_layer_size : int, optional
        Latent features size in spectral convolutions, by default 32
    num_fno_modes : Union[int, List[int]], optional
        Number of Fourier modes kept in spectral convolutions, by default 16
    padding : Union[int, List[int]], optional
        Domain padding for spectral convolutions, by default 8
    padding_type : str, optional
        Type of padding for spectral convolutions, by default "constant"
    activation_fn : nn.Module, optional
        Activation function, by default nn.GELU
    coord_features : bool, optional
        Use coordinate grid as additional feature map, by default True
    rank : float, optional
        Rank of the decomposition, by default 1.0
    factorization : {'CP', 'TT', 'Tucker'}, optional
        Tensor factorization to use to decompose the tensor, by default 'cp'
    fixed_rank_modes : List[int], optional
        A list of modes for which the initial value is not modified, by default None.
        The last mode cannot be fixed due to error computation.
    decomposition_kwargs : dict, optional
        Additional arguments to initialization of factorized tensors, by default
        None (treated as an empty dict).
    """

    def __init__(
        self,
        in_channels: int = 1,
        num_fno_layers: int = 4,
        fno_layer_size: int = 32,
        num_fno_modes: Union[int, List[int]] = 16,
        padding: Union[int, List[int]] = 8,
        padding_type: str = "constant",
        activation_fn: nn.Module = nn.GELU(),
        coord_features: bool = True,
        rank: float = 1.0,
        factorization: str = "cp",
        fixed_rank_modes: List[int] = None,
        decomposition_kwargs: dict = None,
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.num_fno_layers = num_fno_layers
        self.fno_width = fno_layer_size
        self.activation_fn = activation_fn

        # TensorLy arguments
        self.rank = rank
        self.factorization = factorization
        self.fixed_rank_modes = fixed_rank_modes
        # Avoid the shared-mutable-default pitfall: a `dict()` default in the
        # signature is created once and shared by every instance.
        self.decomposition_kwargs = (
            decomposition_kwargs if decomposition_kwargs is not None else {}
        )

        # Add relative coordinate feature
        self.coord_features = coord_features
        if self.coord_features:
            self.in_channels = self.in_channels + 1

        # Padding values for spectral conv
        if isinstance(padding, int):
            padding = [padding]
        self.pad = padding[:1]
        # Negative slice bounds used to strip the padding after the FNO stack;
        # None keeps the full extent when no padding was requested.
        self.ipad = [-pad if pad > 0 else None for pad in self.pad]
        self.padding_type = padding_type

        if isinstance(num_fno_modes, int):
            num_fno_modes = [num_fno_modes]

        # build lift network and FNO stack
        self.build_lift_network()
        self.build_fno(num_fno_modes)

    def build_lift_network(self) -> None:
        """Construct the pointwise network lifting variables to latent space."""
        self.lift_network = torch.nn.Sequential()
        self.lift_network.append(
            layers.Conv1dFCLayer(self.in_channels, self.fno_width // 2)
        )
        self.lift_network.append(self.activation_fn)
        self.lift_network.append(
            layers.Conv1dFCLayer(self.fno_width // 2, self.fno_width)
        )

    def build_fno(self, num_fno_modes: List[int]) -> None:
        """Construct the FNO block.

        Parameters
        ----------
        num_fno_modes : List[int]
            Number of Fourier modes kept in spectral convolutions
        """
        # Build Neural Fourier Operators
        self.spconv_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()
        for _ in range(self.num_fno_layers):
            self.spconv_layers.append(
                FactorizedSpectralConv1d(
                    self.fno_width,
                    self.fno_width,
                    num_fno_modes[0],
                    self.rank,
                    self.factorization,
                    self.fixed_rank_modes,
                    self.decomposition_kwargs,
                )
            )
            # 1x1 convolution acting as a learned skip/bypass path
            self.conv_layers.append(nn.Conv1d(self.fno_width, self.fno_width, 1))

    def forward(self, x: Tensor) -> Tensor:
        # Validate input rank early (consistent with the 2D encoder)
        if x.dim() != 3:
            raise ValueError(
                "Only 3D tensors [batch, in_channels, grid_x] accepted for 1D FNO"
            )

        if self.coord_features:
            coord_feat = self.meshgrid(list(x.shape), x.device)
            x = torch.cat((x, coord_feat), dim=1)

        x = self.lift_network(x)
        # (left, right)
        x = F.pad(x, (0, self.pad[0]), mode=self.padding_type)
        # Spectral layers: activation on every layer except the last
        for k, (conv, w) in enumerate(zip(self.conv_layers, self.spconv_layers)):
            if k < len(self.conv_layers) - 1:
                x = self.activation_fn(conv(x) + w(x))
            else:
                x = conv(x) + w(x)

        # Strip the domain padding added above
        x = x[..., : self.ipad[0]]
        return x

    def meshgrid(self, shape: List[int], device: torch.device) -> Tensor:
        """Creates 1D meshgrid feature in [0, 1].

        Parameters
        ----------
        shape : List[int]
            Tensor shape [batch, channels, grid_x]
        device : torch.device
            Device model is on

        Returns
        -------
        Tensor
            Meshgrid tensor of shape [batch, 1, grid_x]
        """
        bsize, size_x = shape[0], shape[2]
        grid_x = torch.linspace(0, 1, size_x, dtype=torch.float32, device=device)
        grid_x = grid_x.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1)
        return grid_x

    def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]:
        """Convert from grid based (image) to point based representation.

        Parameters
        ----------
        value : Tensor
            Grid tensor of shape [batch, channels, grid_x]

        Returns
        -------
        Tuple[Tensor, List[int]]
            Point tensor of shape [batch * grid_x, channels] and the original
            grid shape (needed by :meth:`points_to_grid` to invert)
        """
        y_shape = list(value.size())
        output = torch.permute(value, (0, 2, 1))
        return output.reshape(-1, output.size(-1)), y_shape

    def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor:
        """Convert from point based to grid based (image) representation.

        Parameters
        ----------
        value : Tensor
            Point tensor of shape [batch * grid_x, channels]
        shape : List[int]
            Original grid shape as returned by :meth:`grid_to_points`

        Returns
        -------
        Tensor
            Grid tensor of shape [batch, channels, grid_x]
        """
        output = value.reshape(shape[0], shape[2], value.size(-1))
        return torch.permute(output, (0, 2, 1))
232
+
233
+
234
+ # ===================================================================
235
+ # ===================================================================
236
+ # 2D TFNO
237
+ # ===================================================================
238
+ # ===================================================================
239
+
240
+
241
class TFNO2DEncoder(nn.Module):
    """2D spectral encoder for TFNO.

    Lifts the input to a latent space with pointwise convolutions, then applies
    a stack of factorized (tensor-decomposed) spectral convolutions, each with
    a parallel 1x1 convolution skip path.

    Parameters
    ----------
    in_channels : int, optional
        Number of input channels, by default 1
    num_fno_layers : int, optional
        Number of spectral convolutional layers, by default 4
    fno_layer_size : int, optional
        Latent features size in spectral convolutions, by default 32
    num_fno_modes : Union[int, List[int]], optional
        Number of Fourier modes kept in spectral convolutions, by default 16
    padding : Union[int, List[int]], optional
        Domain padding for spectral convolutions, by default 8
    padding_type : str, optional
        Type of padding for spectral convolutions, by default "constant"
    activation_fn : nn.Module, optional
        Activation function, by default nn.GELU
    coord_features : bool, optional
        Use coordinate grid as additional feature map, by default True
    rank : float, optional
        Rank of the decomposition, by default 1.0
    factorization : {'CP', 'TT', 'Tucker'}, optional
        Tensor factorization to use to decompose the tensor, by default 'cp'
    fixed_rank_modes : List[int], optional
        A list of modes for which the initial value is not modified, by default None.
        The last mode cannot be fixed due to error computation.
    decomposition_kwargs : dict, optional
        Additional arguments to initialization of factorized tensors, by default
        None (treated as an empty dict).
    """

    def __init__(
        self,
        in_channels: int = 1,
        num_fno_layers: int = 4,
        fno_layer_size: int = 32,
        num_fno_modes: Union[int, List[int]] = 16,
        padding: Union[int, List[int]] = 8,
        padding_type: str = "constant",
        activation_fn: nn.Module = nn.GELU(),
        coord_features: bool = True,
        rank: float = 1.0,
        factorization: str = "cp",
        fixed_rank_modes: List[int] = None,
        decomposition_kwargs: dict = None,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.num_fno_layers = num_fno_layers
        self.fno_width = fno_layer_size
        self.coord_features = coord_features
        self.activation_fn = activation_fn

        # TensorLy arguments
        self.rank = rank
        self.factorization = factorization
        self.fixed_rank_modes = fixed_rank_modes
        # Avoid the shared-mutable-default pitfall: a `dict()` default in the
        # signature is created once and shared by every instance.
        self.decomposition_kwargs = (
            decomposition_kwargs if decomposition_kwargs is not None else {}
        )

        # Add relative coordinate feature
        if self.coord_features:
            self.in_channels = self.in_channels + 2

        # Padding values for spectral conv
        if isinstance(padding, int):
            padding = [padding, padding]
        padding = padding + [0, 0]  # Pad with zeros for smaller lists
        self.pad = padding[:2]
        # Negative slice bounds used to strip the padding after the FNO stack;
        # None keeps the full extent when no padding was requested.
        self.ipad = [-pad if pad > 0 else None for pad in self.pad]
        self.padding_type = padding_type

        if isinstance(num_fno_modes, int):
            num_fno_modes = [num_fno_modes, num_fno_modes]

        # build lift network and FNO stack
        self.build_lift_network()
        self.build_fno(num_fno_modes)

    def build_lift_network(self) -> None:
        """Construct the pointwise network lifting variables to latent space."""
        self.lift_network = torch.nn.Sequential()
        self.lift_network.append(
            layers.Conv2dFCLayer(self.in_channels, self.fno_width // 2)
        )
        self.lift_network.append(self.activation_fn)
        self.lift_network.append(
            layers.Conv2dFCLayer(self.fno_width // 2, self.fno_width)
        )

    def build_fno(self, num_fno_modes: List[int]) -> None:
        """Construct the TFNO block.

        Parameters
        ----------
        num_fno_modes : List[int]
            Number of Fourier modes kept in spectral convolutions
        """
        # Build Neural Fourier Operators
        self.spconv_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()
        for _ in range(self.num_fno_layers):
            self.spconv_layers.append(
                FactorizedSpectralConv2d(
                    self.fno_width,
                    self.fno_width,
                    num_fno_modes[0],
                    num_fno_modes[1],
                    self.rank,
                    self.factorization,
                    self.fixed_rank_modes,
                    self.decomposition_kwargs,
                )
            )
            # 1x1 convolution acting as a learned skip/bypass path
            self.conv_layers.append(nn.Conv2d(self.fno_width, self.fno_width, 1))

    def forward(self, x: Tensor) -> Tensor:
        if x.dim() != 4:
            raise ValueError(
                "Only 4D tensors [batch, in_channels, grid_x, grid_y] accepted for 2D FNO"
            )

        if self.coord_features:
            coord_feat = self.meshgrid(list(x.shape), x.device)
            x = torch.cat((x, coord_feat), dim=1)

        x = self.lift_network(x)
        # (left, right, top, bottom)
        x = F.pad(x, (0, self.pad[1], 0, self.pad[0]), mode=self.padding_type)
        # Spectral layers: activation on every layer except the last
        for k, (conv, w) in enumerate(zip(self.conv_layers, self.spconv_layers)):
            if k < len(self.conv_layers) - 1:
                x = self.activation_fn(conv(x) + w(x))
            else:
                x = conv(x) + w(x)

        # remove padding
        x = x[..., : self.ipad[0], : self.ipad[1]]

        return x

    def meshgrid(self, shape: List[int], device: torch.device) -> Tensor:
        """Creates 2D meshgrid feature in [0, 1] x [0, 1].

        Parameters
        ----------
        shape : List[int]
            Tensor shape [batch, channels, grid_x, grid_y]
        device : torch.device
            Device model is on

        Returns
        -------
        Tensor
            Meshgrid tensor of shape [batch, 2, grid_x, grid_y]
        """
        bsize, size_x, size_y = shape[0], shape[2], shape[3]
        grid_x = torch.linspace(0, 1, size_x, dtype=torch.float32, device=device)
        grid_y = torch.linspace(0, 1, size_y, dtype=torch.float32, device=device)
        grid_x, grid_y = torch.meshgrid(grid_x, grid_y, indexing="ij")
        grid_x = grid_x.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1)
        grid_y = grid_y.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1)
        return torch.cat((grid_x, grid_y), dim=1)

    def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]:
        """Convert from grid based (image) to point based representation.

        Parameters
        ----------
        value : Tensor
            Grid tensor of shape [batch, channels, grid_x, grid_y]

        Returns
        -------
        Tuple[Tensor, List[int]]
            Point tensor of shape [batch * grid_x * grid_y, channels] and the
            original grid shape (needed by :meth:`points_to_grid` to invert)
        """
        y_shape = list(value.size())
        output = torch.permute(value, (0, 2, 3, 1))
        return output.reshape(-1, output.size(-1)), y_shape

    def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor:
        """Convert from point based to grid based (image) representation.

        Parameters
        ----------
        value : Tensor
            Point tensor of shape [batch * grid_x * grid_y, channels]
        shape : List[int]
            Original grid shape as returned by :meth:`grid_to_points`

        Returns
        -------
        Tensor
            Grid tensor of shape [batch, channels, grid_x, grid_y]
        """
        output = value.reshape(shape[0], shape[2], shape[3], value.size(-1))
        return torch.permute(output, (0, 3, 1, 2))
440
+
441
+
442
+ # ===================================================================
443
+ # ===================================================================
444
+ # 3D TFNO
445
+ # ===================================================================
446
+ # ===================================================================
447
+
448
+
449
class TFNO3DEncoder(nn.Module):
    """3D spectral encoder for TFNO.

    Lifts the input to a latent space with pointwise convolutions, then applies
    a stack of factorized (tensor-decomposed) spectral convolutions, each with
    a parallel 1x1 convolution skip path.

    Parameters
    ----------
    in_channels : int, optional
        Number of input channels, by default 1
    num_fno_layers : int, optional
        Number of spectral convolutional layers, by default 4
    fno_layer_size : int, optional
        Latent features size in spectral convolutions, by default 32
    num_fno_modes : Union[int, List[int]], optional
        Number of Fourier modes kept in spectral convolutions, by default 16
    padding : Union[int, List[int]], optional
        Domain padding for spectral convolutions, by default 8
    padding_type : str, optional
        Type of padding for spectral convolutions, by default "constant"
    activation_fn : nn.Module, optional
        Activation function, by default nn.GELU
    coord_features : bool, optional
        Use coordinate grid as additional feature map, by default True
    rank : float, optional
        Rank of the decomposition, by default 1.0
    factorization : {'CP', 'TT', 'Tucker'}, optional
        Tensor factorization to use to decompose the tensor, by default 'cp'
    fixed_rank_modes : List[int], optional
        A list of modes for which the initial value is not modified, by default None.
        The last mode cannot be fixed due to error computation.
    decomposition_kwargs : dict, optional
        Additional arguments to initialization of factorized tensors, by default
        None (treated as an empty dict).
    """

    def __init__(
        self,
        in_channels: int = 1,
        num_fno_layers: int = 4,
        fno_layer_size: int = 32,
        num_fno_modes: Union[int, List[int]] = 16,
        padding: Union[int, List[int]] = 8,
        padding_type: str = "constant",
        activation_fn: nn.Module = nn.GELU(),
        coord_features: bool = True,
        rank: float = 1.0,
        factorization: str = "cp",
        fixed_rank_modes: List[int] = None,
        decomposition_kwargs: dict = None,
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.num_fno_layers = num_fno_layers
        self.fno_width = fno_layer_size
        self.coord_features = coord_features
        self.activation_fn = activation_fn

        # TensorLy arguments
        self.rank = rank
        self.factorization = factorization
        self.fixed_rank_modes = fixed_rank_modes
        # Avoid the shared-mutable-default pitfall: a `dict()` default in the
        # signature is created once and shared by every instance.
        self.decomposition_kwargs = (
            decomposition_kwargs if decomposition_kwargs is not None else {}
        )

        # Add relative coordinate feature
        if self.coord_features:
            self.in_channels = self.in_channels + 3

        # Padding values for spectral conv
        if isinstance(padding, int):
            padding = [padding, padding, padding]
        padding = padding + [0, 0, 0]  # Pad with zeros for smaller lists
        self.pad = padding[:3]
        # Negative slice bounds used to strip the padding after the FNO stack;
        # None keeps the full extent when no padding was requested.
        self.ipad = [-pad if pad > 0 else None for pad in self.pad]
        self.padding_type = padding_type

        if isinstance(num_fno_modes, int):
            num_fno_modes = [num_fno_modes, num_fno_modes, num_fno_modes]

        # build lift network and FNO stack
        self.build_lift_network()
        self.build_fno(num_fno_modes)

    def build_lift_network(self) -> None:
        """Construct the pointwise network lifting variables to latent space."""
        self.lift_network = torch.nn.Sequential()
        self.lift_network.append(
            layers.Conv3dFCLayer(self.in_channels, self.fno_width // 2)
        )
        self.lift_network.append(self.activation_fn)
        self.lift_network.append(
            layers.Conv3dFCLayer(self.fno_width // 2, self.fno_width)
        )

    def build_fno(self, num_fno_modes: List[int]) -> None:
        """Construct the FNO block.

        Parameters
        ----------
        num_fno_modes : List[int]
            Number of Fourier modes kept in spectral convolutions
        """
        # Build Neural Fourier Operators
        self.spconv_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()
        for _ in range(self.num_fno_layers):
            self.spconv_layers.append(
                FactorizedSpectralConv3d(
                    self.fno_width,
                    self.fno_width,
                    num_fno_modes[0],
                    num_fno_modes[1],
                    num_fno_modes[2],
                    self.rank,
                    self.factorization,
                    self.fixed_rank_modes,
                    self.decomposition_kwargs,
                )
            )
            # 1x1 convolution acting as a learned skip/bypass path
            self.conv_layers.append(nn.Conv3d(self.fno_width, self.fno_width, 1))

    def forward(self, x: Tensor) -> Tensor:
        # Validate input rank early (consistent with the 2D encoder)
        if x.dim() != 5:
            raise ValueError(
                "Only 5D tensors [batch, in_channels, grid_x, grid_y, grid_z] accepted for 3D FNO"
            )

        if self.coord_features:
            coord_feat = self.meshgrid(list(x.shape), x.device)
            x = torch.cat((x, coord_feat), dim=1)

        x = self.lift_network(x)
        # (left, right, top, bottom, front, back)
        x = F.pad(
            x,
            (0, self.pad[2], 0, self.pad[1], 0, self.pad[0]),
            mode=self.padding_type,
        )
        # Spectral layers: activation on every layer except the last
        for k, (conv, w) in enumerate(zip(self.conv_layers, self.spconv_layers)):
            if k < len(self.conv_layers) - 1:
                x = self.activation_fn(conv(x) + w(x))
            else:
                x = conv(x) + w(x)

        # Strip the domain padding added above
        x = x[..., : self.ipad[0], : self.ipad[1], : self.ipad[2]]
        return x

    def meshgrid(self, shape: List[int], device: torch.device) -> Tensor:
        """Creates 3D meshgrid feature in the unit cube.

        Parameters
        ----------
        shape : List[int]
            Tensor shape [batch, channels, grid_x, grid_y, grid_z]
        device : torch.device
            Device model is on

        Returns
        -------
        Tensor
            Meshgrid tensor of shape [batch, 3, grid_x, grid_y, grid_z]
        """
        bsize, size_x, size_y, size_z = shape[0], shape[2], shape[3], shape[4]
        grid_x = torch.linspace(0, 1, size_x, dtype=torch.float32, device=device)
        grid_y = torch.linspace(0, 1, size_y, dtype=torch.float32, device=device)
        grid_z = torch.linspace(0, 1, size_z, dtype=torch.float32, device=device)
        grid_x, grid_y, grid_z = torch.meshgrid(grid_x, grid_y, grid_z, indexing="ij")
        grid_x = grid_x.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1)
        grid_y = grid_y.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1)
        grid_z = grid_z.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1)
        return torch.cat((grid_x, grid_y, grid_z), dim=1)

    def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]:
        """Convert from grid based (image) to point based representation.

        Parameters
        ----------
        value : Tensor
            Grid tensor of shape [batch, channels, grid_x, grid_y, grid_z]

        Returns
        -------
        Tuple[Tensor, List[int]]
            Point tensor of shape [batch * grid_x * grid_y * grid_z, channels]
            and the original grid shape (needed by :meth:`points_to_grid`)
        """
        y_shape = list(value.size())
        output = torch.permute(value, (0, 2, 3, 4, 1))
        return output.reshape(-1, output.size(-1)), y_shape

    def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor:
        """Convert from point based to grid based (image) representation.

        Parameters
        ----------
        value : Tensor
            Point tensor of shape [batch * grid_x * grid_y * grid_z, channels]
        shape : List[int]
            Original grid shape as returned by :meth:`grid_to_points`

        Returns
        -------
        Tensor
            Grid tensor of shape [batch, channels, grid_x, grid_y, grid_z]
        """
        output = value.reshape(shape[0], shape[2], shape[3], shape[4], value.size(-1))
        return torch.permute(output, (0, 4, 1, 2, 3))
649
+
650
+
651
+ # ===================================================================
652
+ # ===================================================================
653
+ # 4D TFNO
654
+ # ===================================================================
655
+ # ===================================================================
656
+
657
+
658
class TFNO4DEncoder(nn.Module):
    """4D spectral encoder for TFNO.

    Lifts the input to a latent space with pointwise convolutions, then applies
    a stack of factorized (tensor-decomposed) spectral convolutions, each with
    a parallel kernel-1 convolution skip path.

    Parameters
    ----------
    in_channels : int, optional
        Number of input channels, by default 1
    num_fno_layers : int, optional
        Number of spectral convolutional layers, by default 4
    fno_layer_size : int, optional
        Latent features size in spectral convolutions, by default 32
    num_fno_modes : Union[int, List[int]], optional
        Number of Fourier modes kept in spectral convolutions, by default 16
    padding : Union[int, List[int]], optional
        Domain padding for spectral convolutions, by default 8
    padding_type : str, optional
        Type of padding for spectral convolutions, by default "constant"
    activation_fn : nn.Module, optional
        Activation function, by default nn.GELU
    coord_features : bool, optional
        Use coordinate grid as additional feature map, by default True
    rank : float, optional
        Rank of the decomposition, by default 1.0
    factorization : {'CP', 'TT', 'Tucker'}, optional
        Tensor factorization to use to decompose the tensor, by default 'cp'
    fixed_rank_modes : List[int], optional
        A list of modes for which the initial value is not modified, by default None.
        The last mode cannot be fixed due to error computation.
    decomposition_kwargs : dict, optional
        Additional arguments to initialization of factorized tensors, by default
        None (treated as an empty dict).
    """

    def __init__(
        self,
        in_channels: int = 1,
        num_fno_layers: int = 4,
        fno_layer_size: int = 32,
        num_fno_modes: Union[int, List[int]] = 16,
        padding: Union[int, List[int]] = 8,
        padding_type: str = "constant",
        activation_fn: nn.Module = nn.GELU(),
        coord_features: bool = True,
        rank: float = 1.0,
        factorization: str = "cp",
        fixed_rank_modes: List[int] = None,
        decomposition_kwargs: dict = None,
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.num_fno_layers = num_fno_layers
        self.fno_width = fno_layer_size
        self.coord_features = coord_features
        self.activation_fn = activation_fn

        # TensorLy arguments
        self.rank = rank
        self.factorization = factorization
        self.fixed_rank_modes = fixed_rank_modes
        # Avoid the shared-mutable-default pitfall: a `dict()` default in the
        # signature is created once and shared by every instance.
        self.decomposition_kwargs = (
            decomposition_kwargs if decomposition_kwargs is not None else {}
        )

        # Add relative coordinate feature
        if self.coord_features:
            self.in_channels = self.in_channels + 4

        # Padding values for spectral conv
        if isinstance(padding, int):
            padding = [padding, padding, padding, padding]
        padding = padding + [0, 0, 0, 0]  # Pad with zeros for smaller lists
        self.pad = padding[:4]
        # Negative slice bounds used to strip the padding after the FNO stack;
        # None keeps the full extent when no padding was requested.
        self.ipad = [-pad if pad > 0 else None for pad in self.pad]
        self.padding_type = padding_type

        if isinstance(num_fno_modes, int):
            num_fno_modes = [num_fno_modes, num_fno_modes, num_fno_modes, num_fno_modes]

        # build lift network and FNO stack
        self.build_lift_network()
        self.build_fno(num_fno_modes)

    def build_lift_network(self) -> None:
        """Construct the pointwise network lifting variables to latent space."""
        self.lift_network = torch.nn.Sequential()
        self.lift_network.append(
            layers.ConvNdFCLayer(self.in_channels, self.fno_width // 2)
        )
        self.lift_network.append(self.activation_fn)
        self.lift_network.append(
            layers.ConvNdFCLayer(self.fno_width // 2, self.fno_width)
        )

    def build_fno(self, num_fno_modes: List[int]) -> None:
        """Construct the TFNO block.

        Parameters
        ----------
        num_fno_modes : List[int]
            Number of Fourier modes kept in spectral convolutions
        """
        # Build Neural Fourier Operators
        self.spconv_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()
        for _ in range(self.num_fno_layers):
            self.spconv_layers.append(
                FactorizedSpectralConv4d(
                    self.fno_width,
                    self.fno_width,
                    num_fno_modes[0],
                    num_fno_modes[1],
                    num_fno_modes[2],
                    num_fno_modes[3],
                    self.rank,
                    self.factorization,
                    self.fixed_rank_modes,
                    self.decomposition_kwargs,
                )
            )
            # Kernel-1 N-dimensional convolution acting as a learned skip path
            # (torch has no nn.Conv4d, hence the project-local layer)
            self.conv_layers.append(
                layers.ConvNdKernel1Layer(self.fno_width, self.fno_width)
            )

    def forward(self, x: Tensor) -> Tensor:
        # Validate input rank early (consistent with the 2D encoder)
        if x.dim() != 6:
            raise ValueError(
                "Only 6D tensors [batch, in_channels, grid_x, grid_y, grid_z, grid_t] accepted for 4D FNO"
            )

        if self.coord_features:
            coord_feat = self.meshgrid(list(x.shape), x.device)
            x = torch.cat((x, coord_feat), dim=1)

        x = self.lift_network(x)
        # (left, right, top, bottom, front, back, past, future)
        x = F.pad(
            x,
            (0, self.pad[3], 0, self.pad[2], 0, self.pad[1], 0, self.pad[0]),
            mode=self.padding_type,
        )
        # Spectral layers: activation on every layer except the last
        for k, (conv, w) in enumerate(zip(self.conv_layers, self.spconv_layers)):
            if k < len(self.conv_layers) - 1:
                x = self.activation_fn(conv(x) + w(x))
            else:
                x = conv(x) + w(x)

        # Strip the domain padding added above
        x = x[..., : self.ipad[0], : self.ipad[1], : self.ipad[2], : self.ipad[3]]
        return x

    def meshgrid(self, shape: List[int], device: torch.device) -> Tensor:
        """Creates 4D meshgrid feature in the unit hypercube.

        Parameters
        ----------
        shape : List[int]
            Tensor shape [batch, channels, grid_x, grid_y, grid_z, grid_t]
        device : torch.device
            Device model is on

        Returns
        -------
        Tensor
            Meshgrid tensor of shape [batch, 4, grid_x, grid_y, grid_z, grid_t]
        """
        bsize, size_x, size_y, size_z, size_t = (
            shape[0],
            shape[2],
            shape[3],
            shape[4],
            shape[5],
        )
        grid_x = torch.linspace(0, 1, size_x, dtype=torch.float32, device=device)
        grid_y = torch.linspace(0, 1, size_y, dtype=torch.float32, device=device)
        grid_z = torch.linspace(0, 1, size_z, dtype=torch.float32, device=device)
        grid_t = torch.linspace(0, 1, size_t, dtype=torch.float32, device=device)
        grid_x, grid_y, grid_z, grid_t = torch.meshgrid(
            grid_x, grid_y, grid_z, grid_t, indexing="ij"
        )
        grid_x = grid_x.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1, 1)
        grid_y = grid_y.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1, 1)
        grid_z = grid_z.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1, 1)
        grid_t = grid_t.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1, 1)
        return torch.cat((grid_x, grid_y, grid_z, grid_t), dim=1)

    def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]:
        """Convert from grid based (image) to point based representation.

        Parameters
        ----------
        value : Tensor
            Grid tensor of shape [batch, channels, grid_x, grid_y, grid_z, grid_t]

        Returns
        -------
        Tuple[Tensor, List[int]]
            Point tensor of shape [batch * prod(grid dims), channels] and the
            original grid shape (needed by :meth:`points_to_grid` to invert)
        """
        y_shape = list(value.size())
        output = torch.permute(value, (0, 2, 3, 4, 5, 1))
        return output.reshape(-1, output.size(-1)), y_shape

    def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor:
        """Convert from point based to grid based (image) representation.

        Parameters
        ----------
        value : Tensor
            Point tensor of shape [batch * prod(grid dims), channels]
        shape : List[int]
            Original grid shape as returned by :meth:`grid_to_points`

        Returns
        -------
        Tensor
            Grid tensor of shape [batch, channels, grid_x, grid_y, grid_z, grid_t]
        """
        output = value.reshape(
            shape[0], shape[2], shape[3], shape[4], shape[5], value.size(-1)
        )
        return torch.permute(output, (0, 5, 1, 2, 3, 4))
873
+
874
+
875
+ # ===================================================================
876
+ # ===================================================================
877
+ # General TFNO Model
878
+ # ===================================================================
879
+ # ===================================================================
880
+
881
+
882
class TFNO(nn.Module):
    """Tensor Factorized Fourier neural operator (TFNO) model.

    Note
    ----
    The TFNO architecture supports options for 1D, 2D, 3D and 4D fields which can
    be controlled using the `dimension` parameter.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    decoder_layers : int, optional
        Number of decoder layers, by default 1
    decoder_layer_size : int, optional
        Number of neurons in decoder layers, by default 32
    decoder_activation_fn : str, optional
        Activation function for decoder, by default "silu"
    dimension : int
        Model dimensionality (supports 1, 2, 3 and 4), by default 2
    latent_channels : int, optional
        Latent features size in spectral convolutions, by default 32
    num_fno_layers : int, optional
        Number of spectral convolutional layers, by default 4
    num_fno_modes : Union[int, List[int]], optional
        Number of Fourier modes kept in spectral convolutions, by default 16
    padding : int, optional
        Domain padding for spectral convolutions, by default 8
    padding_type : str, optional
        Type of padding for spectral convolutions, by default "constant"
    activation_fn : str, optional
        Activation function, by default "gelu"
    coord_features : bool, optional
        Use coordinate grid as additional feature map, by default True
    rank : float, optional
        Rank of the decomposition, by default 1.0
    factorization : {'CP', 'TT', 'Tucker'}, optional
        Tensor factorization to use to decompose the tensor, by default 'cp'
    fixed_rank_modes : List[int], optional
        A list of modes for which the initial value is not modified, by default None.
        The last mode cannot be fixed due to error computation.
    decomposition_kwargs : dict, optional
        Additional arguments to initialization of factorized tensors, by default
        None (treated as an empty dict).

    Example
    -------
    >>> # define the 2d TFNO model
    >>> model = physicsnemo.models.fno.TFNO(
    ...     in_channels=4,
    ...     out_channels=3,
    ...     decoder_layers=2,
    ...     decoder_layer_size=32,
    ...     dimension=2,
    ...     latent_channels=32,
    ...     num_fno_layers=2,
    ...     padding=0,
    ... )
    >>> input = torch.randn(32, 4, 32, 32) #(N, C, H, W)
    >>> output = model(input)
    >>> output.size()
    torch.Size([32, 3, 32, 32])

    Note
    ----
    Reference: Rosofsky, Shawn G. and Huerta, E. A. "Magnetohydrodynamics with
    Physics Informed Neural Operators." arXiv preprint arXiv:2302.08332 (2023).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        decoder_layers: int = 1,
        decoder_layer_size: int = 32,
        decoder_activation_fn: str = "silu",
        dimension: int = 2,
        latent_channels: int = 32,
        num_fno_layers: int = 4,
        num_fno_modes: Union[int, List[int]] = 16,
        padding: int = 8,
        padding_type: str = "constant",
        activation_fn: str = "gelu",
        coord_features: bool = True,
        rank: float = 1.0,
        factorization: str = "cp",
        fixed_rank_modes: List[int] = None,
        decomposition_kwargs: dict = None,
    ) -> None:
        super().__init__()

        self.num_fno_layers = num_fno_layers
        self.num_fno_modes = num_fno_modes
        self.padding = padding
        self.padding_type = padding_type
        self.activation_fn = layers.get_activation(activation_fn)
        self.coord_features = coord_features
        self.dimension = dimension

        # TensorLy arguments
        self.rank = rank
        self.factorization = factorization
        self.fixed_rank_modes = fixed_rank_modes
        # Avoid the shared-mutable-default pitfall: a `dict()` default in the
        # signature is created once and shared by every instance.
        self.decomposition_kwargs = (
            decomposition_kwargs if decomposition_kwargs is not None else {}
        )

        # Pointwise decoder net mapping latent features to output channels
        self.decoder_net = FullyConnected(
            in_features=latent_channels,
            layer_size=decoder_layer_size,
            out_features=out_channels,
            num_layers=decoder_layers,
            activation_fn=decoder_activation_fn,
        )

        TFNOModel = self.getTFNOEncoder()

        self.spec_encoder = TFNOModel(
            in_channels,
            num_fno_layers=self.num_fno_layers,
            fno_layer_size=latent_channels,
            num_fno_modes=self.num_fno_modes,
            padding=self.padding,
            padding_type=self.padding_type,
            activation_fn=self.activation_fn,
            coord_features=self.coord_features,
            rank=self.rank,
            factorization=self.factorization,
            fixed_rank_modes=self.fixed_rank_modes,
            decomposition_kwargs=self.decomposition_kwargs,
        )

    def getTFNOEncoder(self):
        """Return the TFNO encoder class matching ``self.dimension``."""
        if self.dimension == 1:
            return TFNO1DEncoder
        elif self.dimension == 2:
            return TFNO2DEncoder
        elif self.dimension == 3:
            return TFNO3DEncoder
        elif self.dimension == 4:
            return TFNO4DEncoder
        else:
            raise NotImplementedError(
                "Invalid dimensionality. Only 1D, 2D, 3D and 4D FNO implemented"
            )

    def forward(self, x: Tensor) -> Tensor:
        # Fourier encoder produces a latent grid representation
        y_latent = self.spec_encoder(x)

        # Reshape to pointwise inputs for the fully connected decoder
        y_latent, y_shape = self.spec_encoder.grid_to_points(y_latent)

        # Decoder
        y = self.decoder_net(y_latent)

        # Convert back into grid
        y = self.spec_encoder.points_to_grid(y, y_shape)

        return y
mhd/train_mhd_vec_pot_tfno.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import hydra
4
+ from omegaconf import OmegaConf
5
+ import torch
6
+ from omegaconf import DictConfig
7
+ from physicsnemo.distributed import DistributedManager
8
+ from physicsnemo.launch.logging import LaunchLogger, PythonLogger
9
+ from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
10
+ from physicsnemo.sym.hydra import to_absolute_path
11
+ from torch.nn.parallel import DistributedDataParallel
12
+ from torch.optim import AdamW
13
+
14
+ from dataloaders import Dedalus2DDataset, MHDDataloaderVecPot
15
+ from losses import LossMHDVecPot_PhysicsNeMo
16
+ from tfno import TFNO
17
+ from utils.plot_utils import plot_predictions_mhd, plot_predictions_mhd_plotly
18
+
19
+ dtype = torch.float
20
+ torch.set_default_dtype(dtype)
21
+
22
+
23
@hydra.main(
    version_base="1.3", config_path="config", config_name="train_mhd_vec_pot_tfno.yaml"
)
def main(cfg: DictConfig) -> None:
    """Train a TFNO surrogate for 2D MHD (vector-potential form) with PhysicsNeMo.

    Builds train/val dataloaders from Dedalus snapshots, constructs the TFNO
    model (wrapped in DDP when distributed), and runs the physics-informed
    training loop with checkpointing and periodic prediction plots.
    """
    DistributedManager.initialize()  # Only call this once in the entire script!
    dist = DistributedManager()  # call if required elsewhere
    cfg = OmegaConf.to_container(cfg, resolve=True)

    # Initialize monitoring.
    log = PythonLogger(name="mhd_pino")
    log.file_logging()

    # Load config file sections.
    log_params = cfg["log_params"]
    model_params = cfg["model_params"]
    dataset_params = cfg["dataset_params"]
    train_loader_params = cfg["train_loader_params"]
    val_loader_params = cfg["val_loader_params"]
    loss_params = cfg["loss_params"]
    optimizer_params = cfg["optimizer_params"]
    train_params = cfg["train_params"]

    load_ckpt = cfg["load_ckpt"]
    output_dir = to_absolute_path(cfg["output_dir"])
    os.makedirs(output_dir, exist_ok=True)

    data_dir = dataset_params["data_dir"]
    ckpt_path = train_params["ckpt_path"]

    # Construct datasets: identical parameters, only the split flag differs.
    dataset_train = Dedalus2DDataset(
        data_dir,
        output_names=dataset_params["output_names"],
        field_names=dataset_params["field_names"],
        num_train=dataset_params["num_train"],
        num_test=dataset_params["num_test"],
        num=dataset_params["num"],
        use_train=True,
    )
    dataset_val = Dedalus2DDataset(
        data_dir,
        output_names=dataset_params["output_names"],
        field_names=dataset_params["field_names"],
        num_train=dataset_params["num_train"],
        num_test=dataset_params["num_test"],
        num=dataset_params["num"],
        use_train=False,
    )

    mhd_dataloader_train = MHDDataloaderVecPot(
        dataset_train,
        sub_x=dataset_params["sub_x"],
        sub_t=dataset_params["sub_t"],
        ind_x=dataset_params["ind_x"],
        ind_t=dataset_params["ind_t"],
    )
    mhd_dataloader_val = MHDDataloaderVecPot(
        dataset_val,
        sub_x=dataset_params["sub_x"],
        sub_t=dataset_params["sub_t"],
        ind_x=dataset_params["ind_x"],
        ind_t=dataset_params["ind_t"],
    )

    dataloader_train, sampler_train = mhd_dataloader_train.create_dataloader(
        batch_size=train_loader_params["batch_size"],
        shuffle=train_loader_params["shuffle"],
        num_workers=train_loader_params["num_workers"],
        pin_memory=train_loader_params["pin_memory"],
        distributed=dist.distributed,
    )
    dataloader_val, sampler_val = mhd_dataloader_val.create_dataloader(
        batch_size=val_loader_params["batch_size"],
        shuffle=val_loader_params["shuffle"],
        num_workers=val_loader_params["num_workers"],
        pin_memory=val_loader_params["pin_memory"],
        distributed=dist.distributed,
    )

    # Define the TFNO model.
    model = TFNO(
        in_channels=model_params["in_dim"],
        out_channels=model_params["out_dim"],
        decoder_layers=model_params["decoder_layers"],
        decoder_layer_size=model_params["fc_dim"],
        dimension=model_params["dimension"],
        latent_channels=model_params["layers"],
        num_fno_layers=model_params["num_fno_layers"],
        num_fno_modes=model_params["modes"],
        padding=[model_params["pad_z"], model_params["pad_y"], model_params["pad_x"]],
        rank=model_params["rank"],
        factorization=model_params["factorization"],
        fixed_rank_modes=model_params["fixed_rank_modes"],
        decomposition_kwargs=model_params["decomposition_kwargs"],
    ).to(dist.device)

    # Set up DistributedDataParallel if using more than a single process.
    # The side stream keeps DDP bucket allocation off the default stream.
    if dist.distributed:
        ddps = torch.cuda.Stream()
        with torch.cuda.stream(ddps):
            model = DistributedDataParallel(
                model,
                device_ids=[dist.local_rank],  # local rank of this process on this node
                output_device=dist.device,
                broadcast_buffers=dist.broadcast_buffers,
                find_unused_parameters=dist.find_unused_parameters,
            )
        torch.cuda.current_stream().wait_stream(ddps)

    # Construct optimizer and scheduler.
    # Fix: weight decay now honors the config when provided (it was previously
    # hard-coded to 0.1, silently ignoring optimizer_params).
    optimizer = AdamW(
        model.parameters(),
        betas=optimizer_params["betas"],
        lr=optimizer_params["lr"],
        weight_decay=optimizer_params.get("weight_decay", 0.1),
    )

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=optimizer_params["milestones"],
        gamma=optimizer_params["gamma"],
    )

    # Physics-informed MHD loss (vector-potential formulation).
    mhd_loss = LossMHDVecPot_PhysicsNeMo(**loss_params)

    # Load model from checkpoint (if exists).
    loaded_epoch = 0
    if load_ckpt:
        loaded_epoch = load_checkpoint(
            ckpt_path, model, optimizer, scheduler, device=dist.device
        )

    # Training loop.
    epochs = train_params["epochs"]
    ckpt_freq = train_params["ckpt_freq"]
    # NOTE(review): other sections of the config are accessed via "field_names";
    # confirm the config really defines a separate "fields" key as well.
    names = dataset_params["fields"]
    input_norm = torch.tensor(model_params["input_norm"]).to(dist.device)
    output_norm = torch.tensor(model_params["output_norm"]).to(dist.device)
    for epoch in range(max(1, loaded_epoch + 1), epochs + 1):
        with LaunchLogger(
            "train",
            epoch=epoch,
            num_mini_batch=len(dataloader_train),
            epoch_alert_freq=1,
        ) as log:
            if dist.distributed:
                sampler_train.set_epoch(epoch)

            # Train loop.
            model.train()

            for i, (inputs, outputs) in enumerate(dataloader_train):
                # Fix: cast via the module-level ``dtype`` (as the validation
                # loop already did) instead of ``.type(torch.FloatTensor)``,
                # which forced a round-trip through CPU float32.
                inputs = inputs.type(dtype).to(dist.device)
                outputs = outputs.type(dtype).to(dist.device)
                # Zero gradients.
                optimizer.zero_grad()
                # Compute predictions: channels-first for the model,
                # channels-last for the loss; normalization undone on output.
                pred = (
                    model((inputs / input_norm).permute(0, 4, 1, 2, 3)).permute(
                        0, 2, 3, 4, 1
                    )
                    * output_norm
                )
                # Compute loss.
                loss, loss_dict = mhd_loss(pred, outputs, inputs, return_loss_dict=True)
                # Back-propagate and update weights.
                loss.backward()
                optimizer.step()

                log.log_minibatch(loss_dict)

            log.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]})
            scheduler.step()

        with LaunchLogger("valid", epoch=epoch) as log:
            # Validation loop.
            model.eval()
            with torch.no_grad():
                for i, (inputs, outputs) in enumerate(dataloader_val):
                    inputs = inputs.type(dtype).to(dist.device)
                    outputs = outputs.type(dtype).to(dist.device)

                    # Compute predictions.
                    pred = (
                        model((inputs / input_norm).permute(0, 4, 1, 2, 3)).permute(
                            0, 2, 3, 4, 1
                        )
                        * output_norm
                    )
                    # Compute loss.
                    loss, loss_dict = mhd_loss(
                        pred, outputs, inputs, return_loss_dict=True
                    )

                    log.log_minibatch(loss_dict)

                    # Interactive (plotly) prediction figures for the first few
                    # batches, at the configured frequency.
                    # NOTE(review): the returned figures are currently
                    # discarded; presumably they were meant to be logged.
                    if (i < log_params["log_num_plots"]) and (
                        epoch % log_params["log_plot_freq"] == 0
                    ):
                        # Add all predictions in batch.
                        for j, _ in enumerate(pred):
                            # Make plots for each field.
                            for index, name in enumerate(names):
                                _ = plot_predictions_mhd_plotly(
                                    pred[j].cpu(),
                                    outputs[j].cpu(),
                                    inputs[j].cpu(),
                                    index=index,
                                    name=name,
                                )

                    # Static (matplotlib) figures saved locally under output_dir.
                    if (i < 2) and (epoch % log_params["log_plot_freq"] == 0):
                        # Add all predictions in batch.
                        for j, _ in enumerate(pred):
                            plot_predictions_mhd(
                                pred[j].cpu(),
                                outputs[j].cpu(),
                                inputs[j].cpu(),
                                names=names,
                                save_path=os.path.join(
                                    output_dir,
                                    "MHD_physicsnemo" + "_" + str(dist.rank),
                                ),
                                save_suffix=i,
                            )

        # Periodic checkpointing from rank 0 only.
        if epoch % ckpt_freq == 0 and dist.rank == 0:
            save_checkpoint(ckpt_path, model, optimizer, scheduler, epoch=epoch)
266
+
267
+
268
# Hydra CLI entry point: parses the YAML config and launches training.
if __name__ == "__main__":
    main()
mhd/utils/plot_utils.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ import math
19
+ import numpy as np
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from tqdm import tqdm
24
+ import traceback
25
+
26
+ import matplotlib.pyplot as plt
27
+ import matplotlib
28
+ from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
29
+ import imageio
30
+ import plotly
31
+ import plotly.express as px
32
+ from plotly.subplots import make_subplots
33
+ import plotly.graph_objects as go
34
+ from IPython.display import HTML, display
35
+
36
+
37
def plot_spectra_mhd(
    k,
    pred_spectra_kin,
    true_spectra_kin,
    pred_spectra_mag,
    true_spectra_mag,
    index_t=-1,
    name="Re100",
    save_path=None,
    save_suffix=None,
    font_size=None,
    sci_limits=None,
    style_kin_pred="b-",
    style_kin_true="k-",
    style_mag_pred="b--",
    style_mag_true="k--",
    xmin=0,
    xmax=200,
    ymin=1e-10,
    ymax=None,
):
    "Plots spectra of predicted and true outputs"
    # Plots kinetic and magnetic energy spectra (predicted vs. true) on
    # log-log axes at one time slice (``index_t``, default: last) and returns
    # the figure; saved to ``{save_path}_spectra[_{save_suffix}].png`` when a
    # path is given.
    # NOTE: these rcParams updates mutate global matplotlib state for the
    # whole process, not just this figure.
    if font_size is not None:
        plt.rcParams.update({"font.size": font_size})

    if sci_limits is not None:
        plt.rcParams.update({"axes.formatter.limits": sci_limits})

    # Select one time slice from each time-indexed spectra sequence.
    E_kin_pred = pred_spectra_kin[index_t]
    E_mag_pred = pred_spectra_mag[index_t]

    E_kin_true = true_spectra_kin[index_t]
    E_mag_true = true_spectra_mag[index_t]

    fig = plt.figure(figsize=(6, 5))

    # Kinetic (solid) and magnetic (dashed) spectra; styles are configurable.
    plt.loglog(k, E_kin_pred, style_kin_pred, label="$E_{kin}$ Pred")
    plt.loglog(k, E_kin_true, style_kin_true, label="$E_{kin}$ True")
    plt.loglog(k, E_mag_pred, style_mag_pred, label="$E_{mag}$ Pred")
    plt.loglog(k, E_mag_true, style_mag_true, label="$E_{mag}$ True")

    plt.xlabel("k")
    plt.ylabel("E(k)")
    plt.axis([xmin, xmax, ymin, ymax])

    plt.title(f"Spectra ${name}$")
    plt.legend(loc="upper right")

    # Save only when a path is given; the suffix distinguishes snapshots.
    if save_path is not None:
        if save_suffix is not None:
            figure_path = f"{save_path}_spectra_{save_suffix}.png"
        else:
            figure_path = f"{save_path}_spectra.png"
        plt.savefig(figure_path, bbox_inches="tight")

    return fig
93
+
94
+
95
def plot_predictions_mhd(
    pred,
    true,
    inputs,
    index_t=-1,
    names=(),
    save_path=None,
    save_suffix=None,
    font_size=None,
    sci_limits=None,
    shading="auto",
    cmap="jet",
):
    """Plot initial condition, exact, predicted, and absolute-error panels.

    One row of four panels is drawn per entry of ``names`` at time slice
    ``index_t`` (default: last), and the figure is optionally saved to
    ``{save_path}[_{save_suffix}].png``.  The figure is closed before
    returning (nothing is returned).

    Assumes ``pred``/``true`` are shaped (Nt, Nx, Ny, Nfields) and that
    ``inputs`` carries (t, x, y) in channels 0-2 with the initial-condition
    fields in channels 3+ — layout inferred from the indexing; confirm against
    the dataloader.

    Fixes vs. original: "Intial" typo in the first panel title; the fourth
    panel now reuses the precomputed ``u_err`` instead of recomputing it;
    unused ``Nt, Nx, Ny, Nfields`` unpack removed; mutable default ``names=[]``
    replaced by an (equivalent, iteration-only) tuple.
    """
    # NOTE: rcParams updates mutate global matplotlib state for the process.
    if font_size is not None:
        plt.rcParams.update({"font.size": font_size})

    if sci_limits is not None:
        plt.rcParams.update({"axes.formatter.limits": sci_limits})

    fig = plt.figure(figsize=(24, 5 * len(names)))

    # Make plots for each field.
    for index, name in enumerate(names):
        u_pred = pred[index_t, ..., index]
        u_true = true[index_t, ..., index]
        u_err = u_pred - u_true

        # Channels 3+ of the inputs hold the initial-condition fields.
        initial_data = inputs[0, ..., 3:]
        u0 = initial_data[..., index]

        x = inputs[0, :, 0, 1]
        y = inputs[0, 0, :, 2]
        X, Y = torch.meshgrid(x, y, indexing="ij")
        t = inputs[index_t, 0, 0, 0]

        plt.subplot(len(names), 4, index * 4 + 1)
        plt.pcolormesh(X, Y, u0, cmap=cmap, shading=shading)
        plt.colorbar()
        plt.title(f"Initial Condition ${name}_0(x,y)$")
        plt.tight_layout()
        plt.axis("square")
        plt.axis("off")

        plt.subplot(len(names), 4, index * 4 + 2)
        plt.pcolormesh(X, Y, u_true, cmap=cmap, shading=shading)
        plt.colorbar()
        plt.title(f"Exact ${name}(x,y,t={t:.2f})$")
        plt.tight_layout()
        plt.axis("square")
        plt.axis("off")

        plt.subplot(len(names), 4, index * 4 + 3)
        plt.pcolormesh(X, Y, u_pred, cmap=cmap, shading=shading)
        plt.colorbar()
        plt.title(f"Predict ${name}(x,y,t={t:.2f})$")
        plt.axis("square")
        plt.tight_layout()
        plt.axis("off")

        plt.subplot(len(names), 4, index * 4 + 4)
        plt.pcolormesh(X, Y, u_err, cmap=cmap, shading=shading)
        plt.colorbar()
        plt.title(f"Absolute Error ${name}(x,y,t={t:.2f})$")
        plt.tight_layout()
        plt.axis("square")
        plt.axis("off")

    if save_path is not None:
        if save_suffix is not None:
            figure_path = f"{save_path}_{save_suffix}.png"
        else:
            figure_path = f"{save_path}.png"
        plt.savefig(figure_path, bbox_inches="tight")
    plt.close()
173
+
174
+
175
def generate_movie_2D(
    preds_y,
    test_y,
    test_x,
    key=0,
    plot_title="",
    field=0,
    val_cbar_index=-1,
    err_cbar_index=-1,
    val_clim=None,
    err_clim=None,
    font_size=None,
    movie_dir="",
    movie_name="movie.gif",
    frame_basename="movie",
    frame_ext="jpg",
    cmap="jet",
    shading="gouraud",
    remove_frames=True,
):
    """Generate a movie of the exact, predicted, and absolute-error fields.

    Renders one frame per time step for sample ``key`` / channel ``field``,
    saves the frames under ``movie_dir`` (when given), assembles them into
    ``movie_name`` with imageio, and optionally deletes the frame files.
    When ``movie_dir`` is empty, frames are drawn but nothing is written.

    ``val_clim``/``err_clim`` default to the color limits of the frames at
    ``val_cbar_index``/``err_cbar_index``.

    Fix vs. original: frame cleanup now catches only ``OSError`` instead of a
    bare ``except`` (which also swallowed KeyboardInterrupt/SystemExit).
    """
    frame_files = []

    if movie_dir:
        os.makedirs(movie_dir, exist_ok=True)

    if font_size is not None:
        # Mutates global matplotlib state for the whole process.
        plt.rcParams.update({"font.size": font_size})

    pred = preds_y[key][..., field]
    true = test_y[key][..., field]
    inputs = test_x[key]
    error = pred - true

    Nt, Nx, Ny = pred.shape

    # Coordinates: channel 0 is t, 1 is x, 2 is y (layout inferred from
    # indexing — confirm against the dataloader).
    t = inputs[:, 0, 0, 0]
    x = inputs[0, :, 0, 1]
    y = inputs[0, 0, :, 2]
    X, Y = torch.meshgrid(x, y, indexing="ij")

    fig, axs = plt.subplots(1, 3, figsize=(18, 5))
    ax1 = axs[0]
    ax2 = axs[1]
    ax3 = axs[2]

    # Draw reference frames once to establish color limits and colorbars.
    pcm1 = ax1.pcolormesh(
        X, Y, true[val_cbar_index], cmap=cmap, label="true", shading=shading
    )
    pcm2 = ax2.pcolormesh(
        X, Y, pred[val_cbar_index], cmap=cmap, label="pred", shading=shading
    )
    pcm3 = ax3.pcolormesh(
        X, Y, error[err_cbar_index], cmap=cmap, label="error", shading=shading
    )

    if val_clim is None:
        val_clim = pcm1.get_clim()
    if err_clim is None:
        err_clim = pcm3.get_clim()

    pcm1.set_clim(val_clim)
    plt.colorbar(pcm1, ax=ax1)
    ax1.axis("square")
    ax1.set_axis_off()

    pcm2.set_clim(val_clim)
    plt.colorbar(pcm2, ax=ax2)
    ax2.axis("square")
    ax2.set_axis_off()

    pcm3.set_clim(err_clim)
    plt.colorbar(pcm3, ax=ax3)
    ax3.axis("square")
    ax3.set_axis_off()

    plt.tight_layout()

    for i in range(Nt):
        # Exact
        ax1.clear()
        pcm1 = ax1.pcolormesh(X, Y, true[i], cmap=cmap, label="true", shading=shading)
        pcm1.set_clim(val_clim)
        ax1.set_title(f"Exact {plot_title}: $t={t[i]:.2f}$")
        ax1.axis("square")
        ax1.set_axis_off()

        # Predictions
        ax2.clear()
        pcm2 = ax2.pcolormesh(X, Y, pred[i], cmap=cmap, label="pred", shading=shading)
        pcm2.set_clim(val_clim)
        ax2.set_title(f"Predict {plot_title}: $t={t[i]:.2f}$")
        ax2.axis("square")
        ax2.set_axis_off()

        # Error
        ax3.clear()
        pcm3 = ax3.pcolormesh(X, Y, error[i], cmap=cmap, label="error", shading=shading)
        pcm3.set_clim(err_clim)
        ax3.set_title(f"Error {plot_title}: $t={t[i]:.2f}$")
        ax3.axis("square")
        ax3.set_axis_off()

        fig.canvas.draw()

        if movie_dir:
            frame_path = os.path.join(movie_dir, f"{frame_basename}-{i:03}.{frame_ext}")
            frame_files.append(frame_path)
            plt.savefig(frame_path, bbox_inches="tight")

    if movie_dir:
        movie_path = os.path.join(movie_dir, movie_name)
        with imageio.get_writer(movie_path, mode="I") as writer:
            for frame in frame_files:
                image = imageio.imread(frame)
                writer.append_data(image)

    if movie_dir and remove_frames:
        # Best-effort cleanup of intermediate frames.
        for frame in frame_files:
            try:
                os.remove(frame)
            except OSError:
                pass
299
+
300
+
301
def plot_predictions_mhd_plotly(
    pred,
    true,
    inputs,
    index=0,
    index_t=-1,
    name="u",
    save_path=None,
    font_size=None,
    shading="auto",
    cmap="jet",
):
    """Build plotly figures for one field: IC, prediction, ground truth, error.

    Returns the tuple ``(fig_ic, fig_pred, fig_true, fig_err)`` for logging
    (e.g. to wandb).  ``save_path``, ``font_size`` and ``shading`` are accepted
    for signature compatibility but are currently unused.

    Fixes vs. original: dead locals removed (``Nt..Nfields`` unpack,
    ``x/y/X/Y`` meshgrid, ``zmin``/``zmax`` were computed but never used);
    the comment above the error figure no longer mislabels it "Ground Truth".
    """
    u_pred = pred[index_t, ..., index]
    u_true = true[index_t, ..., index]

    # Channels 3+ of the inputs hold the initial-condition fields.
    ic = inputs[0, ..., 3:]
    u_ic = ic[..., index]
    u_err = u_pred - u_true

    # Time value of the plotted slice (channel 0 is t — confirm layout).
    t = inputs[index_t, 0, 0, 0]

    labels = {"color": name}

    # Initial conditions
    title_ic = f"{name}0"
    fig_ic = px.imshow(
        u_ic,
        binary_string=False,
        color_continuous_scale=cmap,
        labels=labels,
        title=title_ic,
    )
    fig_ic.update_xaxes(showticklabels=False)
    fig_ic.update_yaxes(showticklabels=False)

    # Predictions
    title_pred = f"Predict {name}: t={t:.2f}"
    fig_pred = px.imshow(
        u_pred,
        binary_string=False,
        color_continuous_scale=cmap,
        labels=labels,
        title=title_pred,
    )
    fig_pred.update_xaxes(showticklabels=False)
    fig_pred.update_yaxes(showticklabels=False)

    # Ground truth
    title_true = f"Exact {name}: t={t:.2f}"
    fig_true = px.imshow(
        u_true,
        binary_string=False,
        color_continuous_scale=cmap,
        labels=labels,
        title=title_true,
    )
    fig_true.update_xaxes(showticklabels=False)
    fig_true.update_yaxes(showticklabels=False)

    # Absolute error
    title_err = f"Error {name}: t={t:.2f}"
    fig_err = px.imshow(
        u_err,
        binary_string=False,
        color_continuous_scale=cmap,
        labels=labels,
        title=title_err,
    )
    fig_err.update_xaxes(showticklabels=False)
    fig_err.update_yaxes(showticklabels=False)

    return fig_ic, fig_pred, fig_true, fig_err
on_startup.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Write commands here that will run as the root user before startup.
3
+ # For example, to clone transformers and install it in dev mode:
4
+ # git clone https://github.com/huggingface/transformers.git
5
+ # cd transformers && pip install -e ".[dev]"
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Cython==3.1.0
2
+ hydra-core==1.3.2
3
+ imageio==2.37.0
4
+ ipython==9.2.0
5
+ mlflow>=2.1.1
6
+ plotly==6.0.1
7
+ tensorly==0.9.0
8
+ tensorly-torch==0.5.0
9
+ termcolor
10
+ wandb
start_server.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Start a JupyterLab server for the MHD demo (container entry point).
#
# Environment:
#   JUPYTER_TOKEN - login token for the lab UI (default: "huggingface").
set -euo pipefail

# Default the token only when unset/empty; callers may export their own.
JUPYTER_TOKEN="${JUPYTER_TOKEN:-huggingface}"

readonly NOTEBOOK_DIR="/mhd-demo"

# Best-effort: hide the JupyterLab announcements pop-up; a failure here
# must not prevent the server from starting.
jupyter labextension disable "@jupyterlab/apputils-extension:announcements" || true

# Relaxed CSP/cookie/XSRF settings allow embedding the lab UI in an iframe
# (e.g. a Hugging Face Space).
jupyter-lab \
  --ip 0.0.0.0 \
  --port 7860 \
  --no-browser \
  --allow-root \
  --ServerApp.token="$JUPYTER_TOKEN" \
  --ServerApp.tornado_settings="{'headers': {'Content-Security-Policy': 'frame-ancestors *'}}" \
  --ServerApp.cookie_options="{'SameSite': 'None', 'Secure': True}" \
  --ServerApp.disable_check_xsrf=True \
  --LabApp.news_url=None \
  --LabApp.check_for_updates_class="jupyterlab.NeverCheckForUpdate" \
  --notebook-dir="$NOTEBOOK_DIR"