phase-hunter / app.py
crimeacs's picture
section plot now works
15dbd99
raw
history blame
11.2 kB
# Gradio app that takes a seismic waveform as input and marks the P and S phases on it.
import gradio as gr
import numpy as np
import pandas as pd
from phasehunter.model import Onset_picker, Updated_onset_picker
from phasehunter.data_preparation import prepare_waveform
import torch
from scipy.stats import gaussian_kde
import obspy
from obspy.clients.fdsn import Client
from obspy.clients.fdsn.header import FDSNNoDataException, FDSNTimeoutException, FDSNInternalServerException
from obspy.geodetics.base import locations2degrees
from obspy.taup import TauPyModel
from obspy.taup.helper_classes import SlownessModelError
from obspy.clients.fdsn.header import URL_MAPPINGS
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
def make_prediction(waveform):
    """Run the phase picker on a waveform stored in a ``.npy`` file.

    Args:
        waveform: path to a ``.npy`` file, passed to ``np.load``.

    Returns:
        Tuple ``(processed_input, p_phase, s_phase)``: the model input built by
        ``prepare_waveform`` and columns 0 (P) and 1 (S) of the model output.
    """
    raw_samples = np.load(waveform)
    model_input = prepare_waveform(raw_samples)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        predictions = model(model_input)
    return model_input, predictions[:, 0], predictions[:, 1]
def mark_phases(waveform):
    """Plot a waveform with its predicted P and S onsets and return the plot as an image.

    Args:
        waveform: path to a ``.npy`` file; forwarded unchanged to ``make_prediction``.

    Returns:
        ``numpy`` RGBA array holding the rendered figure.
    """
    processed_input, p_phase, s_phase = make_prediction(waveform)
    n_samples = processed_input.shape[-1]

    # A 1C (Z-only) input is recognised by an entirely zero third channel.
    # Every sample must be zero: ``sum(x == 0)`` would be truthy as soon as a
    # single sample was exactly zero, misclassifying real 3C data as 1C.
    is_single_component = bool((processed_input[0][2] == 0).all())

    if is_single_component:
        fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)
    else:
        fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)
    try:
        if is_single_component:
            ax[0].plot(processed_input[0][0])
            ax[0].set_ylabel('Norm. Ampl.')
        else:
            for row, label in enumerate(('Z', 'N', 'E')):
                ax[row].plot(processed_input[0][row])
                ax[row].set_ylabel(label)

        # Predictions are scaled by the window length to put them on the sample
        # axis; the last panel shows a KDE of each set of picks (uncertainty).
        for phase, color in ((p_phase, 'r'), (s_phase, 'b')):
            picks = phase * n_samples
            kde = gaussian_kde(picks)
            grid = np.linspace(min(picks) - 10, max(picks) + 10, 500)
            ax[-1].plot(grid, kde(grid), color=color)

        # Dashed line at the mean pick on every panel.
        for a in ax:
            a.axvline(p_phase.mean() * n_samples, color='r', linestyle='--', label='P')
            a.axvline(s_phase.mean() * n_samples, color='b', linestyle='--', label='S')

        ax[-1].set_xlabel('Time, samples')
        ax[-1].set_ylabel('Uncert.')
        ax[-1].legend()
        plt.subplots_adjust(hspace=0., wspace=0.)

        # Convert the plot to an image.
        fig.canvas.draw()
        image = np.array(fig.canvas.renderer.buffer_rgba())
    finally:
        # Always release the figure so a failed request does not leak it.
        plt.close(fig)
    return image
def _select_zne(stream):
    """Return the ``[Z, N, E]`` traces recorded at one location code, or None.

    The component is the last character of each channel code. If a station
    reports several location codes, the alphabetically first one providing all
    three components is used; within a location the first trace seen wins.
    """
    by_location = {}
    for trace in stream:
        components = by_location.setdefault(trace.stats.location, {})
        components.setdefault(trace.stats.channel[-1], trace)
    for location in sorted(by_location):
        components = by_location[location]
        if all(comp in components for comp in 'ZNE'):
            return [components[comp] for comp in 'ZNE']
    return None


def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source_depth_km, velocity_model):
    """Download waveforms around an earthquake and plot P/S picks as a record section.

    Args:
        client_name: FDSN client name passed to ``Client`` (e.g. "IRIS").
        timestamp: origin time, anything ``obspy.UTCDateTime`` accepts.
        eq_lat, eq_lon: epicentre coordinates in degrees.
        radius_km: half-width of the station search box, converted to degrees
            at 111.2 km per degree.
        source_depth_km: source depth used for TauP travel times.
        velocity_model: TauP model name used to place the data window.

    Returns:
        ``numpy`` RGBA array holding the rendered section plot.

    Raises:
        ValueError: if the search box crosses a pole or the +/-180 meridian, or
            if no station produced a usable 3-component waveform.
    """
    km_per_degree = 111.2
    # NOTE(review): no resampling is done; the time axis assumes 100 sps data,
    # the rate the app requires for uploaded waveforms.
    sampling_rate = 100

    # NOTE(review): the same degree width is used for latitude and longitude,
    # so east-west the box is narrower than radius_km away from the equator.
    window = radius_km / km_per_degree
    # Explicit checks rather than ``assert`` so they survive ``python -O``.
    if not (eq_lat - window > -90 and eq_lat + window < 90):
        raise ValueError("Latitude out of bounds")
    if not (eq_lon - window > -180 and eq_lon + window < 180):
        raise ValueError("Longitude out of bounds")

    taup_model = TauPyModel(model=velocity_model)
    client = Client(client_name)

    origin_time = obspy.UTCDateTime(timestamp)
    inv = client.get_stations(network="*", station="*", location="*", channel="*H*",
                              starttime=origin_time, endtime=origin_time + 120,
                              minlatitude=(eq_lat - window), maxlatitude=(eq_lat + window),
                              minlongitude=(eq_lon - window), maxlongitude=(eq_lon + window),
                              level='station')

    distances, t0s, waveforms = [], [], []
    for network in inv:
        for station in network:
            try:
                distance = locations2degrees(eq_lat, eq_lon, station.latitude, station.longitude)
                arrivals = taup_model.get_travel_times(source_depth_in_km=source_depth_km,
                                                       distance_in_degree=distance,
                                                       phase_list=["P", "S"])
                if not arrivals:
                    continue
                # 60 s window starting 15 s before the first predicted arrival.
                win_start = origin_time + arrivals[0].time - 15
                stream = client.get_waveforms(network=network.code, station=station.code,
                                              location="*", channel="*",
                                              starttime=win_start, endtime=win_start + 60)
                # NOTE(review): "H[BH][ZNE]" matches HH?/HB? codes, so BH? channels
                # are excluded — confirm this is intended.
                stream = stream.select(channel="H[BH][ZNE]")
                stream = stream.merge(fill_value=0)
                # Stack explicitly as Z, N, E (the channel order the app documents);
                # the stream's own order is not guaranteed to be Z, N, E.
                traces = _select_zne(stream)
                if traces is None or len({len(tr.data) for tr in traces}) > 1:
                    continue
                try:
                    prepared = prepare_waveform(np.stack([tr.data for tr in traces]))
                except Exception:
                    # Best effort: skip stations whose data cannot be prepared.
                    continue
            except (IndexError, FDSNNoDataException, FDSNTimeoutException,
                    FDSNInternalServerException, SlownessModelError):
                continue
            distances.append(distance)
            t0s.append(win_start)
            waveforms.append(prepared)

    if not waveforms:
        raise ValueError("No usable 3-component waveforms were found for this event")

    with torch.no_grad():
        output = model(torch.vstack(waveforms))
    p_phases = output[:, 0]
    s_phases = output[:, 1]
    n_stations = len(waveforms)

    fig, ax = plt.subplots(nrows=1, figsize=(10, 3))
    try:
        for i, (prepared, t0, distance) in enumerate(zip(waveforms, t0s, distances)):
            # Rows for station i are read with stride n_stations: the model output
            # is assumed to interleave several predictions per station (TODO confirm).
            current_p = p_phases[i::n_stations]
            current_s = s_phases[i::n_stations]
            z_trace = prepared[0, 0]
            n_samples = prepared[0].shape[-1]
            offset = distance * km_per_degree
            # Matplotlib date numbers are in days; one sample is 1/sampling_rate s.
            x = mdates.date2num(t0.datetime) + np.arange(n_samples) / (sampling_rate * 86400.0)
            ax.plot(x, z_trace + offset, color='black', alpha=0.5)
            for picks, color in ((current_p, 'r'), (current_s, 'b')):
                # Clamp so a pick at either edge of the window stays indexable.
                idx = min(max(int(float(picks.mean()) * n_samples), 0), n_samples - 1)
                ax.scatter(x[idx], z_trace.mean() + offset, color=color)
        ax.set_ylabel('Z')
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
        ax.xaxis.set_major_locator(mdates.SecondLocator(interval=10))
        fig.canvas.draw()
        image = np.array(fig.canvas.renderer.buffer_rgba())
    finally:
        # Always release the figure so a failed request does not leak it.
        plt.close(fig)
    return image
# Trained picker, loaded once at import time and used through the module-level
# name ``model`` by make_prediction and predict_on_section.
_CHECKPOINT_PATH = "./weights.ckpt"
model = Onset_picker.load_from_checkpoint(
    _CHECKPOINT_PATH,
    picker=Updated_onset_picker(),
    learning_rate=3e-4,
)
# Evaluation mode (affects layers such as dropout / batch norm) for inference.
model.eval()
# Gradio UI: one tab per way of supplying a waveform to the picker.
with gr.Blocks() as demo:
    gr.Markdown("# PhaseHunter")
    gr.Markdown("""This app allows one to detect P and S seismic phases along with uncertainty of the detection.
The app can be used in three ways: either by selecting one of the sample waveforms;
or by selecting an earthquake from the global earthquake catalogue;
or by uploading a waveform of interest.
""")
    with gr.Tab("Default example"):
        # Define the input and output types for Gradio
        inputs = gr.Dropdown(
            ["data/sample/sample_0.npy",
             "data/sample/sample_1.npy",
             "data/sample/sample_2.npy"],
            label="Sample waveform",
            info="Select one of the samples",
            value="data/sample/sample_0.npy",
        )
        button = gr.Button("Predict phases")
        # gr.Image replaces the deprecated gr.outputs.Image (removed in Gradio 4).
        outputs = gr.Image(label='Waveform with Phases Marked', type='numpy')
        button.click(mark_phases, inputs=inputs, outputs=outputs)

    with gr.Tab("Select earthquake from catalogue"):
        gr.Markdown('Download waveforms recorded around an earthquake and pick P and S phases at every station.')
        client_inputs = gr.Dropdown(
            choices=list(URL_MAPPINGS.keys()),
            label="FDSN Client",
            info="Select one of the available FDSN clients",
            value="IRIS",
            interactive=True,
        )
        with gr.Row():
            timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',
                                          placeholder='YYYY-MM-DD HH:MM:SS',
                                          label="Timestamp",
                                          info="Timestamp of the earthquake",
                                          max_lines=1,
                                          interactive=True)
            eq_lat_inputs = gr.Number(value=35.766,
                                      label="Latitude",
                                      info="Latitude of the earthquake",
                                      interactive=True)
            eq_lon_inputs = gr.Number(value=-117.605,
                                      label="Longitude",
                                      info="Longitude of the earthquake",
                                      interactive=True)
            source_depth_inputs = gr.Number(value=10,
                                            label="Source depth (km)",
                                            info="Depth of the earthquake",
                                            interactive=True)
        radius_inputs = gr.Slider(minimum=1,
                                  maximum=150,
                                  value=50, label="Radius (km)",
                                  info="Select the radius around the earthquake to download data from",
                                  interactive=True)
        velocity_inputs = gr.Dropdown(
            choices=['1066a', '1066b', 'ak135', 'ak135f', 'herrin', 'iasp91', 'jb', 'prem', 'pwdk'],
            label="1D velocity model",
            # The model predicts arrival times that position each station's data window.
            info="Velocity model used to predict arrival times for the data window",
            value="1066a",
            interactive=True,
        )
        button = gr.Button("Predict phases")
        outputs_section = gr.Image(label='Waveforms with Phases Marked', type='numpy')
        button.click(predict_on_section,
                     inputs=[client_inputs, timestamp_inputs,
                             eq_lat_inputs, eq_lon_inputs,
                             radius_inputs, source_depth_inputs, velocity_inputs],
                     outputs=outputs_section)

    with gr.Tab("Predict on your own waveform"):
        # NOTE(review): this tab has no upload component or callback yet.
        gr.Markdown("""
Please upload your waveform in .npy (numpy) format.
Your waveform should be sampled at 100 sps and have 3 (Z, N, E) or 1 (Z) channels.
""")
demo.launch()