Spaces:
Runtime error
Runtime error
section plot updated
Browse files- Gradio_app.ipynb +21 -0
- app.py +47 -26
Gradio_app.ipynb
CHANGED
@@ -61,6 +61,27 @@
|
|
61 |
"execution_count": 4,
|
62 |
"metadata": {},
|
63 |
"output_type": "execute_result"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
}
|
65 |
],
|
66 |
"source": [
|
|
|
61 |
"execution_count": 4,
|
62 |
"metadata": {},
|
63 |
"output_type": "execute_result"
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"name": "stdout",
|
67 |
+
"output_type": "stream",
|
68 |
+
"text": [
|
69 |
+
"Error in callback <function _draw_all_if_interactive at 0x1774d0ea0> (for post_execute):\n"
|
70 |
+
]
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"ename": "KeyboardInterrupt",
|
74 |
+
"evalue": "",
|
75 |
+
"output_type": "error",
|
76 |
+
"traceback": [
|
77 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
78 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
79 |
+
"File \u001b[0;32m~/miniconda3/envs/phasehunter/lib/python3.11/site-packages/matplotlib/pyplot.py:120\u001b[0m, in \u001b[0;36m_draw_all_if_interactive\u001b[0;34m()\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_draw_all_if_interactive\u001b[39m():\n\u001b[1;32m 119\u001b[0m \u001b[39mif\u001b[39;00m matplotlib\u001b[39m.\u001b[39mis_interactive():\n\u001b[0;32m--> 120\u001b[0m draw_all()\n",
|
80 |
+
"File \u001b[0;32m~/miniconda3/envs/phasehunter/lib/python3.11/site-packages/matplotlib/_pylab_helpers.py:132\u001b[0m, in \u001b[0;36mGcf.draw_all\u001b[0;34m(cls, force)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[39mfor\u001b[39;00m manager \u001b[39min\u001b[39;00m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39mget_all_fig_managers():\n\u001b[1;32m 131\u001b[0m \u001b[39mif\u001b[39;00m force \u001b[39mor\u001b[39;00m manager\u001b[39m.\u001b[39mcanvas\u001b[39m.\u001b[39mfigure\u001b[39m.\u001b[39mstale:\n\u001b[0;32m--> 132\u001b[0m manager\u001b[39m.\u001b[39;49mcanvas\u001b[39m.\u001b[39;49mdraw_idle()\n",
|
81 |
+
"File \u001b[0;32m~/miniconda3/envs/phasehunter/lib/python3.11/site-packages/matplotlib/backend_bases.py:2082\u001b[0m, in \u001b[0;36mFigureCanvasBase.draw_idle\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 2080\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_is_idle_drawing:\n\u001b[1;32m 2081\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_idle_draw_cntx():\n\u001b[0;32m-> 2082\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdraw(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
|
82 |
+
"File \u001b[0;32m~/miniconda3/envs/phasehunter/lib/python3.11/site-packages/matplotlib/backends/backend_agg.py:397\u001b[0m, in \u001b[0;36mFigureCanvasAgg.draw\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 395\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mrenderer\u001b[39m.\u001b[39mclear()\n\u001b[1;32m 396\u001b[0m \u001b[39m# Acquire a lock on the shared font cache.\u001b[39;00m\n\u001b[0;32m--> 397\u001b[0m \u001b[39mwith\u001b[39;49;00m RendererAgg\u001b[39m.\u001b[39;49mlock, \\\n\u001b[1;32m 398\u001b[0m (\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtoolbar\u001b[39m.\u001b[39;49m_wait_cursor_for_draw_cm() \u001b[39mif\u001b[39;49;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtoolbar\n\u001b[1;32m 399\u001b[0m \u001b[39melse\u001b[39;49;00m nullcontext()):\n\u001b[1;32m 400\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfigure\u001b[39m.\u001b[39;49mdraw(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mrenderer)\n\u001b[1;32m 401\u001b[0m \u001b[39m# A GUI class may be need to update a window using this draw, so\u001b[39;49;00m\n\u001b[1;32m 402\u001b[0m \u001b[39m# don't forget to call the superclass.\u001b[39;49;00m\n",
|
83 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
84 |
+
]
|
85 |
}
|
86 |
],
|
87 |
"source": [
|
app.py
CHANGED
@@ -21,6 +21,8 @@ from obspy.clients.fdsn.header import URL_MAPPINGS
|
|
21 |
import matplotlib.pyplot as plt
|
22 |
import matplotlib.dates as mdates
|
23 |
|
|
|
|
|
24 |
def make_prediction(waveform):
|
25 |
waveform = np.load(waveform)
|
26 |
processed_input = prepare_waveform(waveform)
|
@@ -80,6 +82,15 @@ def mark_phases(waveform):
|
|
80 |
plt.close(fig)
|
81 |
return image
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source_depth_km, velocity_model):
|
84 |
distances, t0s, st_lats, st_lons, waveforms = [], [], [], [], []
|
85 |
|
@@ -101,6 +112,8 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
|
|
101 |
level='station')
|
102 |
|
103 |
waveforms = []
|
|
|
|
|
104 |
for network in inv:
|
105 |
for station in network:
|
106 |
try:
|
@@ -115,8 +128,12 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
|
|
115 |
starttime = obspy.UTCDateTime(timestamp) + arrivals[0].time - 15
|
116 |
endtime = starttime + 60
|
117 |
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
|
|
120 |
|
121 |
waveform = waveform.select(channel="H[BH][ZNE]")
|
122 |
waveform = waveform.merge(fill_value=0)
|
@@ -148,29 +165,39 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
|
|
148 |
p_phases = output[:, 0]
|
149 |
s_phases = output[:, 1]
|
150 |
|
151 |
-
|
|
|
|
|
|
|
|
|
152 |
for i in range(len(waveforms)):
|
153 |
current_P = p_phases[i::len(waveforms)]
|
154 |
current_S = s_phases[i::len(waveforms)]
|
|
|
155 |
x = [t0s[i] + pd.Timedelta(seconds=k/100) for k in np.linspace(0,6000,6000)]
|
156 |
x = mdates.date2num(x)
|
157 |
-
ax.plot(x, waveforms[i][0, 0]+distances[i]*111.2, color='black', alpha=0.5)
|
158 |
-
ax.scatter(x[int(current_P.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='r')
|
159 |
-
ax.scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b')
|
160 |
-
ax.set_ylabel('Z')
|
161 |
|
162 |
-
|
163 |
-
|
|
|
164 |
|
165 |
-
|
166 |
-
# a.axvline(current_P.mean()*waveforms[i][0].shape[-1], color='r', linestyle='--', label='P')
|
167 |
-
# a.axvline(current_S.mean()*waveforms[i][0].shape[-1], color='b', linestyle='--', label='S')
|
168 |
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
fig.canvas.draw();
|
176 |
image = np.array(fig.canvas.renderer.buffer_rgba())
|
@@ -184,12 +211,6 @@ model = Onset_picker.load_from_checkpoint("./weights.ckpt",
|
|
184 |
learning_rate=3e-4)
|
185 |
model.eval()
|
186 |
|
187 |
-
|
188 |
-
|
189 |
-
# # Create the Gradio interface
|
190 |
-
# gr.Interface(mark_phases, inputs, outputs, title='PhaseHunter').launch()
|
191 |
-
|
192 |
-
|
193 |
with gr.Blocks() as demo:
|
194 |
gr.Markdown("# PhaseHunter")
|
195 |
gr.Markdown("""This app allows one to detect P and S seismic phases along with uncertainty of the detection.
|
@@ -250,7 +271,9 @@ with gr.Blocks() as demo:
|
|
250 |
radius_inputs = gr.Slider(minimum=1,
|
251 |
maximum=150,
|
252 |
value=50, label="Radius (km)",
|
253 |
-
|
|
|
|
|
254 |
interactive=True)
|
255 |
|
256 |
velocity_inputs = gr.Dropdown(
|
@@ -263,7 +286,7 @@ with gr.Blocks() as demo:
|
|
263 |
|
264 |
|
265 |
button = gr.Button("Predict phases")
|
266 |
-
outputs_section = gr.
|
267 |
|
268 |
button.click(predict_on_section,
|
269 |
inputs=[client_inputs, timestamp_inputs,
|
@@ -277,6 +300,4 @@ with gr.Blocks() as demo:
|
|
277 |
Your waveform should be sampled at 100 sps and have 3 (Z, N, E) or 1 (Z) channels.
|
278 |
""")
|
279 |
|
280 |
-
|
281 |
-
|
282 |
demo.launch()
|
|
|
21 |
import matplotlib.pyplot as plt
|
22 |
import matplotlib.dates as mdates
|
23 |
|
24 |
+
from glob import glob
|
25 |
+
|
26 |
def make_prediction(waveform):
|
27 |
waveform = np.load(waveform)
|
28 |
processed_input = prepare_waveform(waveform)
|
|
|
82 |
plt.close(fig)
|
83 |
return image
|
84 |
|
85 |
+
def variance_coefficient(residuals):
|
86 |
+
# calculate the variance of the residuals
|
87 |
+
var = residuals.var()
|
88 |
+
|
89 |
+
# scale the variance to a coefficient between 0 and 1
|
90 |
+
coeff = 1 - (var / (residuals.max() - residuals.min()))
|
91 |
+
|
92 |
+
return coeff
|
93 |
+
|
94 |
def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source_depth_km, velocity_model):
|
95 |
distances, t0s, st_lats, st_lons, waveforms = [], [], [], [], []
|
96 |
|
|
|
112 |
level='station')
|
113 |
|
114 |
waveforms = []
|
115 |
+
cached_waveforms = glob("data/cached/*.mseed")
|
116 |
+
|
117 |
for network in inv:
|
118 |
for station in network:
|
119 |
try:
|
|
|
128 |
starttime = obspy.UTCDateTime(timestamp) + arrivals[0].time - 15
|
129 |
endtime = starttime + 60
|
130 |
|
131 |
+
if f"data/cached/{network.code}_{station.code}_{starttime}.mseed" not in cached_waveforms:
|
132 |
+
waveform = client.get_waveforms(network=network.code, station=station.code, location="*", channel="*",
|
133 |
+
starttime=starttime, endtime=endtime)
|
134 |
+
waveform.write(f"data/cached/{network.code}_{station.code}_{starttime}.mseed", format="MSEED")
|
135 |
+
else:
|
136 |
+
waveform = obspy.read(f"data/cached/{network.code}_{station.code}_{starttime}.mseed")
|
137 |
|
138 |
waveform = waveform.select(channel="H[BH][ZNE]")
|
139 |
waveform = waveform.merge(fill_value=0)
|
|
|
165 |
p_phases = output[:, 0]
|
166 |
s_phases = output[:, 1]
|
167 |
|
168 |
+
# Max confidence - min variance
|
169 |
+
p_max_confidence = np.min([p_phases[i::len(waveforms)].std() for i in range(len(waveforms))])
|
170 |
+
s_max_confidence = np.min([s_phases[i::len(waveforms)].std() for i in range(len(waveforms))])
|
171 |
+
|
172 |
+
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 3), sharex=True)
|
173 |
for i in range(len(waveforms)):
|
174 |
current_P = p_phases[i::len(waveforms)]
|
175 |
current_S = s_phases[i::len(waveforms)]
|
176 |
+
|
177 |
x = [t0s[i] + pd.Timedelta(seconds=k/100) for k in np.linspace(0,6000,6000)]
|
178 |
x = mdates.date2num(x)
|
|
|
|
|
|
|
|
|
179 |
|
180 |
+
# Normalize confidence for the plot
|
181 |
+
p_conf = 1/(current_P.std()/p_max_confidence).item()
|
182 |
+
s_conf = 1/(current_S.std()/s_max_confidence).item()
|
183 |
|
184 |
+
ax[0].plot(x, waveforms[i][0, 0]*10+distances[i]*111.2, color='black', alpha=0.5, lw=1)
|
|
|
|
|
185 |
|
186 |
+
ax[0].scatter(x[int(current_P.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='r', alpha=p_conf, marker='|')
|
187 |
+
ax[0].scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b', alpha=s_conf, marker='|')
|
188 |
+
ax[0].set_ylabel('Z')
|
189 |
|
190 |
+
ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
|
191 |
+
ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=5))
|
192 |
+
|
193 |
+
ax[0].scatter(None, None, color='r', marker='|', label='P')
|
194 |
+
ax[0].scatter(None, None, color='b', marker='|', label='S')
|
195 |
+
ax[0].legend()
|
196 |
+
|
197 |
+
ax[1].scatter(st_lats, st_lons, color='b', marker='d', label='Stations')
|
198 |
+
ax[1].scatter(eq_lat, eq_lon, color='r', marker='*', label='Earthquake')
|
199 |
+
ax[1].legend()
|
200 |
+
plt.subplots_adjust(hspace=0., wspace=0.)
|
201 |
|
202 |
fig.canvas.draw();
|
203 |
image = np.array(fig.canvas.renderer.buffer_rgba())
|
|
|
211 |
learning_rate=3e-4)
|
212 |
model.eval()
|
213 |
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
with gr.Blocks() as demo:
|
215 |
gr.Markdown("# PhaseHunter")
|
216 |
gr.Markdown("""This app allows one to detect P and S seismic phases along with uncertainty of the detection.
|
|
|
271 |
radius_inputs = gr.Slider(minimum=1,
|
272 |
maximum=150,
|
273 |
value=50, label="Radius (km)",
|
274 |
+
step=10,
|
275 |
+
info="""Select the radius around the earthquake to download data from.\n
|
276 |
+
Note that the larger the radius, the longer the app will take to run.""",
|
277 |
interactive=True)
|
278 |
|
279 |
velocity_inputs = gr.Dropdown(
|
|
|
286 |
|
287 |
|
288 |
button = gr.Button("Predict phases")
|
289 |
+
outputs_section = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)
|
290 |
|
291 |
button.click(predict_on_section,
|
292 |
inputs=[client_inputs, timestamp_inputs,
|
|
|
300 |
Your waveform should be sampled at 100 sps and have 3 (Z, N, E) or 1 (Z) channels.
|
301 |
""")
|
302 |
|
|
|
|
|
303 |
demo.launch()
|