qq1990's picture
init
100edb4
import streamlit as st
import torch
import random
import numpy as np
import yaml
from pathlib import Path
from io import BytesIO
import random
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download, snapshot_download
import tempfile
import traceback
import functools as ft
import os
import random
import re
from collections import defaultdict
from datetime import datetime, timedelta
from pathlib import Path
import h5py
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch.utils.data import Dataset
import logging
from Prithvi import *
# Configure logging: root handler at INFO, plus a module-level logger.
# NOTE(review): `logger` is not used anywhere else in this script.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(name=__name__)
# Set page configuration. Kept as the very first Streamlit call, since older
# Streamlit releases require st.set_page_config to precede all other commands.
st.set_page_config(
    page_title="MERRA2 Data Processor",
    initial_sidebar_state="expanded",
    layout="wide",
)
# NOTE(review): `dataset_type` is not read anywhere else in this script;
# inference below always builds a Merra2Dataset regardless of this choice.
dataset_type = st.sidebar.selectbox("Select Dataset Type", ["MERRA2", "GEOS5"], index=0)
st.title("MERRA2 Data Processor with PrithviWxC Model")
def _netcdf_multi_uploader(label, key):
    """Create a sidebar uploader that accepts several NetCDF files at once."""
    return st.sidebar.file_uploader(
        label,
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key=key,
    )


# ---- Sidebar: data / model uploads (widgets render in creation order) ----
st.sidebar.header("Upload MERRA2 Data Files")
uploaded_surface_files = _netcdf_multi_uploader("Upload Surface Data Files", "surface_uploader")
uploaded_vertical_files = _netcdf_multi_uploader("Upload Vertical Data Files", "vertical_uploader")
# Optional single-file uploads: configuration and model weights.
uploaded_config = st.sidebar.file_uploader(
    "Upload config.yaml", type=["yaml", "yml"], key="config_uploader"
)
uploaded_weights = st.sidebar.file_uploader(
    "Upload Model Weights (.pt)", type=["pt"], key="weights_uploader"
)

# ---- Sidebar: task configuration ----
st.sidebar.header("Task Configuration")
lead_times = st.sidebar.multiselect(
    "Select Lead Times", options=[12, 24, 36, 48], default=[12]
)
input_times = st.sidebar.multiselect(
    "Select Input Times", options=[-6, -12, -18, -24], default=[-6]
)
# The (start, end) window is passed to the dataset as a tuple of strings.
time_range = (
    st.sidebar.text_input(
        "Start Time (e.g., 2020-01-01T00:00:00)", value="2020-01-01T00:00:00"
    ),
    st.sidebar.text_input(
        "End Time (e.g., 2020-01-01T23:59:59)", value="2020-01-01T23:59:59"
    ),
)
time_range_start, time_range_end = time_range
# Function to save uploaded files
def save_uploaded_files(uploaded_files, folder_name, max_size_mb=1024):
    """Write a batch of Streamlit uploads into a fresh temporary directory.

    Args:
        uploaded_files: Uploaded file objects (as returned by
            ``st.file_uploader(..., accept_multiple_files=True)``); each must
            provide ``name``, ``size`` and ``getbuffer()``.
        folder_name: Label used in the status messages (e.g. "surface").
        max_size_mb: Per-file size limit in mebibytes.

    Returns:
        ``Path`` of the temporary directory holding the files, or ``None``
        when nothing was uploaded or any file exceeds ``max_size_mb``.

    NOTE(review): this is called at module level, and Streamlit reruns the
    script on widget interactions, so every rerun writes a new copy into a
    new temp dir. Consider caching the result (e.g. in ``st.session_state``).
    """
    if not uploaded_files:
        st.warning(f"No {folder_name} files uploaded.")
        return None

    # Validate every file before writing anything, so an oversized upload
    # does not leave a half-populated directory behind.
    max_bytes = max_size_mb * 1024 * 1024
    for file in uploaded_files:
        if file.size > max_bytes:
            st.error(f"File {file.name} exceeds the maximum size of {max_size_mb} MB.")
            return None

    temp_dir = Path(tempfile.mkdtemp())
    with st.spinner(f"Saving {folder_name} files..."):
        for uploaded_file in uploaded_files:
            # The name is client-supplied: keep only its final component so a
            # value such as "../../x" cannot write outside temp_dir.
            file_path = temp_dir / Path(uploaded_file.name).name
            file_path.write_bytes(uploaded_file.getbuffer())
    st.success(f"Saved {len(uploaded_files)} {folder_name} files.")
    return temp_dir
# Persist the uploads; each call writes into its own new temporary directory.
surf_dir = save_uploaded_files(uploaded_surface_files, "surface")
vert_dir = save_uploaded_files(uploaded_vertical_files, "vertical")

# Show the saved file names in the sidebar, surface first, then vertical.
for _label, _saved_dir in (("Surface", surf_dir), ("Vertical", vert_dir)):
    if not _saved_dir:
        continue
    st.sidebar.subheader(f"{_label} Files Uploaded:")
    for _entry in _saved_dir.iterdir():
        st.sidebar.write(_entry.name)
# Handle Climatology Files
st.sidebar.header("Upload Climatology Files (If Missing)")
# Bundled climatology/scaler file locations
default_clim_dir = Path("Prithvi-WxC/examples/climatology")
surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"
# All four bundled files must be present to skip the upload step.
clim_files_exist = all(
    path.exists()
    for path in (
        surf_in_scal_path,
        vert_in_scal_path,
        surf_out_scal_path,
        vert_out_scal_path,
    )
)

# Always define the upload handles, so later checks never depend on
# short-circuit evaluation to avoid a NameError.
uploaded_clim_surface = None
uploaded_clim_vertical = None

if clim_files_exist:
    clim_surf_path = surf_in_scal_path
    clim_vert_path = vert_in_scal_path
else:
    st.sidebar.warning("Climatology files are missing.")
    uploaded_clim_surface = st.sidebar.file_uploader(
        "Upload Climatology Surface File",
        type=["nc", "netcdf"],
        key="clim_surface_uploader",
    )
    uploaded_clim_vertical = st.sidebar.file_uploader(
        "Upload Climatology Vertical File",
        type=["nc", "netcdf"],
        key="clim_vertical_uploader",
    )
    if uploaded_clim_surface and uploaded_clim_vertical:
        clim_temp_dir = Path(tempfile.mkdtemp())
        # Upload names are client-supplied; keep only the base name.
        clim_surf_path = clim_temp_dir / Path(uploaded_clim_surface.name).name
        clim_surf_path.write_bytes(uploaded_clim_surface.getbuffer())
        clim_vert_path = clim_temp_dir / Path(uploaded_clim_vertical.name).name
        clim_vert_path.write_bytes(uploaded_clim_vertical.getbuffer())
        st.success("Climatology files uploaded and saved.")
    else:
        # clim_surf_path / clim_vert_path stay undefined here; the inference
        # step checks for this case and stops before using them.
        st.warning("Please upload both climatology surface and vertical files.")
# Save uploaded config.yaml
if uploaded_config:
    # NamedTemporaryFile creates the file atomically. tempfile.mktemp() only
    # returns a name (deprecated: race between naming and creating the file).
    # delete=False keeps the file after closing so it can be reopened later.
    with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as tmp_config:
        tmp_config.write(uploaded_config.getbuffer())
    config_path = Path(tmp_config.name)
    st.sidebar.success("Config.yaml uploaded and saved.")
else:
    # Use default config.yaml path
    config_path = Path("Prithvi-WxC/examples/config.yaml")
    if not config_path.exists():
        st.sidebar.error("Default config.yaml not found. Please upload a config file.")
        st.stop()
# Save uploaded model weights
if uploaded_weights:
    # NamedTemporaryFile creates the file atomically. tempfile.mktemp() only
    # returns a name (deprecated: race between naming and creating the file).
    # delete=False keeps the file after closing so torch.load can read it.
    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp_weights:
        tmp_weights.write(uploaded_weights.getbuffer())
    weights_path = Path(tmp_weights.name)
    st.sidebar.success("Model weights uploaded and saved.")
else:
    # Use default weights path
    weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
    if not weights_path.exists():
        st.sidebar.error("Default model weights not found. Please upload model weights.")
        st.stop()
def _anomaly_variance_path(preferred_dir, file_name):
    """Locate an anomaly-variance (output scaler) file.

    Return ``preferred_dir / file_name`` when that file exists, otherwise the
    copy under ``default_clim_dir``. The climatology uploaders only supply
    the two input-scaler files, so when those come from an upload the
    anomaly-variance files normally still live in the bundled directory.
    """
    candidate = Path(preferred_dir) / file_name
    return candidate if candidate.exists() else default_clim_dir / file_name


# Button to run inference
if st.sidebar.button("Run Inference"):
    # Presumably the dataset cannot work with an empty lead/input time list;
    # reject it up front with a clear message.
    if not lead_times or not input_times:
        st.error("Please select at least one lead time and one input time.")
        st.stop()

    # Initialize device
    torch.jit.enable_onednn_fusion(True)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        st.write(f"Using device: {torch.cuda.get_device_name()}")
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        st.write("Using device: CPU")

    # Seed every RNG in use so repeated runs are reproducible.
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    torch.manual_seed(42)
    np.random.seed(42)

    # Variables and parameters passed to the dataset and PrithviWxC below.
    surface_vars = [
        "EFLUX",
        "GWETROOT",
        "HFLUX",
        "LAI",
        "LWGAB",
        "LWGEM",
        "LWTUP",
        "PS",
        "QV2M",
        "SLP",
        "SWGNT",
        "SWTNT",
        "T2M",
        "TQI",
        "TQL",
        "TQV",
        "TS",
        "U10M",
        "V10M",
        "Z0M",
    ]
    static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
    vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
    levels = [
        34.0,
        39.0,
        41.0,
        43.0,
        44.0,
        45.0,
        48.0,
        51.0,
        53.0,
        56.0,
        63.0,
        68.0,
        71.0,
        72.0,
    ]
    padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
    residual = "climate"
    masking_mode = "local"
    decoder_shifting = True
    masking_ratio = 0.99
    positional_encoding = "fourier"

    # Initialize Dataset
    default_data_dir = Path("Prithvi-WxC/examples/merra-2")
    try:
        with st.spinner("Initializing dataset..."):
            # Validate climatology files
            if not clim_files_exist and not (uploaded_clim_surface and uploaded_clim_vertical):
                st.error("Climatology files are missing. Please upload both surface and vertical climatology files.")
                st.stop()
            dataset = Merra2Dataset(
                time_range=time_range,
                lead_times=lead_times,
                input_times=input_times,
                # Use the uploaded data when present; otherwise fall back to
                # the bundled example data.
                data_path_surface=surf_dir or default_data_dir,
                data_path_vertical=vert_dir or default_data_dir,
                climatology_path_surface=default_clim_dir,
                climatology_path_vertical=default_clim_dir,
                surface_vars=surface_vars,
                static_surface_vars=static_surface_vars,
                vertical_vars=vertical_vars,
                levels=levels,
                positional_encoding=positional_encoding,
            )
            # Explicit check instead of `assert`, which is stripped under -O.
            if len(dataset) == 0:
                raise ValueError("There doesn't seem to be any valid data.")
            st.success("Dataset initialized successfully.")
    except Exception:
        st.error("Error initializing dataset:")
        st.error(traceback.format_exc())
        st.stop()

    # Load scalers
    try:
        with st.spinner("Loading scalers..."):
            # Input scalers: the resolved (bundled or uploaded) climatology files.
            surf_in_scal_path = clim_surf_path
            vert_in_scal_path = clim_vert_path
            # Output scalers: next to the input scalers if present, otherwise
            # the bundled copies.
            surf_out_scal_path = _anomaly_variance_path(
                clim_surf_path.parent, "anomaly_variance_surface.nc"
            )
            vert_out_scal_path = _anomaly_variance_path(
                clim_vert_path.parent, "anomaly_variance_vertical.nc"
            )
            if not surf_out_scal_path.exists() or not vert_out_scal_path.exists():
                st.error("Anomaly variance scaler files are missing.")
                st.stop()
            in_mu, in_sig = input_scalers(
                surface_vars,
                vertical_vars,
                levels,
                surf_in_scal_path,
                vert_in_scal_path,
            )
            output_sig = output_scalers(
                surface_vars,
                vertical_vars,
                levels,
                surf_out_scal_path,
                vert_out_scal_path,
            )
            static_mu, static_sig = static_input_scalers(
                surf_in_scal_path,
                static_surface_vars,
            )
            st.success("Scalers loaded successfully.")
    except Exception:
        st.error("Error loading scalers:")
        st.error(traceback.format_exc())
        st.stop()

    # Load configuration
    try:
        with st.spinner("Loading configuration..."):
            with open(config_path, "r", encoding="utf-8") as f:
                config = yaml.safe_load(f)
            # An empty file (or `params:` with no value) loads as None; treat
            # both as "no parameters" so the missing-keys report below is shown.
            model_params = (config or {}).get("params") or {}
            required_params = [
                "in_channels", "input_size_time", "in_channels_static",
                "input_scalers_epsilon", "static_input_scalers_epsilon",
                "n_lats_px", "n_lons_px", "patch_size_px",
                "mask_unit_size_px", "embed_dim", "n_blocks_encoder",
                "n_blocks_decoder", "mlp_multiplier", "n_heads",
                "dropout", "drop_path", "parameter_dropout",
            ]
            missing_params = [p for p in required_params if p not in model_params]
            if missing_params:
                st.error(f"Missing configuration parameters: {missing_params}")
                st.stop()
            st.success("Configuration loaded successfully.")
    except Exception:
        st.error("Error loading configuration:")
        st.error(traceback.format_exc())
        st.stop()

    # Initialize the model
    try:
        with st.spinner("Initializing model..."):
            model = PrithviWxC(
                in_channels=model_params["in_channels"],
                input_size_time=model_params["input_size_time"],
                in_channels_static=model_params["in_channels_static"],
                input_scalers_mu=in_mu,
                input_scalers_sigma=in_sig,
                input_scalers_epsilon=model_params["input_scalers_epsilon"],
                static_input_scalers_mu=static_mu,
                static_input_scalers_sigma=static_sig,
                static_input_scalers_epsilon=model_params["static_input_scalers_epsilon"],
                # output_scalers receives the square root of the loaded
                # anomaly-variance values.
                output_scalers=output_sig**0.5,
                n_lats_px=model_params["n_lats_px"],
                n_lons_px=model_params["n_lons_px"],
                patch_size_px=model_params["patch_size_px"],
                mask_unit_size_px=model_params["mask_unit_size_px"],
                mask_ratio_inputs=masking_ratio,
                embed_dim=model_params["embed_dim"],
                n_blocks_encoder=model_params["n_blocks_encoder"],
                n_blocks_decoder=model_params["n_blocks_decoder"],
                mlp_multiplier=model_params["mlp_multiplier"],
                n_heads=model_params["n_heads"],
                dropout=model_params["dropout"],
                drop_path=model_params["drop_path"],
                parameter_dropout=model_params["parameter_dropout"],
                residual=residual,
                masking_mode=masking_mode,
                decoder_shifting=decoder_shifting,
                positional_encoding=positional_encoding,
                checkpoint_encoder=[],
                checkpoint_decoder=[],
            )
            st.success("Model initialized successfully.")
    except Exception:
        st.error("Error initializing model:")
        st.error(traceback.format_exc())
        st.stop()

    # Load model weights
    try:
        with st.spinner("Loading model weights..."):
            # SECURITY(review): the weights may be a user upload, and torch.load
            # can unpickle arbitrary objects unless weights_only=True. Only load
            # checkpoints from trusted sources, or switch to weights_only=True
            # if the checkpoint format allows it.
            state_dict = torch.load(weights_path, map_location=device)
            if "model_state" in state_dict:
                state_dict = state_dict["model_state"]
            model.load_state_dict(state_dict, strict=True)
            model.to(device)
            st.success("Model weights loaded successfully.")
    except Exception:
        st.error("Error loading model weights:")
        st.error(traceback.format_exc())
        st.stop()

    # Prepare data batch: take the first sample and move its tensors to device.
    try:
        with st.spinner("Preparing data batch..."):
            data = next(iter(dataset))
            batch = preproc([data], padding)
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(device)
            st.success("Data batch prepared successfully.")
    except Exception:
        st.error("Error preparing data batch:")
        st.error(traceback.format_exc())
        st.stop()

    # Run inference in eval mode without autograd bookkeeping.
    try:
        with st.spinner("Running model inference..."):
            model.eval()
            with torch.no_grad():
                out = model(batch)
            st.success("Model inference completed successfully.")
    except Exception:
        st.error("Error during model inference:")
        st.error(traceback.format_exc())
        st.stop()

    # Display output
    st.header("Inference Results")
    st.write(out)  # Adjust based on the structure of 'out'
    # Optionally, provide download links or visualizations
    # For example, if 'out' contains tensors or dataframes:
    # st.write("Output Tensor:", out["some_key"].cpu().numpy())
else:
    st.info("Please upload the necessary files and click 'Run Inference' to start.")