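"""Streamlit app: process MERRA-2 data and run inference with the PrithviWxC model.

The sidebar collects MERRA-2 surface and vertical NetCDF files, optional
climatology files, a config.yaml, and model weights. Pressing "Run Inference"
builds a Merra2Dataset, instantiates PrithviWxC, loads the weights, and runs a
single forward pass on one batch.
"""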
import logging
import random
import tempfile
import traceback
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import yaml

# Merra2Dataset, PrithviWxC, preproc, input_scalers, output_scalers, and
# static_input_scalers are expected to come from the Prithvi package.
from Prithvi import *
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set page configuration
st.set_page_config(
    page_title="MERRA2 Data Processor",
    layout="wide",
    initial_sidebar_state="expanded",
)
dataset_type = st.sidebar.selectbox(
    "Select Dataset Type",
    options=["MERRA2", "GEOS5"],
    index=0,
)
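# Note: `dataset_type` is collected here but not yet consumed below; the
# pipeline currently assumes MERRA-2 inputs throughout.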
st.title("MERRA2 Data Processor with PrithviWxC Model")

# Sidebar for file uploads
st.sidebar.header("Upload MERRA2 Data Files")
# File uploader for surface data
uploaded_surface_files = st.sidebar.file_uploader(
    "Upload Surface Data Files",
    type=["nc", "netcdf"],
    accept_multiple_files=True,
    key="surface_uploader",
)

# File uploader for vertical data
uploaded_vertical_files = st.sidebar.file_uploader(
    "Upload Vertical Data Files",
    type=["nc", "netcdf"],
    accept_multiple_files=True,
    key="vertical_uploader",
)

# Optional: upload config.yaml
uploaded_config = st.sidebar.file_uploader(
    "Upload config.yaml",
    type=["yaml", "yml"],
    key="config_uploader",
)

# Optional: upload model weights
uploaded_weights = st.sidebar.file_uploader(
    "Upload Model Weights (.pt)",
    type=["pt"],
    key="weights_uploader",
)
# Other configurations
st.sidebar.header("Task Configuration")
lead_times = st.sidebar.multiselect(
    "Select Lead Times (hours)",
    options=[12, 24, 36, 48],
    default=[12],
)
input_times = st.sidebar.multiselect(
    "Select Input Times (hours relative to the forecast start)",
    options=[-6, -12, -18, -24],
    default=[-6],
)
time_range_start = st.sidebar.text_input(
    "Start Time (e.g., 2020-01-01T00:00:00)",
    value="2020-01-01T00:00:00",
)
time_range_end = st.sidebar.text_input(
    "End Time (e.g., 2020-01-01T23:59:59)",
    value="2020-01-01T23:59:59",
)
# ISO-8601 strings; Merra2Dataset is assumed to parse these.
time_range = (time_range_start, time_range_end)
# Function to save uploaded files to a temporary directory
def save_uploaded_files(uploaded_files, folder_name, max_size_mb=1024):
    if not uploaded_files:
        st.warning(f"No {folder_name} files uploaded.")
        return None
    # Validate file sizes
    for file in uploaded_files:
        if file.size > max_size_mb * 1024 * 1024:
            st.error(f"File {file.name} exceeds the maximum size of {max_size_mb} MB.")
            return None
    temp_dir = tempfile.mkdtemp()
    with st.spinner(f"Saving {folder_name} files..."):
        for uploaded_file in uploaded_files:
            file_path = Path(temp_dir) / uploaded_file.name
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
    st.success(f"Saved {len(uploaded_files)} {folder_name} files.")
    return Path(temp_dir)
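# Note: tempfile.mkdtemp() directories are not cleaned up automatically; on a
# long-running server you may want to delete them after inference.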
# Save uploaded files
surf_dir = save_uploaded_files(uploaded_surface_files, "surface")
vert_dir = save_uploaded_files(uploaded_vertical_files, "vertical")

# Display uploaded files
if surf_dir:
    st.sidebar.subheader("Surface Files Uploaded:")
    for file in surf_dir.iterdir():
        st.sidebar.write(file.name)
if vert_dir:
    st.sidebar.subheader("Vertical Files Uploaded:")
    for file in vert_dir.iterdir():
        st.sidebar.write(file.name)
# Handle climatology files
st.sidebar.header("Upload Climatology Files (If Missing)")

# Default climatology file paths
default_clim_dir = Path("Prithvi-WxC/examples/climatology")
surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"

# Check whether all climatology files exist
clim_files_exist = all(
    [
        surf_in_scal_path.exists(),
        vert_in_scal_path.exists(),
        surf_out_scal_path.exists(),
        vert_out_scal_path.exists(),
    ]
)
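# The musigma_* files hold the per-variable mean/std used to scale model
# inputs; the anomaly_variance_* files hold the variances used to scale
# outputs (see the scaler-loading step below).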
if not clim_files_exist:
    st.sidebar.warning("Climatology files are missing.")
    uploaded_clim_surface = st.sidebar.file_uploader(
        "Upload Climatology Surface File",
        type=["nc", "netcdf"],
        key="clim_surface_uploader",
    )
    uploaded_clim_vertical = st.sidebar.file_uploader(
        "Upload Climatology Vertical File",
        type=["nc", "netcdf"],
        key="clim_vertical_uploader",
    )
    if uploaded_clim_surface and uploaded_clim_vertical:
        clim_temp_dir = tempfile.mkdtemp()
        clim_surf_path = Path(clim_temp_dir) / uploaded_clim_surface.name
        with open(clim_surf_path, "wb") as f:
            f.write(uploaded_clim_surface.getbuffer())
        clim_vert_path = Path(clim_temp_dir) / uploaded_clim_vertical.name
        with open(clim_vert_path, "wb") as f:
            f.write(uploaded_clim_vertical.getbuffer())
        st.success("Climatology files uploaded and saved.")
    else:
        st.warning("Please upload both climatology surface and vertical files.")
else:
    clim_surf_path = surf_in_scal_path
    clim_vert_path = vert_in_scal_path
# Save uploaded config.yaml (tempfile.mktemp is deprecated; use NamedTemporaryFile)
if uploaded_config:
    with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f:
        f.write(uploaded_config.getbuffer())
        config_path = Path(f.name)
    st.sidebar.success("config.yaml uploaded and saved.")
else:
    # Fall back to the default config.yaml path
    config_path = Path("Prithvi-WxC/examples/config.yaml")
    if not config_path.exists():
        st.sidebar.error("Default config.yaml not found. Please upload a config file.")
        st.stop()
# Save uploaded model weights
if uploaded_weights:
    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
        f.write(uploaded_weights.getbuffer())
        weights_path = Path(f.name)
    st.sidebar.success("Model weights uploaded and saved.")
else:
    # Fall back to the default weights path
    weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
    if not weights_path.exists():
        st.sidebar.error("Default model weights not found. Please upload model weights.")
        st.stop()
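# Streamlit reruns this script top to bottom on every interaction, so the
# dataset and model below are rebuilt on each click. A hedged sketch of how
# repeated rebuilds could be avoided with Streamlit's resource cache
# (function name and arguments are illustrative only):
#
#     @st.cache_resource
#     def load_model(config_path, weights_path):
#         ...build PrithviWxC and load weights once...
#         return model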
# Button to run inference
if st.sidebar.button("Run Inference"):
    # Initialize device
    torch.jit.enable_onednn_fusion(True)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        st.write(f"Using device: {torch.cuda.get_device_name()}")
        # benchmark=True conflicts with deterministic=True; prefer
        # determinism here since seeds are fixed below.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        st.write("Using device: CPU")

    # Set random seeds for reproducibility
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    torch.manual_seed(42)
    np.random.seed(42)
    # Define variables and parameters
    surface_vars = [
        "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", "LWTUP",
        "PS", "QV2M", "SLP", "SWGNT", "SWTNT", "T2M", "TQI", "TQL", "TQV",
        "TS", "U10M", "V10M", "Z0M",
    ]
    static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
    vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
    levels = [
        34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0,
        51.0, 53.0, 56.0, 63.0, 68.0, 71.0, 72.0,
    ]
    padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
    residual = "climate"
    masking_mode = "local"
    decoder_shifting = True
    masking_ratio = 0.99
    positional_encoding = "fourier"
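    # These settings are assumed to mirror the Prithvi-WxC example defaults:
    # masking_ratio=0.99 masks 99% of the input mask units,
    # residual="climate" makes the model predict anomalies relative to
    # climatology, and the lat padding of [0, -1] trims one latitude row so
    # the grid divides evenly into patches.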
    # Initialize dataset
    try:
        with st.spinner("Initializing dataset..."):
            # Validate climatology files
            if not clim_files_exist and not (uploaded_clim_surface and uploaded_clim_vertical):
                st.error("Climatology files are missing. Please upload both surface and vertical climatology files.")
                st.stop()
            # Prefer the uploaded data directories; fall back to the bundled examples.
            dataset = Merra2Dataset(
                time_range=time_range,
                lead_times=lead_times,
                input_times=input_times,
                data_path_surface=surf_dir or Path("Prithvi-WxC/examples/merra-2"),
                data_path_vertical=vert_dir or Path("Prithvi-WxC/examples/merra-2"),
                climatology_path_surface=Path("Prithvi-WxC/examples/climatology"),
                climatology_path_vertical=Path("Prithvi-WxC/examples/climatology"),
                surface_vars=surface_vars,
                static_surface_vars=static_surface_vars,
                vertical_vars=vertical_vars,
                levels=levels,
                positional_encoding=positional_encoding,
            )
            assert len(dataset) > 0, "There doesn't seem to be any valid data."
        st.success("Dataset initialized successfully.")
    except Exception:
        st.error("Error initializing dataset:")
        st.error(traceback.format_exc())
        st.stop()
    # Load scalers
    try:
        with st.spinner("Loading scalers..."):
            # The scaler paths are assumed to match the climatology paths
            surf_in_scal_path = clim_surf_path
            vert_in_scal_path = clim_vert_path
            surf_out_scal_path = clim_surf_path.parent / "anomaly_variance_surface.nc"
            vert_out_scal_path = clim_vert_path.parent / "anomaly_variance_vertical.nc"
            # Check that the output scaler files exist
            if not surf_out_scal_path.exists() or not vert_out_scal_path.exists():
                st.error("Anomaly variance scaler files are missing.")
                st.stop()
            in_mu, in_sig = input_scalers(
                surface_vars,
                vertical_vars,
                levels,
                surf_in_scal_path,
                vert_in_scal_path,
            )
            output_sig = output_scalers(
                surface_vars,
                vertical_vars,
                levels,
                surf_out_scal_path,
                vert_out_scal_path,
            )
            static_mu, static_sig = static_input_scalers(
                surf_in_scal_path,
                static_surface_vars,
            )
        st.success("Scalers loaded successfully.")
    except Exception:
        st.error("Error loading scalers:")
        st.error(traceback.format_exc())
        st.stop()
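    # in_mu/in_sig standardize the dynamic inputs and static_mu/static_sig
    # the static inputs; output_sig holds per-variable anomaly variances,
    # so the model constructor below takes its square root to get a
    # standard deviation (output_scalers=output_sig**0.5).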
    # Load configuration
    try:
        with st.spinner("Loading configuration..."):
            with open(config_path, "r") as f:
                config = yaml.safe_load(f)
            # Validate that the required model parameters are present
            required_params = [
                "in_channels", "input_size_time", "in_channels_static",
                "input_scalers_epsilon", "static_input_scalers_epsilon",
                "n_lats_px", "n_lons_px", "patch_size_px",
                "mask_unit_size_px", "embed_dim", "n_blocks_encoder",
                "n_blocks_decoder", "mlp_multiplier", "n_heads",
                "dropout", "drop_path", "parameter_dropout",
            ]
            missing_params = [p for p in required_params if p not in config.get("params", {})]
            if missing_params:
                st.error(f"Missing configuration parameters: {missing_params}")
                st.stop()
        st.success("Configuration loaded successfully.")
    except Exception:
        st.error("Error loading configuration:")
        st.error(traceback.format_exc())
        st.stop()
    # Initialize the model
    try:
        with st.spinner("Initializing model..."):
            params = config["params"]
            model = PrithviWxC(
                in_channels=params["in_channels"],
                input_size_time=params["input_size_time"],
                in_channels_static=params["in_channels_static"],
                input_scalers_mu=in_mu,
                input_scalers_sigma=in_sig,
                input_scalers_epsilon=params["input_scalers_epsilon"],
                static_input_scalers_mu=static_mu,
                static_input_scalers_sigma=static_sig,
                static_input_scalers_epsilon=params["static_input_scalers_epsilon"],
                output_scalers=output_sig**0.5,  # variance -> standard deviation
                n_lats_px=params["n_lats_px"],
                n_lons_px=params["n_lons_px"],
                patch_size_px=params["patch_size_px"],
                mask_unit_size_px=params["mask_unit_size_px"],
                mask_ratio_inputs=masking_ratio,
                embed_dim=params["embed_dim"],
                n_blocks_encoder=params["n_blocks_encoder"],
                n_blocks_decoder=params["n_blocks_decoder"],
                mlp_multiplier=params["mlp_multiplier"],
                n_heads=params["n_heads"],
                dropout=params["dropout"],
                drop_path=params["drop_path"],
                parameter_dropout=params["parameter_dropout"],
                residual=residual,
                masking_mode=masking_mode,
                decoder_shifting=decoder_shifting,
                positional_encoding=positional_encoding,
                checkpoint_encoder=[],
                checkpoint_decoder=[],
            )
        st.success("Model initialized successfully.")
    except Exception:
        st.error("Error initializing model:")
        st.error(traceback.format_exc())
        st.stop()
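    # Note: newer PyTorch releases default torch.load to weights_only=True;
    # if the checkpoint stores more than raw tensors you may need to pass
    # weights_only=False explicitly.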
    # Load model weights
    try:
        with st.spinner("Loading model weights..."):
            state_dict = torch.load(weights_path, map_location=device)
            # Checkpoints may wrap the weights under a "model_state" key
            if "model_state" in state_dict:
                state_dict = state_dict["model_state"]
            model.load_state_dict(state_dict, strict=True)
            model.to(device)
        st.success("Model weights loaded successfully.")
    except Exception:
        st.error("Error loading model weights:")
        st.error(traceback.format_exc())
        st.stop()
    # Prepare a data batch
    try:
        with st.spinner("Preparing data batch..."):
            data = next(iter(dataset))
            batch = preproc([data], padding)
            # Move all tensors in the batch onto the target device
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(device)
        st.success("Data batch prepared successfully.")
    except Exception:
        st.error("Error preparing data batch:")
        st.error(traceback.format_exc())
        st.stop()
    # Run inference
    try:
        with st.spinner("Running model inference..."):
            rng_state_1 = torch.get_rng_state()
            model.eval()
            with torch.no_grad():
                out = model(batch)
        st.success("Model inference completed successfully.")
    except Exception:
        st.error("Error during model inference:")
        st.error(traceback.format_exc())
        st.stop()
    # Display output
    st.header("Inference Results")
    st.write(out)  # Adjust based on the structure of 'out'
    # Optionally, provide download links or visualizations.
    # For example, if 'out' contains tensors or dataframes:
    # st.write("Output Tensor:", out["some_key"].cpu().numpy())
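    # A hedged visualization sketch, assuming `out` is a tensor shaped
    # (batch, parameter, lat, lon); the channel index 0 is illustrative only:
    #
    #     field = out[0, 0].cpu().numpy()
    #     fig, ax = plt.subplots()
    #     im = ax.imshow(field, origin="lower")
    #     fig.colorbar(im, ax=ax)
    #     st.pyplot(fig)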
else:
    st.info("Please upload the necessary files and click 'Run Inference' to start.")