Spaces:
Runtime error
Runtime error
Now original waveform is displayed
Browse files- Gradio_app.ipynb +16 -20
- app.py +7 -6
Gradio_app.ipynb
CHANGED
@@ -2,14 +2,14 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"metadata": {},
|
7 |
"outputs": [
|
8 |
{
|
9 |
"name": "stdout",
|
10 |
"output_type": "stream",
|
11 |
"text": [
|
12 |
-
"Running on local URL: http://127.0.0.1:
|
13 |
"\n",
|
14 |
"To create a public link, set `share=True` in `launch()`.\n"
|
15 |
]
|
@@ -17,7 +17,7 @@
|
|
17 |
{
|
18 |
"data": {
|
19 |
"text/html": [
|
20 |
-
"<div><iframe src=\"http://127.0.0.1:
|
21 |
],
|
22 |
"text/plain": [
|
23 |
"<IPython.core.display.HTML object>"
|
@@ -30,22 +30,16 @@
|
|
30 |
"data": {
|
31 |
"text/plain": []
|
32 |
},
|
33 |
-
"execution_count":
|
34 |
"metadata": {},
|
35 |
"output_type": "execute_result"
|
36 |
},
|
37 |
-
{
|
38 |
-
"name": "stderr",
|
39 |
-
"output_type": "stream",
|
40 |
-
"text": [
|
41 |
-
"No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.\n"
|
42 |
-
]
|
43 |
-
},
|
44 |
{
|
45 |
"name": "stdout",
|
46 |
"output_type": "stream",
|
47 |
"text": [
|
48 |
-
"
|
|
|
49 |
]
|
50 |
}
|
51 |
],
|
@@ -83,8 +77,9 @@
|
|
83 |
" if len(waveform.shape) == 1:\n",
|
84 |
" waveform = waveform.reshape(1, waveform.shape[0])\n",
|
85 |
"\n",
|
|
|
86 |
" processed_input = prepare_waveform(waveform)\n",
|
87 |
-
"
|
88 |
" # Make prediction\n",
|
89 |
" with torch.inference_mode():\n",
|
90 |
" output = model(processed_input)\n",
|
@@ -92,33 +87,34 @@
|
|
92 |
" p_phase = output[:, 0]\n",
|
93 |
" s_phase = output[:, 1]\n",
|
94 |
"\n",
|
95 |
-
" return processed_input, p_phase, s_phase\n",
|
|
|
96 |
"\n",
|
97 |
"def mark_phases(waveform, uploaded_file, p_thres, s_thres):\n",
|
98 |
"\n",
|
99 |
" if uploaded_file is not None:\n",
|
100 |
" waveform = uploaded_file.name\n",
|
101 |
"\n",
|
102 |
-
" processed_input, p_phase, s_phase = make_prediction(waveform)\n",
|
103 |
"\n",
|
104 |
" # Create a plot of the waveform with the phases marked\n",
|
105 |
" if sum(processed_input[0][2] == 0): #if input is 1C\n",
|
106 |
" fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)\n",
|
107 |
"\n",
|
108 |
-
" ax[0].plot(
|
109 |
" ax[0].set_ylabel('Norm. Ampl.')\n",
|
110 |
"\n",
|
111 |
" else: #if input is 3C\n",
|
112 |
" fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)\n",
|
113 |
-
" ax[0].plot(
|
114 |
-
" ax[1].plot(
|
115 |
-
" ax[2].plot(
|
116 |
"\n",
|
117 |
" ax[0].set_ylabel('Z')\n",
|
118 |
" ax[1].set_ylabel('N')\n",
|
119 |
" ax[2].set_ylabel('E')\n",
|
120 |
"\n",
|
121 |
-
"
|
122 |
" do_we_have_p = (p_phase.std().item()*60 < p_thres)\n",
|
123 |
" if do_we_have_p:\n",
|
124 |
" p_phase_plot = p_phase*processed_input.shape[-1]\n",
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": 16,
|
6 |
"metadata": {},
|
7 |
"outputs": [
|
8 |
{
|
9 |
"name": "stdout",
|
10 |
"output_type": "stream",
|
11 |
"text": [
|
12 |
+
"Running on local URL: http://127.0.0.1:7869\n",
|
13 |
"\n",
|
14 |
"To create a public link, set `share=True` in `launch()`.\n"
|
15 |
]
|
|
|
17 |
{
|
18 |
"data": {
|
19 |
"text/html": [
|
20 |
+
"<div><iframe src=\"http://127.0.0.1:7869/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
21 |
],
|
22 |
"text/plain": [
|
23 |
"<IPython.core.display.HTML object>"
|
|
|
30 |
"data": {
|
31 |
"text/plain": []
|
32 |
},
|
33 |
+
"execution_count": 16,
|
34 |
"metadata": {},
|
35 |
"output_type": "execute_result"
|
36 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
{
|
38 |
"name": "stdout",
|
39 |
"output_type": "stream",
|
40 |
"text": [
|
41 |
+
"4\n",
|
42 |
+
"0.02744414610788226\n"
|
43 |
]
|
44 |
}
|
45 |
],
|
|
|
77 |
" if len(waveform.shape) == 1:\n",
|
78 |
" waveform = waveform.reshape(1, waveform.shape[0])\n",
|
79 |
"\n",
|
80 |
+
" orig_waveform = waveform[:, :6000].copy()\n",
|
81 |
" processed_input = prepare_waveform(waveform)\n",
|
82 |
+
"\n",
|
83 |
" # Make prediction\n",
|
84 |
" with torch.inference_mode():\n",
|
85 |
" output = model(processed_input)\n",
|
|
|
87 |
" p_phase = output[:, 0]\n",
|
88 |
" s_phase = output[:, 1]\n",
|
89 |
"\n",
|
90 |
+
" return processed_input, p_phase, s_phase, orig_waveform\n",
|
91 |
+
"\n",
|
92 |
"\n",
|
93 |
"def mark_phases(waveform, uploaded_file, p_thres, s_thres):\n",
|
94 |
"\n",
|
95 |
" if uploaded_file is not None:\n",
|
96 |
" waveform = uploaded_file.name\n",
|
97 |
"\n",
|
98 |
+
" processed_input, p_phase, s_phase, orig_waveform = make_prediction(waveform)\n",
|
99 |
"\n",
|
100 |
" # Create a plot of the waveform with the phases marked\n",
|
101 |
" if sum(processed_input[0][2] == 0): #if input is 1C\n",
|
102 |
" fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)\n",
|
103 |
"\n",
|
104 |
+
" ax[0].plot(orig_waveform[0], color='black', lw=1)\n",
|
105 |
" ax[0].set_ylabel('Norm. Ampl.')\n",
|
106 |
"\n",
|
107 |
" else: #if input is 3C\n",
|
108 |
" fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)\n",
|
109 |
+
" ax[0].plot(orig_waveform[0], color='black', lw=1)\n",
|
110 |
+
" ax[1].plot(orig_waveform[1], color='black', lw=1)\n",
|
111 |
+
" ax[2].plot(orig_waveform[2], color='black', lw=1)\n",
|
112 |
"\n",
|
113 |
" ax[0].set_ylabel('Z')\n",
|
114 |
" ax[1].set_ylabel('N')\n",
|
115 |
" ax[2].set_ylabel('E')\n",
|
116 |
"\n",
|
117 |
+
"\n",
|
118 |
" do_we_have_p = (p_phase.std().item()*60 < p_thres)\n",
|
119 |
" if do_we_have_p:\n",
|
120 |
" p_phase_plot = p_phase*processed_input.shape[-1]\n",
|
app.py
CHANGED
@@ -36,6 +36,7 @@ def make_prediction(waveform):
|
|
36 |
if len(waveform.shape) == 1:
|
37 |
waveform = waveform.reshape(1, waveform.shape[0])
|
38 |
|
|
|
39 |
processed_input = prepare_waveform(waveform)
|
40 |
|
41 |
# Make prediction
|
@@ -45,7 +46,7 @@ def make_prediction(waveform):
|
|
45 |
p_phase = output[:, 0]
|
46 |
s_phase = output[:, 1]
|
47 |
|
48 |
-
return processed_input, p_phase, s_phase
|
49 |
|
50 |
|
51 |
def mark_phases(waveform, uploaded_file, p_thres, s_thres):
|
@@ -53,20 +54,20 @@ def mark_phases(waveform, uploaded_file, p_thres, s_thres):
|
|
53 |
if uploaded_file is not None:
|
54 |
waveform = uploaded_file.name
|
55 |
|
56 |
-
processed_input, p_phase, s_phase = make_prediction(waveform)
|
57 |
|
58 |
# Create a plot of the waveform with the phases marked
|
59 |
if sum(processed_input[0][2] == 0): # if input is 1C
|
60 |
fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)
|
61 |
|
62 |
-
ax[0].plot(
|
63 |
ax[0].set_ylabel("Norm. Ampl.")
|
64 |
|
65 |
else: # if input is 3C
|
66 |
fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)
|
67 |
-
ax[0].plot(
|
68 |
-
ax[1].plot(
|
69 |
-
ax[2].plot(
|
70 |
|
71 |
ax[0].set_ylabel("Z")
|
72 |
ax[1].set_ylabel("N")
|
|
|
36 |
if len(waveform.shape) == 1:
|
37 |
waveform = waveform.reshape(1, waveform.shape[0])
|
38 |
|
39 |
+
orig_waveform = waveform[:, :6000].copy()
|
40 |
processed_input = prepare_waveform(waveform)
|
41 |
|
42 |
# Make prediction
|
|
|
46 |
p_phase = output[:, 0]
|
47 |
s_phase = output[:, 1]
|
48 |
|
49 |
+
return processed_input, p_phase, s_phase, orig_waveform
|
50 |
|
51 |
|
52 |
def mark_phases(waveform, uploaded_file, p_thres, s_thres):
|
|
|
54 |
if uploaded_file is not None:
|
55 |
waveform = uploaded_file.name
|
56 |
|
57 |
+
processed_input, p_phase, s_phase, orig_waveform = make_prediction(waveform)
|
58 |
|
59 |
# Create a plot of the waveform with the phases marked
|
60 |
if sum(processed_input[0][2] == 0): # if input is 1C
|
61 |
fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)
|
62 |
|
63 |
+
ax[0].plot(orig_waveform[0], color="black", lw=1)
|
64 |
ax[0].set_ylabel("Norm. Ampl.")
|
65 |
|
66 |
else: # if input is 3C
|
67 |
fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)
|
68 |
+
ax[0].plot(orig_waveform[0], color="black", lw=1)
|
69 |
+
ax[1].plot(orig_waveform[1], color="black", lw=1)
|
70 |
+
ax[2].plot(orig_waveform[2], color="black", lw=1)
|
71 |
|
72 |
ax[0].set_ylabel("Z")
|
73 |
ax[1].set_ylabel("N")
|