# Gradio app that takes a seismic waveform as input and marks the P and S phases
# (with pick uncertainty) on it. A second tab downloads waveforms for a catalogued
# earthquake from an FDSN provider and picks phases on a whole section of stations.
import os
from glob import glob

import gradio as gr
import numpy as np
import pandas as pd
import torch
from scipy.stats import gaussian_kde
from bmi_topography import Topography
import earthpy.spatial as es
import obspy
from obspy.clients.fdsn import Client
from obspy.clients.fdsn.header import FDSNNoDataException, FDSNTimeoutException, FDSNInternalServerException
from obspy.clients.fdsn.header import URL_MAPPINGS
from obspy.geodetics.base import locations2degrees
from obspy.taup import TauPyModel
from obspy.taup.helper_classes import SlownessModelError
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.colors import LightSource

from phasehunter.model import Onset_picker, Updated_onset_picker
from phasehunter.data_preparation import prepare_waveform

# Approximate length of one degree of great-circle arc, in km.
_KM_PER_DEGREE = 111.2
# Sampling rate (Hz) the app assumes for waveforms (see the instructions in the UI).
_SAMPLING_RATE_HZ = 100
# Length (s) of the window the model's normalised picks are scaled to.
_WINDOW_SECONDS = 60
# Directory where downloaded miniSEED waveforms are cached between runs.
_CACHE_DIR = "data/cached"
# FDSN request failures that are treated as "no data here" rather than fatal.
_FDSN_ERRORS = (IndexError, FDSNNoDataException, FDSNTimeoutException, FDSNInternalServerException)
# Column order of the pick table returned by predict_on_section.
_PICK_COLUMNS = ['station_name', 'starttime', 'p_phase', 'p_uncertainty',
                 's_phase', 's_uncertainty', 'velocity_p', 'velocity_s']


def _figure_to_image(fig):
    """Render *fig*, return its RGBA pixel buffer as a numpy array, and close it."""
    try:
        fig.canvas.draw()
        return np.array(fig.canvas.renderer.buffer_rgba())
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)


def _message_image(text):
    """Return an image containing only *text*, used to report problems in the UI."""
    fig, ax = plt.subplots()
    ax.text(0.5, 0.5, text, ha='center', va='center')
    return _figure_to_image(fig)


def _empty_picks():
    """Return an empty pick table with the same columns as a successful run."""
    return pd.DataFrame(columns=_PICK_COLUMNS)


def _pick_index(fraction, n_samples):
    """Convert a pick given as a fraction of the window into a valid sample index."""
    index = int(float(fraction) * n_samples)
    # A pick at (or numerically past) the window edge must not index out of range.
    return min(max(index, 0), n_samples - 1)


def _confidence_alpha(std, best_std):
    """Map a pick spread to a plot alpha: 1.0 for the tightest pick, smaller for wider ones."""
    if std <= 0:
        return 1.0
    return min(1.0, best_std / std)


def _load_station_stream(client, network_code, station_code, starttime, endtime):
    """Return the obspy Stream for one station window, reading or filling the on-disk cache.

    Raises whatever ``client.get_waveforms`` raises when the data are unavailable.
    """
    cache_path = f"{_CACHE_DIR}/{network_code}_{station_code}_{starttime}.mseed"
    if os.path.exists(cache_path):
        print('Reading cached waveform')
        return obspy.read(cache_path)
    print('Downloading waveform')
    stream = client.get_waveforms(network=network_code, station=station_code, location="*",
                                  channel="*", starttime=starttime, endtime=endtime)
    # The cache directory may not exist on a fresh checkout.
    os.makedirs(_CACHE_DIR, exist_ok=True)
    stream.write(cache_path, format="MSEED")
    print('Finished downloading and caching waveform')
    return stream


def make_prediction(waveform):
    """Run the phase picker on a waveform stored in a ``.npy`` file.

    Args:
        waveform: Path of the ``.npy`` file (anything ``np.load`` accepts).

    Returns:
        ``(processed_input, p_phase, s_phase)``: the model input produced by
        ``prepare_waveform`` and columns 0 (P) and 1 (S) of the model output.
        The plotting code treats these as fractions of the window length.
    """
    raw_waveform = np.load(waveform)
    processed_input = prepare_waveform(raw_waveform)
    with torch.no_grad():
        output = model(processed_input)
    p_phase = output[:, 0]
    s_phase = output[:, 1]
    return processed_input, p_phase, s_phase


def mark_phases(waveform, uploaded_file):
    """Plot a single-station waveform with its predicted P and S picks.

    Args:
        waveform: Path of the sample waveform selected in the dropdown.
        uploaded_file: Optional uploaded file that takes precedence over *waveform*.
            It is accepted either as an object with a ``.name`` path attribute or
            as a plain path string.

    Returns:
        RGBA image (numpy array) of the traces, the mean P/S picks (dashed lines)
        and a kernel-density "uncertainty" panel of the individual picks.
    """
    if uploaded_file is not None:
        waveform = getattr(uploaded_file, "name", uploaded_file)
    processed_input, p_phase, s_phase = make_prediction(waveform)
    traces = processed_input[0]
    n_samples = processed_input.shape[-1]

    # Treat the input as 1C (Z only) when channel 2 is entirely zero
    # (presumably zero-padded by prepare_waveform -- confirm). A single zero
    # sample in a real 3C record must not flip it to 1C.
    if bool((traces[2] == 0).all()):
        fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)
        ax[0].plot(traces[0], color='black', lw=1)
        ax[0].set_ylabel('Norm. Ampl.')
    else:
        fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)
        for row, label in enumerate(('Z', 'N', 'E')):
            ax[row].plot(traces[row], color='black', lw=1)
            ax[row].set_ylabel(label)

    # Bottom panel: KDE of the individual picks, converted to samples.
    for phase, color in ((p_phase, 'r'), (s_phase, 'b')):
        picks = phase * n_samples
        density = gaussian_kde(picks)
        grid = np.linspace(float(picks.min()) - 10, float(picks.max()) + 10, 500)
        ax[-1].plot(grid, density(grid), color=color)

    p_mean = float(p_phase.mean()) * n_samples
    s_mean = float(s_phase.mean()) * n_samples
    for a in ax:
        a.axvline(p_mean, color='r', linestyle='--', label='P')
        a.axvline(s_mean, color='b', linestyle='--', label='S')
    ax[-1].set_xlabel('Time, samples')
    ax[-1].set_ylabel('Uncert.')
    ax[-1].legend()
    fig.subplots_adjust(hspace=0., wspace=0.)

    return _figure_to_image(fig)


def bin_distances(distances, bin_size=10):
    """Return the index of the first distance that falls in each bin.

    Distances are grouped by ``distance // bin_size``, so *bin_size* is in the
    same units as *distances*. Indices are returned in order of first appearance.
    """
    first_index_per_bin = {}
    for index, distance in enumerate(distances):
        # setdefault keeps the earliest index seen for each bin.
        first_index_per_bin.setdefault(distance // bin_size, index)
    return list(first_index_per_bin.values())


def variance_coefficient(residuals):
    """Return ``1 - var(residuals) / (max - min)``.

    *residuals* must provide ``.var()``, ``.max()`` and ``.min()``
    (e.g. a numpy array). The result is not bounded to [0, 1] in general.
    Constant residuals have zero spread and return 1.0 instead of NaN (0/0).
    """
    spread = residuals.max() - residuals.min()
    if spread == 0:
        return 1.0
    return 1 - (residuals.var() / spread)


def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source_depth_km, velocity_model, max_waveforms):
    """Download waveforms around an earthquake, pick P/S phases and plot the section.

    Args:
        client_name: FDSN provider key (one of ``URL_MAPPINGS``).
        timestamp: Origin time, in any format ``obspy.UTCDateTime`` accepts.
        eq_lat, eq_lon: Epicentre, in degrees.
        radius_km: Half-width of the station search box, in km.
        source_depth_km: Source depth used for TauP travel times.
        velocity_model: TauP 1D model name used to place the download windows.
        max_waveforms: Upper bound on the number of stations that are picked/plotted.

    Returns:
        ``(image, picks)``: an RGBA image (numpy array) and a DataFrame with
        columns ``_PICK_COLUMNS``. On failure the image shows a message and
        the DataFrame is empty, so the Gradio outputs always get two values.
    """
    max_waveforms = max(1, int(max_waveforms))
    window = radius_km / _KM_PER_DEGREE  # search half-width, in degrees

    # Explicit checks (asserts vanish under ``python -O``), reported like other errors.
    if not (eq_lat - window > -90 and eq_lat + window < 90):
        return _message_image("Latitude out of bounds"), _empty_picks()
    if not (eq_lon - window > -180 and eq_lon + window < 180):
        return _message_image("Longitude out of bounds"), _empty_picks()

    taup_model = TauPyModel(model=velocity_model)
    client = Client(client_name)
    origin_time = obspy.UTCDateTime(timestamp)

    try:
        print('Starting to download inventory')
        inv = client.get_stations(network="*", station="*", location="*", channel="*H*",
                                  starttime=origin_time, endtime=origin_time + 120,
                                  minlatitude=(eq_lat - window), maxlatitude=(eq_lat + window),
                                  minlongitude=(eq_lon - window), maxlongitude=(eq_lon + window),
                                  level='station')
        print('Finished downloading inventory')
    except _FDSN_ERRORS:
        return _message_image('Something is wrong with the data provider, try another'), _empty_picks()

    distances, t0s, st_lats, st_lons, waveforms, names = [], [], [], [], [], []
    for network in inv:
        # Skip the synthetic networks
        if network.code == 'SY':
            continue
        for station in network:
            station_id = f"{network.code}.{station.code}"
            print(f"Processing {station_id}...")
            distance = locations2degrees(eq_lat, eq_lon, station.latitude, station.longitude)
            try:
                arrivals = taup_model.get_travel_times(source_depth_in_km=source_depth_km,
                                                       distance_in_degree=distance,
                                                       phase_list=["P", "S"])
            except SlownessModelError as err:
                print(f"Skipping {station_id}: no travel times ({err})")
                continue
            if not arrivals:
                continue

            # Window starts 15 s before the first predicted arrival.
            starttime = origin_time + arrivals[0].time - 15
            endtime = starttime + _WINDOW_SECONDS
            try:
                stream = _load_station_stream(client, network.code, station.code, starttime, endtime)
            except _FDSN_ERRORS:
                print(f'Skipping {network.code}_{station.code}_{starttime}')
                continue

            stream = stream.select(channel="H[BH][ZNE]").merge(fill_value=0)[:3]
            # The model input needs exactly three components of equal length.
            if len(stream) != 3 or len({len(trace.data) for trace in stream}) > 1:
                continue
            try:
                prepared = prepare_waveform(np.stack([trace.data for trace in stream]))
            except Exception as err:  # best effort: one bad station must not abort the section
                print(f"Skipping {station_id}: could not prepare waveform ({err})")
                continue
            distances.append(distance)
            t0s.append(starttime)
            st_lats.append(station.latitude)
            st_lons.append(station.longitude)
            waveforms.append(prepared)
            names.append(station_id)
            print(f"Added {station_id} to the list of waveforms")

    # If there are no waveforms, return an empty plot
    if not waveforms:
        return _message_image('No waveforms found'), _empty_picks()

    # Keep at most one station per 10 km distance bin, then cap the total count.
    first_distances = bin_distances(distances, bin_size=10 / _KM_PER_DEGREE)
    n_selected = min(len(first_distances), max_waveforms)
    selection_indexes = np.random.choice(first_distances, n_selected, replace=False)
    waveforms = [torch.tensor(np.array(waveforms[j])) for j in selection_indexes]
    distances = np.array(distances)[selection_indexes]
    t0s = np.array(t0s)[selection_indexes]
    st_lats = np.array(st_lats)[selection_indexes]
    st_lons = np.array(st_lons)[selection_indexes]
    names = np.array(names)[selection_indexes]
    n_traces = len(waveforms)

    print('Starting to run predictions')
    with torch.no_grad():
        output = model(torch.vstack(waveforms))
    p_phases = output[:, 0]
    s_phases = output[:, 1]
    # The model output interleaves repeated picks: rows i, i+n, i+2n, ... belong to trace i.
    p_stds = [float(p_phases[i::n_traces].std()) for i in range(n_traces)]
    s_stds = [float(s_phases[i::n_traces].std()) for i in range(n_traces)]
    # Max confidence = smallest spread across traces.
    p_best_std = min(p_stds)
    s_best_std = min(s_stds)

    print(f"Starting plotting {n_traces} waveforms")
    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 3))

    # Plot topography
    print('Fetching topography')
    params = Topography.DEFAULT.copy()
    extra_window = 0.5
    params["south"] = min(st_lats.min(), eq_lat) - extra_window
    params["north"] = max(st_lats.max(), eq_lat) + extra_window
    params["west"] = min(st_lons.min(), eq_lon) - extra_window
    params["east"] = max(st_lons.max(), eq_lon) + extra_window
    topo_map = Topography(**params)
    topo_map.fetch()
    topo_map.load()

    print('Plotting topo')
    hillshade = es.hillshade(topo_map.da[0], altitude=10)
    topo_map.da.plot(ax=ax[1], cmap='Greys', add_colorbar=False, add_labels=False)
    topo_map.da.plot(ax=ax[2], cmap='Greys', add_colorbar=False, add_labels=False)
    # Without an extent, imshow draws in pixel coordinates and resets the lon/lat view.
    ax[1].imshow(hillshade, cmap="Greys", alpha=0.5,
                 extent=(params["west"], params["east"], params["south"], params["north"]))

    pick_rows = []
    for i in range(n_traces):
        print(f"Plotting waveform {i+1}/{n_traces}")
        current_P = p_phases[i::n_traces]
        current_S = s_phases[i::n_traces]
        trace = waveforms[i]
        n_samples = trace.shape[-1]

        # Exact sample times (1/_SAMPLING_RATE_HZ s apart) as matplotlib date numbers.
        times = (mdates.date2num(t0s[i].datetime)
                 + np.arange(n_samples) / _SAMPLING_RATE_HZ / 86400.0)

        offset_km = float(distances[i]) * _KM_PER_DEGREE
        # NOTE(review): channel 0 is labelled 'Z', but the order after select/merge
        # comes from obspy -- confirm index 0 is really the vertical component.
        z_trace = trace[0, 0]
        baseline = float(z_trace.mean()) + offset_km
        ax[0].plot(times, z_trace * 10 + offset_km, color='black', alpha=0.5, lw=1)
        ax[0].scatter(times[_pick_index(current_P.mean(), n_samples)], baseline,
                      color='r', alpha=_confidence_alpha(p_stds[i], p_best_std), marker='|')
        ax[0].scatter(times[_pick_index(current_S.mean(), n_samples)], baseline,
                      color='b', alpha=_confidence_alpha(s_stds[i], s_best_std), marker='|')

        # Pick times relative to the origin time, and the implied average velocities.
        delta_t = t0s[i].timestamp - origin_time.timestamp
        p_time = delta_t + float(current_P.mean()) * _WINDOW_SECONDS
        s_time = delta_t + float(current_S.mean()) * _WINDOW_SECONDS
        velocity_p = offset_km / p_time
        velocity_s = offset_km / s_time
        print(f"Station {st_lats[i]}, {st_lons[i]} has P velocity {velocity_p} and S velocity {velocity_s}")
        pick_rows.append({
            'station_name': names[i],
            'starttime': str(t0s[i]),
            'p_phase': p_time,
            'p_uncertainty': p_stds[i] * _WINDOW_SECONDS,
            's_phase': s_time,
            's_uncertainty': s_stds[i] * _WINDOW_SECONDS,
            'velocity_p': velocity_p,
            'velocity_s': velocity_s,
        })

        # Colour the straight station-to-epicentre path by the implied velocity.
        path_lons = np.linspace(st_lons[i], eq_lon, 50)
        path_lats = np.linspace(st_lats[i], eq_lat, 50)
        ax[1].scatter(path_lons, path_lats, c=np.full_like(path_lons, velocity_p), alpha=0.5, vmin=0, vmax=8)
        ax[2].scatter(path_lons, path_lats, c=np.full_like(path_lons, velocity_s), alpha=0.5, vmin=0, vmax=8)

    ax[0].set_ylabel('Z')
    ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
    ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=20))

    # Add legend
    ax[0].scatter(None, None, color='r', marker='|', label='P')
    ax[0].scatter(None, None, color='b', marker='|', label='S')
    ax[0].legend()

    print('Plotting stations')
    for i in (1, 2):
        ax[i].scatter(st_lons, st_lats, color='b', label='Stations')
        ax[i].scatter(eq_lon, eq_lat, color='r', marker='*', label='Earthquake')

    # Colorbars share the 0-8 km/s range used by the path scatters above.
    for column, phase_name in ((1, 'P'), (2, 'S')):
        mappable = plt.cm.ScalarMappable(norm=plt.Normalize(vmin=0, vmax=8))
        mappable.set_array([])
        cbar = fig.colorbar(mappable, ax=ax[column])
        cbar.set_label(f'{phase_name} Velocity (km/s)')
        ax[column].set_title(f'{phase_name} Velocity')

    fig.subplots_adjust(hspace=0., wspace=0.5)
    return _figure_to_image(fig), pd.DataFrame(pick_rows, columns=_PICK_COLUMNS)


model = Onset_picker.load_from_checkpoint("./weights.ckpt",
                                          picker=Updated_onset_picker(),
                                          learning_rate=3e-4)
model.eval()

with gr.Blocks() as demo:
    gr.HTML("""
This app allows one to detect P and S seismic phases along with uncertainty of the detection.
Please upload your waveform in .npy
(numpy) format.
Your waveform should be sampled at 100 samples per second and have 3 (Z, N, E) or 1 (Z) channels. If your file is longer than 60 seconds, the app will only use the first 60 seconds of the waveform.
""")
    with gr.Tab("Try on a single station"):
        with gr.Row():
            # Define the input and output types for Gradio
            inputs = gr.Dropdown(
                ["data/sample/sample_0.npy",
                 "data/sample/sample_1.npy",
                 "data/sample/sample_2.npy"],
                label="Sample waveform",
                info="Select one of the samples",
                value="data/sample/sample_0.npy",
            )
            upload = gr.File(label="Or upload your own waveform")

        button = gr.Button("Predict phases")
        outputs = gr.Image(label='Waveform with Phases Marked', type='numpy', interactive=False)

        button.click(mark_phases, inputs=[inputs, upload], outputs=outputs)

    with gr.Tab("Select earthquake from catalogue"):
        gr.Markdown("""Select an earthquake from the global earthquake catalogue and the app will download the waveform from the FDSN client of your choice. """)
        with gr.Row():
            client_inputs = gr.Dropdown(
                choices=list(URL_MAPPINGS.keys()),
                label="FDSN Client",
                info="Select one of the available FDSN clients",
                value="IRIS",
                interactive=True,
            )
            velocity_inputs = gr.Dropdown(
                choices=['1066a', '1066b', 'ak135', 'ak135f', 'herrin', 'iasp91', 'jb', 'prem', 'pwdk'],
                label="1D velocity model",
                info="Velocity model for station selection",
                value="1066a",
                interactive=True,
            )

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Row():
                    timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',
                                                  placeholder='YYYY-MM-DD HH:MM:SS',
                                                  label="Timestamp",
                                                  info="Timestamp of the earthquake",
                                                  max_lines=1,
                                                  interactive=True)
                    eq_lat_inputs = gr.Number(value=35.766,
                                              label="Latitude",
                                              info="Latitude of the earthquake",
                                              interactive=True)
                    eq_lon_inputs = gr.Number(value=-117.605,
                                              label="Longitude",
                                              info="Longitude of the earthquake",
                                              interactive=True)
                    source_depth_inputs = gr.Number(value=10,
                                                    label="Source depth (km)",
                                                    info="Depth of the earthquake",
                                                    interactive=True)

            with gr.Column(scale=2):
                with gr.Row():
                    radius_inputs = gr.Slider(minimum=1,
                                              maximum=150,
                                              value=50,
                                              label="Radius (km)",
                                              step=10,
                                              info="""Select the radius around the earthquake to download data from.\n
                                              Note that the larger the radius, the longer the app will take to run.""",
                                              interactive=True)
                    max_waveforms_inputs = gr.Slider(minimum=1,
                                                     maximum=100,
                                                     value=10,
                                                     label="Max waveforms per section",
                                                     step=1,
                                                     info="Maximum number of waveforms to show per section\n (to avoid long prediction times)",
                                                     interactive=True,
                                                     )

        button = gr.Button("Predict phases")
        output_image = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)
        output_picks = gr.Dataframe(label='# Pick data', type='pandas', interactive=False)

        button.click(predict_on_section,
                     inputs=[client_inputs, timestamp_inputs,
                             eq_lat_inputs, eq_lon_inputs,
                             radius_inputs, source_depth_inputs,
                             velocity_inputs, max_waveforms_inputs],
                     outputs=[output_image, output_picks])

demo.launch()