Climate-ML-Foundation-Models / inference_utils.py
qq1990's picture
init
100edb4
raw
history blame contribute delete
595 Bytes
import torch
import streamlit as st
from aurora import rollout, Aurora
def run_inference(selected_model, model, batch, device, *, steps=2):
    """Run inference with the chosen model and return its output.

    Args:
        selected_model: Name of the model family. ``"Prithvi"`` and
            ``"Aurora"`` are handled; any other value is reported as an
            error through Streamlit.
        model: The loaded model. It is switched to eval mode before use.
        batch: Input handed to ``model`` (Prithvi) or to ``rollout``
            (Aurora).
        device: Not used by this function; kept in the signature so
            existing callers keep working.
        steps: Number of autoregressive rollout steps for Aurora.
            Defaults to 2, the value that used to be hard-coded. Ignored
            for Prithvi.

    Returns:
        For ``"Prithvi"``: whatever ``model(batch)`` returns.
        For ``"Aurora"``: a list with one prediction per rollout step,
        each moved to the CPU.
        For anything else: ``None`` (after ``st.error`` has been called).
    """
    if selected_model == "Prithvi":
        model.eval()
        # Single forward pass; gradients are not needed for inference.
        with torch.no_grad():
            return model(batch)

    if selected_model == "Aurora":
        model.eval()
        with torch.inference_mode():
            # Each prediction is moved to the CPU as soon as it is produced
            # so the collected list does not stay on the model's device.
            return [
                pred.to("cpu")
                for pred in rollout(model, batch, steps=steps)
            ]

    st.error("Inference not implemented for this model.")
    return None