import torch
import streamlit as st
from aurora import rollout, Aurora


def run_inference(selected_model, model, batch, device):
    """Run a forward pass for the selected model and return its predictions."""
    if selected_model == "Prithvi":
        model.eval()
        with torch.no_grad():
            # Prithvi: a single forward pass over the prepared batch.
            out = model(batch)
        return out
    elif selected_model == "Aurora":
        model.eval()
        with torch.inference_mode():
            # Aurora: roll the model forward autoregressively.
            # Example: predict 2 steps ahead, moving each prediction to the CPU.
            out = [pred.to("cpu") for pred in rollout(model, batch, steps=2)]
        return out
    else:
        st.error("Inference not implemented for this model.")
        return None
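
For reference, here is a minimal sketch of how run_inference could be exercised for the Aurora branch. It assumes the AuroraSmall model and the example batch shapes from the Aurora package's documentation; the random tensors are placeholders for real weather data, not the data the app actually feeds in, and the checkpoint name should be checked against the current package release.

from datetime import datetime

import torch
from aurora import AuroraSmall, Batch, Metadata

# Keep the sketch on CPU; the app chooses the device dynamically.
device = "cpu"

# Small Aurora variant with its published pretrained checkpoint (name assumed from the package docs).
model = AuroraSmall()
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")

# Dummy batch: two history time steps on a coarse 17 x 32 lat/lon grid, random values only.
batch = Batch(
    surf_vars={k: torch.randn(1, 2, 17, 32) for k in ("2t", "10u", "10v", "msl")},
    static_vars={k: torch.randn(17, 32) for k in ("lsm", "z", "slt")},
    atmos_vars={k: torch.randn(1, 2, 4, 17, 32) for k in ("z", "u", "v", "t", "q")},
    metadata=Metadata(
        lat=torch.linspace(90, -90, 17),
        lon=torch.linspace(0, 360, 32 + 1)[:-1],
        time=(datetime(2020, 6, 1, 12, 0),),
        atmos_levels=(100, 250, 500, 850),
    ),
)

predictions = run_inference("Aurora", model, batch, device)
print(f"Generated {len(predictions)} forecast steps")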