File size: 5,564 Bytes
100edb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import streamlit as st
import tempfile
from pathlib import Path
import torch
import traceback
import yaml
# from Prithvi import PrithviWxC, Merra2Dataset, input_scalers, output_scalers, static_input_scalers, preproc

def _persist_upload(uploaded_file, suffix):
    """Write a Streamlit upload to a secure temporary file and return its Path.

    Uses NamedTemporaryFile(delete=False) instead of the deprecated and
    race-prone tempfile.mktemp(): the file is created and opened atomically,
    so no other process can claim the name first.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
        f.write(uploaded_file.getbuffer())
    return Path(f.name)


def prithvi_config_ui():
    """Render the Streamlit controls for configuring a Prithvi-WxC run.

    Collects two demo parameters, surface/vertical NetCDF uploads,
    climatology files (uploaded only when the default on-disk set is
    incomplete), a config.yaml, and model weights. Falls back to the
    default paths under ``Prithvi-WxC/examples`` when nothing is uploaded,
    and halts the app (``st.stop``) if a required default is missing.

    Returns:
        tuple: ``(config, uploaded_surface_files, uploaded_vertical_files,
        clim_surf_path, clim_vert_path, config_path, weights_path)``.
        ``clim_surf_path``/``clim_vert_path`` are ``None`` until both
        climatology files are available.
    """
    st.subheader("Prithvi Model Configuration")
    param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
    param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")

    config = {"param1": param1, "param2": param2}

    st.markdown("### Upload Data Files for Prithvi Model")
    uploaded_surface_files = st.file_uploader(
        "Upload Surface Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="surface_uploader",
    )

    uploaded_vertical_files = st.file_uploader(
        "Upload Vertical Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="vertical_uploader",
    )

    st.markdown("### Upload Climatology Files (If Missing)")
    # Default climatology scaler files shipped with the Prithvi-WxC examples.
    default_clim_dir = Path("Prithvi-WxC/examples/climatology")
    surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
    vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
    surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
    vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"
    # All four must exist for the defaults to be usable.
    clim_files_exist = all([
        surf_in_scal_path.exists(),
        vert_in_scal_path.exists(),
        surf_out_scal_path.exists(),
        vert_out_scal_path.exists(),
    ])

    if not clim_files_exist:
        st.warning("Climatology files are missing.")
        uploaded_clim_surface = st.file_uploader(
            "Upload Climatology Surface File",
            type=["nc", "netcdf"],
            key="clim_surface_uploader",
        )
        uploaded_clim_vertical = st.file_uploader(
            "Upload Climatology Vertical File",
            type=["nc", "netcdf"],
            key="clim_vertical_uploader",
        )
        if uploaded_clim_surface and uploaded_clim_vertical:
            # Keep the original filenames (downstream code may key on them),
            # so save both into a fresh private directory.
            clim_temp_dir = tempfile.mkdtemp()
            clim_surf_path = Path(clim_temp_dir) / uploaded_clim_surface.name
            with open(clim_surf_path, "wb") as f:
                f.write(uploaded_clim_surface.getbuffer())
            clim_vert_path = Path(clim_temp_dir) / uploaded_clim_vertical.name
            with open(clim_vert_path, "wb") as f:
                f.write(uploaded_clim_vertical.getbuffer())
            st.success("Climatology files uploaded and saved.")
        else:
            st.warning("Please upload both climatology surface and vertical files.")
            clim_surf_path, clim_vert_path = None, None
    else:
        clim_surf_path = surf_in_scal_path
        clim_vert_path = vert_in_scal_path

    uploaded_config = st.file_uploader(
        "Upload config.yaml",
        type=["yaml", "yml"],
        key="config_uploader",
    )

    if uploaded_config:
        config_path = _persist_upload(uploaded_config, ".yaml")
        st.success("Config.yaml uploaded and saved.")
    else:
        config_path = Path("Prithvi-WxC/examples/config.yaml")
        if not config_path.exists():
            st.error("Default config.yaml not found. Please upload a config file.")
            st.stop()

    uploaded_weights = st.file_uploader(
        "Upload Model Weights (.pt)",
        type=["pt"],
        key="weights_uploader",
    )

    if uploaded_weights:
        weights_path = _persist_upload(uploaded_weights, ".pt")
        st.success("Model weights uploaded and saved.")
    else:
        weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
        if not weights_path.exists():
            st.error("Default model weights not found. Please upload model weights.")
            st.stop()

    return config, uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, config_path, weights_path


def initialize_prithvi_model(config, config_path, weights_path, device):
    """Build the PrithviWxC model and its normalization scalers.

    Currently a stub: the YAML config at ``config_path`` is parsed, but model
    construction and scaler loading are not implemented yet, so every return
    value is ``None``.

    Returns:
        tuple: ``(model, in_mu, in_sig, output_sig, static_mu, static_sig)``,
        all ``None`` until the real implementation lands.
    """
    # Parse the run configuration; the stub does not consume it yet.
    with open(config_path, "r") as cfg_file:
        cfg = yaml.safe_load(cfg_file)

    # TODO: intended flow once implemented (pseudo-code):
    #   in_mu, in_sig = input_scalers(...)
    #   output_sig = output_scalers(...)
    #   static_mu, static_sig = static_input_scalers(...)
    #   model = PrithviWxC(**cfg["params"], ...)
    #   state_dict = torch.load(weights_path, map_location=device)
    #   model.load_state_dict(
    #       state_dict["model_state"] if "model_state" in state_dict else state_dict,
    #       strict=True,
    #   )
    #   model.to(device)

    # Placeholder results until the loading logic above is written.
    model = None
    in_mu = in_sig = output_sig = static_mu = static_sig = None
    return model, in_mu, in_sig, output_sig, static_mu, static_sig


def prepare_prithvi_batch(uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, device):
    """Assemble a MERRA-2 batch for Prithvi inference.

    Stub: dataset construction and preprocessing are not implemented yet,
    so this always returns ``None``.
    """
    # TODO: intended flow once implemented (pseudo-code):
    #   dataset = Merra2Dataset(...)
    #   data = next(iter(dataset))
    #   batch = preproc([data], padding={...})
    #   then move every torch.Tensor value in `batch` onto `device`.
    return None