Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoModelForImageSegmentation, AutoProcessor
|
| 3 |
+
from sunpy.net import Fido, attrs as a
|
| 4 |
+
import astropy.units as u
|
| 5 |
+
import numpy as np
|
| 6 |
+
from fastapi import FastAPI
|
| 7 |
+
import gradio as gr
|
| 8 |
+
from datetime import datetime, timedelta
|
| 9 |
+
from sklearn.linear_model import LogisticRegression
|
| 10 |
+
import xarray
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
import os
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
# FastAPI app
|
| 16 |
+
app = FastAPI()
|
| 17 |
+
|
| 18 |
+
# Surya model setup
|
| 19 |
+
MODEL_ID = "nasa-ibm-ai4science/ar_segmentation_surya"
|
| 20 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
model = AutoModelForImageSegmentation.from_pretrained(MODEL_ID).to(device)
|
| 22 |
+
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
| 23 |
+
|
| 24 |
+
# Cache directory for SDO data
|
| 25 |
+
CACHE_DIR = Path("/persistent-storage/sdo_data")
|
| 26 |
+
CACHE_DIR.mkdir(exist_ok=True)
|
| 27 |
+
|
| 28 |
+
# Simplified NRLMSISE-00 model for neutral density (placeholder for WAM-IPE)
|
| 29 |
+
def nrlmsise_density(altitude_km, flare_factor=1.0):
|
| 30 |
+
"""Estimate neutral density using a simplified NRLMSISE-00 model."""
|
| 31 |
+
base_density = 1e-12 # kg/m^3 at 550 km (approximate)
|
| 32 |
+
density = base_density * flare_factor # Scale by flare/CME impact
|
| 33 |
+
return density
|
| 34 |
+
|
| 35 |
+
# Simplified flare/CME probability model
|
| 36 |
+
def train_flare_model():
|
| 37 |
+
"""Train a simple logistic regression model for flare probability."""
|
| 38 |
+
# Dummy data: active region size (MmΒ²) vs. flare probability
|
| 39 |
+
X = np.array([[100], [500], [1000], [2000]]) # Size in MmΒ²
|
| 40 |
+
y = np.array([0, 0.5, 0.8, 0.95]) # Flare probability
|
| 41 |
+
model = LogisticRegression().fit(X, y)
|
| 42 |
+
return model
|
| 43 |
+
|
| 44 |
+
flare_model = train_flare_model()
|
| 45 |
+
|
| 46 |
+
# Drag calculation
|
| 47 |
+
def calculate_drag(density, altitude_km=550, velocity=7.5e3, cd=2.2, area=10):
|
| 48 |
+
"""Calculate drag force and orbital decay for a Starlink satellite."""
|
| 49 |
+
drag_force = 0.5 * density * velocity**2 * cd * area # N
|
| 50 |
+
mass = 260 # kg (Starlink satellite mass, approximate)
|
| 51 |
+
acceleration = drag_force / mass # m/s^2
|
| 52 |
+
decay_rate = acceleration * 86400 / 1000 # km/day
|
| 53 |
+
return drag_force, decay_rate
|
| 54 |
+
|
| 55 |
+
# Fetch and preprocess SDO data
|
| 56 |
+
def fetch_sdo_data(start_time, end_time):
|
| 57 |
+
"""Fetch SDO AIA 171Γ
data for the given time range."""
|
| 58 |
+
try:
|
| 59 |
+
query = Fido.search(
|
| 60 |
+
a.Time(start_time, end_time),
|
| 61 |
+
a.Instrument("AIA"),
|
| 62 |
+
a.Wavelength(171 * u.angstrom)
|
| 63 |
+
)
|
| 64 |
+
files = Fido.fetch(query, path=str(CACHE_DIR / "aia_{file}"), progress=True)
|
| 65 |
+
return files
|
| 66 |
+
except Exception as e:
|
| 67 |
+
print(f"Error fetching SDO data: {e}")
|
| 68 |
+
return None
|
| 69 |
+
|
| 70 |
+
def preprocess_sdo_data(files):
|
| 71 |
+
"""Preprocess SDO data for Surya model."""
|
| 72 |
+
from sunpy.map import Map
|
| 73 |
+
if not files:
|
| 74 |
+
return None
|
| 75 |
+
sdo_map = Map(files[0])
|
| 76 |
+
data = sdo_map.data # 4096x4096 image
|
| 77 |
+
data = data / np.max(data) # Normalize
|
| 78 |
+
data = np.expand_dims(data, axis=(0, 1)) # Shape: (1, 1, H, W)
|
| 79 |
+
return data
|
| 80 |
+
|
| 81 |
+
# Run Surya segmentation
|
| 82 |
+
def run_surya_segmentation(sdo_data):
|
| 83 |
+
"""Run Surya AR segmentation on SDO data."""
|
| 84 |
+
if sdo_data is None:
|
| 85 |
+
return None
|
| 86 |
+
inputs = processor(images=sdo_data, return_tensors="pt").to(device)
|
| 87 |
+
with torch.no_grad():
|
| 88 |
+
outputs = model(**inputs)
|
| 89 |
+
masks = torch.sigmoid(outputs.logits).cpu().numpy() # Segmentation masks
|
| 90 |
+
return masks
|
| 91 |
+
|
| 92 |
+
# Analyze active regions
|
| 93 |
+
def analyze_active_regions(masks):
|
| 94 |
+
"""Extract active region properties from segmentation masks."""
|
| 95 |
+
if masks is None:
|
| 96 |
+
return []
|
| 97 |
+
mask = masks[0, 0] # First image, first channel
|
| 98 |
+
ar_size = np.sum(mask > 0.5) * 0.1 # Approximate size in MmΒ² (pixel-to-Mm conversion)
|
| 99 |
+
return [{"id": "AR1", "size_mm2": ar_size}]
|
| 100 |
+
|
| 101 |
+
# Predict flare/CME and density impact
|
| 102 |
+
def predict_events(active_regions):
|
| 103 |
+
"""Predict flare/CME probabilities and density impact."""
|
| 104 |
+
if not active_regions:
|
| 105 |
+
return {"flare_prob": 0.0, "cme_prob": 0.0, "flare_factor": 1.0}
|
| 106 |
+
ar_size = active_regions[0]["size_mm2"]
|
| 107 |
+
flare_prob = flare_model.predict_proba([[ar_size]])[0, 1]
|
| 108 |
+
cme_prob = flare_prob * 0.5 # Simplified: CMEs less likely than flares
|
| 109 |
+
flare_factor = 1.0 + flare_prob * 2.0 # Scale density by flare intensity
|
| 110 |
+
return {"flare_prob": flare_prob, "cme_prob": cme_prob, "flare_factor": flare_factor}
|
| 111 |
+
|
| 112 |
+
# Main prediction function
|
| 113 |
+
def predict_drag(start_time_str):
|
| 114 |
+
"""Predict drag based on SDO data and active region analysis."""
|
| 115 |
+
try:
|
| 116 |
+
start_time = datetime.fromisoformat(start_time_str.replace("Z", "+00:00"))
|
| 117 |
+
end_time = start_time + timedelta(minutes=15)
|
| 118 |
+
sdo_files = fetch_sdo_data(start_time.isoformat(), end_time.isoformat())
|
| 119 |
+
sdo_data = preprocess_sdo_data(sdo_files)
|
| 120 |
+
masks = run_surya_segmentation(sdo_data)
|
| 121 |
+
active_regions = analyze_active_regions(masks)
|
| 122 |
+
event_probs = predict_events(active_regions)
|
| 123 |
+
|
| 124 |
+
# Estimate density and drag
|
| 125 |
+
density = nrlmsise_density(altitude_km=550, flare_factor=event_probs["flare_factor"])
|
| 126 |
+
drag_force, decay_rate = calculate_drag(density)
|
| 127 |
+
|
| 128 |
+
# Generate recommendations
|
| 129 |
+
recommendations = []
|
| 130 |
+
if event_probs["flare_prob"] > 0.7:
|
| 131 |
+
recommendations.append("Consider raising orbit by 5-10 km.")
|
| 132 |
+
if decay_rate > 0.5:
|
| 133 |
+
recommendations.append("Switch satellites to edge-on safe mode.")
|
| 134 |
+
|
| 135 |
+
# Save visualization
|
| 136 |
+
if masks is not None:
|
| 137 |
+
plt.imshow(masks[0, 0], cmap="hot")
|
| 138 |
+
plt.title("Active Region Segmentation")
|
| 139 |
+
plt.savefig(CACHE_DIR / "ar_mask.png")
|
| 140 |
+
plt.close()
|
| 141 |
+
|
| 142 |
+
return {
|
| 143 |
+
"active_regions": active_regions,
|
| 144 |
+
"flare_probability": event_probs["flare_prob"],
|
| 145 |
+
"cme_probability": event_probs["cme_prob"],
|
| 146 |
+
"neutral_density": density,
|
| 147 |
+
"drag_force": drag_force,
|
| 148 |
+
"orbital_decay": decay_rate,
|
| 149 |
+
"recommendations": recommendations,
|
| 150 |
+
"visualization": str(CACHE_DIR / "ar_mask.png")
|
| 151 |
+
}
|
| 152 |
+
except Exception as e:
|
| 153 |
+
return {"error": str(e)}
|
| 154 |
+
|
| 155 |
+
# FastAPI endpoint
|
| 156 |
+
@app.post("/predict_drag")
|
| 157 |
+
async def api_predict_drag(start_time: str):
|
| 158 |
+
"""API endpoint for drag prediction."""
|
| 159 |
+
result = predict_drag(start_time)
|
| 160 |
+
return result
|
| 161 |
+
|
| 162 |
+
# Gradio interface
|
| 163 |
+
def gradio_predict(start_time):
|
| 164 |
+
"""Gradio interface for drag prediction."""
|
| 165 |
+
result = predict_drag(start_time)
|
| 166 |
+
if "error" in result:
|
| 167 |
+
return f"Error: {result['error']}"
|
| 168 |
+
|
| 169 |
+
output = f"""
|
| 170 |
+
**Active Regions**: {result['active_regions']}
|
| 171 |
+
**Flare Probability**: {result['flare_probability']:.2f}
|
| 172 |
+
**CME Probability**: {result['cme_probability']:.2f}
|
| 173 |
+
**Neutral Density**: {result['neutral_density']:.2e} kg/mΒ³
|
| 174 |
+
**Drag Force**: {result['drag_force']:.2e} N
|
| 175 |
+
**Orbital Decay**: {result['orbital_decay']:.2f} km/day
|
| 176 |
+
**Recommendations**: {', '.join(result['recommendations'])}
|
| 177 |
+
"""
|
| 178 |
+
return output, result["visualization"]
|
| 179 |
+
|
| 180 |
+
# Gradio app
|
| 181 |
+
iface = gr.Interface(
|
| 182 |
+
fn=gradio_predict,
|
| 183 |
+
inputs=gr.Textbox(label="Start Time (ISO format, e.g., 2025-08-21T12:00:00Z)"),
|
| 184 |
+
outputs=[gr.Textbox(label="Prediction Results"), gr.Image(label="Active Region Visualization")],
|
| 185 |
+
title="Starlink Drag Prediction App",
|
| 186 |
+
description="Predict satellite drag based on solar active regions using NASA-IBM Surya model."
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# Run Gradio app (Hugging Face Spaces handles this automatically)
|
| 190 |
+
if __name__ == "__main__":
|
| 191 |
+
iface.launch()
|