File size: 15,921 Bytes
100edb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 |
import streamlit as st
import torch
import random
import numpy as np
import yaml
from pathlib import Path
from io import BytesIO
import random
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download, snapshot_download
import tempfile
import traceback
import functools as ft
import os
import random
import re
from collections import defaultdict
from datetime import datetime, timedelta
from pathlib import Path
import h5py
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch.utils.data import Dataset
import logging
from Prithvi import *
# Logging: INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Page-level settings, applied before any other Streamlit element is created.
_page_settings = {
    "page_title": "MERRA2 Data Processor",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_page_settings)

# Dataset selector shown in the sidebar; index 0 ("MERRA2") is preselected.
# NOTE(review): `dataset_type` is not read anywhere else in the visible code.
with st.sidebar:
    dataset_type = st.selectbox(
        "Select Dataset Type",
        options=["MERRA2", "GEOS5"],
        index=0,
    )

# Main-page heading.
st.title("MERRA2 Data Processor with PrithviWxC Model")
# Sidebar upload widgets. The surface and vertical uploaders accept several
# NetCDF files each; the config and weights uploaders accept a single file.
with st.sidebar:
    st.header("Upload MERRA2 Data Files")

    # Surface-level data files.
    uploaded_surface_files = st.file_uploader(
        "Upload Surface Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="surface_uploader",
    )

    # Vertical-level data files.
    uploaded_vertical_files = st.file_uploader(
        "Upload Vertical Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="vertical_uploader",
    )

    # Optional model configuration (a default path is used when absent).
    uploaded_config = st.file_uploader(
        "Upload config.yaml",
        type=["yaml", "yml"],
        key="config_uploader",
    )

    # Optional model weights (a default path is used when absent).
    uploaded_weights = st.file_uploader(
        "Upload Model Weights (.pt)",
        type=["pt"],
        key="weights_uploader",
    )
# Task settings collected from the sidebar.
with st.sidebar:
    st.header("Task Configuration")

    # Forecast lead times handed to the dataset.
    # NOTE(review): units are presumably hours -- confirm against Merra2Dataset.
    lead_times = st.multiselect(
        "Select Lead Times",
        options=[12, 24, 36, 48],
        default=[12],
    )

    # Negative offsets of the input time steps handed to the dataset.
    input_times = st.multiselect(
        "Select Input Times",
        options=[-6, -12, -18, -24],
        default=[-6],
    )

    # Start and end of the time window, entered as free text (not validated here).
    time_range_start = st.text_input(
        "Start Time (e.g., 2020-01-01T00:00:00)",
        value="2020-01-01T00:00:00",
    )
    time_range_end = st.text_input(
        "End Time (e.g., 2020-01-01T23:59:59)",
        value="2020-01-01T23:59:59",
    )

# (start, end) pair passed unchanged to the dataset constructor.
time_range = (time_range_start, time_range_end)
# Function to save uploaded files
def save_uploaded_files(uploaded_files, folder_name, max_size_mb=1024):
    """Write uploaded files into a freshly created temporary directory.

    Args:
        uploaded_files: Sequence of uploaded-file objects exposing ``name``,
            ``size`` and ``getbuffer()``; may be empty or ``None``.
        folder_name: Label used only in the status messages shown to the user.
        max_size_mb: Per-file size limit; a file is rejected when its ``size``
            exceeds ``max_size_mb * 1024 * 1024`` bytes.

    Returns:
        ``Path`` of the temporary directory holding the saved files, or
        ``None`` when nothing was uploaded, a file exceeds the size limit, or
        a file name is unusable.
    """
    if not uploaded_files:
        st.warning(f"No {folder_name} files uploaded.")
        return None

    # Validate every file before anything is written to disk.
    max_size_bytes = max_size_mb * 1024 * 1024
    safe_names = []
    for candidate in uploaded_files:
        if candidate.size > max_size_bytes:
            st.error(f"File {candidate.name} exceeds the maximum size of {max_size_mb} MB.")
            return None
        # The name is supplied by the client: keep only its final path
        # component so it cannot point outside the temporary directory.
        safe_name = Path(candidate.name).name
        if safe_name in ("", ".", ".."):
            st.error(f"File name {candidate.name!r} is not valid.")
            return None
        safe_names.append(safe_name)

    temp_dir = Path(tempfile.mkdtemp())
    with st.spinner(f"Saving {folder_name} files..."):
        for uploaded_file, safe_name in zip(uploaded_files, safe_names):
            with open(temp_dir / safe_name, "wb") as f:
                f.write(uploaded_file.getbuffer())
    st.success(f"Saved {len(uploaded_files)} {folder_name} files.")
    return temp_dir
# Persist both upload groups; each call returns a directory Path or None.
surf_dir = save_uploaded_files(uploaded_surface_files, "surface")
vert_dir = save_uploaded_files(uploaded_vertical_files, "vertical")

# Echo the saved file names in the sidebar, one subheader per group.
with st.sidebar:
    if surf_dir:
        st.subheader("Surface Files Uploaded:")
        for saved_file in surf_dir.iterdir():
            st.write(saved_file.name)
    if vert_dir:
        st.subheader("Vertical Files Uploaded:")
        for saved_file in vert_dir.iterdir():
            st.write(saved_file.name)
# Handle Climatology Files
st.sidebar.header("Upload Climatology Files (If Missing)")

# Default locations of the four scaler/climatology files.
default_clim_dir = Path("Prithvi-WxC/examples/climatology")
surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"

# True only when all four default files are present on disk.
clim_files_exist = all(
    candidate.exists()
    for candidate in (
        surf_in_scal_path,
        vert_in_scal_path,
        surf_out_scal_path,
        vert_out_scal_path,
    )
)

# Defaults so these names are always bound, whichever branch runs below.
uploaded_clim_surface = None
uploaded_clim_vertical = None
clim_surf_path = None
clim_vert_path = None

if clim_files_exist:
    clim_surf_path = surf_in_scal_path
    clim_vert_path = vert_in_scal_path
else:
    st.sidebar.warning("Climatology files are missing.")
    uploaded_clim_surface = st.sidebar.file_uploader(
        "Upload Climatology Surface File",
        type=["nc", "netcdf"],
        key="clim_surface_uploader",
    )
    uploaded_clim_vertical = st.sidebar.file_uploader(
        "Upload Climatology Vertical File",
        type=["nc", "netcdf"],
        key="clim_vertical_uploader",
    )
    if uploaded_clim_surface and uploaded_clim_vertical:
        clim_temp_dir = Path(tempfile.mkdtemp())
        # File names come from the client: keep only the final path component
        # so the files are always written inside the temporary directory.
        clim_surf_path = clim_temp_dir / Path(uploaded_clim_surface.name).name
        with open(clim_surf_path, "wb") as f:
            f.write(uploaded_clim_surface.getbuffer())
        clim_vert_path = clim_temp_dir / Path(uploaded_clim_vertical.name).name
        with open(clim_vert_path, "wb") as f:
            f.write(uploaded_clim_vertical.getbuffer())
        st.success("Climatology files uploaded and saved.")
    else:
        st.warning("Please upload both climatology surface and vertical files.")
# Save uploaded config.yaml
# An uploaded file takes precedence; otherwise the default path is used and
# the script stops when that default does not exist.
if uploaded_config:
    # NamedTemporaryFile(delete=False) creates and opens the file in one step,
    # avoiding the name-then-open race of the deprecated tempfile.mktemp().
    with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f:
        f.write(uploaded_config.getbuffer())
    config_path = Path(f.name)
    st.sidebar.success("Config.yaml uploaded and saved.")
else:
    # Use default config.yaml path
    config_path = Path("Prithvi-WxC/examples/config.yaml")
    if not config_path.exists():
        st.sidebar.error("Default config.yaml not found. Please upload a config file.")
        st.stop()
# Save uploaded model weights
# An uploaded file takes precedence; otherwise the default path is used and
# the script stops when that default does not exist.
if uploaded_weights:
    # NamedTemporaryFile(delete=False) creates and opens the file in one step,
    # avoiding the name-then-open race of the deprecated tempfile.mktemp().
    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
        f.write(uploaded_weights.getbuffer())
    weights_path = Path(f.name)
    st.sidebar.success("Model weights uploaded and saved.")
else:
    # Use default weights path
    weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
    if not weights_path.exists():
        st.sidebar.error("Default model weights not found. Please upload model weights.")
        st.stop()
# Button to run inference
# The whole pipeline (dataset -> scalers -> config -> model -> weights ->
# batch -> forward pass) runs only on the rerun triggered by the button.
# Each stage reports its own failure and halts the script with st.stop().
# NOTE(review): the st.stop() calls inside the try blocks rely on Streamlit's
# stop signal not being caught by `except Exception` -- confirm for the
# Streamlit version in use.
if st.sidebar.button("Run Inference"):
    # Initialize device
    torch.jit.enable_onednn_fusion(True)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        st.write(f"Using device: {torch.cuda.get_device_name()}")
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        st.write("Using device: CPU")

    # Set random seeds
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    torch.manual_seed(42)
    np.random.seed(42)

    # Define variables and parameters
    surface_vars = [
        "EFLUX",
        "GWETROOT",
        "HFLUX",
        "LAI",
        "LWGAB",
        "LWGEM",
        "LWTUP",
        "PS",
        "QV2M",
        "SLP",
        "SWGNT",
        "SWTNT",
        "T2M",
        "TQI",
        "TQL",
        "TQV",
        "TS",
        "U10M",
        "V10M",
        "Z0M",
    ]
    static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
    vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
    # Level values handed to both the dataset and the scaler loaders.
    levels = [
        34.0,
        39.0,
        41.0,
        43.0,
        44.0,
        45.0,
        48.0,
        51.0,
        53.0,
        56.0,
        63.0,
        68.0,
        71.0,
        72.0,
    ]
    # Per-dimension padding passed to preproc() when building the batch.
    padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
    residual = "climate"
    masking_mode = "local"
    decoder_shifting = True
    masking_ratio = 0.99
    positional_encoding = "fourier"

    # Initialize Dataset
    try:
        with st.spinner("Initializing dataset..."):
            # Validate climatology files
            if not clim_files_exist and not (uploaded_clim_surface and uploaded_clim_vertical):
                st.error("Climatology files are missing. Please upload both surface and vertical climatology files.")
                st.stop()

            # Read the user's uploads when present; otherwise fall back to the
            # bundled example data. Previously the uploads were saved but the
            # dataset always read the example directory.
            default_data_dir = Path("Prithvi-WxC/examples/merra-2")
            data_path_surface = surf_dir if surf_dir is not None else default_data_dir
            data_path_vertical = vert_dir if vert_dir is not None else default_data_dir

            dataset = Merra2Dataset(
                time_range=time_range,
                lead_times=lead_times,
                input_times=input_times,
                data_path_surface=data_path_surface,
                data_path_vertical=data_path_vertical,
                climatology_path_surface=Path("Prithvi-WxC/examples/climatology"),
                climatology_path_vertical=Path("Prithvi-WxC/examples/climatology"),
                surface_vars=surface_vars,
                static_surface_vars=static_surface_vars,
                vertical_vars=vertical_vars,
                levels=levels,
                positional_encoding=positional_encoding,
            )
            # Explicit check instead of `assert`, which is removed under -O.
            if len(dataset) <= 0:
                raise ValueError("There doesn't seem to be any valid data.")
        st.success("Dataset initialized successfully.")
    except Exception:
        st.error("Error initializing dataset:")
        st.error(traceback.format_exc())
        st.stop()

    # Load scalers
    try:
        with st.spinner("Loading scalers..."):
            # Assuming the scaler paths are the same as climatology paths
            surf_in_scal_path = clim_surf_path
            vert_in_scal_path = clim_vert_path
            # Output scalers are looked up next to the input scaler files.
            surf_out_scal_path = Path(clim_surf_path.parent) / "anomaly_variance_surface.nc"
            vert_out_scal_path = Path(clim_vert_path.parent) / "anomaly_variance_vertical.nc"

            # Check if output scaler files exist
            if not surf_out_scal_path.exists() or not vert_out_scal_path.exists():
                st.error("Anomaly variance scaler files are missing.")
                st.stop()

            in_mu, in_sig = input_scalers(
                surface_vars,
                vertical_vars,
                levels,
                surf_in_scal_path,
                vert_in_scal_path,
            )
            output_sig = output_scalers(
                surface_vars,
                vertical_vars,
                levels,
                surf_out_scal_path,
                vert_out_scal_path,
            )
            static_mu, static_sig = static_input_scalers(
                surf_in_scal_path,
                static_surface_vars,
            )
        st.success("Scalers loaded successfully.")
    except Exception:
        st.error("Error loading scalers:")
        st.error(traceback.format_exc())
        st.stop()

    # Load configuration
    try:
        with st.spinner("Loading configuration..."):
            with open(config_path, "r") as f:
                config = yaml.safe_load(f)

            # Validate config: every key below is read from config["params"]
            # when the model is constructed.
            required_params = [
                "in_channels", "input_size_time", "in_channels_static",
                "input_scalers_epsilon", "static_input_scalers_epsilon",
                "n_lats_px", "n_lons_px", "patch_size_px",
                "mask_unit_size_px", "embed_dim", "n_blocks_encoder",
                "n_blocks_decoder", "mlp_multiplier", "n_heads",
                "dropout", "drop_path", "parameter_dropout"
            ]
            missing_params = [param for param in required_params if param not in config.get("params", {})]
            if missing_params:
                st.error(f"Missing configuration parameters: {missing_params}")
                st.stop()
        st.success("Configuration loaded successfully.")
    except Exception:
        st.error("Error loading configuration:")
        st.error(traceback.format_exc())
        st.stop()

    # Initialize the model
    try:
        with st.spinner("Initializing model..."):
            params = config["params"]
            model = PrithviWxC(
                in_channels=params["in_channels"],
                input_size_time=params["input_size_time"],
                in_channels_static=params["in_channels_static"],
                input_scalers_mu=in_mu,
                input_scalers_sigma=in_sig,
                input_scalers_epsilon=params["input_scalers_epsilon"],
                static_input_scalers_mu=static_mu,
                static_input_scalers_sigma=static_sig,
                static_input_scalers_epsilon=params["static_input_scalers_epsilon"],
                # Square root of the loaded output scalers.
                # NOTE(review): presumably variance -> standard deviation,
                # given the "anomaly_variance_*" file names -- confirm.
                output_scalers=output_sig**0.5,
                n_lats_px=params["n_lats_px"],
                n_lons_px=params["n_lons_px"],
                patch_size_px=params["patch_size_px"],
                mask_unit_size_px=params["mask_unit_size_px"],
                mask_ratio_inputs=masking_ratio,
                embed_dim=params["embed_dim"],
                n_blocks_encoder=params["n_blocks_encoder"],
                n_blocks_decoder=params["n_blocks_decoder"],
                mlp_multiplier=params["mlp_multiplier"],
                n_heads=params["n_heads"],
                dropout=params["dropout"],
                drop_path=params["drop_path"],
                parameter_dropout=params["parameter_dropout"],
                residual=residual,
                masking_mode=masking_mode,
                decoder_shifting=decoder_shifting,
                positional_encoding=positional_encoding,
                checkpoint_encoder=[],
                checkpoint_decoder=[],
            )
        st.success("Model initialized successfully.")
    except Exception:
        st.error("Error initializing model:")
        st.error(traceback.format_exc())
        st.stop()

    # Load model weights
    try:
        with st.spinner("Loading model weights..."):
            # SECURITY NOTE: torch.load unpickles the file, and weights_path
            # may point at a user-uploaded file. Only load checkpoints from
            # trusted sources; consider weights_only=True if the checkpoint
            # format allows it.
            state_dict = torch.load(weights_path, map_location=device)
            # Some checkpoints wrap the weights under a "model_state" key.
            if "model_state" in state_dict:
                state_dict = state_dict["model_state"]
            model.load_state_dict(state_dict, strict=True)
            model.to(device)
        st.success("Model weights loaded successfully.")
    except Exception:
        st.error("Error loading model weights:")
        st.error(traceback.format_exc())
        st.stop()

    # Prepare data batch
    try:
        with st.spinner("Preparing data batch..."):
            # Only the first sample of the dataset is used.
            data = next(iter(dataset))
            batch = preproc([data], padding)
            # Move every tensor entry of the batch to the selected device.
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(device)
        st.success("Data batch prepared successfully.")
    except Exception:
        st.error("Error preparing data batch:")
        st.error(traceback.format_exc())
        st.stop()

    # Run inference
    try:
        with st.spinner("Running model inference..."):
            with torch.no_grad():
                model.eval()
                out = model(batch)
        st.success("Model inference completed successfully.")
    except Exception:
        st.error("Error during model inference:")
        st.error(traceback.format_exc())
        st.stop()

    # Display output
    st.header("Inference Results")
    st.write(out)  # Adjust based on the structure of 'out'
    # Optionally, provide download links or visualizations
    # For example, if 'out' contains tensors or dataframes:
    # st.write("Output Tensor:", out["some_key"].cpu().numpy())
else:
    st.info("Please upload the necessary files and click 'Run Inference' to start.")
|