Spaces:
Runtime error
Runtime error
Updated layout
Browse files- .DS_Store +0 -0
- Gradio_app.ipynb +140 -98
- app.py +85 -83
- phasehunter/model.py +0 -313
- phasehunter/training.py +0 -104
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
Gradio_app.ipynb
CHANGED
@@ -2,29 +2,14 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"model = Onset_picker.load_from_checkpoint(\"./weights.ckpt\",\n",
|
10 |
-
" picker=Updated_onset_picker(),\n",
|
11 |
-
" learning_rate=3e-4)\n",
|
12 |
-
"model.eval()\n",
|
13 |
-
"model.freeze()\n",
|
14 |
-
"script = model.to_torchscript()\n",
|
15 |
-
"torch.jit.save(script, \"model.pt\")"
|
16 |
-
]
|
17 |
-
},
|
18 |
-
{
|
19 |
-
"cell_type": "code",
|
20 |
-
"execution_count": 32,
|
21 |
"metadata": {},
|
22 |
"outputs": [
|
23 |
{
|
24 |
"name": "stdout",
|
25 |
"output_type": "stream",
|
26 |
"text": [
|
27 |
-
"Running on local URL: http://127.0.0.1:
|
28 |
"\n",
|
29 |
"To create a public link, set `share=True` in `launch()`.\n"
|
30 |
]
|
@@ -32,7 +17,7 @@
|
|
32 |
{
|
33 |
"data": {
|
34 |
"text/html": [
|
35 |
-
"<div><iframe src=\"http://127.0.0.1:
|
36 |
],
|
37 |
"text/plain": [
|
38 |
"<IPython.core.display.HTML object>"
|
@@ -45,7 +30,7 @@
|
|
45 |
"data": {
|
46 |
"text/plain": []
|
47 |
},
|
48 |
-
"execution_count":
|
49 |
"metadata": {},
|
50 |
"output_type": "execute_result"
|
51 |
},
|
@@ -116,13 +101,69 @@
|
|
116 |
"name": "stderr",
|
117 |
"output_type": "stream",
|
118 |
"text": [
|
119 |
-
"/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/
|
120 |
" waveforms = np.array(waveforms)[selection_indexes]\n",
|
121 |
-
"/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/
|
122 |
" waveforms = np.array(waveforms)[selection_indexes]\n",
|
123 |
-
"/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/
|
124 |
" waveforms = [torch.tensor(waveform) for waveform in waveforms]\n"
|
125 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
}
|
127 |
],
|
128 |
"source": [
|
@@ -149,7 +190,7 @@
|
|
149 |
"\n",
|
150 |
"import matplotlib.pyplot as plt\n",
|
151 |
"import matplotlib.dates as mdates\n",
|
152 |
-
"from
|
153 |
"\n",
|
154 |
"from glob import glob\n",
|
155 |
"\n",
|
@@ -309,8 +350,8 @@
|
|
309 |
" \n",
|
310 |
" waveform = waveform.select(channel=\"H[BH][ZNE]\")\n",
|
311 |
" waveform = waveform.merge(fill_value=0)\n",
|
312 |
-
" waveform = waveform[:3]\n",
|
313 |
-
"
|
314 |
" len_check = [len(x.data) for x in waveform]\n",
|
315 |
" if len(set(len_check)) > 1:\n",
|
316 |
" continue\n",
|
@@ -371,8 +412,8 @@
|
|
371 |
" s_max_confidence = np.min([s_phases[i::len(waveforms)].std() for i in range(len(waveforms))])\n",
|
372 |
"\n",
|
373 |
" print(f\"Starting plotting {len(waveforms)} waveforms\")\n",
|
374 |
-
" fig, ax = plt.subplots(
|
375 |
-
"\n",
|
376 |
" # Plot topography\n",
|
377 |
" print('Fetching topography')\n",
|
378 |
" params = Topography.DEFAULT.copy()\n",
|
@@ -417,9 +458,6 @@
|
|
417 |
" ax[0].scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b', alpha=s_conf, marker='|')\n",
|
418 |
" ax[0].set_ylabel('Z')\n",
|
419 |
"\n",
|
420 |
-
" ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))\n",
|
421 |
-
" ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=20))\n",
|
422 |
-
"\n",
|
423 |
" delta_t = t0s[i].timestamp - obspy.UTCDateTime(timestamp).timestamp\n",
|
424 |
"\n",
|
425 |
" velocity_p = (distances[i]*111.2)/(delta_t+current_P.mean()*60).item()\n",
|
@@ -437,30 +475,37 @@
|
|
437 |
" y = np.linspace(st_lats[i], eq_lat, 50)\n",
|
438 |
" \n",
|
439 |
" # Plot the array\n",
|
440 |
-
" ax[1].scatter(x, y, c=np.zeros_like(x)+velocity_p, alpha=0.
|
441 |
-
" ax[2].scatter(x, y, c=np.zeros_like(x)+velocity_s, alpha=0.
|
442 |
"\n",
|
443 |
" # Add legend\n",
|
444 |
" ax[0].scatter(None, None, color='r', marker='|', label='P')\n",
|
445 |
" ax[0].scatter(None, None, color='b', marker='|', label='S')\n",
|
|
|
|
|
446 |
" ax[0].legend()\n",
|
447 |
"\n",
|
448 |
" print('Plotting stations')\n",
|
449 |
" for i in range(1,3):\n",
|
450 |
" ax[i].scatter(st_lons, st_lats, color='b', label='Stations')\n",
|
451 |
" ax[i].scatter(eq_lon, eq_lat, color='r', marker='*', label='Earthquake')\n",
|
|
|
|
|
452 |
"\n",
|
453 |
-
"
|
454 |
-
"
|
455 |
-
"
|
456 |
-
"
|
|
|
457 |
"\n",
|
458 |
-
" cbar
|
459 |
-
"
|
460 |
" ax[2].set_title('S Velocity')\n",
|
461 |
"\n",
|
|
|
|
|
|
|
462 |
" plt.subplots_adjust(hspace=0., wspace=0.5)\n",
|
463 |
-
"\n",
|
464 |
" fig.canvas.draw();\n",
|
465 |
" image = np.array(fig.canvas.renderer.buffer_rgba())\n",
|
466 |
" plt.close(fig)\n",
|
@@ -482,7 +527,6 @@
|
|
482 |
" }\n",
|
483 |
"</style></h1> \n",
|
484 |
" \n",
|
485 |
-
"\n",
|
486 |
" <p style=\"font-size: 16px; margin-bottom: 20px;\">Detect <span style=\"background-image: linear-gradient(to right, #ED213A, #93291E); \n",
|
487 |
" -webkit-background-clip: text;\n",
|
488 |
" -webkit-text-fill-color: transparent;\n",
|
@@ -531,68 +575,66 @@
|
|
531 |
" </div>\n",
|
532 |
" \"\"\")\n",
|
533 |
" with gr.Row(): \n",
|
534 |
-
"
|
535 |
-
"
|
536 |
-
"
|
537 |
-
"
|
538 |
-
"
|
539 |
-
"
|
540 |
-
"
|
541 |
-
"\n",
|
542 |
-
"
|
543 |
-
"
|
544 |
-
"
|
545 |
-
"
|
546 |
-
"
|
547 |
-
"
|
548 |
-
"
|
549 |
-
"
|
550 |
-
"
|
|
|
551 |
"\n",
|
552 |
-
" with gr.Column(scale=
|
553 |
-
"
|
554 |
-
"
|
555 |
-
"
|
556 |
-
"
|
557 |
-
"
|
558 |
-
"
|
559 |
-
"
|
560 |
-
"
|
561 |
-
"
|
562 |
-
"
|
563 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
564 |
" interactive=True)\n",
|
565 |
-
" \n",
|
566 |
-
" eq_lon_inputs = gr.Number(value=-117.605,\n",
|
567 |
-
" label=\"Longitude\",\n",
|
568 |
-
" info=\"Longitude of the earthquake\",\n",
|
569 |
-
" interactive=True)\n",
|
570 |
-
" \n",
|
571 |
-
" source_depth_inputs = gr.Number(value=10,\n",
|
572 |
-
" label=\"Source depth (km)\",\n",
|
573 |
-
" info=\"Depth of the earthquake\",\n",
|
574 |
-
" interactive=True)\n",
|
575 |
" \n",
|
576 |
-
"\n",
|
577 |
-
" \n",
|
578 |
" with gr.Column(scale=2):\n",
|
579 |
-
"
|
580 |
-
"
|
581 |
-
"
|
582 |
-
"
|
583 |
-
"
|
584 |
-
"
|
585 |
-
"
|
586 |
-
"
|
587 |
-
"
|
588 |
-
"
|
589 |
-
"
|
590 |
-
"
|
591 |
-
"
|
592 |
-
"
|
593 |
-
"
|
594 |
-
"
|
595 |
-
" )\n",
|
596 |
" \n",
|
597 |
" button = gr.Button(\"Predict phases\")\n",
|
598 |
" output_image = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)\n",
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": 51,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
"metadata": {},
|
7 |
"outputs": [
|
8 |
{
|
9 |
"name": "stdout",
|
10 |
"output_type": "stream",
|
11 |
"text": [
|
12 |
+
"Running on local URL: http://127.0.0.1:7897\n",
|
13 |
"\n",
|
14 |
"To create a public link, set `share=True` in `launch()`.\n"
|
15 |
]
|
|
|
17 |
{
|
18 |
"data": {
|
19 |
"text/html": [
|
20 |
+
"<div><iframe src=\"http://127.0.0.1:7897/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
21 |
],
|
22 |
"text/plain": [
|
23 |
"<IPython.core.display.HTML object>"
|
|
|
30 |
"data": {
|
31 |
"text/plain": []
|
32 |
},
|
33 |
+
"execution_count": 51,
|
34 |
"metadata": {},
|
35 |
"output_type": "execute_result"
|
36 |
},
|
|
|
101 |
"name": "stderr",
|
102 |
"output_type": "stream",
|
103 |
"text": [
|
104 |
+
"/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:224: FutureWarning: The input object of type 'Tensor' is an array-like implementing one of the corresponding protocols (`__array__`, `__array_interface__` or `__array_struct__`); but not a sequence (or 0-D). In the future, this object will be coerced as if it was first converted using `np.array(obj)`. To retain the old behaviour, you have to either modify the type 'Tensor', or assign to an empty array created with `np.empty(correct_shape, dtype=object)`.\n",
|
105 |
" waveforms = np.array(waveforms)[selection_indexes]\n",
|
106 |
+
"/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:224: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
|
107 |
" waveforms = np.array(waveforms)[selection_indexes]\n",
|
108 |
+
"/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:231: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
109 |
" waveforms = [torch.tensor(waveform) for waveform in waveforms]\n"
|
110 |
]
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"name": "stdout",
|
114 |
+
"output_type": "stream",
|
115 |
+
"text": [
|
116 |
+
"Starting plotting 3 waveforms\n",
|
117 |
+
"Fetching topography\n",
|
118 |
+
"Plotting topo\n"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"name": "stderr",
|
123 |
+
"output_type": "stream",
|
124 |
+
"text": [
|
125 |
+
"/Users/anovosel/miniconda3/envs/phasehunter/lib/python3.11/site-packages/bmi_topography/api_key.py:49: UserWarning: You are using a demo key to fetch data from OpenTopography, functionality will be limited. See https://bmi-topography.readthedocs.io/en/latest/#api-key for more information.\n",
|
126 |
+
" warnings.warn(\n"
|
127 |
+
]
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"name": "stdout",
|
131 |
+
"output_type": "stream",
|
132 |
+
"text": [
|
133 |
+
"Plotting waveform 1/3\n",
|
134 |
+
"Station 36.11758, -117.85486 has P velocity 4.987805380766392 and S velocity 2.9782985042350987\n",
|
135 |
+
"Plotting waveform 2/3\n"
|
136 |
+
]
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"name": "stderr",
|
140 |
+
"output_type": "stream",
|
141 |
+
"text": [
|
142 |
+
"/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:299: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
|
143 |
+
" output_picks = output_picks.append(pd.DataFrame({'station_name': [names[i]], 'starttime' : [str(t0s[i])],\n",
|
144 |
+
"/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:299: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
|
145 |
+
" output_picks = output_picks.append(pd.DataFrame({'station_name': [names[i]], 'starttime' : [str(t0s[i])],\n"
|
146 |
+
]
|
147 |
+
},
|
148 |
+
{
|
149 |
+
"name": "stdout",
|
150 |
+
"output_type": "stream",
|
151 |
+
"text": [
|
152 |
+
"Station 35.98249, -117.80885 has P velocity 4.255522557803516 and S velocity 2.2929437916670583\n",
|
153 |
+
"Plotting waveform 3/3\n",
|
154 |
+
"Station 35.69235, -117.75051 has P velocity 2.979034174961547 and S velocity 1.3728192788753049\n",
|
155 |
+
"Plotting stations\n"
|
156 |
+
]
|
157 |
+
},
|
158 |
+
{
|
159 |
+
"name": "stderr",
|
160 |
+
"output_type": "stream",
|
161 |
+
"text": [
|
162 |
+
"/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:299: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
|
163 |
+
" output_picks = output_picks.append(pd.DataFrame({'station_name': [names[i]], 'starttime' : [str(t0s[i])],\n",
|
164 |
+
"/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:324: UserWarning: FixedFormatter should only be used together with FixedLocator\n",
|
165 |
+
" ax[i].set_xticklabels(ax[i].get_xticks(), rotation = 50)\n"
|
166 |
+
]
|
167 |
}
|
168 |
],
|
169 |
"source": [
|
|
|
190 |
"\n",
|
191 |
"import matplotlib.pyplot as plt\n",
|
192 |
"import matplotlib.dates as mdates\n",
|
193 |
+
"from mpl_toolkits.axes_grid1 import ImageGrid\n",
|
194 |
"\n",
|
195 |
"from glob import glob\n",
|
196 |
"\n",
|
|
|
350 |
" \n",
|
351 |
" waveform = waveform.select(channel=\"H[BH][ZNE]\")\n",
|
352 |
" waveform = waveform.merge(fill_value=0)\n",
|
353 |
+
" waveform = waveform[:3].sort(keys=['channel'], reverse=True)\n",
|
354 |
+
"\n",
|
355 |
" len_check = [len(x.data) for x in waveform]\n",
|
356 |
" if len(set(len_check)) > 1:\n",
|
357 |
" continue\n",
|
|
|
412 |
" s_max_confidence = np.min([s_phases[i::len(waveforms)].std() for i in range(len(waveforms))])\n",
|
413 |
"\n",
|
414 |
" print(f\"Starting plotting {len(waveforms)} waveforms\")\n",
|
415 |
+
" fig, ax = plt.subplots(ncols=3, figsize=(10, 3))\n",
|
416 |
+
" \n",
|
417 |
" # Plot topography\n",
|
418 |
" print('Fetching topography')\n",
|
419 |
" params = Topography.DEFAULT.copy()\n",
|
|
|
458 |
" ax[0].scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b', alpha=s_conf, marker='|')\n",
|
459 |
" ax[0].set_ylabel('Z')\n",
|
460 |
"\n",
|
|
|
|
|
|
|
461 |
" delta_t = t0s[i].timestamp - obspy.UTCDateTime(timestamp).timestamp\n",
|
462 |
"\n",
|
463 |
" velocity_p = (distances[i]*111.2)/(delta_t+current_P.mean()*60).item()\n",
|
|
|
475 |
" y = np.linspace(st_lats[i], eq_lat, 50)\n",
|
476 |
" \n",
|
477 |
" # Plot the array\n",
|
478 |
+
" ax[1].scatter(x, y, c=np.zeros_like(x)+velocity_p, alpha=0.1, vmin=0, vmax=8)\n",
|
479 |
+
" ax[2].scatter(x, y, c=np.zeros_like(x)+velocity_s, alpha=0.1, vmin=0, vmax=8)\n",
|
480 |
"\n",
|
481 |
" # Add legend\n",
|
482 |
" ax[0].scatter(None, None, color='r', marker='|', label='P')\n",
|
483 |
" ax[0].scatter(None, None, color='b', marker='|', label='S')\n",
|
484 |
+
" ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))\n",
|
485 |
+
" ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=20))\n",
|
486 |
" ax[0].legend()\n",
|
487 |
"\n",
|
488 |
" print('Plotting stations')\n",
|
489 |
" for i in range(1,3):\n",
|
490 |
" ax[i].scatter(st_lons, st_lats, color='b', label='Stations')\n",
|
491 |
" ax[i].scatter(eq_lon, eq_lat, color='r', marker='*', label='Earthquake')\n",
|
492 |
+
" ax[i].set_aspect('equal')\n",
|
493 |
+
" ax[i].set_xticklabels(ax[i].get_xticks(), rotation = 50)\n",
|
494 |
"\n",
|
495 |
+
" fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.8,\n",
|
496 |
+
" wspace=0.02, hspace=0.02)\n",
|
497 |
+
" \n",
|
498 |
+
" cb_ax = fig.add_axes([0.83, 0.1, 0.02, 0.8])\n",
|
499 |
+
" cbar = fig.colorbar(ax[2].scatter(None, None, c=velocity_p, alpha=0.5, vmin=0, vmax=8), cax=cb_ax)\n",
|
500 |
"\n",
|
501 |
+
" cbar.set_label('Velocity (km/s)')\n",
|
502 |
+
" ax[1].set_title('P Velocity')\n",
|
503 |
" ax[2].set_title('S Velocity')\n",
|
504 |
"\n",
|
505 |
+
" for a in ax:\n",
|
506 |
+
" a.tick_params(axis='both', which='major', labelsize=8)\n",
|
507 |
+
" \n",
|
508 |
" plt.subplots_adjust(hspace=0., wspace=0.5)\n",
|
|
|
509 |
" fig.canvas.draw();\n",
|
510 |
" image = np.array(fig.canvas.renderer.buffer_rgba())\n",
|
511 |
" plt.close(fig)\n",
|
|
|
527 |
" }\n",
|
528 |
"</style></h1> \n",
|
529 |
" \n",
|
|
|
530 |
" <p style=\"font-size: 16px; margin-bottom: 20px;\">Detect <span style=\"background-image: linear-gradient(to right, #ED213A, #93291E); \n",
|
531 |
" -webkit-background-clip: text;\n",
|
532 |
" -webkit-text-fill-color: transparent;\n",
|
|
|
575 |
" </div>\n",
|
576 |
" \"\"\")\n",
|
577 |
" with gr.Row(): \n",
|
578 |
+
" with gr.Column(scale=2):\n",
|
579 |
+
" client_inputs = gr.Dropdown(\n",
|
580 |
+
" choices = list(URL_MAPPINGS.keys()), \n",
|
581 |
+
" label=\"FDSN Client\", \n",
|
582 |
+
" info=\"Select one of the available FDSN clients\",\n",
|
583 |
+
" value = \"IRIS\",\n",
|
584 |
+
" interactive=True\n",
|
585 |
+
" )\n",
|
586 |
+
"\n",
|
587 |
+
" velocity_inputs = gr.Dropdown(\n",
|
588 |
+
" choices = ['1066a', '1066b', 'ak135', \n",
|
589 |
+
" 'ak135f', 'herrin', 'iasp91', \n",
|
590 |
+
" 'jb', 'prem', 'pwdk'], \n",
|
591 |
+
" label=\"1D velocity model\", \n",
|
592 |
+
" info=\"Velocity model for station selection\",\n",
|
593 |
+
" value = \"1066a\",\n",
|
594 |
+
" interactive=True\n",
|
595 |
+
" )\n",
|
596 |
"\n",
|
597 |
+
" with gr.Column(scale=2):\n",
|
598 |
+
" timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',\n",
|
599 |
+
" placeholder='YYYY-MM-DD HH:MM:SS',\n",
|
600 |
+
" label=\"Timestamp\",\n",
|
601 |
+
" info=\"Timestamp of the earthquake\",\n",
|
602 |
+
" max_lines=1,\n",
|
603 |
+
" interactive=True)\n",
|
604 |
+
" \n",
|
605 |
+
" source_depth_inputs = gr.Number(value=10,\n",
|
606 |
+
" label=\"Source depth (km)\",\n",
|
607 |
+
" info=\"Depth of the earthquake\",\n",
|
608 |
+
" interactive=True)\n",
|
609 |
+
" \n",
|
610 |
+
" with gr.Column(scale=2):\n",
|
611 |
+
" eq_lat_inputs = gr.Number(value=35.766, \n",
|
612 |
+
" label=\"Latitude\", \n",
|
613 |
+
" info=\"Latitude of the earthquake\",\n",
|
614 |
+
" interactive=True)\n",
|
615 |
+
" \n",
|
616 |
+
" eq_lon_inputs = gr.Number(value=-117.605,\n",
|
617 |
+
" label=\"Longitude\",\n",
|
618 |
+
" info=\"Longitude of the earthquake\",\n",
|
619 |
" interactive=True)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
620 |
" \n",
|
|
|
|
|
621 |
" with gr.Column(scale=2):\n",
|
622 |
+
" radius_inputs = gr.Slider(minimum=1, \n",
|
623 |
+
" maximum=200, \n",
|
624 |
+
" value=50, label=\"Radius (km)\", \n",
|
625 |
+
" step=10,\n",
|
626 |
+
" info=\"\"\"Select the radius around the earthquake to download data from.\\n \n",
|
627 |
+
" Note that the larger the radius, the longer the app will take to run.\"\"\",\n",
|
628 |
+
" interactive=True)\n",
|
629 |
+
" \n",
|
630 |
+
" max_waveforms_inputs = gr.Slider(minimum=1,\n",
|
631 |
+
" maximum=100,\n",
|
632 |
+
" value=10,\n",
|
633 |
+
" label=\"Max waveforms per section\",\n",
|
634 |
+
" step=1,\n",
|
635 |
+
" info=\"Maximum number of waveforms to show per section\\n (to avoid long prediction times)\",\n",
|
636 |
+
" interactive=True,\n",
|
637 |
+
" )\n",
|
|
|
638 |
" \n",
|
639 |
" button = gr.Button(\"Predict phases\")\n",
|
640 |
" output_image = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)\n",
|
app.py
CHANGED
@@ -21,7 +21,7 @@ from obspy.clients.fdsn.header import URL_MAPPINGS
|
|
21 |
|
22 |
import matplotlib.pyplot as plt
|
23 |
import matplotlib.dates as mdates
|
24 |
-
from
|
25 |
|
26 |
from glob import glob
|
27 |
|
@@ -181,8 +181,8 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
|
|
181 |
|
182 |
waveform = waveform.select(channel="H[BH][ZNE]")
|
183 |
waveform = waveform.merge(fill_value=0)
|
184 |
-
waveform = waveform[:3]
|
185 |
-
|
186 |
len_check = [len(x.data) for x in waveform]
|
187 |
if len(set(len_check)) > 1:
|
188 |
continue
|
@@ -243,8 +243,8 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
|
|
243 |
s_max_confidence = np.min([s_phases[i::len(waveforms)].std() for i in range(len(waveforms))])
|
244 |
|
245 |
print(f"Starting plotting {len(waveforms)} waveforms")
|
246 |
-
fig, ax = plt.subplots(
|
247 |
-
|
248 |
# Plot topography
|
249 |
print('Fetching topography')
|
250 |
params = Topography.DEFAULT.copy()
|
@@ -289,9 +289,6 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
|
|
289 |
ax[0].scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b', alpha=s_conf, marker='|')
|
290 |
ax[0].set_ylabel('Z')
|
291 |
|
292 |
-
ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
|
293 |
-
ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=20))
|
294 |
-
|
295 |
delta_t = t0s[i].timestamp - obspy.UTCDateTime(timestamp).timestamp
|
296 |
|
297 |
velocity_p = (distances[i]*111.2)/(delta_t+current_P.mean()*60).item()
|
@@ -309,30 +306,37 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
|
|
309 |
y = np.linspace(st_lats[i], eq_lat, 50)
|
310 |
|
311 |
# Plot the array
|
312 |
-
ax[1].scatter(x, y, c=np.zeros_like(x)+velocity_p, alpha=0.
|
313 |
-
ax[2].scatter(x, y, c=np.zeros_like(x)+velocity_s, alpha=0.
|
314 |
|
315 |
# Add legend
|
316 |
ax[0].scatter(None, None, color='r', marker='|', label='P')
|
317 |
ax[0].scatter(None, None, color='b', marker='|', label='S')
|
|
|
|
|
318 |
ax[0].legend()
|
319 |
|
320 |
print('Plotting stations')
|
321 |
for i in range(1,3):
|
322 |
ax[i].scatter(st_lons, st_lats, color='b', label='Stations')
|
323 |
ax[i].scatter(eq_lon, eq_lat, color='r', marker='*', label='Earthquake')
|
|
|
|
|
324 |
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
|
|
329 |
|
330 |
-
cbar
|
331 |
-
|
332 |
ax[2].set_title('S Velocity')
|
333 |
|
|
|
|
|
|
|
334 |
plt.subplots_adjust(hspace=0., wspace=0.5)
|
335 |
-
|
336 |
fig.canvas.draw();
|
337 |
image = np.array(fig.canvas.renderer.buffer_rgba())
|
338 |
plt.close(fig)
|
@@ -354,7 +358,6 @@ with gr.Blocks() as demo:
|
|
354 |
}
|
355 |
</style></h1>
|
356 |
|
357 |
-
|
358 |
<p style="font-size: 16px; margin-bottom: 20px;">Detect <span style="background-image: linear-gradient(to right, #ED213A, #93291E);
|
359 |
-webkit-background-clip: text;
|
360 |
-webkit-text-fill-color: transparent;
|
@@ -393,77 +396,76 @@ with gr.Blocks() as demo:
|
|
393 |
button.click(mark_phases, inputs=[inputs, upload], outputs=outputs)
|
394 |
|
395 |
with gr.Tab("Select earthquake from catalogue"):
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
|
|
403 |
""")
|
404 |
with gr.Row():
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
|
|
422 |
|
423 |
-
with gr.Column(scale=
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
max_lines=1,
|
430 |
-
interactive=True)
|
431 |
-
|
432 |
-
eq_lat_inputs = gr.Number(value=35.766,
|
433 |
-
label="Latitude",
|
434 |
-
info="Latitude of the earthquake",
|
435 |
interactive=True)
|
436 |
-
|
437 |
-
eq_lon_inputs = gr.Number(value=-117.605,
|
438 |
-
label="Longitude",
|
439 |
-
info="Longitude of the earthquake",
|
440 |
-
interactive=True)
|
441 |
-
|
442 |
-
source_depth_inputs = gr.Number(value=10,
|
443 |
-
label="Source depth (km)",
|
444 |
-
info="Depth of the earthquake",
|
445 |
-
interactive=True)
|
446 |
|
447 |
-
|
448 |
-
|
|
|
|
|
|
|
449 |
with gr.Column(scale=2):
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
467 |
|
468 |
button = gr.Button("Predict phases")
|
469 |
output_image = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)
|
|
|
21 |
|
22 |
import matplotlib.pyplot as plt
|
23 |
import matplotlib.dates as mdates
|
24 |
+
from mpl_toolkits.axes_grid1 import ImageGrid
|
25 |
|
26 |
from glob import glob
|
27 |
|
|
|
181 |
|
182 |
waveform = waveform.select(channel="H[BH][ZNE]")
|
183 |
waveform = waveform.merge(fill_value=0)
|
184 |
+
waveform = waveform[:3].sort(keys=['channel'], reverse=True)
|
185 |
+
|
186 |
len_check = [len(x.data) for x in waveform]
|
187 |
if len(set(len_check)) > 1:
|
188 |
continue
|
|
|
243 |
s_max_confidence = np.min([s_phases[i::len(waveforms)].std() for i in range(len(waveforms))])
|
244 |
|
245 |
print(f"Starting plotting {len(waveforms)} waveforms")
|
246 |
+
fig, ax = plt.subplots(ncols=3, figsize=(10, 3))
|
247 |
+
|
248 |
# Plot topography
|
249 |
print('Fetching topography')
|
250 |
params = Topography.DEFAULT.copy()
|
|
|
289 |
ax[0].scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b', alpha=s_conf, marker='|')
|
290 |
ax[0].set_ylabel('Z')
|
291 |
|
|
|
|
|
|
|
292 |
delta_t = t0s[i].timestamp - obspy.UTCDateTime(timestamp).timestamp
|
293 |
|
294 |
velocity_p = (distances[i]*111.2)/(delta_t+current_P.mean()*60).item()
|
|
|
306 |
y = np.linspace(st_lats[i], eq_lat, 50)
|
307 |
|
308 |
# Plot the array
|
309 |
+
ax[1].scatter(x, y, c=np.zeros_like(x)+velocity_p, alpha=0.1, vmin=0, vmax=8)
|
310 |
+
ax[2].scatter(x, y, c=np.zeros_like(x)+velocity_s, alpha=0.1, vmin=0, vmax=8)
|
311 |
|
312 |
# Add legend
|
313 |
ax[0].scatter(None, None, color='r', marker='|', label='P')
|
314 |
ax[0].scatter(None, None, color='b', marker='|', label='S')
|
315 |
+
ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
|
316 |
+
ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=20))
|
317 |
ax[0].legend()
|
318 |
|
319 |
print('Plotting stations')
|
320 |
for i in range(1,3):
|
321 |
ax[i].scatter(st_lons, st_lats, color='b', label='Stations')
|
322 |
ax[i].scatter(eq_lon, eq_lat, color='r', marker='*', label='Earthquake')
|
323 |
+
ax[i].set_aspect('equal')
|
324 |
+
ax[i].set_xticklabels(ax[i].get_xticks(), rotation = 50)
|
325 |
|
326 |
+
fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.8,
|
327 |
+
wspace=0.02, hspace=0.02)
|
328 |
+
|
329 |
+
cb_ax = fig.add_axes([0.83, 0.1, 0.02, 0.8])
|
330 |
+
cbar = fig.colorbar(ax[2].scatter(None, None, c=velocity_p, alpha=0.5, vmin=0, vmax=8), cax=cb_ax)
|
331 |
|
332 |
+
cbar.set_label('Velocity (km/s)')
|
333 |
+
ax[1].set_title('P Velocity')
|
334 |
ax[2].set_title('S Velocity')
|
335 |
|
336 |
+
for a in ax:
|
337 |
+
a.tick_params(axis='both', which='major', labelsize=8)
|
338 |
+
|
339 |
plt.subplots_adjust(hspace=0., wspace=0.5)
|
|
|
340 |
fig.canvas.draw();
|
341 |
image = np.array(fig.canvas.renderer.buffer_rgba())
|
342 |
plt.close(fig)
|
|
|
358 |
}
|
359 |
</style></h1>
|
360 |
|
|
|
361 |
<p style="font-size: 16px; margin-bottom: 20px;">Detect <span style="background-image: linear-gradient(to right, #ED213A, #93291E);
|
362 |
-webkit-background-clip: text;
|
363 |
-webkit-text-fill-color: transparent;
|
|
|
396 |
button.click(mark_phases, inputs=[inputs, upload], outputs=outputs)
|
397 |
|
398 |
with gr.Tab("Select earthquake from catalogue"):
|
399 |
+
|
400 |
+
gr.HTML("""
|
401 |
+
<div style="padding: 20px; border-radius: 10px; font-size: 16px;">
|
402 |
+
<p style="font-weight: bold; font-size: 24px; margin-bottom: 20px;">Using PhaseHunter to Analyze Seismic Waveforms</p>
|
403 |
+
<p>Select an earthquake from the global earthquake catalogue and the app will download the waveform from the FDSN client of your choice. The app will use a velocity model of your choice to select appropriate time windows for each station within a specified radius of the earthquake.</p>
|
404 |
+
<p>The app will then analyze the waveforms and mark the detected phases on the waveform. Pick data for each waveform is reported in seconds from the start of the waveform.</p>
|
405 |
+
<p>Velocities are derived from distance and travel time determined by PhaseHunter picks (<span style="font-style: italic;">v = distance/predicted_pick_time</span>). The background of the velocity plot is colored by DEM.</p>
|
406 |
+
</div>
|
407 |
""")
|
408 |
with gr.Row():
|
409 |
+
with gr.Column(scale=2):
|
410 |
+
client_inputs = gr.Dropdown(
|
411 |
+
choices = list(URL_MAPPINGS.keys()),
|
412 |
+
label="FDSN Client",
|
413 |
+
info="Select one of the available FDSN clients",
|
414 |
+
value = "IRIS",
|
415 |
+
interactive=True
|
416 |
+
)
|
417 |
+
|
418 |
+
velocity_inputs = gr.Dropdown(
|
419 |
+
choices = ['1066a', '1066b', 'ak135',
|
420 |
+
'ak135f', 'herrin', 'iasp91',
|
421 |
+
'jb', 'prem', 'pwdk'],
|
422 |
+
label="1D velocity model",
|
423 |
+
info="Velocity model for station selection",
|
424 |
+
value = "1066a",
|
425 |
+
interactive=True
|
426 |
+
)
|
427 |
|
428 |
+
with gr.Column(scale=2):
|
429 |
+
timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',
|
430 |
+
placeholder='YYYY-MM-DD HH:MM:SS',
|
431 |
+
label="Timestamp",
|
432 |
+
info="Timestamp of the earthquake",
|
433 |
+
max_lines=1,
|
|
|
|
|
|
|
|
|
|
|
|
|
434 |
interactive=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
435 |
|
436 |
+
source_depth_inputs = gr.Number(value=10,
|
437 |
+
label="Source depth (km)",
|
438 |
+
info="Depth of the earthquake",
|
439 |
+
interactive=True)
|
440 |
+
|
441 |
with gr.Column(scale=2):
|
442 |
+
eq_lat_inputs = gr.Number(value=35.766,
|
443 |
+
label="Latitude",
|
444 |
+
info="Latitude of the earthquake",
|
445 |
+
interactive=True)
|
446 |
+
|
447 |
+
eq_lon_inputs = gr.Number(value=-117.605,
|
448 |
+
label="Longitude",
|
449 |
+
info="Longitude of the earthquake",
|
450 |
+
interactive=True)
|
451 |
+
|
452 |
+
with gr.Column(scale=2):
|
453 |
+
radius_inputs = gr.Slider(minimum=1,
|
454 |
+
maximum=200,
|
455 |
+
value=50, label="Radius (km)",
|
456 |
+
step=10,
|
457 |
+
info="""Select the radius around the earthquake to download data from.\n
|
458 |
+
Note that the larger the radius, the longer the app will take to run.""",
|
459 |
+
interactive=True)
|
460 |
+
|
461 |
+
max_waveforms_inputs = gr.Slider(minimum=1,
|
462 |
+
maximum=100,
|
463 |
+
value=10,
|
464 |
+
label="Max waveforms per section",
|
465 |
+
step=1,
|
466 |
+
info="Maximum number of waveforms to show per section\n (to avoid long prediction times)",
|
467 |
+
interactive=True,
|
468 |
+
)
|
469 |
|
470 |
button = gr.Button("Predict phases")
|
471 |
output_image = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)
|
phasehunter/model.py
DELETED
@@ -1,313 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import torch
|
3 |
-
import torch.nn.functional as F
|
4 |
-
from torch import nn
|
5 |
-
from torchmetrics import MeanAbsoluteError
|
6 |
-
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
7 |
-
|
8 |
-
import lightning as pl
|
9 |
-
|
10 |
-
class BlurPool1D(nn.Module):
|
11 |
-
def __init__(self, channels, pad_type="reflect", filt_size=3, stride=2, pad_off=0):
|
12 |
-
super(BlurPool1D, self).__init__()
|
13 |
-
self.filt_size = filt_size
|
14 |
-
self.pad_off = pad_off
|
15 |
-
self.pad_sizes = [
|
16 |
-
int(1.0 * (filt_size - 1) / 2),
|
17 |
-
int(np.ceil(1.0 * (filt_size - 1) / 2)),
|
18 |
-
]
|
19 |
-
self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
|
20 |
-
self.stride = stride
|
21 |
-
self.off = int((self.stride - 1) / 2.0)
|
22 |
-
self.channels = channels
|
23 |
-
|
24 |
-
# print('Filter size [%i]' % filt_size)
|
25 |
-
if self.filt_size == 1:
|
26 |
-
a = np.array(
|
27 |
-
[
|
28 |
-
1.0,
|
29 |
-
]
|
30 |
-
)
|
31 |
-
elif self.filt_size == 2:
|
32 |
-
a = np.array([1.0, 1.0])
|
33 |
-
elif self.filt_size == 3:
|
34 |
-
a = np.array([1.0, 2.0, 1.0])
|
35 |
-
elif self.filt_size == 4:
|
36 |
-
a = np.array([1.0, 3.0, 3.0, 1.0])
|
37 |
-
elif self.filt_size == 5:
|
38 |
-
a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
|
39 |
-
elif self.filt_size == 6:
|
40 |
-
a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
|
41 |
-
elif self.filt_size == 7:
|
42 |
-
a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
|
43 |
-
|
44 |
-
filt = torch.Tensor(a)
|
45 |
-
filt = filt / torch.sum(filt)
|
46 |
-
self.register_buffer("filt", filt[None, None, :].repeat((self.channels, 1, 1)))
|
47 |
-
|
48 |
-
self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
|
49 |
-
|
50 |
-
def forward(self, inp):
|
51 |
-
if self.filt_size == 1:
|
52 |
-
if self.pad_off == 0:
|
53 |
-
return inp[:, :, :: self.stride]
|
54 |
-
else:
|
55 |
-
return self.pad(inp)[:, :, :: self.stride]
|
56 |
-
else:
|
57 |
-
return F.conv1d(
|
58 |
-
self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]
|
59 |
-
)
|
60 |
-
|
61 |
-
|
62 |
-
def get_pad_layer_1d(pad_type):
|
63 |
-
if pad_type in ["refl", "reflect"]:
|
64 |
-
PadLayer = nn.ReflectionPad1d
|
65 |
-
elif pad_type in ["repl", "replicate"]:
|
66 |
-
PadLayer = nn.ReplicationPad1d
|
67 |
-
elif pad_type == "zero":
|
68 |
-
PadLayer = nn.ZeroPad1d
|
69 |
-
else:
|
70 |
-
print("Pad type [%s] not recognized" % pad_type)
|
71 |
-
return PadLayer
|
72 |
-
|
73 |
-
|
74 |
-
from masksembles import common
|
75 |
-
|
76 |
-
|
77 |
-
class Masksembles1D(nn.Module):
|
78 |
-
def __init__(self, channels: int, n: int, scale: float):
|
79 |
-
super().__init__()
|
80 |
-
|
81 |
-
self.channels = channels
|
82 |
-
self.n = n
|
83 |
-
self.scale = scale
|
84 |
-
|
85 |
-
masks = common.generation_wrapper(channels, n, scale)
|
86 |
-
masks = torch.from_numpy(masks)
|
87 |
-
|
88 |
-
self.masks = torch.nn.Parameter(masks, requires_grad=False)
|
89 |
-
|
90 |
-
def forward(self, inputs):
|
91 |
-
batch = inputs.shape[0]
|
92 |
-
x = torch.split(inputs.unsqueeze(1), batch // self.n, dim=0)
|
93 |
-
x = torch.cat(x, dim=1).permute([1, 0, 2, 3])
|
94 |
-
x = x * self.masks.unsqueeze(1).unsqueeze(-1)
|
95 |
-
x = torch.cat(torch.split(x, 1, dim=0), dim=1)
|
96 |
-
|
97 |
-
return x.squeeze(0).type(inputs.dtype)
|
98 |
-
|
99 |
-
|
100 |
-
class BasicBlock(nn.Module):
|
101 |
-
expansion = 1
|
102 |
-
|
103 |
-
def __init__(self, in_planes, planes, stride=1, kernel_size=7, groups=1):
|
104 |
-
super(BasicBlock, self).__init__()
|
105 |
-
self.conv1 = nn.Conv1d(
|
106 |
-
in_planes,
|
107 |
-
planes,
|
108 |
-
kernel_size=kernel_size,
|
109 |
-
stride=stride,
|
110 |
-
padding="same",
|
111 |
-
bias=False,
|
112 |
-
)
|
113 |
-
self.bn1 = nn.BatchNorm1d(planes)
|
114 |
-
self.conv2 = nn.Conv1d(
|
115 |
-
planes,
|
116 |
-
planes,
|
117 |
-
kernel_size=kernel_size,
|
118 |
-
stride=1,
|
119 |
-
padding="same",
|
120 |
-
bias=False,
|
121 |
-
)
|
122 |
-
self.bn2 = nn.BatchNorm1d(planes)
|
123 |
-
|
124 |
-
self.shortcut = nn.Sequential(
|
125 |
-
nn.Conv1d(
|
126 |
-
in_planes,
|
127 |
-
self.expansion * planes,
|
128 |
-
kernel_size=1,
|
129 |
-
stride=stride,
|
130 |
-
padding="same",
|
131 |
-
bias=False,
|
132 |
-
),
|
133 |
-
nn.BatchNorm1d(self.expansion * planes),
|
134 |
-
)
|
135 |
-
|
136 |
-
def forward(self, x):
|
137 |
-
out = F.relu(self.bn1(self.conv1(x)))
|
138 |
-
out = self.bn2(self.conv2(out))
|
139 |
-
out += self.shortcut(x)
|
140 |
-
out = F.relu(out)
|
141 |
-
return out
|
142 |
-
|
143 |
-
|
144 |
-
class Updated_onset_picker(nn.Module):
|
145 |
-
def __init__(
|
146 |
-
self,
|
147 |
-
):
|
148 |
-
super().__init__()
|
149 |
-
|
150 |
-
# self.activation = nn.ReLU()
|
151 |
-
# self.maxpool = nn.MaxPool1d(2)
|
152 |
-
|
153 |
-
self.n_masks = 128
|
154 |
-
|
155 |
-
self.block1 = nn.Sequential(
|
156 |
-
BasicBlock(3, 8, kernel_size=7, groups=1),
|
157 |
-
nn.GELU(),
|
158 |
-
BlurPool1D(8, filt_size=3, stride=2),
|
159 |
-
nn.GroupNorm(2, 8),
|
160 |
-
)
|
161 |
-
|
162 |
-
self.block2 = nn.Sequential(
|
163 |
-
BasicBlock(8, 16, kernel_size=7, groups=8),
|
164 |
-
nn.GELU(),
|
165 |
-
BlurPool1D(16, filt_size=3, stride=2),
|
166 |
-
nn.GroupNorm(2, 16),
|
167 |
-
)
|
168 |
-
|
169 |
-
self.block3 = nn.Sequential(
|
170 |
-
BasicBlock(16, 32, kernel_size=7, groups=16),
|
171 |
-
nn.GELU(),
|
172 |
-
BlurPool1D(32, filt_size=3, stride=2),
|
173 |
-
nn.GroupNorm(2, 32),
|
174 |
-
)
|
175 |
-
|
176 |
-
self.block4 = nn.Sequential(
|
177 |
-
BasicBlock(32, 64, kernel_size=7, groups=32),
|
178 |
-
nn.GELU(),
|
179 |
-
BlurPool1D(64, filt_size=3, stride=2),
|
180 |
-
nn.GroupNorm(2, 64),
|
181 |
-
)
|
182 |
-
|
183 |
-
self.block5 = nn.Sequential(
|
184 |
-
BasicBlock(64, 128, kernel_size=7, groups=64),
|
185 |
-
nn.GELU(),
|
186 |
-
BlurPool1D(128, filt_size=3, stride=2),
|
187 |
-
nn.GroupNorm(2, 128),
|
188 |
-
)
|
189 |
-
|
190 |
-
self.block6 = nn.Sequential(
|
191 |
-
Masksembles1D(128, self.n_masks, 2.0),
|
192 |
-
BasicBlock(128, 256, kernel_size=7, groups=128),
|
193 |
-
nn.GELU(),
|
194 |
-
BlurPool1D(256, filt_size=3, stride=2),
|
195 |
-
nn.GroupNorm(2, 256),
|
196 |
-
)
|
197 |
-
|
198 |
-
self.block7 = nn.Sequential(
|
199 |
-
Masksembles1D(256, self.n_masks, 2.0),
|
200 |
-
BasicBlock(256, 512, kernel_size=7, groups=256),
|
201 |
-
BlurPool1D(512, filt_size=3, stride=2),
|
202 |
-
nn.GELU(),
|
203 |
-
nn.GroupNorm(2, 512),
|
204 |
-
)
|
205 |
-
|
206 |
-
self.block8 = nn.Sequential(
|
207 |
-
Masksembles1D(512, self.n_masks, 2.0),
|
208 |
-
BasicBlock(512, 1024, kernel_size=7, groups=512),
|
209 |
-
BlurPool1D(1024, filt_size=3, stride=2),
|
210 |
-
nn.GELU(),
|
211 |
-
nn.GroupNorm(2, 1024),
|
212 |
-
)
|
213 |
-
|
214 |
-
self.block9 = nn.Sequential(
|
215 |
-
Masksembles1D(1024, self.n_masks, 2.0),
|
216 |
-
BasicBlock(1024, 128, kernel_size=7, groups=128),
|
217 |
-
# BlurPool1D(512, filt_size=3, stride=2),
|
218 |
-
# nn.GELU(),
|
219 |
-
# nn.GroupNorm(2,512),
|
220 |
-
)
|
221 |
-
|
222 |
-
self.out = nn.Sequential(nn.Linear(3072, 2), nn.Sigmoid())
|
223 |
-
|
224 |
-
def forward(self, x):
|
225 |
-
# Feature extraction
|
226 |
-
|
227 |
-
x = self.block1(x)
|
228 |
-
x = self.block2(x)
|
229 |
-
|
230 |
-
x = self.block3(x)
|
231 |
-
x = self.block4(x)
|
232 |
-
|
233 |
-
x = self.block5(x)
|
234 |
-
x = self.block6(x)
|
235 |
-
|
236 |
-
x = self.block7(x)
|
237 |
-
x = self.block8(x)
|
238 |
-
|
239 |
-
x = self.block9(x)
|
240 |
-
|
241 |
-
# Regressor
|
242 |
-
x = x.flatten(start_dim=1)
|
243 |
-
x = self.out(x)
|
244 |
-
|
245 |
-
return x
|
246 |
-
|
247 |
-
class Onset_picker(pl.LightningModule):
|
248 |
-
def __init__(self, picker, learning_rate):
|
249 |
-
super().__init__()
|
250 |
-
self.picker = picker
|
251 |
-
self.learning_rate = learning_rate
|
252 |
-
self.save_hyperparameters(ignore=['picker'])
|
253 |
-
self.mae = MeanAbsoluteError()
|
254 |
-
|
255 |
-
def compute_loss(self, y, pick, mae_name=False):
|
256 |
-
y_filt = y[y != 0]
|
257 |
-
pick_filt = pick[y != 0]
|
258 |
-
if len(y_filt) > 0:
|
259 |
-
loss = F.l1_loss(y_filt, pick_filt.flatten())
|
260 |
-
if mae_name != False:
|
261 |
-
mae_phase = self.mae(y_filt, pick_filt.flatten())*60
|
262 |
-
self.log(f'MAE/{mae_name}_val', mae_phase, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
|
263 |
-
else:
|
264 |
-
loss = 0
|
265 |
-
return loss
|
266 |
-
|
267 |
-
def training_step(self, batch, batch_idx):
|
268 |
-
# training_step defines the train loop.
|
269 |
-
x, y_p, y_s = batch
|
270 |
-
# x, y_p, y_s, y_pg, y_sg, y_pn, y_sn = batch
|
271 |
-
|
272 |
-
picks = self.picker(x)
|
273 |
-
|
274 |
-
p_pick = picks[:,0]
|
275 |
-
s_pick = picks[:,1]
|
276 |
-
|
277 |
-
p_loss = self.compute_loss(y_p, p_pick)
|
278 |
-
s_loss = self.compute_loss(y_s, s_pick)
|
279 |
-
|
280 |
-
loss = (p_loss+s_loss)/2
|
281 |
-
|
282 |
-
self.log('Loss/train', loss, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
|
283 |
-
|
284 |
-
return loss
|
285 |
-
|
286 |
-
def validation_step(self, batch, batch_idx):
|
287 |
-
|
288 |
-
x, y_p, y_s = batch
|
289 |
-
|
290 |
-
picks = self.picker(x)
|
291 |
-
|
292 |
-
p_pick = picks[:,0]
|
293 |
-
s_pick = picks[:,1]
|
294 |
-
|
295 |
-
p_loss = self.compute_loss(y_p, p_pick, mae_name='P')
|
296 |
-
s_loss = self.compute_loss(y_s, s_pick, mae_name='S')
|
297 |
-
|
298 |
-
loss = (p_loss+s_loss)/2
|
299 |
-
|
300 |
-
self.log('Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
|
301 |
-
|
302 |
-
return loss
|
303 |
-
|
304 |
-
def configure_optimizers(self):
|
305 |
-
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
306 |
-
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, cooldown=10, threshold=1e-3)
|
307 |
-
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 3e-4, epochs=300, steps_per_epoch=len(train_loader))
|
308 |
-
monitor = 'Loss/train'
|
309 |
-
return {"optimizer": optimizer, "lr_scheduler": scheduler, 'monitor': monitor}
|
310 |
-
|
311 |
-
def forward(self, x):
|
312 |
-
picks = self.picker(x)
|
313 |
-
return picks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
phasehunter/training.py
DELETED
@@ -1,104 +0,0 @@
|
|
1 |
-
|
2 |
-
import torch
|
3 |
-
|
4 |
-
from data_preparation import augment, collation_fn, my_split_by_node
|
5 |
-
from model import Onset_picker, Updated_onset_picker
|
6 |
-
|
7 |
-
import webdataset as wds
|
8 |
-
|
9 |
-
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
|
10 |
-
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
|
11 |
-
from lightning.pytorch.strategies import DDPStrategy
|
12 |
-
from lightning import seed_everything
|
13 |
-
import lightning as pl
|
14 |
-
|
15 |
-
seed_everything(42, workers=False)
|
16 |
-
torch.set_float32_matmul_precision('medium')
|
17 |
-
|
18 |
-
batch_size = 256
|
19 |
-
num_workers = 16 #int(os.cpu_count())
|
20 |
-
n_iters_in_epoch = 5000
|
21 |
-
|
22 |
-
train_dataset = (
|
23 |
-
wds.WebDataset("data/sample/shard-00{0000..0001}.tar",
|
24 |
-
# splitter=my_split_by_worker,
|
25 |
-
nodesplitter=my_split_by_node)
|
26 |
-
.decode()
|
27 |
-
.map(augment)
|
28 |
-
.shuffle(5000)
|
29 |
-
.batched(batchsize=batch_size,
|
30 |
-
collation_fn=collation_fn,
|
31 |
-
partial=False
|
32 |
-
)
|
33 |
-
).with_epoch(n_iters_in_epoch//num_workers)
|
34 |
-
|
35 |
-
|
36 |
-
val_dataset = (
|
37 |
-
wds.WebDataset("data/sample/shard-00{0000..0000}.tar",
|
38 |
-
# splitter=my_split_by_worker,
|
39 |
-
nodesplitter=my_split_by_node)
|
40 |
-
.decode()
|
41 |
-
.map(augment)
|
42 |
-
.repeat()
|
43 |
-
.batched(batchsize=batch_size,
|
44 |
-
collation_fn=collation_fn,
|
45 |
-
partial=False
|
46 |
-
)
|
47 |
-
).with_epoch(100)
|
48 |
-
|
49 |
-
|
50 |
-
train_loader = wds.WebLoader(train_dataset,
|
51 |
-
num_workers=num_workers,
|
52 |
-
shuffle=False,
|
53 |
-
pin_memory=True,
|
54 |
-
batch_size=None)
|
55 |
-
|
56 |
-
val_loader = wds.WebLoader(val_dataset,
|
57 |
-
num_workers=0,
|
58 |
-
shuffle=False,
|
59 |
-
pin_memory=True,
|
60 |
-
batch_size=None)
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
# model
|
65 |
-
model = Onset_picker(picker=Updated_onset_picker(),
|
66 |
-
learning_rate=3e-4)
|
67 |
-
# model = torch.compile(model, mode="reduce-overhead")
|
68 |
-
|
69 |
-
logger = TensorBoardLogger("tensorboard_logdir", name="FAST")
|
70 |
-
|
71 |
-
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="Loss/val", filename="chkp-{epoch:02d}")
|
72 |
-
lr_callback = LearningRateMonitor(logging_interval='epoch')
|
73 |
-
# swa_callback = StochasticWeightAveraging(swa_lrs=0.05)
|
74 |
-
|
75 |
-
# # train model
|
76 |
-
trainer = pl.Trainer(
|
77 |
-
precision='16-mixed',
|
78 |
-
|
79 |
-
callbacks=[checkpoint_callback, lr_callback],
|
80 |
-
|
81 |
-
devices='auto',
|
82 |
-
accelerator='auto',
|
83 |
-
|
84 |
-
strategy=DDPStrategy(find_unused_parameters=False,
|
85 |
-
static_graph=True,
|
86 |
-
gradient_as_bucket_view=True),
|
87 |
-
benchmark=True,
|
88 |
-
|
89 |
-
gradient_clip_val=0.5,
|
90 |
-
# ckpt_path='path/to/saved/checkpoints/chkp.ckpt',
|
91 |
-
|
92 |
-
# fast_dev_run=True,
|
93 |
-
|
94 |
-
logger=logger,
|
95 |
-
log_every_n_steps=50,
|
96 |
-
enable_progress_bar=True,
|
97 |
-
|
98 |
-
max_epochs=300,
|
99 |
-
)
|
100 |
-
|
101 |
-
trainer.fit(model=model,
|
102 |
-
train_dataloaders=train_loader,
|
103 |
-
val_dataloaders=val_loader,
|
104 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|