import numpy as np
import torch
from scipy import signal
from scipy.signal import butter, detrend, lfilter, sosfilt

# Length (in samples) of every crop produced for the model.
CROP_LENGTH = 6000
# Margin (in samples) used when placing the anchor phase inside a crop.
PADDING = 120

# Band-pass corners (Hz), sampling rate (Hz) and Butterworth order shared by
# training augmentation and inference preprocessing.  The sampling rate is
# hard-coded: input traces are assumed to be 100 Hz -- TODO confirm upstream.
_LOWCUT = 0.2
_HIGHCUT = 40
_FS = 100
_FILTER_ORDER = 5


def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter in transfer-function form.

    Args:
        lowcut: Lower corner frequency in Hz.
        highcut: Upper corner frequency in Hz.
        fs: Sampling rate in Hz.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        ``(b, a)`` numerator / denominator polynomial coefficients.

    Note:
        The ``(b, a)`` form is numerically fragile for higher orders and low
        normalized corner frequencies; ``butter_bandpass_filter`` therefore
        filters with second-order sections instead.  This function is kept
        for callers that need the polynomial coefficients.
    """
    nyq = 0.5 * fs  # Nyquist frequency
    b, a = butter(order, [lowcut / nyq, highcut / nyq], btype="band")
    return b, a


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Causally band-pass filter *data* along its last axis.

    Uses the second-order-sections representation, which scipy recommends
    over ``(b, a)`` because converting poles to polynomial coefficients
    loses precision for band-pass designs like 0.2-40 Hz at 100 Hz, order 5.

    Args:
        data: Array to filter; filtering runs along the last axis.
        lowcut: Lower corner frequency in Hz.
        highcut: Upper corner frequency in Hz.
        fs: Sampling rate in Hz.
        order: Butterworth order.

    Returns:
        The filtered array, same shape as *data*.
    """
    nyq = 0.5 * fs
    sos = butter(
        order, [lowcut / nyq, highcut / nyq], btype="band", output="sos"
    )
    return sosfilt(sos, data)


def rotate_waveform(waveform, angle):
    """Phase-shift every frequency component of *waveform* by *angle* radians.

    The analytic signal ``x + 1j*H(x)`` (``H`` = Hilbert transform, taken
    along the last axis) is multiplied by ``exp(1j*angle)``.  The real part
    of the returned complex array is ``x*cos(angle) - H(x)*sin(angle)``,
    i.e. the input with a constant phase rotation applied.

    Multiplying the whole FFT by a single phasor would instead just scale the
    time series by that complex scalar, so its real part would be only
    ``x*cos(angle)``.

    Args:
        waveform: Real array; the transform runs along the last axis.
        angle: Rotation angle in radians.

    Returns:
        Complex array of the same shape; callers typically take ``.real``.
    """
    analytic = signal.hilbert(waveform, axis=-1)
    return analytic * np.exp(1j * angle)


def _select_crop_start(first_phase, crop_length, padding, test):
    """Choose the crop start sample so that *first_phase* lies in the crop.

    - Phase deeper than ``crop_length`` into the trace: the phase lands a
      random ``[padding, crop_length - padding)`` samples into the crop
      (test mode: exactly ``2 * padding``).
    - Phase later than ``padding``: the phase lands a random
      ``[0, first_phase - padding)`` samples into the crop (test mode:
      exactly ``padding``).
    - Otherwise (including "no phase" = 0): start at ``padding``.
    """
    if first_phase - (crop_length - padding) > padding:
        if test:
            return int(first_phase - 2 * padding)
        offset = torch.randint(
            low=padding, high=crop_length - padding, size=(1,)
        ).item()
        return int(first_phase - offset)
    if int(first_phase - padding) > 0:
        if test:
            return int(first_phase - padding)
        offset = torch.randint(
            low=0, high=int(first_phase - padding), size=(1,)
        ).item()
        return int(first_phase - offset)
    return padding


def _crop(waveform, start, crop_length):
    """Cut ``crop_length`` samples from *waveform* near *start*.

    Traces shorter than ``crop_length`` are zero-padded at the end first, so
    the crop always has exactly ``crop_length`` samples.  A window that would
    run past the end of the trace is shifted left, but never before sample 0
    (a negative start would make the slice wrap around the array).

    Returns:
        ``(crop, start)`` where *start* is the start actually used.
    """
    missing = crop_length - waveform.shape[-1]
    if missing > 0:
        pad_width = [(0, 0)] * (waveform.ndim - 1) + [(0, missing)]
        waveform = np.pad(waveform, pad_width, mode="constant", constant_values=0)
    overflow = start + crop_length - waveform.shape[-1]
    if overflow > 0:
        start -= overflow
    start = max(start, 0)
    return waveform[..., start:start + crop_length], start


def _preprocess(waveform):
    """Detrend, band-pass, Tukey-taper (alpha=0.1) and detrend again (last axis)."""
    waveform = detrend(waveform)
    waveform = butter_bandpass_filter(
        waveform, lowcut=_LOWCUT, highcut=_HIGHCUT, fs=_FS, order=_FILTER_ORDER
    )
    taper = signal.windows.tukey(waveform.shape[-1], alpha=0.1)
    waveform = waveform * taper
    return detrend(waveform)


def _normalize(waveform):
    """Scale *waveform* so its peak absolute value is 1; leave all-zero input as is."""
    peak = np.max(np.abs(waveform))
    if peak > 0:
        return waveform / peak
    return waveform  # avoid 0/0 -> NaN for silent traces


def _pad_to_three_channels(waveform):
    """Zero-fill missing channels so the result has at least 3 rows.

    Existing channels are kept in the first rows.  Per the original comment
    ("Added Z component only"), single-channel input is presumably the
    vertical component -- TODO confirm the channel order upstream.
    """
    n_channels = waveform.shape[0]
    if n_channels >= 3:
        return waveform
    padded = np.zeros((3, waveform.shape[-1]))
    padded[:n_channels] = waveform
    return padded


def _normalize_target(sample_index, crop_length):
    """Map a pick index inside the crop to ``(0, 1)``; return 0 if outside or NaN."""
    value = sample_index / crop_length
    if (value <= 0) or (value >= 1) or np.isnan(value):
        return 0
    return value


def augment(sample, crop_length=CROP_LENGTH, padding=PADDING):
    """Crop, preprocess and (for training) randomly augment one sample.

    Args:
        sample: Mapping with ``"waveform.npy"`` (2-D array indexed as
            ``[channel, sample]``) and ``"meta.json"`` (dict with ``"split"``,
            ``"trace_p_arrival_sample"`` and ``"trace_s_arrival_sample"``;
            missing picks may be ``None`` or NaN).
        crop_length: Number of samples in the returned crop.
        padding: Margin used when positioning the anchor phase in the crop.

    Returns:
        ``(waveform, target_p, target_s)``.  *waveform* has ``crop_length``
        samples and at least 3 channels, is peak-normalized (unless silent),
        and is randomly phase-rotated / channel-dropped when the split is
        ``"train"``.  The targets are pick positions as a fraction of
        ``crop_length`` in ``(0, 1)``, or 0 when the pick is missing, falls
        outside the crop, or the crop is unusable (non-finite or all zero).
    """
    waveform = sample["waveform.npy"]
    meta = sample["meta.json"]
    test = meta["split"] != "train"

    target_p = meta["trace_p_arrival_sample"]
    target_s = meta["trace_s_arrival_sample"]
    if target_p is None:
        target_p = 0
    if target_s is None:
        target_s = 0

    # Randomly pick one of the available phases to anchor the crop.  NaN and
    # non-positive picks fail ``x > 0`` and are skipped.
    # NOTE(review): the choice is random even when ``test`` is True, so
    # evaluation crops differ between runs when both picks exist -- confirm
    # this is intended.
    phases = [x for x in (target_p, target_s) if x > 0]
    if phases:
        first_phase = phases[torch.randint(len(phases), size=(1,)).item()]
    else:
        first_phase = 0  # no pick at all: crop starts at ``padding``

    start = _select_crop_start(first_phase, crop_length, padding, test)
    cropped, start = _crop(waveform, start, crop_length)
    new_target_p = target_p - start
    new_target_s = target_s - start

    # detrend raises on non-finite data, so check the raw crop before
    # preprocessing as well as the filtered result afterwards.
    processed = _preprocess(cropped) if np.all(np.isfinite(cropped)) else None
    if processed is None or not np.all(np.isfinite(processed)):
        processed = np.zeros(cropped.shape)
        new_target_p = 0
        new_target_s = 0
    if not np.any(processed):
        # A silent crop carries no pick information.
        new_target_p = 0
        new_target_s = 0

    waveform_norm = _pad_to_three_channels(_normalize(processed))

    if not test:
        # Random constant phase rotation (50 %).  The angle is drawn in
        # degrees and converted to the radians rotate_waveform expects.
        apply_rotation = torch.randint(0, 2, size=(1,)).item()
        angle_deg = torch.empty(1).uniform_(0.01, 359.9).item()
        if apply_rotation == 1:
            waveform_norm = rotate_waveform(
                waveform_norm, np.deg2rad(angle_deg)
            ).real

        # Channel dropout (50 %): overwrite channel 1 or 2 with a tiny
        # constant; channel 0 is never dropped.
        apply_dropout = torch.randint(0, 2, size=(1,)).item()
        channel = torch.randint(1, 3, size=(1,)).item()
        if apply_dropout == 1:
            waveform_norm[channel, :] = 1e-6

    return (
        waveform_norm,
        _normalize_target(new_target_p, crop_length),
        _normalize_target(new_target_s, crop_length),
    )


def collation_fn(sample):
    """Stack a list of ``augment`` outputs into float32 tensors.

    Args:
        sample: Sequence of ``(waveform, target_p, target_s)`` tuples.

    Returns:
        ``(waveforms, targets_p, targets_s)``; waveforms are stacked along a
        new leading batch axis, targets into 1-D tensors.
    """
    waveforms = np.stack([item[0] for item in sample])
    targets_p = np.stack([item[1] for item in sample])
    targets_s = np.stack([item[2] for item in sample])
    return (
        torch.tensor(waveforms, dtype=torch.float),
        torch.tensor(targets_p, dtype=torch.float),
        torch.tensor(targets_s, dtype=torch.float),
    )


def my_split_by_node(urls):
    """Return this rank's share of *urls*: every ``world_size``-th item from ``rank``.

    When ``torch.distributed`` is unavailable or no process group has been
    initialized (single-process runs), all URLs are returned instead of
    raising.
    """
    urls = list(urls)
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return urls
    return urls[dist.get_rank()::dist.get_world_size()]


def prepare_waveform(waveform, crop_length=CROP_LENGTH, batch_size=128):
    """Preprocess one trace for inference and replicate it into a batch.

    The trace is zero-padded or truncated to ``crop_length`` samples, then
    detrended, band-passed, tapered and peak-normalized exactly like
    training crops; missing channels are zero-filled up to 3.

    Args:
        waveform: 2-D array indexed as ``[channel, sample]`` with <= 3 channels.
        crop_length: Number of samples to keep.
        batch_size: Number of identical copies in the returned batch.

    Returns:
        float32 tensor of shape ``(batch_size, 3, crop_length)``.

    Raises:
        AssertionError: More than 3 channels, NaN in the data, or an
            all-zero trace after preprocessing.  Raised explicitly rather
            than via ``assert`` so the checks survive ``python -O`` while
            keeping the exception type existing callers see.
    """
    waveform = np.asarray(waveform)
    if waveform.shape[0] > 3:
        raise AssertionError("Waveform has more than 3 channels")

    n_samples = waveform.shape[-1]
    if n_samples < crop_length:
        waveform = np.pad(
            waveform,
            ((0, 0), (0, crop_length - n_samples)),
            mode="constant",
            constant_values=0,
        )
    elif n_samples > crop_length:
        waveform = waveform[:, :crop_length]

    # detrend would raise a generic error on NaN; report it explicitly.
    if np.isnan(waveform).any():
        raise AssertionError("Nan in waveform")
    waveform = _preprocess(waveform)
    if np.isnan(waveform).any():
        raise AssertionError("Nan in waveform")
    # A detrended signal sums to ~0 anyway; what matters is that it is not
    # identically zero (which would make normalization divide by zero).
    if not np.any(waveform):
        raise AssertionError("Sum of waveform sample is zero")

    waveform = _pad_to_three_channels(_normalize(waveform))
    batch = np.repeat(waveform[np.newaxis], batch_size, axis=0)
    return torch.tensor(batch, dtype=torch.float)