Spaces:
Running
Running
| import h5py | |
| import numpy as np | |
| import os | |
| import matplotlib.pyplot as plt | |
| import pickle | |
| from plot_styles import apply_physrev_style | |
| from astropy.cosmology import Planck18 | |
| from astropy.cosmology import z_at_value | |
| from astropy import units as u | |
| from scipy.stats import gumbel_r | |
| # Apply the style (NOTE: apply_physrev_style() is imported above but not called anywhere in this section) | |
def get_detection_threshold(normalized, alpha, gumbel=True, list_hyp=False, n_hyp=None):
    """Compute the detection threshold(s) for a false-alarm level ``alpha``.

    Parameters
    ----------
    normalized : array_like
        2-D array of normalized noise-only statistics. Rows are treated as
        noise realizations and columns as hypotheses: the single-threshold
        branches reduce with ``max(axis=1)`` and the per-hypothesis branches
        work column by column.
    alpha : float
        Target false-alarm probability.
    gumbel : bool, optional
        If True, fit ``scipy.stats.gumbel_r`` to the statistic and return its
        inverse survival function at ``alpha``. If False, use empirical
        quantiles instead.
    list_hyp : bool, optional
        If True, return one threshold per hypothesis (column). If False,
        return a single threshold on the per-row maximum over hypotheses.
    n_hyp : int, optional
        Number of hypotheses for the Bonferroni correction used only in the
        empirical per-hypothesis branch (``gumbel=False, list_hyp=True``).
        Defaults to the number of columns of ``normalized``.

    Returns
    -------
    float, list of float, or numpy.ndarray
        A scalar threshold when ``list_hyp`` is False. Otherwise one threshold
        per column: a list (Gumbel branch) or an array (empirical branch).
    """
    normalized = np.asarray(normalized)

    if gumbel:
        if list_hyp:
            # Independent Gumbel fit for each hypothesis column.
            return [gumbel_r(*gumbel_r.fit(column)).isf(alpha) for column in normalized.T]
        # Single Gumbel fit to the maximum statistic across hypotheses.
        return gumbel_r(*gumbel_r.fit(np.max(normalized, axis=1))).isf(alpha)

    if list_hyp:
        # Bonferroni-corrected per-hypothesis quantile. The previous version
        # divided by len(tpl_vector), a name not defined in this scope
        # (NameError). The hypothesis count is now a parameter, defaulting to
        # the number of columns.
        if n_hyp is None:
            n_hyp = normalized.shape[1]
        return np.quantile(normalized, 1 - alpha / n_hyp, axis=0)

    return np.quantile(np.max(normalized, axis=1), 1 - alpha)
| # New function for interactive plotting | |
def _load_noise_statistics(noise_file):
    """Return (normalized, mean, std) of the noise-only best losses stored in `noise_file`."""
    if not os.path.exists(noise_file):
        raise FileNotFoundError(f"Noise file {noise_file} not found.")
    with h5py.File(noise_file, 'r') as f:
        all_best_losses_noise = f['all_best_losses_noise'][()]
    mean_noise = all_best_losses_noise.mean(axis=0)
    std_noise = all_best_losses_noise.std(axis=0)
    normalized = (all_best_losses_noise - mean_noise) / std_noise
    return normalized, mean_noise, std_noise


def _load_cached_results(snrs):
    """Concatenate cached injection results and their SNRs for every SNR in `snrs`.

    Missing cache files are skipped silently. Returns (results_list, snr_array).
    """
    results, snr_values = [], []
    for snr in snrs:
        cache_file = f"paper_scatter_cache_{snr}.pkl"
        if not os.path.exists(cache_file):
            continue
        # NOTE(review): pickle.load is only safe on trusted files; these caches
        # are assumed to be produced locally by this project.
        with open(cache_file, "rb") as f:
            results_, snr_ = pickle.load(f)
        results.extend(results_)
        snr_values.extend(snr_)
    return results, np.array(snr_values)


def plot_mass_vs_distance_or_redshift(
        snrs=(30,), alpha=1e-4, y_axis="Redshift", x_axis="Primary Mass", colorbar_var="ef"):
    """
    Interactive scatter plot of source-frame mass vs redshift/distance for detected injections.

    snrs: iterable of SNRs whose cached results are included
    alpha: false alarm rate used to set the detection threshold
    y_axis: 'Redshift' or 'Luminosity Distance'
    x_axis: 'Primary Mass' or 'Secondary Mass'
    colorbar_var: colour variable, either a short code ('e0', 'ef', 'm1', 'm2')
        or the matching app dropdown label ('Initial eccentricity',
        'Final eccentricity', 'Primary mass', 'Secondary mass'). Unknown
        values fall back to the final eccentricity.
    Returns: matplotlib figure
    Raises: FileNotFoundError if the noise file is missing.
    """
    normalized, mean_noise, std_noise = _load_noise_statistics("paper_results_tdi.h5")
    results_detection, snr_values = _load_cached_results(snrs)

    detection_threshold = get_detection_threshold(normalized, alpha)
    # Peak noise-normalized statistic of each injection. The comparison below
    # yields a boolean array even when no results were loaded, so the mask
    # arithmetic does not hit a TypeError on an empty float array.
    peak_stats = np.array(
        [np.max((r['losses'] - mean_noise) / std_noise) for r in results_detection])
    detected = peak_stats > detection_threshold

    m1_values = np.array([r['m1'] for r in results_detection])
    m2_values = np.array([r['m2'] for r in results_detection])
    distances = np.array([r['dist'] for r in results_detection])
    e0_values = np.array([r['e0'] for r in results_detection])
    ef_values = np.array([r['ef'] for r in results_detection])

    det_mask = np.isin(snr_values, snrs) & detected

    filtered_distances = distances[det_mask]
    filtered_z = np.array(
        [z_at_value(Planck18.luminosity_distance, d * u.Gpc) for d in filtered_distances])
    # Convert to source-frame masses via m / (1 + z). The cached masses are
    # presumably detector-frame (redshifted); verify against the cache producer.
    filtered_m1 = m1_values[det_mask] / (1 + filtered_z)
    filtered_m2 = m2_values[det_mask] / (1 + filtered_z)
    filtered_ef = ef_values[det_mask]
    filtered_e0 = e0_values[det_mask]

    # Keys are the app.py dropdown labels; the short codes from the docstring
    # are aliased onto them so the default ("ef") resolves correctly.
    colorbar_map = {
        "Final eccentricity": (filtered_ef, 'Final Eccentricity $e_f$', 'plasma'),
        "Initial eccentricity": (filtered_e0, 'Initial Eccentricity $e_0$', 'cividis'),
        "Primary mass": (filtered_m1, r'Primary Mass [M$_\odot$]', 'viridis'),
        "Secondary mass": (filtered_m2, r'Secondary Mass [M$_\odot$]', 'viridis'),
    }
    short_codes = {
        "ef": "Final eccentricity",
        "e0": "Initial eccentricity",
        "m1": "Primary mass",
        "m2": "Secondary mass",
    }
    key = short_codes.get(colorbar_var, colorbar_var)
    color_data, color_label, cmap = colorbar_map.get(key, colorbar_map["Final eccentricity"])

    if x_axis == "Primary Mass":
        x, xlabel = filtered_m1, r'Source frame primary mass [M$_\odot$]'
    else:
        x, xlabel = filtered_m2, r'Source frame secondary mass [M$_\odot$]'

    if y_axis == "Redshift":
        y, ylabel = filtered_z, 'Redshift'
    else:
        y, ylabel = filtered_distances, 'Luminosity Distance [Gpc]'

    fig, ax = plt.subplots(figsize=(7, 5))
    scatter = ax.scatter(x, y, c=color_data, cmap=cmap, alpha=0.7, marker='o')
    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label(color_label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig