|
import json |
|
import logging |
|
import os |
|
from collections import namedtuple |
|
from datetime import datetime, timedelta |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from detect_peaks import detect_peaks |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_picks(
    preds,
    file_names=None,
    begin_times=None,
    station_ids=None,
    dt=0.01,
    phases=("P", "S"),
    config=None,
    waveforms=None,
    use_amplitude=False,
):
    """Extract phase picks from prediction results.

    Args:
        preds: Array of shape [Nb, Nt, Ns, Nc] ("batch, time, station, channel").
            Channel 0 is skipped; channel ``k + 1`` is searched for peaks and
            labelled ``phases[k]``.
        file_names: ``None`` (zero-padded batch indices are used), a single
            name (str or bytes) shared by every batch item, or a
            list/tuple/array of Nb names (str or bytes).
        begin_times: Sequence of Nb ISO-8601 strings (str or bytes) giving the
            time of sample 0 of each batch item. Defaults to the Unix epoch.
        station_ids: Indexed as ``station_ids[i][j]`` (batch item, station).
            ``None``, or an empty entry for a batch item, falls back to
            zero-padded station indices.
        dt: Sampling interval in seconds.
        phases: Phase labels, one per prediction channel after channel 0.
        config: Optional object providing ``min_p_prob``, ``min_s_prob``,
            ``mpd`` and ``post_sec``. Without it every phase uses a 0.3
            probability threshold, ``mpd`` is 50 samples and ``post_sec`` is 4 s.
        waveforms: Optional array indexed as ``waveforms[i, :, j, :]``; only
            read when ``use_amplitude`` is True.
        use_amplitude: If True and ``waveforms`` is given, add a
            ``phase_amplitude`` entry to every pick.

    Returns:
        list[dict]: One dict per pick with keys ``file_name``, ``station_id``,
        ``begin_time``, ``phase_index``, ``phase_time``, ``phase_score``,
        ``phase_type``, ``dt`` and, when requested, ``phase_amplitude``.
    """

    def _as_str(value):
        # Names and times may be given as bytes; normalize them to str.
        return value.decode() if isinstance(value, bytes) else value

    if config is None:
        mph = {phase: 0.3 for phase in phases}
        mpd = 50
        post_idx = int(4 / dt)
    else:
        mph = {"P": config.min_p_prob, "S": config.min_s_prob, "PS": 0.3}
        mpd = config.mpd
        post_idx = int(config.post_sec / dt)
    # The amplitude search extends up to three post-pick windows.
    amp_window = post_idx * 3

    Nb, _, Ns, Nc = preds.shape

    if file_names is None:
        file_names = [f"{i:04d}" for i in range(Nb)]
    elif isinstance(file_names, (np.ndarray, list, tuple)):
        file_names = [_as_str(x) for x in file_names]
    else:
        # A single name is shared by every batch item.
        file_names = [_as_str(file_names)] * Nb

    if begin_times is None:
        begin_times = ["1970-01-01T00:00:00.000+00:00"] * Nb
    else:
        begin_times = [_as_str(x) for x in begin_times]

    picks = []
    for i in range(Nb):
        file_name = file_names[i]
        begin_time = datetime.fromisoformat(begin_times[i])
        begin_time_str = begin_time.isoformat(timespec="milliseconds")

        for j in range(Ns):
            if station_ids is None or len(station_ids[i]) == 0:
                station_id = f"{j:04d}"
            else:
                station_id = _as_str(station_ids[i][j])

            # Per-sample envelope: max |amplitude| across the last (channel) axis.
            amp = None
            if waveforms is not None and use_amplitude:
                amp = np.max(np.abs(waveforms[i, :, j, :]), axis=-1)

            for k in range(Nc - 1):
                phase_type = phases[k]
                idxs, probs = detect_peaks(preds[i, :, j, k + 1], mph=mph[phase_type], mpd=mpd, show=False)
                for n, (phase_index, phase_prob) in enumerate(zip(idxs, probs)):
                    pick_time = begin_time + timedelta(seconds=phase_index * dt)
                    pick = {
                        "file_name": file_name,
                        "station_id": station_id,
                        "begin_time": begin_time_str,
                        "phase_index": int(phase_index),
                        "phase_time": pick_time.isoformat(timespec="milliseconds"),
                        "phase_score": round(phase_prob, 3),
                        "phase_type": phase_type,
                        "dt": dt,
                    }
                    if amp is not None:
                        # Stop the search at the next pick found on the same channel.
                        next_pick = idxs[n + 1] if n + 1 < len(idxs) else phase_index + amp_window
                        end = min(phase_index + amp_window, next_pick)
                        pick["phase_amplitude"] = np.max(amp[phase_index:end]).item()
                    picks.append(pick)

    return picks
|
|
|
|
|
def extract_amplitude(data, picks, window_p=10, window_s=5, config=None):
    """Measure the peak absolute amplitude after each P and S pick.

    For every record and station, an envelope is built as the maximum absolute
    value across the last axis of ``data[i][:, station, :]``. For each pick
    index, the envelope maximum is taken over ``[idx, idx + window)``, cut
    short at the following pick of the same phase.

    Args:
        data: Sequence of arrays indexed as ``[time, station, channel]``.
        picks: Sequence, parallel to ``data``, of records whose ``p_idx`` and
            ``s_idx`` are per-station lists of sample indices.
        window_p: P-wave search window in seconds.
        window_s: S-wave search window in seconds.
        config: Optional object whose ``dt`` is the sampling interval;
            0.01 s is used when omitted.

    Returns:
        list: One ``amplitude(p_amp, s_amp)`` namedtuple per record, each field
        a per-station list of peak amplitudes.
    """
    Amplitude = namedtuple("amplitude", ["p_amp", "s_amp"])
    step = config.dt if config is not None else 0.01
    p_len = int(window_p / step)
    s_len = int(window_s / step)

    def _peaks(envelope, starts, length):
        out = []
        for n, start in enumerate(starts):
            stop = start + length
            if n + 1 < len(starts):
                # Do not run into the window of the next pick.
                stop = min(stop, starts[n + 1])
            out.append(np.max(envelope[start:stop]))
        return out

    results = []
    for record, pick in zip(data, picks):
        p_rows, s_rows = [], []
        for sta in range(record.shape[1]):
            envelope = np.max(np.abs(record[:, sta, :]), axis=-1)
            p_rows.append(_peaks(envelope, pick.p_idx[sta], p_len))
            s_rows.append(_peaks(envelope, pick.s_idx[sta], s_len))
        results.append(Amplitude(p_rows, s_rows))
    return results
|
|
|
|
|
def save_picks(picks, output_dir, amps=None, fname=None):
    """Write picks as a tab-separated table to ``output_dir/fname``.

    Every row starts with ``fname``, ``t0`` and the P/S pick indices and
    probabilities. Nested per-station lists are rendered as comma-separated
    ``[...]`` groups, e.g. ``[[1, 2], [3]]`` -> ``[1,2],[3]``. Probabilities
    use 3 decimals and amplitudes 3-digit scientific notation.

    When ``amps`` is given, ``p_amp``/``s_amp`` columns are appended and no
    ``ps_*`` columns are written. Otherwise ``ps_idx``/``ps_prob`` columns are
    written if the first pick has a ``ps_idx`` attribute. An empty ``picks``
    produces a header-only file.

    Args:
        picks: Sequence of records with ``fname``, ``t0``, ``p_idx``,
            ``p_prob``, ``s_idx``, ``s_prob`` (optionally ``ps_idx``,
            ``ps_prob``) attributes.
        output_dir: Directory to write into.
        amps: Optional sequence, parallel to ``picks``, of records with
            ``p_amp`` and ``s_amp`` attributes.
        fname: Output file name; defaults to ``"picks.csv"``.

    Returns:
        int: Always 0.
    """
    if fname is None:
        fname = "picks.csv"

    def _rows(values, fmt):
        return ",".join("[" + ",".join(map(fmt, row)) + "]" for row in values)

    def int2s(values):
        return _rows(values, str)

    def flt2s(values):
        return _rows(values, "{:0.3f}".format)

    def sci2s(values):
        return _rows(values, "{:0.3e}".format)

    def _base(pick):
        # Columns shared by every output layout.
        return (
            f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}"
            f"\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}"
        )

    header = "fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob"
    with open(os.path.join(output_dir, fname), "w") as fp:
        if amps is not None:
            fp.write(header + "\tp_amp\ts_amp\n")
            for pick, amp in zip(picks, amps):
                fp.write(f"{_base(pick)}\t{sci2s(amp.p_amp)}\t{sci2s(amp.s_amp)}\n")
        elif len(picks) > 0 and hasattr(picks[0], "ps_idx"):
            fp.write(header + "\tps_idx\tps_prob\n")
            for pick in picks:
                fp.write(f"{_base(pick)}\t{int2s(pick.ps_idx)}\t{flt2s(pick.ps_prob)}\n")
        else:
            fp.write(header + "\n")
            for pick in picks:
                fp.write(_base(pick) + "\n")

    return 0
|
|
|
|
|
def calc_timestamp(timestamp, sec):
    """Shift a timestamp string by ``sec`` seconds.

    Args:
        timestamp: Naive time string in ``%Y-%m-%dT%H:%M:%S.%f`` form; the
            fractional-seconds part may also be omitted
            (``%Y-%m-%dT%H:%M:%S``).
        sec: Offset in seconds (may be fractional or negative).

    Returns:
        str: Shifted time formatted as ``YYYY-mm-ddTHH:MM:SS.fff``.
        Microseconds are truncated, not rounded, to milliseconds.

    Raises:
        ValueError: If ``timestamp`` matches neither format.
    """
    fmt = "%Y-%m-%dT%H:%M:%S.%f"
    try:
        base = datetime.strptime(timestamp, fmt)
    except ValueError:
        # Accept whole-second timestamps that carry no fractional part.
        base = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S")
    shifted = base + timedelta(seconds=sec)
    # %f yields 6 digits; dropping the last 3 keeps milliseconds.
    return shifted.strftime(fmt)[:-3]
|
|
|
|
|
def save_picks_json(picks, output_dir, dt=0.01, amps=None, fname=None):
    """Write picks as a flat JSON list of per-pick records.

    Each record has these keys, in order:

    - ``id``: ``pick.station_id``.
    - ``timestamp``: ``pick.t0`` shifted by ``idx * dt`` seconds via
      ``calc_timestamp``.
    - ``prob``: the pick probability.
    - ``amp``: the pick amplitude (only when ``amps`` is given).
    - ``type``: ``"p"`` or ``"s"``.

    For each input record, all P picks are emitted before its S picks.

    Args:
        picks: Sequence of records with ``station_id``, ``t0``, ``p_idx``,
            ``p_prob``, ``s_idx`` and ``s_prob``; the index/prob fields are
            nested per-station lists.
        output_dir: Directory to write into.
        dt: Sampling interval in seconds, used to turn indices into times.
        amps: Optional sequence, parallel to ``picks``, of records with
            ``p_amp``/``s_amp`` nested like the index fields.
        fname: Output file name; defaults to ``"picks.json"``.

    Returns:
        int: Always 0.
    """
    if fname is None:
        fname = "picks.json"

    records = []

    def _emit(pick, phase_type, *columns):
        # columns = per-station indices, probabilities[, amplitudes], walked in
        # lockstep station by station and then pick by pick.
        for station_values in zip(*columns):
            for values in zip(*station_values):
                record = {
                    "id": pick.station_id,
                    "timestamp": calc_timestamp(pick.t0, float(values[0]) * dt),
                    # float() accepts NumPy scalars and plain Python numbers alike.
                    "prob": float(values[1]),
                }
                if len(values) > 2:
                    record["amp"] = float(values[2])
                record["type"] = phase_type
                records.append(record)

    if amps is None:
        for pick in picks:
            _emit(pick, "p", pick.p_idx, pick.p_prob)
            _emit(pick, "s", pick.s_idx, pick.s_prob)
    else:
        for pick, amplitude in zip(picks, amps):
            _emit(pick, "p", pick.p_idx, pick.p_prob, amplitude.p_amp)
            _emit(pick, "s", pick.s_idx, pick.s_prob, amplitude.s_amp)

    with open(os.path.join(output_dir, fname), "w") as fp:
        json.dump(records, fp)

    return 0
|
|
|
|
|
def convert_true_picks(fname, itp, its, itps=None):
    """Bundle ground-truth pick indices into per-record namedtuples.

    Args:
        fname: Sequence of record names (str or bytes; bytes are decoded).
        itp: Sequence, indexed like ``fname``, of P pick indices.
        its: Sequence, indexed like ``fname``, of S pick indices.
        itps: Optional sequence, indexed like ``fname``, of PS pick indices.

    Returns:
        list: ``phase(fname, p_idx, s_idx)`` namedtuples, with an extra
        ``ps_idx`` field when ``itps`` is given.

    Raises:
        IndexError: If an index sequence is shorter than ``fname``.
    """

    def _as_str(name):
        return name.decode() if isinstance(name, bytes) else name

    if itps is None:
        record = namedtuple("phase", ["fname", "p_idx", "s_idx"])
        return [record(_as_str(name), itp[i], its[i]) for i, name in enumerate(fname)]

    record = namedtuple("phase", ["fname", "p_idx", "s_idx", "ps_idx"])
    return [record(_as_str(name), itp[i], its[i], itps[i]) for i, name in enumerate(fname)]
|
|
|
|
|
def calc_metrics(nTP, nP, nT):
    """Compute precision, recall and F1 score from pick counts.

    Args:
        nTP: Number of true-positive picks.
        nP: Number of predicted (positive) picks.
        nT: Number of true (reference) picks.

    Returns:
        list: ``[precision, recall, f1]``. Any ratio whose denominator is zero
        is reported as 0.0 rather than raising or producing NaN/inf.
    """
    precision = nTP / nP if nP else 0.0
    recall = nTP / nT if nT else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return [precision, recall, f1]
|
|
|
|
|
def calc_performance(picks, true_picks, tol=3.0, dt=1.0):
    """Score predicted picks against ground-truth picks, phase by phase.

    ``picks`` and ``true_picks`` are parallel sequences (one entry per record)
    of namedtuples. Every field of ``true_picks[0]`` except ``fname`` is
    treated as a phase whose value is a per-station list of sample indices.

    A (predicted, true) pair on the same station whose difference
    ``dt * (pred - true)`` satisfies ``|diff| <= tol`` counts as a true
    positive. Every such pair is counted, so one true pick matched by several
    predictions contributes several true positives.

    Args:
        picks: Predicted picks, one record per entry.
        true_picks: Ground-truth picks, parallel to ``picks``.
        tol: Matching tolerance, in the same units as ``dt * index``.
        dt: Scale applied to index differences.

    Returns:
        dict: phase name -> ``[precision, recall, f1]`` from ``calc_metrics``.
    """
    assert len(picks) == len(true_picks)
    logging.info("Total records: {}".format(len(picks)))

    def count(stations):
        # Total number of picks across all stations of one record.
        return sum(len(x) for x in stations)

    metrics = {}
    for phase in true_picks[0]._fields:
        if phase == "fname":
            continue
        true_positive, positive, true = 0, 0, 0
        residual = []
        for pred_rec, true_rec in zip(picks, true_picks):
            pred_stations = getattr(pred_rec, phase)
            true_stations = getattr(true_rec, phase)
            true += count(true_stations)
            positive += count(pred_stations)
            # Compare station by station: stations may carry different numbers
            # of picks, which a single rectangular np.array cannot represent.
            for pred_idx, true_idx in zip(pred_stations, true_stations):
                # diff[t, p] = dt * (pred[p] - true[t])
                diff = dt * (np.asarray(pred_idx)[np.newaxis, :] - np.asarray(true_idx)[:, np.newaxis])
                matched = np.abs(diff) <= tol
                residual.extend(diff[matched])
                true_positive += np.sum(matched)
        metrics[phase] = calc_metrics(true_positive, positive, true)

        # Avoid NumPy's mean-of-empty warning; NaN is logged either way.
        if residual:
            res_mean, res_std = np.mean(residual), np.std(residual)
        else:
            res_mean, res_std = float("nan"), float("nan")

        logging.info(f"{phase}-phase:")
        logging.info(f"True={true}, Positive={positive}, True Positive={true_positive}")
        logging.info(f"Precision={metrics[phase][0]:.3f}, Recall={metrics[phase][1]:.3f}, F1={metrics[phase][2]:.3f}")
        logging.info(f"Residual mean={res_mean:.4f}, std={res_std:.4f}")

    return metrics
|
|
|
|
|
def save_prob_h5(probs, fnames, output_h5):
    """Store each probability array as a float32 dataset in ``output_h5``.

    Args:
        probs: Sequence of arrays, one per record.
        fnames: Sequence of dataset names (str or bytes, including NumPy byte
            strings). A trailing ``.npz`` is removed. ``None`` uses
            zero-padded record indices.
        output_h5: Object exposing ``create_dataset(name, data=..., dtype=...)``
            (presumably an open h5py file or group — confirm with callers).

    Returns:
        int: Always 0.
    """
    if fnames is None:
        names = [f"{i:04d}" for i in range(len(probs))]
    else:
        names = []
        for name in fnames:
            if isinstance(name, bytes):
                name = name.decode()
            # Remove the exact suffix; str.rstrip(".npz") would also strip any
            # trailing '.', 'n', 'p' or 'z' characters of the stem.
            if name.endswith(".npz"):
                name = name[: -len(".npz")]
            names.append(name)
    for prob, name in zip(probs, names):
        output_h5.create_dataset(name, data=prob, dtype="float32")
    return 0
|
|
|
|
|
def save_prob(probs, fnames, prob_dir):
    """Save each probability array to ``prob_dir/<name>.npz`` under key ``prob``.

    Args:
        probs: Sequence of arrays, one per record.
        fnames: Sequence of names (str or bytes, including NumPy byte strings).
            A trailing ``.npz`` is removed before it is re-added. ``None`` uses
            zero-padded record indices.
        prob_dir: Existing directory to write the files into.

    Returns:
        int: Always 0.
    """
    if fnames is None:
        names = [f"{i:04d}" for i in range(len(probs))]
    else:
        names = []
        for name in fnames:
            if isinstance(name, bytes):
                name = name.decode()
            # Remove the exact suffix; str.rstrip(".npz") would also strip any
            # trailing '.', 'n', 'p' or 'z' characters of the stem.
            if name.endswith(".npz"):
                name = name[: -len(".npz")]
            names.append(name)
    for prob, name in zip(probs, names):
        np.savez(os.path.join(prob_dir, name + ".npz"), prob=prob)
    return 0
|
|