# %% [markdown]
# # PhaseHunter
#
# Gradio app that takes a seismic waveform as input and marks two phases
# (P and S), together with the spread of the model's picks, on the waveform.
#
# The notebook is organised as: imports -> configuration -> model loading ->
# helper functions -> Gradio UI. It is meant to survive
# "Restart Kernel -> Run All".

# %%
# Imports: standard library, third-party, project-local.
from glob import glob
from pathlib import Path

import gradio as gr
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import obspy
import pandas as pd
import torch
from obspy.clients.fdsn import Client
from obspy.clients.fdsn.header import (
    URL_MAPPINGS,
    FDSNInternalServerException,
    FDSNNoDataException,
    FDSNTimeoutException,
)
from obspy.geodetics.base import locations2degrees
from obspy.taup import TauPyModel
from obspy.taup.helper_classes import SlownessModelError
from scipy.stats import gaussian_kde

from phasehunter.data_preparation import prepare_waveform
from phasehunter.model import Onset_picker, Updated_onset_picker

# %%
# Configuration: everything a reader might want to tune lives here.
WEIGHTS_PATH = "./weights.ckpt"
CACHE_DIR = Path("data/cached")
SAMPLE_WAVEFORMS = [
    "data/sample/sample_0.npy",
    "data/sample/sample_1.npy",
    "data/sample/sample_2.npy",
]
VELOCITY_MODELS = ['1066a', '1066b', 'ak135', 'ak135f', 'herrin', 'iasp91', 'jb', 'prem', 'pwdk']

KM_PER_DEGREE = 111.2           # used to convert the search radius and the section offsets
SAMPLING_RATE_HZ = 100          # the app text below asks for waveforms sampled at 100 sps
SECONDS_BEFORE_ARRIVAL = 15     # the download window starts this long before the first arrival
WINDOW_LENGTH_S = 60            # length of the downloaded window
STATION_SEARCH_SPAN_S = 120     # stations must be available in this span after the origin time
SECONDS_PER_DAY = 86400.0       # Matplotlib date numbers are expressed in days

# %%
# Load the picker once; every callback below reads this module-level `model`.
model = Onset_picker.load_from_checkpoint(WEIGHTS_PATH,
                                          picker=Updated_onset_picker(),
                                          learning_rate=3e-4)
model.eval()


# %%
def _figure_to_image(fig):
    """Render ``fig`` and return its RGBA pixel buffer as a NumPy array.

    The figure is closed even if drawing fails, so repeated callbacks do not
    accumulate open figures.
    """
    try:
        fig.canvas.draw()
        return np.array(fig.canvas.buffer_rgba())
    finally:
        plt.close(fig)


def _relative_confidence(spread, smallest_spread):
    """Map a pick spread to a plotting alpha in [0, 1].

    The station with the smallest spread gets 1.0; larger spreads get
    ``smallest_spread / spread``. Degenerate values (zero or non-finite
    spreads) fall back to 1.0 instead of producing NaN or inf.
    """
    if not (np.isfinite(spread) and np.isfinite(smallest_spread)) or spread <= 0:
        return 1.0
    return float(np.clip(smallest_spread / spread, 0.0, 1.0))


# %%
def make_prediction(waveform):
    """Load a waveform from disk and run the picker on it.

    Parameters
    ----------
    waveform : str
        Path of a file readable by ``np.load``.

    Returns
    -------
    tuple
        ``(processed_input, p_phase, s_phase)`` where ``processed_input`` is
        the output of ``prepare_waveform`` and ``p_phase`` / ``s_phase`` are
        columns 0 and 1 of the model output.
    """
    waveform = np.load(waveform)
    processed_input = prepare_waveform(waveform)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(processed_input)

    return processed_input, output[:, 0], output[:, 1]


def mark_phases(waveform):
    """Plot a waveform with the predicted P and S picks and return the image.

    Parameters
    ----------
    waveform : str
        Path of a ``.npy`` waveform file (forwarded to ``make_prediction``).

    Returns
    -------
    numpy.ndarray
        RGBA image of the rendered figure.
    """
    processed_input, p_phase, s_phase = make_prediction(waveform)
    n_samples = processed_input.shape[-1]

    # A record is treated as single-component only when the WHOLE third
    # channel is zero. Testing for "any zero sample" would misclassify a
    # three-component record that merely contains one zero.
    # NOTE(review): assumes prepare_waveform leaves absent channels as exact zeros - confirm.
    is_single_component = bool((processed_input[0][2] == 0).all())

    if is_single_component:
        fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)
        ax[0].plot(processed_input[0][0])
        ax[0].set_ylabel('Norm. Ampl.')
    else:
        fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)
        for axis, channel, label in zip(ax, processed_input[0], ('Z', 'N', 'E')):
            axis.plot(channel)
            axis.set_ylabel(label)

    # Bottom panel: kernel density estimate of the picks, in samples.
    for phase, color in ((p_phase, 'r'), (s_phase, 'b')):
        picks = np.asarray(phase * n_samples)
        density = gaussian_kde(picks)
        grid = np.linspace(picks.min() - 10, picks.max() + 10, 500)
        ax[-1].plot(grid, density(grid), color=color)

    # Mean pick of each phase, drawn across every panel.
    p_mean = float(p_phase.mean()) * n_samples
    s_mean = float(s_phase.mean()) * n_samples
    for axis in ax:
        axis.axvline(p_mean, color='r', linestyle='--', label='P')
        axis.axvline(s_mean, color='b', linestyle='--', label='S')

    ax[-1].set_xlabel('Time, samples')
    ax[-1].set_ylabel('Uncert.')
    ax[-1].legend()

    fig.subplots_adjust(hspace=0., wspace=0.)

    return _figure_to_image(fig)


def variance_coefficient(residuals):
    """Return ``1 - var(residuals) / (max(residuals) - min(residuals))``.

    Constant residuals (zero range) have zero variance; ``1.0`` is returned
    for them instead of the NaN a 0/0 division would give.

    NOTE(review): the formula divides a variance by a range, so the result is
    not guaranteed to lie in [0, 1] for arbitrary input - confirm the intended
    normalisation before relying on that bound.
    """
    var = residuals.var()
    value_range = residuals.max() - residuals.min()
    if value_range == 0:
        return 1.0
    return 1 - (var / value_range)


# %%
def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source_depth_km, velocity_model):
    """Download waveforms around an earthquake, pick P/S and plot a record section.

    Parameters
    ----------
    client_name : str
        Name of the FDSN client passed to ``obspy.clients.fdsn.Client``.
    timestamp : str
        Earthquake origin time, parsed by ``obspy.UTCDateTime``.
    eq_lat, eq_lon : float
        Earthquake latitude and longitude in degrees.
    radius_km : float
        Half-size of the station search box, in kilometres.
    source_depth_km : float
        Source depth handed to TauP.
    velocity_model : str
        Name of the TauP velocity model.

    Returns
    -------
    numpy.ndarray
        RGBA image with the record section (left) and a station map (right).

    Raises
    ------
    gr.Error
        If the search box leaves the valid latitude/longitude range, if the
        client has no stations in the box, or if no usable three-component
        waveform could be obtained.
    """
    # NOTE(review): the same angular window is used for latitude and
    # longitude, so the box is narrower than radius_km in the east-west
    # direction away from the equator.
    window = radius_km / KM_PER_DEGREE

    # Validate before touching the network; gr.Error is shown to the user
    # and, unlike assert, is not stripped when Python runs with -O.
    if not (eq_lat - window > -90 and eq_lat + window < 90):
        raise gr.Error("Latitude out of bounds")
    if not (eq_lon - window > -180 and eq_lon + window < 180):
        raise gr.Error("Longitude out of bounds")

    taup_model = TauPyModel(model=velocity_model)
    client = Client(client_name)
    origin_time = obspy.UTCDateTime(timestamp)

    try:
        inv = client.get_stations(network="*", station="*", location="*", channel="*H*",
                                  starttime=origin_time,
                                  endtime=origin_time + STATION_SEARCH_SPAN_S,
                                  minlatitude=(eq_lat - window), maxlatitude=(eq_lat + window),
                                  minlongitude=(eq_lon - window), maxlongitude=(eq_lon + window),
                                  level='station')
    except FDSNNoDataException:
        raise gr.Error("No stations found for the selected client, time and area.") from None

    # Downloaded windows are cached on disk; make sure the folder exists
    # before the first write.
    CACHE_DIR.mkdir(parents=True, exist_ok=True)

    distances, t0s, st_lats, st_lons, waveforms = [], [], [], [], []

    for network in inv:
        for station in network:
            try:
                distance = locations2degrees(eq_lat, eq_lon, station.latitude, station.longitude)

                arrivals = taup_model.get_travel_times(source_depth_in_km=source_depth_km,
                                                       distance_in_degree=distance,
                                                       phase_list=["P", "S"])
                if len(arrivals) == 0:
                    continue

                # Window positioned relative to the first returned arrival.
                window_start = origin_time + arrivals[0].time - SECONDS_BEFORE_ARRIVAL
                window_end = window_start + WINDOW_LENGTH_S

                # Same file name scheme as before, so existing caches stay valid.
                cache_file = CACHE_DIR / f"{network.code}_{station.code}_{window_start}.mseed"
                if cache_file.exists():
                    stream = obspy.read(str(cache_file))
                else:
                    stream = client.get_waveforms(network=network.code, station=station.code,
                                                  location="*", channel="*",
                                                  starttime=window_start, endtime=window_end)
                    stream.write(str(cache_file), format="MSEED")

                # NOTE(review): this pattern matches codes starting with "H"
                # (e.g. HHZ/HHN/HHE); confirm that excluding BH? channels is intended.
                stream = stream.select(channel="H[BH][ZNE]")
                stream = stream.merge(fill_value=0)
                stream = stream[:3]

                # Only complete three-component records of equal length can be stacked.
                if len(stream) != 3:
                    continue
                if len({len(trace.data) for trace in stream}) > 1:
                    continue

                try:
                    prepared = prepare_waveform(np.stack([trace.data for trace in stream]))
                except Exception:
                    # Best effort: skip records the preprocessing cannot handle.
                    continue

                distances.append(distance)
                t0s.append(window_start)
                st_lats.append(station.latitude)
                st_lons.append(station.longitude)
                waveforms.append(prepared)

            except (IndexError, FDSNNoDataException, FDSNTimeoutException,
                    FDSNInternalServerException):
                # A failure for one station must not abort the whole section.
                continue

    if not waveforms:
        raise gr.Error("No usable three-component waveforms were found. "
                       "Try a larger radius or a different FDSN client.")

    n_stations = len(waveforms)

    with torch.no_grad():
        output = model(torch.vstack(waveforms))

    p_phases = output[:, 0]
    s_phases = output[:, 1]

    # Rows i, i + n_stations, i + 2 * n_stations, ... are read as the picks of
    # station i (same slicing as the original code).
    # NOTE(review): this relies on how the model orders its output rows - confirm.
    p_picks = [p_phases[i::n_stations] for i in range(n_stations)]
    s_picks = [s_phases[i::n_stations] for i in range(n_stations)]

    p_spreads = np.array([float(picks.std()) for picks in p_picks])
    s_spreads = np.array([float(picks.std()) for picks in s_picks])

    # Max confidence = min spread; used to normalise the marker transparency.
    p_min_spread = np.nanmin(p_spreads) if np.isfinite(p_spreads).any() else np.nan
    s_min_spread = np.nanmin(s_spreads) if np.isfinite(s_spreads).any() else np.nan

    # The two panels show different quantities (time vs. longitude), so their
    # x axes must NOT be shared.
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 3))

    for i in range(n_stations):
        trace_z = waveforms[i][0, 0]
        n_samples = trace_z.shape[-1]
        offset_km = distances[i] * KM_PER_DEGREE

        # Time axis in Matplotlib date numbers (days): one point per sample.
        t0_num = mdates.date2num(t0s[i].datetime)
        x = t0_num + np.arange(n_samples) / SAMPLING_RATE_HZ / SECONDS_PER_DAY

        ax[0].plot(x, trace_z * 10 + offset_km, color='black', alpha=0.5, lw=1)

        marker_y = float(trace_z.mean()) + offset_km
        for picks, spread, min_spread, color in (
            (p_picks[i], p_spreads[i], p_min_spread, 'r'),
            (s_picks[i], s_spreads[i], s_min_spread, 'b'),
        ):
            # Clamp so a pick at (or beyond) the window edge cannot index
            # outside the time vector.
            pick_index = int(np.clip(int(float(picks.mean()) * n_samples), 0, n_samples - 1))
            ax[0].scatter(x[pick_index], marker_y, color=color,
                          alpha=_relative_confidence(spread, min_spread), marker='|')

    ax[0].set_ylabel('Z')
    ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
    ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=5))

    # Empty scatters only provide the legend entries.
    ax[0].scatter([], [], color='r', marker='|', label='P')
    ax[0].scatter([], [], color='b', marker='|', label='S')
    ax[0].legend()

    # Station map: longitude on x, latitude on y.
    ax[1].scatter(st_lons, st_lats, color='b', marker='d', label='Stations')
    ax[1].scatter(eq_lon, eq_lat, color='r', marker='*', label='Earthquake')
    ax[1].set_xlabel('Longitude')
    ax[1].set_ylabel('Latitude')
    # The panels touch (wspace=0), so put the map's y axis on the outer side.
    ax[1].yaxis.tick_right()
    ax[1].yaxis.set_label_position('right')
    ax[1].legend()

    fig.subplots_adjust(hspace=0., wspace=0.)

    return _figure_to_image(fig)


# %%
with gr.Blocks() as demo:
    gr.Markdown("# PhaseHunter")
    gr.Markdown("""This app allows one to detect P and S seismic phases along with uncertainty of the detection. 
    The app can be used in three ways: either by selecting one of the sample waveforms;
    or by selecting an earthquake from the global earthquake catalogue;
    or by uploading a waveform of interest.
    """)

    with gr.Tab("Default example"):
        sample_inputs = gr.Dropdown(
            SAMPLE_WAVEFORMS,
            label="Sample waveform",
            info="Select one of the samples",
            value=SAMPLE_WAVEFORMS[0],
        )

        sample_button = gr.Button("Predict phases")
        # gr.outputs.* is deprecated (see the warning Gradio emits); use the
        # component class directly, as the catalogue tab already does.
        sample_outputs = gr.Image(label='Waveform with Phases Marked', type='numpy', interactive=False)

        sample_button.click(mark_phases, inputs=sample_inputs, outputs=sample_outputs)

    with gr.Tab("Select earthquake from catalogue"):
        gr.Markdown("Select an FDSN client and describe the earthquake; waveforms recorded "
                    "by stations around it are downloaded and the P and S phases are marked.")

        client_inputs = gr.Dropdown(
            choices=list(URL_MAPPINGS.keys()),
            label="FDSN Client",
            info="Select one of the available FDSN clients",
            value="IRIS",
            interactive=True,
        )

        with gr.Row():
            timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',
                                          placeholder='YYYY-MM-DD HH:MM:SS',
                                          label="Timestamp",
                                          info="Timestamp of the earthquake",
                                          max_lines=1,
                                          interactive=True)

            eq_lat_inputs = gr.Number(value=35.766,
                                      label="Latitude",
                                      info="Latitude of the earthquake",
                                      interactive=True)

            eq_lon_inputs = gr.Number(value=-117.605,
                                      label="Longitude",
                                      info="Longitude of the earthquake",
                                      interactive=True)

            source_depth_inputs = gr.Number(value=10,
                                            label="Source depth (km)",
                                            info="Depth of the earthquake",
                                            interactive=True)

            radius_inputs = gr.Slider(minimum=1,
                                      maximum=150,
                                      value=50, label="Radius (km)",
                                      step=10,
                                      info=("Select the radius around the earthquake to download data from. "
                                            "Note that the larger the radius, the longer the app will take to run."),
                                      interactive=True)

            velocity_inputs = gr.Dropdown(
                choices=VELOCITY_MODELS,
                label="1D velocity model",
                info="Velocity model for station selection",
                value="1066a",
                interactive=True,
            )

        section_button = gr.Button("Predict phases")
        outputs_section = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)

        section_button.click(predict_on_section,
                             inputs=[client_inputs, timestamp_inputs,
                                     eq_lat_inputs, eq_lon_inputs,
                                     radius_inputs, source_depth_inputs, velocity_inputs],
                             outputs=outputs_section)

    with gr.Tab("Predict on your own waveform"):
        # TODO: this tab only shows instructions; no upload widget is wired up yet.
        gr.Markdown("""
        Please upload your waveform in .npy (numpy) format. 
        Your waveform should be sampled at 100 sps and have 3 (Z, N, E) or 1 (Z) channels.
        """)

# %%
demo.launch()