xiazeyu commited on
Commit
083cb0a
0 Parent(s):

1.0.0b1@gradio

Browse files
Files changed (5) hide show
  1. .gitattributes +35 -0
  2. .gitignore +240 -0
  3. README.md +75 -0
  4. app.py +441 -0
  5. requirements.txt +1 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
157
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
158
+
159
+ # User-specific stuff
160
+ .idea/**/workspace.xml
161
+ .idea/**/tasks.xml
162
+ .idea/**/usage.statistics.xml
163
+ .idea/**/dictionaries
164
+ .idea/**/shelf
165
+
166
+ # AWS User-specific
167
+ .idea/**/aws.xml
168
+
169
+ # Generated files
170
+ .idea/**/contentModel.xml
171
+
172
+ # Sensitive or high-churn files
173
+ .idea/**/dataSources/
174
+ .idea/**/dataSources.ids
175
+ .idea/**/dataSources.local.xml
176
+ .idea/**/sqlDataSources.xml
177
+ .idea/**/dynamic.xml
178
+ .idea/**/uiDesigner.xml
179
+ .idea/**/dbnavigator.xml
180
+
181
+ # Gradle
182
+ .idea/**/gradle.xml
183
+ .idea/**/libraries
184
+
185
+ # Gradle and Maven with auto-import
186
+ # When using Gradle or Maven with auto-import, you should exclude module files,
187
+ # since they will be recreated, and may cause churn. Uncomment if using
188
+ # auto-import.
189
+ # .idea/artifacts
190
+ # .idea/compiler.xml
191
+ # .idea/jarRepositories.xml
192
+ # .idea/modules.xml
193
+ # .idea/*.iml
194
+ # .idea/modules
195
+ # *.iml
196
+ # *.ipr
197
+
198
+ # CMake
199
+ cmake-build-*/
200
+
201
+ # Mongo Explorer plugin
202
+ .idea/**/mongoSettings.xml
203
+
204
+ # File-based project format
205
+ *.iws
206
+
207
+ # IntelliJ
208
+ out/
209
+
210
+ # mpeltonen/sbt-idea plugin
211
+ .idea_modules/
212
+
213
+ # JIRA plugin
214
+ atlassian-ide-plugin.xml
215
+
216
+ # Cursive Clojure plugin
217
+ .idea/replstate.xml
218
+
219
+ # SonarLint plugin
220
+ .idea/sonarlint/
221
+
222
+ # Crashlytics plugin (for Android Studio and IntelliJ)
223
+ com_crashlytics_export_strings.xml
224
+ crashlytics.properties
225
+ crashlytics-build.properties
226
+ fabric.properties
227
+
228
+ # Editor-based Rest Client
229
+ .idea/httpRequests
230
+
231
+ # Android studio 3.1+ serialized cache file
232
+ .idea/caches/build_file_checksums.ser
233
+
234
+ .idea/
235
+
236
+ # Experimental Files
237
+ runs/
238
+
239
+ # MacOS Files
240
+ .DS_Store
README.md ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: WildTorch
3
+ emoji: 🔥
4
+ colorFrom: red
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 4.26.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+ # WildTorch
13
+
14
+ [![Hatch project](https://img.shields.io/badge/%F0%9F%A5%9A-Hatch-4051b5.svg)](https://github.com/pypa/hatch)
15
+ [![Read the Docs](https://readthedocs.org/projects/wildtorch/badge/)](https://wildtorch.readthedocs.io/)
16
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.10968647.svg)](https://doi.org/10.5281/zenodo.10968647)
17
+
18
+ WildTorch: Leveraging GPU Acceleration for High-Fidelity, Stochastic Wildfire Simulations with PyTorch
19
+
20
+ GitHub: [https://github.com/xiazeyu/WildTorch](https://github.com/xiazeyu/WildTorch)
21
+
22
+ ### Installation
23
+
24
+ Install with minimal dependencies:
25
+
26
+ ```shell
27
+ pip install wildtorch
28
+ ```
29
+
30
+ Install with full dependencies (includes visualization and logging):
31
+
32
+ ```shell
33
+ pip install 'wildtorch[full]'
34
+ ```
35
+
36
+ ### Quick Start
37
+
38
+ ```shell
39
+ pip install 'wildtorch[full]'
40
+ ```
41
+
42
+ ```python
43
+ import wildtorch as wt
44
+
45
+ wildfire_map = wt.dataset.generate_empty_dataset()
46
+
47
+ simulator = wt.WildTorchSimulator(
48
+ wildfire_map=wildfire_map,
49
+ simulator_constants=wt.SimulatorConstants(p_continue_burn=0.7),
50
+ initial_ignition=wt.utils.create_ignition(shape=wildfire_map[0].shape),
51
+ )
52
+
53
+ logger = wt.logger.Logger()
54
+
55
+ for i in range(200):
56
+ simulator.step()
57
+ logger.log_stats(
58
+ step=i,
59
+ num_cells_on_fire=wt.metrics.cell_on_fire(simulator.fire_state).item(),
60
+ num_cells_burned_out=wt.metrics.cell_burned_out(simulator.fire_state).item(),
61
+ )
62
+ logger.snapshot_simulation(simulator)
63
+
64
+ logger.save_logs()
65
+ logger.save_snapshots()
66
+
67
+ ```
68
+
69
+ ### Demo
70
+
71
+ See Our Live Demo at [Hugging Face Space](https://xiazeyu-wildtorch.hf.space/).
72
+
73
+ ### API Documents
74
+
75
+ See at Our [Read the Docs](https://wildtorch.readthedocs.io/).
app.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import cast
2
+ import uuid
3
+ import os
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+
10
+ import wildtorch as wt
11
+
12
+ USE_OFFLINE_DATA = False
13
+ ENABLE_DOWNLOAD_SNAPSHOTS = False
14
+
15
+ if USE_OFFLINE_DATA:
16
+ wildfire_sim_maps = torch.load('wildfire_sim_maps.pt')
17
+ else:
18
+ wildfire_sim_maps = wt.dataset.load_wildfire_sim_maps()
19
+ # torch.save(wildfire_sim_maps, 'wildfire_sim_maps.pt')
20
+
21
+ DEFAULT_SHAPE = (512, 512)
22
+ DEFAULT_STATE = {
23
+ 'ds': {
24
+ 'name': None,
25
+ 'shape': None,
26
+ 'data': None,
27
+ },
28
+ 'constants': {
29
+ 'p_h': 0.58,
30
+ 'c_1': 0.045,
31
+ 'c_2': 0.131,
32
+ 'a': 0.078,
33
+ 'theta_w': 0,
34
+ 'v': 10,
35
+ 'p_firebreak': 0.9,
36
+ 'p_continue_burn': 0.6,
37
+ 'device': torch.device('cpu'),
38
+ 'dtype': torch.float32,
39
+ },
40
+ 'ignition': None,
41
+ 'out_video_path': None,
42
+ 'snapshots_path': None,
43
+ 'checkpoint': None,
44
+ 'logger': None,
45
+ }
46
+
47
+ with (gr.Blocks() as demo):
48
+ def remove_state_files(in_state):
49
+ if in_state['out_video_path'] is not None:
50
+ os.remove(in_state['out_video_path'])
51
+ if in_state['snapshots_path'] is not None:
52
+ os.remove(in_state['snapshots_path'])
53
+
54
+
55
+ state_var = gr.State(DEFAULT_STATE, delete_callback=remove_state_files)
56
+ with gr.Tabs(selected='tab_1') as tabs:
57
+ with gr.Tab("1. Datasets", interactive=True, id='tab_1') as tab_1:
58
+ sel_dataset = gr.Dropdown(cast(list, wildfire_sim_maps['name']) + ['empty'], label='Dataset')
59
+ with gr.Row() as shape_row:
60
+ sel_shape_h = gr.Number(label="Map Height", visible=False)
61
+ sel_shape_w = gr.Number(label="Map Width", visible=False)
62
+
63
+ with gr.Row() as preview_row:
64
+ canopy_img = gr.Image(label="canopy")
65
+ density_img = gr.Image(label="density")
66
+ slope_img = gr.Image(label="slope")
67
+
68
+ tab_1_confirm_btn = gr.Button("Confirm", interactive=True)
69
+
70
+
71
+ @tab_1_confirm_btn.click(inputs=[state_var], outputs=[state_var, tabs])
72
+ def jump_to_tab_2(in_state):
73
+ return in_state, gr.Tabs(selected='tab_2')
74
+
75
+ with gr.Tab("2. Simulation Constants and Initial Ignition", interactive=False, id='tab_2') as tab_2:
76
+ with gr.Row():
77
+ sel_p_h = gr.Slider(label="p_h",
78
+ info="The probability that a burnable cell adjacent to a burning cell will "
79
+ "catch fire at the next time step under normal conditions",
80
+ value=DEFAULT_STATE['constants']['p_h'], minimum=0, maximum=1, step=0.01,
81
+ interactive=True)
82
+ sel_p_continue_burn = gr.Slider(label="p_continue_burn",
83
+ info="The probability that a burning cell will continue to burn "
84
+ "at the next time step",
85
+ value=DEFAULT_STATE['constants']['p_continue_burn'], minimum=0,
86
+ maximum=1,
87
+ step=0.01,
88
+ interactive=True)
89
+ with gr.Row():
90
+ sel_a = gr.Slider(label="a",
91
+ info="The coefficient of ground elevation",
92
+ value=DEFAULT_STATE['constants']['a'], minimum=0, maximum=1, step=0.001,
93
+ interactive=True)
94
+ sel_p_firebreak = gr.Slider(label="p_firebreak",
95
+ info="The probability that a burnable cell will not catch fire even "
96
+ "if it is adjacent to a burning cell",
97
+ value=DEFAULT_STATE['constants']['p_firebreak'], minimum=0, maximum=1,
98
+ step=0.01, interactive=True)
99
+ with gr.Row():
100
+ sel_c_1 = gr.Slider(label="c_1",
101
+ info="The coefficient of wind velocity",
102
+ value=DEFAULT_STATE['constants']['c_1'], minimum=0, maximum=1, step=0.001,
103
+ interactive=True)
104
+ sel_c_2 = gr.Slider(label="c_2",
105
+ info="The coefficient of wind direction",
106
+ value=DEFAULT_STATE['constants']['c_2'], minimum=0, maximum=1, step=0.001,
107
+ interactive=True)
108
+ with gr.Row():
109
+ sel_theta_w = gr.Slider(label="theta_w",
110
+ info="The direction of the wind in degrees, measured clockwise from north",
111
+ value=DEFAULT_STATE['constants']['theta_w'], minimum=0, maximum=360, step=1,
112
+ interactive=True)
113
+ sel_v = gr.Slider(label="v",
114
+ info="The wind velocity, unit in m/s",
115
+ value=DEFAULT_STATE['constants']['v'], minimum=0, maximum=60, step=1,
116
+ interactive=True)
117
+ with gr.Row():
118
+ sel_device = gr.Dropdown(label="device", choices=['cpu', 'cuda', 'mps'],
119
+ info="The device to use",
120
+ value='cpu', allow_custom_value=True, interactive=True)
121
+ sel_dtype = gr.Dropdown(label="data type", choices=['float16', 'float32', 'float64'],
122
+ info="The data type to use",
123
+ value='float32', interactive=True)
124
+
125
+
126
+ @gr.on(triggers=[sel_p_h.input, sel_c_1.input, sel_c_2.input, sel_a.input,
127
+ sel_theta_w.input, sel_v.input, sel_p_firebreak.input,
128
+ sel_p_continue_burn.input, sel_device.input, sel_dtype.input],
129
+ inputs=[state_var, sel_p_h, sel_c_1, sel_c_2, sel_a, sel_theta_w, sel_v,
130
+ sel_p_firebreak, sel_p_continue_burn, sel_device, sel_dtype],
131
+ outputs=[state_var])
132
+ def update_constants_state(in_state, in_p_h, in_c_1, in_c_2, in_a, in_theta_w, in_v, in_p_firebreak,
133
+ in_p_continue_burn,
134
+ in_device, in_dtype):
135
+ in_state['constants']['p_h'] = in_p_h
136
+ in_state['constants']['c_1'] = in_c_1
137
+ in_state['constants']['c_2'] = in_c_2
138
+ in_state['constants']['a'] = in_a
139
+ in_state['constants']['theta_w'] = in_theta_w
140
+ in_state['constants']['v'] = in_v
141
+ in_state['constants']['p_firebreak'] = in_p_firebreak
142
+ in_state['constants']['p_continue_burn'] = in_p_continue_burn
143
+ in_state['constants']['device'] = torch.device(in_device)
144
+ in_state['constants']['dtype'] = {
145
+ 'float16': torch.float16,
146
+ 'float32': torch.float32,
147
+ 'float64': torch.float64,
148
+ }[in_dtype]
149
+ return in_state
150
+
151
+
152
+ sel_ignition_mode = gr.Dropdown(label="Initial Ignition", choices=['random', 'center', 'custom'],
153
+ interactive=True)
154
+
155
+ gr.Markdown(
156
+ 'to use custom ignition, please use the crop to fix the size, and then draw on the image. Please '
157
+ 'click on the green button once done. Drawing on the black will be good choices.')
158
+ with gr.Row():
159
+ custom_ignition_paint = gr.Paint(label="custom ignition", image_mode='L', interactive=True,
160
+ brush=gr.Brush(default_size=3, color_mode='fixed'))
161
+ ignition_img_over_map = gr.Image(label="ignition over map")
162
+
163
+
164
+ @gr.on(triggers=[sel_shape_h.input, sel_shape_w.input],
165
+ inputs=[state_var, sel_shape_h, sel_shape_w],
166
+ outputs=[state_var, canopy_img, density_img, slope_img, tab_2, custom_ignition_paint])
167
+ def update_preview_row(in_state, in_h, in_w):
168
+ shape = in_h, in_w
169
+ data = wt.dataset.generate_empty_dataset(shape)
170
+ in_state['ds']['shape'] = shape
171
+ in_state['ds']['data'] = data
172
+ in_state['ignition'] = None
173
+ return in_state, gr.Image(wt.utils.colorize_array(np.array(data[0]))), gr.Image(
174
+ wt.utils.colorize_array(np.array(data[1]))), gr.Image(
175
+ wt.utils.colorize_array(np.array(data[2]))), gr.Tab(interactive=True), gr.ImageEditor(
176
+ crop_size=(shape[1], shape[0]))
177
+
178
+
179
+ @sel_dataset.change(inputs=[state_var, sel_dataset],
180
+ outputs=[state_var, sel_shape_h, sel_shape_w, tab_2, custom_ignition_paint,
181
+ canopy_img, density_img, slope_img])
182
+ def update_shape_row(in_state, in_dataset):
183
+ if in_dataset == 'empty':
184
+ shape = DEFAULT_SHAPE
185
+ data = wt.dataset.generate_empty_dataset(shape)
186
+ editable = True
187
+ else:
188
+ idx_dict = {item['name']: index for index, item in enumerate(wildfire_sim_maps)}
189
+ shape = tuple(cast(torch.Tensor, wildfire_sim_maps[idx_dict[in_dataset]]['shape']).tolist())
190
+ data = wt.dataset.transform_wildfire_sim_map(wildfire_sim_maps[idx_dict[in_dataset]])
191
+ editable = False
192
+ in_state['ds']['name'] = in_dataset
193
+ in_state['ds']['shape'] = shape
194
+ in_state['ds']['data'] = data
195
+ return in_state, gr.Number(value=shape[0], interactive=editable, visible=True), gr.Number(
196
+ value=shape[1],
197
+ interactive=editable,
198
+ visible=True), gr.Tab(interactive=True), gr.ImageEditor(interactive=True,
199
+ crop_size=(shape[1], shape[0])), gr.Image(
200
+ wt.utils.colorize_array(np.array(data[0]))), gr.Image(
201
+ wt.utils.colorize_array(np.array(data[1]))), gr.Image(
202
+ wt.utils.colorize_array(np.array(data[2])))
203
+
204
+
205
+ tab_2_confirm_btn = gr.Button("Confirm", interactive=False)
206
+
207
+
208
+ @sel_ignition_mode.input(
209
+ inputs=[state_var, sel_ignition_mode, custom_ignition_paint],
210
+ outputs=[state_var, ignition_img_over_map, tab_2_confirm_btn])
211
+ def update_ignition_img(in_state, in_mode, in_custom):
212
+ ignition = torch.zeros(in_state['ds']['shape'], dtype=torch.bool)
213
+
214
+ if in_mode == 'random':
215
+ ignition = wt.utils.create_ignition(shape=in_state['ds']['shape'], mode='random-single')
216
+ elif in_mode == 'center':
217
+ ignition = wt.utils.create_ignition(shape=in_state['ds']['shape'], mode='center')
218
+ elif in_mode == 'custom':
219
+ if in_custom['composite'] is not None:
220
+ ignition_ndarray = in_custom['composite'] != 0
221
+ ignition = torch.tensor(ignition_ndarray)
222
+ else:
223
+ return in_state, gr.Image(
224
+ wt.utils.colorize_array(wt.utils.compose_vis_wildfire_map(in_state['ds']['data']),
225
+ cmap='grey')), gr.Button(interactive=False)
226
+
227
+ in_state['ignition'] = ignition
228
+ ignition_ndarray = wt.utils.to_ndarray(ignition)
229
+
230
+ ignition__over_map = wt.utils.overlay_arrays(
231
+ wt.utils.colorize_array(ignition_ndarray),
232
+ wt.utils.colorize_array(wt.utils.compose_vis_wildfire_map(in_state['ds']['data']),
233
+ cmap='grey'),
234
+ 0.5
235
+ )
236
+
237
+ return in_state, gr.Image((ignition__over_map * 255).astype(np.uint8)), gr.Button(interactive=True)
238
+
239
+
240
+ @custom_ignition_paint.change(
241
+ inputs=[state_var, custom_ignition_paint],
242
+ outputs=[state_var, sel_ignition_mode, ignition_img_over_map, tab_2_confirm_btn])
243
+ def update_ignition_img_over_map(in_state, in_custom):
244
+ if in_custom['composite'] is not None:
245
+ ignition_ndarray = in_custom['composite'] != 0
246
+ ignition = torch.tensor(ignition_ndarray)
247
+ else:
248
+ return in_state, gr.Dropdown(), gr.Image(), gr.Button()
249
+ in_state['ignition'] = ignition
250
+
251
+ ignition__over_map = wt.utils.overlay_arrays(
252
+ wt.utils.colorize_array(ignition_ndarray),
253
+ wt.utils.colorize_array(wt.utils.compose_vis_wildfire_map(in_state['ds']['data']),
254
+ cmap='grey'),
255
+ 0.5
256
+ )
257
+
258
+ return in_state, gr.Dropdown(value='custom'), gr.Image(
259
+ (ignition__over_map * 255).astype(np.uint8)), gr.Button(interactive=True)
260
+
261
+ with gr.Tab("3. Simulation Control", interactive=False, id='tab_3') as tab_3:
262
+ @tab_2_confirm_btn.click(inputs=[state_var], outputs=[state_var, tabs, tab_3])
263
+ def update_tab_34_components(in_state):
264
+ return in_state, gr.Tabs(selected='tab_3'), gr.Tab(interactive=True)
265
+
266
+
267
+ with gr.Row():
268
+ with gr.Column():
269
+ gr.Markdown("## Memory Control")
270
+ checkpoint_cb = gr.Checkbox(label="Checkpoint -> Memory", value=False, interactive=True)
271
+ run_from_cp_cb = gr.Checkbox(label="Begin from Memory", value=False, interactive=True)
272
+ reset_btn = gr.Button("Reset Memory", interactive=True)
273
+ with gr.Column():
274
+ gr.Markdown("## Misc Control")
275
+ sel_steps = gr.Number(label="Number of Steps", value=200, minimum=1, step=1, interactive=True)
276
+ auto_run_cb = gr.Checkbox(label="Auto Run", value=False, interactive=True)
277
+ auto_reseed_cb = gr.Checkbox(label="Auto Regenerate Seed when open Tab", value=False,
278
+ interactive=True)
279
+ track_p_burn_cb = gr.Checkbox(label="Track p(burn), slow", value=False, interactive=True)
280
+ with gr.Column():
281
+ gr.Markdown("## Random Seed Control")
282
+ sel_seed = gr.Number(label="Random Seed", value=torch.Generator().seed(), minimum=0, step=1,
283
+ interactive=True)
284
+ random_seed_btn = gr.Button("Randomize Seed", interactive=True)
285
+
286
+
287
+ @random_seed_btn.click(inputs=[state_var], outputs=[state_var, sel_seed])
288
+ def randomize_seed(in_state):
289
+ return in_state, torch.Generator().seed()
290
+
291
+ with gr.Row():
292
+ run_btn = gr.Button("Run Simulation", interactive=True)
293
+ download_snap_btn = gr.DownloadButton(label="Download Snapshots", interactive=False, visible=False)
294
+
295
+ progress_bar = gr.Progress(track_tqdm=True)
296
+
297
+ with gr.Row():
298
+ output_video = gr.Video(label="Simulation Video", interactive=False, autoplay=True)
299
+
300
+ stats_plot = gr.LinePlot(title="Simulation Stats", interactive=True, height=600,
301
+ width=600, )
302
+
303
+ with gr.Tab("4. Advanced Simulation", interactive=False, id='tab_4') as tab_4:
304
+
305
+ sel_tab4_step = gr.Slider(label='Step', minimum=0, step=1, value=0, interactive=True)
306
+ with gr.Row():
307
+ cof_tb = gr.Textbox(label='cell_on_fire', interactive=False)
308
+ cbo_tb = gr.Textbox(label='cell_burned_out', interactive=False)
309
+ with gr.Row():
310
+ fire_state_img = gr.Image(label="Fire State", interactive=False)
311
+ p_burn_plot = gr.Image(label="p(burn)", interactive=False) # gr.Plot is bad at presenting
312
+ stats_df = gr.DataFrame()
313
+
314
+
315
+ @sel_tab4_step.input(inputs=[state_var, sel_tab4_step],
316
+ outputs=[state_var, fire_state_img, p_burn_plot, cof_tb, cbo_tb])
317
+ def update_tab4_step(in_state, in_user_step):
318
+
319
+ o_fsi, o_pbp = gr.Image(), gr.Image()
320
+ o_cof_tb, o_cbo_tb = gr.Textbox(), gr.Textbox()
321
+
322
+ if in_state['logger'] is not None:
323
+ snapshot = in_state['logger'].snapshots[in_user_step]
324
+ log = in_state['logger'].logs[in_user_step]
325
+ o_fsi = gr.Image(
326
+ value=wt.utils.colorize_array(wt.utils.compose_vis_fire_state(snapshot['fire_state'])))
327
+ if len(in_state['logger'].p_burns) > 0:
328
+ p_burn_arr = in_state['logger'].p_burns[in_user_step].cpu().numpy()
329
+ o_pbp = gr.Image(
330
+ value=wt.utils.colorize_array(p_burn_arr))
331
+ o_cof_tb = gr.Textbox(value=str(log['num_cells_on_fire']))
332
+ o_cbo_tb = gr.Textbox(value=str(log['num_cells_burned_out']))
333
+
334
+ return in_state, o_fsi, o_pbp, o_cof_tb, o_cbo_tb
335
+
336
+
337
+ @tab_3.select(
338
+ inputs=[state_var, auto_run_cb, sel_steps, sel_seed, auto_reseed_cb, checkpoint_cb, run_from_cp_cb,
339
+ track_p_burn_cb],
340
+ outputs=[state_var, output_video, tab_4, stats_plot, download_snap_btn, sel_seed, sel_tab4_step,
341
+ stats_df])
342
+ def auto_run_simulation(in_state, in_auto_run, in_steps, in_seed, in_auto_reseed, in_checkpoint_cb,
343
+ in_run_from_cp_cb, in_track_p_burn_cb):
344
+ o_s = in_state
345
+ o_v = gr.Video()
346
+ o_t = gr.Tab()
347
+ o_lp = gr.LinePlot()
348
+ o_dsb = gr.DownloadButton()
349
+ o_ts = gr.Slider()
350
+ o_sdf = pd.DataFrame()
351
+ if in_auto_reseed:
352
+ in_seed = torch.Generator().seed()
353
+ if in_auto_run:
354
+ o_s, o_v, o_t, o_lp, o_dsb, o_ts, o_sdf = run_simulation(in_state,
355
+ in_steps,
356
+ in_seed,
357
+ in_checkpoint_cb,
358
+ in_run_from_cp_cb,
359
+ in_track_p_burn_cb)
360
+ return o_s, o_v, o_t, o_lp, o_dsb, in_seed, o_ts, o_sdf
361
+
362
+
363
+ @reset_btn.click(inputs=[state_var],
364
+ outputs=[state_var])
365
+ def reset_simulation(in_state):
366
+ if in_state['checkpoint'] is not None:
367
+ in_state['checkpoint'] = None
368
+ gr.Info('Checkpoint Cleared.')
369
+ return in_state
370
+
371
+
372
+ @run_btn.click(inputs=[state_var, sel_steps, sel_seed, checkpoint_cb, run_from_cp_cb, track_p_burn_cb],
373
+ outputs=[state_var, output_video, tab_4, stats_plot, download_snap_btn, sel_tab4_step,
374
+ stats_df])
375
+ def run_simulation(in_state, in_steps, in_seed, in_checkpoint_cb, in_run_from_cp_cb, in_track_p_burn_cb,
376
+ in_progress=gr.Progress(track_tqdm=True)):
377
+ if in_state['out_video_path'] is None:
378
+ in_state['out_video_path'] = f'runs/{str(uuid.uuid4())}.mp4'
379
+ simulator = wt.WildTorchSimulator(
380
+ wildfire_map=in_state['ds']['data'],
381
+ simulator_constants=wt.SimulatorConstants(
382
+ p_h=in_state['constants']['p_h'],
383
+ c_1=in_state['constants']['c_1'],
384
+ c_2=in_state['constants']['c_2'],
385
+ a=in_state['constants']['a'],
386
+ theta_w=in_state['constants']['theta_w'],
387
+ v=in_state['constants']['v'],
388
+ p_firebreak=in_state['constants']['p_firebreak'],
389
+ p_continue_burn=in_state['constants']['p_continue_burn'],
390
+ device=in_state['constants']['device'],
391
+ dtype=in_state['constants']['dtype'],
392
+ ),
393
+ maximum_step=in_steps,
394
+ initial_ignition=in_state['ignition'],
395
+ seed=in_seed,
396
+ )
397
+
398
+ if in_state['checkpoint'] is not None and in_run_from_cp_cb:
399
+ simulator.load_checkpoint(in_state['checkpoint'], restore_seed=False)
400
+
401
+ logger = wt.logger.Logger(disable_writing=True, verbose=False)
402
+
403
+ for i in in_progress.tqdm(range(in_steps)):
404
+ simulator.step()
405
+ logger.snapshot_simulation(simulator)
406
+ logger.log_stats(
407
+ step=i,
408
+ num_cells_on_fire=wt.metrics.cell_on_fire(simulator.fire_state).item(),
409
+ num_cells_burned_out=wt.metrics.cell_burned_out(simulator.fire_state).item(),
410
+ )
411
+ if in_track_p_burn_cb:
412
+ logger.log_p_burn(simulator)
413
+
414
+ gr.Info('Simulation Completed. Generating Video ...')
415
+
416
+ in_state['logger'] = logger
417
+
418
+ if in_checkpoint_cb:
419
+ in_state['checkpoint'] = simulator.checkpoint
420
+
421
+ if ENABLE_DOWNLOAD_SNAPSHOTS:
422
+ logger.snapshots_filepath = in_state['snapshots_path'] = f'runs/{str(uuid.uuid4())}.pt'
423
+ logger.save_snapshots()
424
+ can_download_snapshots = True
425
+ else:
426
+ can_download_snapshots = False
427
+
428
+ wt.utils.animate_snapshots(logger.snapshots, simulator.wildfire_map,
429
+ out_filename=in_state['out_video_path'])
430
+
431
+ m_stats_df = pd.DataFrame(logger.logs)
432
+ m_stats_df = m_stats_df.melt(id_vars=["step"], var_name="key", value_name="value")
433
+
434
+ o_stats_df = pd.DataFrame(logger.logs)
435
+ return in_state, gr.Video(value=in_state['out_video_path']), gr.Tab(interactive=True), gr.LinePlot(
436
+ m_stats_df, x='step', y='value', color="key", color_legend_position="bottom",
437
+ tooltip=["step", "key", "value"], container=False, ), gr.DownloadButton(
438
+ value=in_state['snapshots_path'], interactive=can_download_snapshots,
439
+ visible=can_download_snapshots), gr.Slider(maximum=in_steps - 1), o_stats_df
440
+
441
+ demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ wildtorch[full]