FassikaF commited on
Commit
b7055ff
Β·
verified Β·
1 Parent(s): 91e95c7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -0
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForImageSegmentation, AutoProcessor
3
+ from sunpy.net import Fido, attrs as a
4
+ import astropy.units as u
5
+ import numpy as np
6
+ from fastapi import FastAPI
7
+ import gradio as gr
8
+ from datetime import datetime, timedelta
9
+ from sklearn.linear_model import LogisticRegression
10
+ import xarray
11
+ import matplotlib.pyplot as plt
12
+ import os
13
+ from pathlib import Path
14
+
15
+ # FastAPI app
16
+ app = FastAPI()
17
+
18
+ # Surya model setup
19
+ MODEL_ID = "nasa-ibm-ai4science/ar_segmentation_surya"
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ model = AutoModelForImageSegmentation.from_pretrained(MODEL_ID).to(device)
22
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
23
+
24
+ # Cache directory for SDO data
25
+ CACHE_DIR = Path("/persistent-storage/sdo_data")
26
+ CACHE_DIR.mkdir(exist_ok=True)
27
+
28
+ # Simplified NRLMSISE-00 model for neutral density (placeholder for WAM-IPE)
29
+ def nrlmsise_density(altitude_km, flare_factor=1.0):
30
+ """Estimate neutral density using a simplified NRLMSISE-00 model."""
31
+ base_density = 1e-12 # kg/m^3 at 550 km (approximate)
32
+ density = base_density * flare_factor # Scale by flare/CME impact
33
+ return density
34
+
35
+ # Simplified flare/CME probability model
36
+ def train_flare_model():
37
+ """Train a simple logistic regression model for flare probability."""
38
+ # Dummy data: active region size (MmΒ²) vs. flare probability
39
+ X = np.array([[100], [500], [1000], [2000]]) # Size in MmΒ²
40
+ y = np.array([0, 0.5, 0.8, 0.95]) # Flare probability
41
+ model = LogisticRegression().fit(X, y)
42
+ return model
43
+
44
+ flare_model = train_flare_model()
45
+
46
+ # Drag calculation
47
+ def calculate_drag(density, altitude_km=550, velocity=7.5e3, cd=2.2, area=10):
48
+ """Calculate drag force and orbital decay for a Starlink satellite."""
49
+ drag_force = 0.5 * density * velocity**2 * cd * area # N
50
+ mass = 260 # kg (Starlink satellite mass, approximate)
51
+ acceleration = drag_force / mass # m/s^2
52
+ decay_rate = acceleration * 86400 / 1000 # km/day
53
+ return drag_force, decay_rate
54
+
55
+ # Fetch and preprocess SDO data
56
+ def fetch_sdo_data(start_time, end_time):
57
+ """Fetch SDO AIA 171Γ… data for the given time range."""
58
+ try:
59
+ query = Fido.search(
60
+ a.Time(start_time, end_time),
61
+ a.Instrument("AIA"),
62
+ a.Wavelength(171 * u.angstrom)
63
+ )
64
+ files = Fido.fetch(query, path=str(CACHE_DIR / "aia_{file}"), progress=True)
65
+ return files
66
+ except Exception as e:
67
+ print(f"Error fetching SDO data: {e}")
68
+ return None
69
+
70
+ def preprocess_sdo_data(files):
71
+ """Preprocess SDO data for Surya model."""
72
+ from sunpy.map import Map
73
+ if not files:
74
+ return None
75
+ sdo_map = Map(files[0])
76
+ data = sdo_map.data # 4096x4096 image
77
+ data = data / np.max(data) # Normalize
78
+ data = np.expand_dims(data, axis=(0, 1)) # Shape: (1, 1, H, W)
79
+ return data
80
+
81
+ # Run Surya segmentation
82
+ def run_surya_segmentation(sdo_data):
83
+ """Run Surya AR segmentation on SDO data."""
84
+ if sdo_data is None:
85
+ return None
86
+ inputs = processor(images=sdo_data, return_tensors="pt").to(device)
87
+ with torch.no_grad():
88
+ outputs = model(**inputs)
89
+ masks = torch.sigmoid(outputs.logits).cpu().numpy() # Segmentation masks
90
+ return masks
91
+
92
+ # Analyze active regions
93
+ def analyze_active_regions(masks):
94
+ """Extract active region properties from segmentation masks."""
95
+ if masks is None:
96
+ return []
97
+ mask = masks[0, 0] # First image, first channel
98
+ ar_size = np.sum(mask > 0.5) * 0.1 # Approximate size in MmΒ² (pixel-to-Mm conversion)
99
+ return [{"id": "AR1", "size_mm2": ar_size}]
100
+
101
+ # Predict flare/CME and density impact
102
+ def predict_events(active_regions):
103
+ """Predict flare/CME probabilities and density impact."""
104
+ if not active_regions:
105
+ return {"flare_prob": 0.0, "cme_prob": 0.0, "flare_factor": 1.0}
106
+ ar_size = active_regions[0]["size_mm2"]
107
+ flare_prob = flare_model.predict_proba([[ar_size]])[0, 1]
108
+ cme_prob = flare_prob * 0.5 # Simplified: CMEs less likely than flares
109
+ flare_factor = 1.0 + flare_prob * 2.0 # Scale density by flare intensity
110
+ return {"flare_prob": flare_prob, "cme_prob": cme_prob, "flare_factor": flare_factor}
111
+
112
+ # Main prediction function
113
+ def predict_drag(start_time_str):
114
+ """Predict drag based on SDO data and active region analysis."""
115
+ try:
116
+ start_time = datetime.fromisoformat(start_time_str.replace("Z", "+00:00"))
117
+ end_time = start_time + timedelta(minutes=15)
118
+ sdo_files = fetch_sdo_data(start_time.isoformat(), end_time.isoformat())
119
+ sdo_data = preprocess_sdo_data(sdo_files)
120
+ masks = run_surya_segmentation(sdo_data)
121
+ active_regions = analyze_active_regions(masks)
122
+ event_probs = predict_events(active_regions)
123
+
124
+ # Estimate density and drag
125
+ density = nrlmsise_density(altitude_km=550, flare_factor=event_probs["flare_factor"])
126
+ drag_force, decay_rate = calculate_drag(density)
127
+
128
+ # Generate recommendations
129
+ recommendations = []
130
+ if event_probs["flare_prob"] > 0.7:
131
+ recommendations.append("Consider raising orbit by 5-10 km.")
132
+ if decay_rate > 0.5:
133
+ recommendations.append("Switch satellites to edge-on safe mode.")
134
+
135
+ # Save visualization
136
+ if masks is not None:
137
+ plt.imshow(masks[0, 0], cmap="hot")
138
+ plt.title("Active Region Segmentation")
139
+ plt.savefig(CACHE_DIR / "ar_mask.png")
140
+ plt.close()
141
+
142
+ return {
143
+ "active_regions": active_regions,
144
+ "flare_probability": event_probs["flare_prob"],
145
+ "cme_probability": event_probs["cme_prob"],
146
+ "neutral_density": density,
147
+ "drag_force": drag_force,
148
+ "orbital_decay": decay_rate,
149
+ "recommendations": recommendations,
150
+ "visualization": str(CACHE_DIR / "ar_mask.png")
151
+ }
152
+ except Exception as e:
153
+ return {"error": str(e)}
154
+
155
+ # FastAPI endpoint
156
+ @app.post("/predict_drag")
157
+ async def api_predict_drag(start_time: str):
158
+ """API endpoint for drag prediction."""
159
+ result = predict_drag(start_time)
160
+ return result
161
+
162
+ # Gradio interface
163
+ def gradio_predict(start_time):
164
+ """Gradio interface for drag prediction."""
165
+ result = predict_drag(start_time)
166
+ if "error" in result:
167
+ return f"Error: {result['error']}"
168
+
169
+ output = f"""
170
+ **Active Regions**: {result['active_regions']}
171
+ **Flare Probability**: {result['flare_probability']:.2f}
172
+ **CME Probability**: {result['cme_probability']:.2f}
173
+ **Neutral Density**: {result['neutral_density']:.2e} kg/mΒ³
174
+ **Drag Force**: {result['drag_force']:.2e} N
175
+ **Orbital Decay**: {result['orbital_decay']:.2f} km/day
176
+ **Recommendations**: {', '.join(result['recommendations'])}
177
+ """
178
+ return output, result["visualization"]
179
+
180
+ # Gradio app
181
+ iface = gr.Interface(
182
+ fn=gradio_predict,
183
+ inputs=gr.Textbox(label="Start Time (ISO format, e.g., 2025-08-21T12:00:00Z)"),
184
+ outputs=[gr.Textbox(label="Prediction Results"), gr.Image(label="Active Region Visualization")],
185
+ title="Starlink Drag Prediction App",
186
+ description="Predict satellite drag based on solar active regions using NASA-IBM Surya model."
187
+ )
188
+
189
+ # Run Gradio app (Hugging Face Spaces handles this automatically)
190
+ if __name__ == "__main__":
191
+ iface.launch()