Merge pull request #1 from lauracabayol/clean_code
- .gitignore +4 -3
- notebooks/{Fig7_colourspace.py → Colourspace.py} +36 -37
- notebooks/Comparison_methodology.py +517 -0
- notebooks/Feature_space.py +24 -256
- notebooks/Fig6_qualitycut.py +0 -164
- notebooks/{Fig2_NMAD.py → NMAD.py} +30 -45
- notebooks/{Fig3_PIT_CRPS.py → PIT_CRPS.py} +31 -39
- notebooks/Qualitycut.py +241 -0
- notebooks/Table_metrics.py +21 -23
- notebooks/nz.py +215 -0
- notebooks/{Fig4_pz_examples.py → pz_examples.py} +23 -42
- temps/archive.py +91 -25
- temps/plots.py +41 -62
- temps/temps.py +225 -216
- temps/temps_arch.py +11 -17
- temps/utils.py +58 -144
.gitignore CHANGED
@@ -1,6 +1,7 @@
-
-
-/
+temps/__pycache__/*
+notebooks/.ipynb_checkpoints/
+notebooks/cache
+notebooks/*.ipynb
 /notebooks/developer_notebooks
 temps/.ipynb_checkpoints/
 *.ipynb
notebooks/{Fig7_colourspace.py → Colourspace.py} RENAMED
@@ -5,11 +5,11 @@
 #       extension: .py
 #       format_name: light
 #       format_version: '1.5'
-#       jupytext_version: 1.
+#       jupytext_version: 1.16.2
 #   kernelspec:
-#     display_name:
+#     display_name: temps
 #     language: python
-#     name:
+#     name: temps
 # ---

 # # FIGURE COLOURSPACE IN THE PAPER
@@ -23,6 +23,7 @@ import os
 from astropy.io import fits
 from astropy.table import Table
 import torch
+from pathlib import Path

 #matplotlib settings
 from matplotlib import rcParams
@@ -30,19 +31,11 @@ import matplotlib.pyplot as plt
 rcParams["mathtext.fontset"] = "stix"
 rcParams["font.family"] = "STIXGeneral"

-
-
-import
-
-
-from archive import archive
-from utils import nmad
-from temps_arch import EncoderPhotometry, MeasureZ
-from temps import Temps_module
-from plots import plot_nz
-
+from temps.archive import Archive
+from temps.utils import nmad
+from temps.temps_arch import EncoderPhotometry, MeasureZ
+from temps.temps import TempsModule

-# -

 def estimate_som_map(df, plot_arg='z', nx=40, ny=40):
     """
@@ -98,10 +91,14 @@ def plot_som_map(som_data, plot_arg = 'z', vmin=0, vmax=1):

 # ### LOAD DATA

+#define here the directory containing the photometric catalogues
+parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+modules_dir = Path('../data/models/')
+
 # +
 filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

-hdu_list = fits.open(
+hdu_list = fits.open(parent_dir/filename_valid)
 cat = Table(hdu_list[1].data).to_pandas()
 cat = cat[cat['FLAG_PHOT']==0]
 cat = cat[cat['mu_class_L07']==1]
@@ -116,27 +113,29 @@ ID = cat['ID']
 VISmag = cat['MAG_VIS']
 zsflag = cat['reliable_S15']

-photoz_archive =
+photoz_archive = Archive(path = parent_dir,only_zspec=False)
 f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
 col, colerr = photoz_archive._to_colors(f, ferr)

+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
 # +
 dfs = {}

 for il, lab in enumerate(['z','L15','DA']):

     nn_features = EncoderPhotometry()
-    nn_features.load_state_dict(torch.load(
+    nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
     nn_z = MeasureZ(num_gauss=6)
-    nn_z.load_state_dict(torch.load(
-
-
-
-    z,
+    nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt',map_location=torch.device('cpu')))
+
+    temps_module = TempsModule(nn_features, nn_z)
+
+    z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
                                       return_pz=True)
     # Create a DataFrame with the desired columns
-    df = pd.DataFrame(np.c_[ID, VISmag,z,
-                      columns=['ID','VISmag','z','
+    df = pd.DataFrame(np.c_[ID, VISmag,z,odds, ztarget,zsflag, specz_or_photo],
+                      columns=['ID','VISmag','z', 'odds','ztarget','zsflag','S15_L15_flag'])

     # Calculate additional columns or operations if needed
     df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
@@ -152,15 +151,15 @@ for il, lab in enumerate(['z','L15','DA']):
 # ### LOAD TRAINED MODELS AND EVALUATE PDFs AND REDSHIFT

 #define here the directory containing the photometric catalogues
-parent_dir = '/data/astro/
-modules_dir = '../data/models/'
+parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+modules_dir = Path('../data/models/')

 df_z = dfs['z']
 df_z_DA = dfs['DA']

 # ##### LOAD TRAIN SOM ON TRAINING DATA

-df_som = pd.read_csv(
+df_som = pd.read_csv(parent_dir/'som_dataframe.csv', header = 0, sep =',')
 df_z = df_z.merge(df_som, on = 'ID')
 df_z_DA = df_z_DA.merge(df_som, on = 'ID')

@@ -171,10 +170,10 @@ df_l15 = df_z[(df_z.ztarget>0)]
 df_l15_DA = df_z_DA[(df_z_DA.ztarget>0)]

 df_l15_euclid = df_z[(df_z.VISmag <24.5) & (df_z.z > 0.2) & (df_z.z < 2.6)]
-df_l15_euclid_cut= df_l15_euclid[df_l15_euclid.
+df_l15_euclid_cut= df_l15_euclid[df_l15_euclid.odds>df_l15_euclid['odds'].quantile(0.2)]

 df_l15_euclid_da = df_z_DA[(df_z_DA.VISmag <24.5) & (df_z_DA.z > 0.2) & (df_z_DA.z < 2.6)]
-df_l15_euclid_cut_da= df_l15_euclid_da[df_l15_euclid_da.
+df_l15_euclid_cut_da= df_l15_euclid_da[df_l15_euclid_da.odds>df_l15_euclid['odds'].quantile(0.2)]

 # ## MAKE SOM PLOT

@@ -186,7 +185,7 @@ fig, axs = plt.subplots(6, 4, figsize=(13, 15), sharex=True, sharey=True, gridsp
 # Plot in the top row (axs[0, i])
 #top row, spectroscopic sample
 columns = ['ztarget','z','zwerr','count']
-titles = [r'$z_{true}$',r'$z$',r'$z_{\rm error}$','Counts']
+titles = [r'$z_{true}$ (A)',r'$z$ (B)',r'$z_{\rm error}$ (C)','Counts']
 limits = [[0,4],[0,4],[-0.5,0.5],[0,50]]
 for ii in range(4):
     som_data = estimate_som_map(df_zspec, plot_arg=columns[ii], nx=40, ny=40)
@@ -245,13 +244,13 @@ axs[4, 0].set_ylabel(r'$y$', fontsize=14)
 axs[5, 0].set_ylabel(r'$y$', fontsize=14)


-fig.text(0.09, 0.815, r'$z_{\rm s}$
-fig.text(0.09, 0.69, r'L15
-fig.text(0.09, 0.56, r'L15
-fig.text(0.09, 0.44, r'$Euclid$
-fig.text(0.09, 0.3, r'$Euclid$
+fig.text(0.09, 0.815, r'$z_{\rm s}$ samp. (1)', va='center', rotation='vertical', fontsize=16)
+fig.text(0.09, 0.69, r'L15 samp. (2)', va='center', rotation='vertical', fontsize=16)
+fig.text(0.09, 0.56, r'L15 samp. + DA (3)', va='center', rotation='vertical', fontsize=14)
+fig.text(0.09, 0.44, r'$Euclid$ samp. + DA (4)', va='center', rotation='vertical', fontsize=14)
+fig.text(0.09, 0.3, r'$Euclid$ samp. + QC (5)', va='center', rotation='vertical', fontsize=14)

-fig.text(0.09, 0.17, r'
+fig.text(0.09, 0.17, r'(5) + DA ', va='center', rotation='vertical', fontsize=13)


 plt.savefig('SOM_colourspace.pdf', format='pdf', bbox_inches='tight', dpi=300)
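Every notebook in this PR repeats the same refactor: flat imports such as `from archive import archive` become package imports (`from temps.archive import Archive`), string paths become `pathlib.Path` objects, and checkpoints are loaded with `map_location` so they also work on CPU-only machines. A minimal sketch of the new calling convention, assembled from the lines above; the placeholder colour matrix and its width are illustrative assumptions, not part of the repo:

import numpy as np
import torch
from pathlib import Path
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.temps import TempsModule

modules_dir = Path('../data/models/')
lab = 'DA'  # one of 'z', 'L15', 'DA', as in the notebooks

# Load the two checkpoints on CPU, as the diff does.
nn_features = EncoderPhotometry()
nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt', map_location=torch.device('cpu')))
nn_z = MeasureZ(num_gauss=6)
nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt', map_location=torch.device('cpu')))

temps_module = TempsModule(nn_features, nn_z)

# Placeholder input: in the notebooks, col comes from Archive._to_colors.
col = np.random.rand(100, 6).astype(np.float32)  # assumed width; must match the encoder
z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col), return_pz=True)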
notebooks/Comparison_methodology.py ADDED
@@ -0,0 +1,517 @@

# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: temps
#     language: python
#     name: temps
# ---

# +
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
import os
from astropy.table import Table

from temps.utils import nmad
from scipy import stats
from pathlib import Path
# -

#define here the directory containing the photometric catalogues
parent_dir = '/data/astro/scratch/lcabayol/EUCLID/DAz/DC2_results_to_share/'


# +
# List of FITS files to be processed
fits_files = [
    'GDE_RF_full.fits',
    'GDE_PHOSPHOROS_V2_full.fits',
    'OIL_LEPHARE_full.fits',
    'JDV_DNF_A_full.fits',
    'JSP_FRANKENZ_full.fits',
    'MBR_METAPHOR_full.fits',
    'GDE_ADABOOST_full.fits',
    'CSC_GPZ_best_full.fits',
    'SFO_CPZ_full.fits',
    'AAL_NNPZ_V3_full.fits'
]

# Corresponding redshift column names
redshift_columns = [
    'REDSHIFT_RF',
    'REDSHIFT_PHOSPHOROS',
    'REDSHIFT_LEPHARE',
    'REDSHIFT_DNF',
    'REDSHIFT_FRANKENZ',
    'REDSHIFT_METAPHOR',
    'REDSHIFT_ADABOOST',
    'REDSHIFT_GPZ',
    'REDSHIFT_CPZ',
    'REDSHIFT_NNPZ'
]

# Initialize an empty DataFrame for merging
merged_df = pd.DataFrame()

# Process each FITS file
for fits_file, redshift_col in zip(fits_files, redshift_columns):
    print(fits_file)
    # Open the FITS file
    hdu_list = fits.open(os.path.join(parent_dir,fits_file))
    df = Table(hdu_list[1].data).to_pandas()
    df = df[df.REDSHIFT!=0]
    df = df[['ID', 'VIS','SPECZ', 'REDSHIFT']].rename(columns={'REDSHIFT': redshift_col})
    # Merge with the main DataFrame
    if merged_df.empty:
        merged_df = df
    else:
        merged_df = pd.merge(merged_df, df, on=['ID', 'VIS', 'SPECZ'], how='outer')


# -

# ## OPEN DATA

# +
modules_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

hdu_list = fits.open(modules_dir/filename_valid)
cat_full = Table(hdu_list[1].data).to_pandas()

cat_full = cat_full[['ID','z_spec_S15','reliable_S15','mu_class_L07']]

merged_df['reliable_S15'] = cat_full.reliable_S15
merged_df['z_spec_S15'] = cat_full.z_spec_S15
merged_df['mu_class_L07'] = cat_full.mu_class_L07
merged_df['ID_catfull'] = cat_full.ID
# -

merged_df_specz = merged_df[(merged_df.z_spec_S15>0)&(merged_df.SPECZ>0)&(merged_df.reliable_S15==1)&(merged_df.mu_class_L07==1)&(merged_df.VIS!=np.inf)]

# ## ONLY SPECZ SAMPLE

scatter, outliers =[],[]
for im, method in enumerate(redshift_columns):
    print(method)
    df_method = merged_df_specz.dropna(subset=method)
    zerr = (df_method.SPECZ - df_method[method] ) / (1 + df_method.SPECZ)
    print(len(zerr[np.abs(zerr)>0.15]) /len(zerr))
    scatter.append(nmad(zerr))
    outliers.append(len(zerr[np.abs(zerr)>0.15]) / len(df_method))


# +
labs = [
    'RF',
    'PHOSPHOROS',
    'LEPHARE',
    'DNF',
    'FRANKENZ',
    'METAPHOR',
    'ADABOOST',
    'GPZ',
    'CPZ',
    'NNPZ',
]

# Colors from colormap
cmap = plt.get_cmap('tab20')
colors = [cmap(i / len(labs)) for i in range(len(labs))]

# Plotting
plt.figure(figsize=(10, 6))
for i in range(len(labs)):
    plt.scatter(outliers[i]*100, scatter[i], color=colors[i], label=labs[i], marker = '^')

# Adding legend
plt.legend(fontsize=12)
plt.ylabel(r'NMAD $[\Delta z]$', fontsize=14)
plt.xlabel('Outlier fraction [%]', fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

plt.xlim(5,35)
plt.ylim(0,0.14)

# Display plot
plt.show()
# -

# ### ADD TEMPS PREDICTIONS

import torch
from temps.archive import Archive
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.temps import TempsModule

# +
data_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

hdu_list = fits.open(data_dir/filename_valid)
cat_phot = Table(hdu_list[1].data).to_pandas()
# -

cat_phot = cat_phot[cat_phot.ID.isin(merged_df_specz.ID_catfull)]

# +
photoz_archive = Archive(path = '/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5',
                         only_zspec=True)
f, ferr = photoz_archive._extract_fluxes(catalogue= cat_phot)
col, colerr = photoz_archive._to_colors(f, ferr)

ID = cat_phot.ID

# +
modules_dir = Path('/nfs/pic.es/user/l/lcabayol/EUCLID/TEMPS/data/models')

nn_features = EncoderPhotometry()
nn_features.load_state_dict(torch.load(modules_dir / f'modelF_DA.pt',map_location=torch.device('cpu')))
nn_z = MeasureZ(num_gauss=6)
nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_DA.pt', map_location=torch.device('cpu')))

temps_module = TempsModule(nn_features, nn_z)

z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
                                  return_pz=True)
df = pd.DataFrame(np.c_[ID, z],
                  columns=['ID','TEMPS'])

df = df.dropna()
# -

merged_df_specz= merged_df_specz.merge(df, left_on='ID_catfull', right_on='ID')

# Corresponding redshift column names
redshift_columns = redshift_columns + ['TEMPS']

scatter, outliers =[],[]
for im, method in enumerate(redshift_columns):
    print(method)
    df_method = merged_df_specz.dropna(subset=method)
    zerr = (df_method.SPECZ - df_method[method] ) / (1 + df_method.SPECZ)
    print(len(zerr[np.abs(zerr)>0.15]) /len(zerr))
    scatter.append(nmad(zerr))
    outliers.append(len(zerr[np.abs(zerr)>0.15]) / len(df_method))


# +
labs = [
    'RF',
    'PHOSPHOROS',
    'LEPHARE',
    'DNF',
    'FRANKENZ',
    'METAPHOR',
    'ADABOOST',
    'GPZ',
    'CPZ',
    'NNPZ',
    'TEMPS'
]

# Colors from colormap
cmap = plt.get_cmap('tab20')
colors = [cmap(i / len(labs)) for i in range(len(labs))]

# Plotting
plt.figure(figsize=(10, 6))
for i in range(len(labs)):
    plt.scatter(outliers[i]*100, scatter[i], color=colors[i], label=labs[i], marker = '^')

# Adding legend
plt.legend(fontsize=12)
plt.ylabel(r'NMAD $[\Delta z]$', fontsize=14)
plt.xlabel('Outlier fraction [%]', fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

plt.xlim(5,35)
plt.ylim(0,0.14)

# Display plot
plt.show()
# -

# ## ANOTHER SELECTION

# +
# List of FITS files to be processed
fits_files = [
    'GDE_RF_full.fits',
    'GDE_PHOSPHOROS_V2_full.fits',
    'OIL_LEPHARE_full.fits',
    'JDV_DNF_A_full.fits',
    'JSP_FRANKENZ_full.fits',
    'MBR_METAPHOR_full.fits',
    'GDE_ADABOOST_full.fits',
    'CSC_GPZ_best_full.fits',
    'SFO_CPZ_full.fits',
    'AAL_NNPZ_V3_full.fits'
]

# Corresponding redshift column names
redshift_columns = [
    'REDSHIFT_RF',
    'REDSHIFT_PHOSPHOROS',
    'REDSHIFT_LEPHARE',
    'REDSHIFT_DNF',
    'REDSHIFT_FRANKENZ',
    'REDSHIFT_METAPHOR',
    'REDSHIFT_ADABOOST',
    'REDSHIFT_GPZ',
    'REDSHIFT_CPZ',
    'REDSHIFT_NNPZ'
]

use_columns = [
    'USE_RF',
    'USE_PHOSPHOROS',
    'USE_LEPHARE',
    'USE_DNF',
    'USE_FRANKENZ',
    'USE_METAPHOR',
    'USE_ADABOOST',
    'USE_GPZ',
    'USE_CPZ',
    'USE_NNPZ'
]

# Initialize an empty DataFrame for merging
merged_df = pd.DataFrame()

# Process each FITS file
for fits_file, redshift_col,use_col in zip(fits_files, redshift_columns,use_columns):
    print(fits_file)
    # Open the FITS file
    hdu_list = fits.open(os.path.join(parent_dir,fits_file))
    df = Table(hdu_list[1].data).to_pandas()
    df = df[df.REDSHIFT!=0]
    df = df[['ID', 'VIS', 'SPECZ', 'REDSHIFT', 'L15PHZ', 'USE']].rename(columns={'REDSHIFT': redshift_col, 'USE': use_col})
    # Merge with the main DataFrame
    if merged_df.empty:
        merged_df = df
    else:
        merged_df = pd.merge(merged_df, df, on=['ID', 'VIS', 'SPECZ','L15PHZ'], how='outer')


# -

merged_df['comp_z'] = np.where(merged_df['SPECZ'] > 0, merged_df['SPECZ'], merged_df['L15PHZ'])
#merged_df = merged_df[(merged_df.comp_z>0)&(merged_df.comp_z<4)&(merged_df.VIS>23.5)]
merged_df = merged_df[(merged_df.comp_z>0)&(merged_df.comp_z<4)&(merged_df.VIS<25)]

# +
modules_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

hdu_list = fits.open(modules_dir/filename_valid)
cat_full = Table(hdu_list[1].data).to_pandas()

merged_df['ID_catfull'] = cat_full.ID

# +
data_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

hdu_list = fits.open(data_dir/filename_valid)
cat_phot = Table(hdu_list[1].data).to_pandas()
# -

cat_phot = cat_phot[cat_phot.ID.isin(merged_df.ID_catfull)]

# +
photoz_archive = Archive(path = '/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5',
                         only_zspec=False)
f, ferr = photoz_archive._extract_fluxes(catalogue= cat_phot)
col, colerr = photoz_archive._to_colors(f, ferr)

ID = cat_phot.ID

# +
modules_dir = Path('/nfs/pic.es/user/l/lcabayol/EUCLID/TEMPS/data/models')

nn_features = EncoderPhotometry()
nn_features.load_state_dict(torch.load(modules_dir/f'modelF_DA.pt',map_location=torch.device('cpu')))
nn_z = MeasureZ(num_gauss=6)
nn_z.load_state_dict(torch.load(modules_dir/f'modelZ_DA.pt',map_location=torch.device('cpu')))

temps_module = TempsModule(nn_features, nn_z)

z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
                                  return_pz=True)

nn_features = EncoderPhotometry()
nn_features.load_state_dict(torch.load(modules_dir/f'modelF_z.pt',map_location=torch.device('cpu')))
nn_z = MeasureZ(num_gauss=6)
nn_z.load_state_dict(torch.load(modules_dir/f'modelZ_z.pt',map_location=torch.device('cpu')))

temps_module = TempsModule(nn_features, nn_z)
znoda, pz, odds_noda = temps_module.get_pz(input_data=torch.Tensor(col),
                                           return_pz=True)

nn_features = EncoderPhotometry()
nn_features.load_state_dict(torch.load(modules_dir/f'modelF_L15.pt',map_location=torch.device('cpu')))
nn_z = MeasureZ(num_gauss=6)
nn_z.load_state_dict(torch.load(modules_dir/f'modelZ_L15.pt',map_location=torch.device('cpu')))

temps_module = TempsModule(nn_features, nn_z)
z_L15, pz, odds_L15 = temps_module.get_pz(input_data=torch.Tensor(col),
                                          return_pz=True)

df = pd.DataFrame(np.c_[ID, z, odds, znoda, odds_noda,z_L15, odds_L15],
                  columns=['ID','TEMPS', 'flag_TEMPS', 'TEMPS_noda', 'flag_TEMPSnoda', 'TEMPS_L15', 'flag_L15'])

df = df.dropna()

# +
percent=0.3
df['USE_TEMPS'] = np.zeros(shape=len(df))
# Calculate the 50th percentile (median) value of 'Flag_temps'
threshold = df['flag_TEMPS'].quantile(percent)

# Set 'USE_TEMPS' to 1 if 'Flag_temps' is in the top 50% (greater than or equal to the threshold)
df['USE_TEMPS'] = np.where(df['flag_TEMPS'] >= threshold, 1, 0)

# +
percent=0.3
df['USE_TEMPS_noda'] = np.zeros(shape=len(df))
# Calculate the 50th percentile (median) value of 'Flag_temps'
threshold = df['flag_TEMPSnoda'].quantile(percent)

# Set 'USE_TEMPS' to 1 if 'Flag_temps' is in the top 50% (greater than or equal to the threshold)
df['USE_TEMPS_noda'] = np.where(df['flag_TEMPSnoda'] >= threshold, 1, 0)

# +
percent=0.3
df['USE_TEMPS_L15'] = np.zeros(shape=len(df))
# Calculate the 50th percentile (median) value of 'Flag_temps'
threshold = df['flag_L15'].quantile(percent)

# Set 'USE_TEMPS' to 1 if 'Flag_temps' is in the top 50% (greater than or equal to the threshold)
df['USE_TEMPS_L15'] = np.where(df['flag_L15'] >= threshold, 1, 0)
# -

merged_df_temps = merged_df.merge(df, left_on='ID_catfull', right_on='ID')

# Corresponding redshift column names
redshift_columns = [
    'REDSHIFT_RF',
    'REDSHIFT_PHOSPHOROS',
    'REDSHIFT_LEPHARE',
    'REDSHIFT_DNF',
    'REDSHIFT_FRANKENZ',
    'REDSHIFT_METAPHOR',
    'REDSHIFT_ADABOOST',
    'REDSHIFT_GPZ',
    'REDSHIFT_CPZ',
    'REDSHIFT_NNPZ'
]

redshift_columns = redshift_columns + ['TEMPS', 'TEMPS_noda', 'TEMPS_L15']
use_columns = use_columns + ['USE_TEMPS','USE_TEMPS_noda', 'USE_TEMPS_L15']

merged_df_temps = merged_df_temps[merged_df_temps.VIS <25]


scatter, outliers, size =[],[], []
for method, use in(zip(redshift_columns, use_columns)):
    print(method)
    #df_method = merged_df_temps.dropna(subset=method)
    df_method = merged_df_temps[(merged_df_temps.loc[:, method]>0.2)&(merged_df_temps.loc[:, method]<2.6)]
    df_method = df_method[df_method.VIS<24.5]
    norm_size = len(df_method)
    df_method = df_method[df_method.loc[:, use]==1]
    zerr = (df_method.comp_z - df_method[method] ) / (1 + df_method.comp_z)
    scatter.append(nmad(zerr))
    outliers.append(len(zerr[np.abs(zerr)>0.15]) / len(df_method))
    size.append(len(df_method)/norm_size)
    print(nmad(zerr),len(zerr[np.abs(zerr)>0.15]) / len(df_method), len(df_method) /norm_size )


scatter_faint, outliers_faint, size_faint =[],[], []
for method, use in(zip(redshift_columns, use_columns)):
    print(method)
    #df_method = merged_df_temps.dropna(subset=method)
    df_method = merged_df_temps[(merged_df_temps.loc[:,'VIS']>23.5)&(merged_df_temps.loc[:,'VIS']<25)]
    #df_method = df_method[df_method.loc[:, use]==1]
    #df_method = merged_df_temps[(merged_df_temps.loc[:,'VIS']>23.5)&(merged_df_temps.loc[:,'VIS']<24.5)]
    zerr = (df_method.comp_z - df_method[method] ) / (1 + df_method.comp_z)
    scatter_faint.append(nmad(zerr))
    outliers_faint.append(len(zerr[np.abs(zerr)>0.15]) / len(df_method))
    size_faint.append(len(df_method))
    print(nmad(zerr),len(zerr[np.abs(zerr)>0.15]) / len(df_method), len(df_method))


# +
import matplotlib.pyplot as plt
import numpy as np
from pastamarkers import markers

# Define labels for the models
labs = [
    'RF', 'PHOSPHOROS', 'LEPHARE', 'DNF', 'FRANKENZ', 'METAPHOR',
    'ADABOOST', 'GPZ', 'CPZ', 'NNPZ', 'TEMPS', 'TEMPS - no DA', 'TEMPS - L15'
]

markers_pasta = [markers.penne, markers.conchiglie, markers.tortellini, markers.creste, markers.spaghetti, markers.ravioli, markers.tagliatelle, markers.mezzelune,markers.puntine, markers.stelline , 's', 'o', '^']

labs_faint = [f"{lab}_faint" for lab in labs]  # Labels for the faint data


# Colors from colormap
cmap = plt.get_cmap('tab20')
colors = [cmap(i / len(labs)) for i in range(len(labs))]

# Create subplots with 2 panels stacked vertically
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12), sharex=False)

# Plotting for the top panel
for i in range(len(labs)):
    if labs[i] == 'TEMPS - no DA' or labs[i] == 'TEMPS - L15':
        ax1.scatter(np.nan, np.nan, color=colors[i], label=labs[i], marker=markers_pasta[i], s=300)
    elif labs[i]=='CPZ':
        ax1.scatter(outliers[i] * 100, scatter[i], color=colors[i], label=labs[i], marker=markers_pasta[i], s=300)
        ax1.text(outliers[i] * 100 -0.2, scatter[i] + 0.001, f'{int(np.around(size[i] * 100))}', fontsize=12, verticalalignment='bottom')

    elif labs[i]=='ADABOOST':
        ax1.scatter(outliers[i] * 100, scatter[i], color=colors[i], label=labs[i], marker=markers_pasta[i], s=300)
        ax1.text(outliers[i] * 100 - 0.5, scatter[i] - 0.004, f'{int(np.around(size[i] * 100))}', fontsize=12, verticalalignment='bottom')

    else:
        ax1.scatter(outliers[i] * 100, scatter[i], color=colors[i], label=labs[i], marker=markers_pasta[i], s=300)
        ax1.text(outliers[i] * 100 - 0.5, scatter[i] + 0.001, f'{int(np.around(size[i] * 100))}', fontsize=12, verticalalignment='bottom')

# Customizations for the top plot
ax1.set_ylabel(r'NMAD $[\Delta z]$', fontsize=24)
ax1.legend(fontsize=14)
ax1.tick_params(axis='both', which='major', labelsize=20)

# Plotting for the bottom panel (faint data)
for i in range(len(labs)):
    ax2.scatter(outliers_faint[i] * 100, scatter_faint[i], color=colors[i], label=labs[i], marker=markers_pasta[i], s=300)

# Customizations for the bottom plot
ax2.set_ylabel(r'NMAD $[\Delta z]$', fontsize=24)
ax2.set_xlabel('Outlier fraction [%]', fontsize=24)
ax2.tick_params(axis='both', which='major', labelsize=20)

# Display the plot
plt.tight_layout()
#plt.savefig('Comparison_paper.pdf', bbox_inches='tight')
plt.show()

# -

cat_val_z = cat_val[['RA','DEC']].merge(cat_all[['RA','DEC','z_spec_S15','photo_z_L15','reliable_S15','mu_class_L07']], on = ['RA','DEC'])

merged_df = merged_df.merge(cat_val_z, on = ['RA','DEC'])
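The comparison above ranks every code by the same two numbers: the NMAD of the scaled residuals Delta z = (z_spec - z_phot) / (1 + z_spec) and the fraction of |Delta z| > 0.15 outliers. A self-contained sketch of both metrics on synthetic residuals; the nmad formula is the standard photo-z definition and is assumed to match what temps.utils.nmad computes:

import numpy as np

def nmad(zerr):
    # Normalised median absolute deviation: 1.4826 * median(|x - median(x)|).
    zerr = np.asarray(zerr)
    return 1.4826 * np.median(np.abs(zerr - np.median(zerr)))

def outlier_fraction(zerr, cut=0.15):
    # Fraction of objects with |Delta z| above the cut.
    return np.mean(np.abs(np.asarray(zerr)) > cut)

# Synthetic residuals: a tight core plus a few catastrophic outliers.
rng = np.random.default_rng(0)
zerr = rng.normal(0, 0.03, 1000)
zerr[:30] = rng.uniform(0.2, 1.0, 30)
print(nmad(zerr), outlier_fraction(zerr))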
notebooks/Feature_space.py CHANGED
@@ -5,11 +5,11 @@
 #       extension: .py
 #       format_name: light
 #       format_version: '1.5'
-#       jupytext_version: 1.
+#       jupytext_version: 1.16.2
 #   kernelspec:
-#     display_name:
+#     display_name: temps
 #     language: python
-#     name:
+#     name: temps
 # ---

 # # DOMAIN ADAPTATION INTUITION
@@ -23,6 +23,9 @@ import os
 from astropy.io import fits
 from astropy.table import Table
 import torch
+from pathlib import Path
+import seaborn as sns
+

 #matplotlib settings
 from matplotlib import rcParams
@@ -30,28 +33,22 @@ import matplotlib.pyplot as plt
 rcParams["mathtext.fontset"] = "stix"
 rcParams["font.family"] = "STIXGeneral"

-
-
-import
-
-
-from archive import archive
-from utils import nmad
-from temps_arch import EncoderPhotometry, MeasureZ
-from temps import Temps_module
-from plots import plot_nz
-# -
+from temps.archive import Archive
+from temps.utils import nmad
+from temps.temps_arch import EncoderPhotometry, MeasureZ
+from temps.temps import TempsModule
+from temps.plots import plot_nz

 # ## LOAD DATA

 #define here the directory containing the photometric catalogues
-parent_dir = '/data/astro/
-modules_dir = '../data/models/'
+parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+modules_dir = Path('../data/models/')

 # +
 filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

-hdu_list = fits.open(
+hdu_list = fits.open(parent_dir/filename_valid)
 cat = Table(hdu_list[1].data).to_pandas()
 cat = cat[cat['FLAG_PHOT']==0]
 cat = cat[cat['mu_class_L07']==1]
@@ -70,7 +67,7 @@ cat['specz_or_photo']=specz_or_photo

 # ### EXTRACT PHOTOMETRY

-photoz_archive =
+photoz_archive = Archive(path = parent_dir,only_zspec=False)
 f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
 col, colerr = photoz_archive._to_colors(f, ferr)

@@ -80,7 +77,7 @@ features_all = np.zeros((3,len(cat),10))
 for il, lab in enumerate(['z','L15','DA']):

     nn_features = EncoderPhotometry()
-    nn_features.load_state_dict(torch.load(
+    nn_features.load_state_dict(torch.load(modules_dir/f'modelF_{lab}.pt',map_location=torch.device('cpu')))

     features = nn_features(torch.Tensor(col))
     features = features.detach().cpu().numpy()
@@ -132,7 +129,7 @@ autoencoder = Autoencoder(input_dim=10,
 criterion = nn.L1Loss()
 optimizer = optim.Adam(autoencoder.parameters(), lr=0.0001)

-# +
+# + jupyter={"outputs_hidden": true}
 # Define the number of epochs
 num_epochs = 100
 for epoch in range(num_epochs):
@@ -162,9 +159,7 @@ print('Training finished')

 # #### EVALUTATE AUTOENCODER

-# + [markdown] jupyter={"source_hidden": true}
 # cat.to_csv('features_cat.csv', header=True, sep=',')
-# -

 indexes_specz = cat[(cat.specz_or_photo==0)&(cat.reliable_S15>0)].reset_index().index

@@ -173,6 +168,8 @@ for i in range(3):
     _, features = autoencoder(torch.Tensor(features_all[i]))
     features_all_reduced[i] = features.detach().cpu().numpy()

+features_all.shape
+
 # ### Plot the features

 start = 0
@@ -182,7 +179,6 @@ values_not_in_indexes_specz = all_values - set(indexes_specz)
 indexes_nospecz = sorted(values_not_in_indexes_specz)

 # +
-import seaborn as sns

 # Create subplots with three panels
 fig, axs = plt.subplots(1, 3, figsize=(15, 5))
@@ -223,14 +219,14 @@ axs[1].set_title('Trained on L15')

 # Third subplot
 features_all_reduced_nospecz = pd.DataFrame(features_all_reduced[2, indexes_nospecz, :]).drop_duplicates().values
-sns.kdeplot(x=
-            y=
+sns.kdeplot(x=features_all_reduced[2, indexes_nospecz, 0],
+            y=features_all_reduced[2, indexes_nospecz, 1],
             clip=(-1, 5),
             ax=axs[2],
             color='salmon',
             label='Wide-field sample')
-sns.kdeplot(x=
-            y=
+sns.kdeplot(x=features_all_reduced[2, indexes_specz, 0],
+            y=features_all_reduced[2, indexes_specz,1],
             clip=(-1, 5),
             ax=axs[2],
             color='lightskyblue',
@@ -252,200 +248,7 @@ axs[2].legend(legend_handles, legend_labels, loc='upper right', fontsize=16)
 # Adjust layout
 plt.tight_layout()

-plt.savefig('Contourplot.pdf', bbox_inches='tight')
-plt.show()
-
-# -
-
-
-
-
-np.savetxt('features.txt',features_all_reduced.reshape(3*164816, 2))
-
-
-
-
-
-# +
-photoz_archive = archive(path = parent_dir,only_zspec=False)
-
-fig, ax = plt.subplots(ncols = 3, figsize=(15,4), sharex=True, sharey=True)
-colors = ['navy', 'goldenrod']
-titles = [r'Training: $z_s$', r'Training: L15',r'Training: $z_s$ + DA']
-x_min, x_max = -5,5
-y_min, y_max = -5,5
-x_grid, y_grid = np.meshgrid(np.linspace(x_min, x_max, 10), np.linspace(y_min, y_max, 10))
-xy_grid = np.vstack([x_grid.ravel(), y_grid.ravel()])
-density_grid = density_estimation(xy_grid).reshape(x_grid.shape)
-for il, lab in enumerate(['z','L15','DA']):
-
-
-    nn_features = EncoderPhotometry()
-    nn_features.load_state_dict(torch.load(os.path.join(modules_dir,f'modelF_{lab}.pt')))
-
-    for it, target_type in enumerate(['L15','zs']):
-        if target_type=='zs':
-            cat_sub = photoz_archive._select_only_zspec(cat)
-            cat_sub = photoz_archive._clean_zspec_sample(cat_sub)
-
-        elif target_type=='L15':
-            cat_sub = photoz_archive._exclude_only_zspec(cat)
-        else:
-            assert False
-
-        cat_sub = photoz_archive._clean_photometry(cat_sub)
-        print(cat_sub.shape)
-
-
-        f, ferr = photoz_archive._extract_fluxes(cat_sub)
-        col, colerr = photoz_archive._to_colors(f, ferr)
-
-        features = nn_features(torch.Tensor(col))
-        features = features.detach().cpu().numpy()
-
-
-        #xy = np.vstack([features[:1000,0], features[:1000,1]])
-        #zd = gaussian_kde(xy)(xy)
-        #ax[il].scatter(features[:1000,0], features[:1000,1],c=zd, s=3)
-
-        xy = np.vstack([features[:,0], features[:,1]])
-        density_estimation = gaussian_kde(xy)
-
-        # Define grid for plotting density lines
-
-        xy_grid = np.vstack([x_grid.ravel(), y_grid.ravel()])
-        density_grid = density_estimation(xy_grid).reshape(x_grid.shape)
-
-        # Plot contour lines representing density
-        ax[il].contour(x_grid, y_grid, density_grid, colors=colors[it], label = f'{target_type}')
-
-
-    ax[il].set_title(titles[il])
-    ax[il].set_xlim(-5,5)
-    ax[il].set_ylim(-5,5)
-
-
-ax[0].set_ylabel('Feature 1', fontsize=14)
-#plt.ylabel('Feature 2', fontsize=14)
-
-#assert False
-
-
-# -
-
-H
-
-H
-
-xedges
-
-yedges
-
-# +
-import matplotlib.colors as colors
-from matplotlib import path
-import numpy as np
-from matplotlib import pyplot as plt
-try:
-    from astropy.convolution import Gaussian2DKernel, convolve
-    astro_smooth = True
-except ImportError as IE:
-    astro_smooth = False
-
-np.random.seed(123)
-#t = np.linspace(-5,1.2,1000)
-x = features[:1000,0]
-y = features[:1000,1]
-
-H, xedges, yedges = np.histogram2d(x,y, bins=(10,10))
-xmesh, ymesh = np.meshgrid(xedges[:-1], yedges[:-1])
-
-# Smooth the contours (if astropy is installed)
-if astro_smooth:
-    kernel = Gaussian2DKernel(x_stddev=1.)
-    H=convolve(H,kernel)
-
-fig,ax = plt.subplots(1, figsize=(7,6))
-clevels = ax.contour(xmesh,ymesh,H.T,lw=.9,cmap='winter')#,zorder=90)
-ax.scatter(x,y,s=3)
-#ax.set_xlim(-20,5)
-#ax.set_ylim(-20,5)
-
-# Identify points within contours
-#p = clevels.collections[0].get_paths()
-#inside = np.full_like(x,False,dtype=bool)
-#for level in p:
-#    inside |= level.contains_points(zip(*(x,y)))
-
-#ax.plot(x[~inside],y[~inside],'kx')
-#plt.show(block=False)
-# -
-
-density_grid
-
-features.shape, zd.shape
-
-# + jupyter={"outputs_hidden": true}
-xy = np.vstack([features[:,0], features[:,1]])
-zd = gaussian_kde(xy)(xy)
-plt.scatter(features[:,0], features[:,1],c=zd)
-
-
-# +
-# Make the base corner plot
-figure = corner.corner(features[:,:2], quantiles=[0.16, 0.84], show_titles=False, color ='crimson')
-corner.corner(samples2, fig=fig)
-ndim=2
-# Extract the axes
-axes = np.array(figure.axes).reshape((ndim, ndim))
-
-
-for a in axes[np.triu_indices(ndim)]:
-    a.remove()
-
-# +
-import numpy as np
-import matplotlib.pyplot as plt
-from scipy.stats import gaussian_kde
-
-# Assuming 'features' is your data array with shape (n_samples, 2)
-
-# Calculate the density estimate
-xy = np.vstack([features[:,0], features[:,1]])
-density_estimation = gaussian_kde(xy)
-
-# Define grid for plotting density lines
-
-xy_grid = np.vstack([x_grid.ravel(), y_grid.ravel()])
-density_grid = density_estimation(xy_grid).reshape(x_grid.shape)
-
-# Plot contour lines representing density
-plt.contour(x_grid, y_grid, density_grid, colors='black')
-
-# Optionally, you can add a scatter plot on top of the density lines for better visualization
-#plt.scatter(features[:,0], features[:,1], color='blue', alpha=0.5)
-
-# Set labels and title
-plt.xlabel('Feature 1')
-plt.ylabel('Feature 2')
-plt.title('Density Lines Plot')
-
-# Show plot
+#plt.savefig('Contourplot.pdf', bbox_inches='tight')
 plt.show()

 # -
@@ -454,41 +257,6 @@ plt.show()



-corner_plot = corner.corner(Arinyo_preds,
-                            labels=[r'$b$', r'$\beta$', '$q_1$', '$k_{vav}$','$a_v$','$b_v$','$k_p$','$q_2$'],
-                            truths=Arinyo_coeffs_central[test_snap],
-                            truth_color='crimson')
-
-import corner
-figure = corner.corner(features, quantiles=[0.16, 0.5, 0.84], show_titles=False)
-axes = np.array(fig.axes).reshape((ndim, ndim))
-for a in axes[np.triu_indices(ndim)]:
-    a.remove()



-# +
-# My data
-x = features[:,0]
-y = features[:,1]
-
-# Peform the kernel density estimate
-k = stats.gaussian_kde(np.vstack([x, y]))
-xi, yi = np.mgrid[-5:5,-5:5]
-zi = k(np.vstack([xi.flatten(), yi.flatten()]))
-
-
-fig = plt.figure()
-ax = fig.gca()
-
-
-CS = ax.contour(xi, yi, zi.reshape(xi.shape), colors='crimson')
-
-ax.set_xlim(-5, 5)
-ax.set_ylim(-5, 5)
-
-plt.show()
-# -
-
-
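The surviving part of Feature_space.py compresses the 10-dimensional encoder features to two dimensions with a small autoencoder (L1 loss, Adam at lr=0.0001) before drawing the KDE contours. The Autoencoder class itself sits outside the hunks shown, so the sketch below is an assumption about its shape; the only constraints taken from the diff are input_dim=10, the (reconstruction, latent) return order implied by `_, features = autoencoder(...)`, and the loss and optimizer choice:

import torch
import torch.nn as nn

class Autoencoder(nn.Module):
    # Hypothetical 10D -> 2D autoencoder; hidden layer sizes are assumptions.
    def __init__(self, input_dim=10, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(input_dim, 32), nn.ReLU(),
                                     nn.Linear(32, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 32), nn.ReLU(),
                                     nn.Linear(32, input_dim))

    def forward(self, x):
        latent = self.encoder(x)
        # Return (reconstruction, latent) so `_, features = model(x)` works.
        return self.decoder(latent), latent

model = Autoencoder()
x = torch.rand(128, 10)       # stand-in for the encoder features
recon, latent = model(x)
loss = nn.L1Loss()(recon, x)  # the notebook optimises this with Adam, lr=0.0001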
notebooks/Fig6_qualitycut.py DELETED
@@ -1,164 +0,0 @@

# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.14.5
#   kernelspec:
#     display_name: insight
#     language: python
#     name: insight
# ---

# # FIGURE 6 IN THE PAPER

# ## QUALITY CUTS

# %load_ext autoreload
# %autoreload 2

import pandas as pd
import numpy as np
import os
import torch
from scipy import stats

#matplotlib settings
from matplotlib import rcParams
import matplotlib.pyplot as plt
rcParams["mathtext.fontset"] = "stix"
rcParams["font.family"] = "STIXGeneral"

#insight modules
import sys
sys.path.append('../temps')
#from insight_arch import EncoderPhotometry, MeasureZ
#from insight import Insight_module
from archive import archive
from utils import nmad
from temps_arch import EncoderPhotometry, MeasureZ
from temps import Temps_module


# ### LOAD DATA (ONLY SPECZ)

#define here the directory containing the photometric catalogues
parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
modules_dir = '../data/models/'

photoz_archive = archive(path = parent_dir,only_zspec=True,flags_kept=[1. , 1.1, 1.4, 1.5, 2,2.1,2.4,2.5,3., 3.1, 3.4, 3.5, 4., 9. , 9.1, 9.3, 9.4, 9.5,11.1, 11.5, 12.1, 12.5, 13. , 13.1, 13.5, 14, ])
f_test_specz, ferr_test_specz, specz_test ,VIS_mag_test = photoz_archive.get_testing_data()


# ### LOAD TRAINED MODELS AND EVALUATE PDF OF RANDOM EXAMPLES

# +
# Initialize an empty dictionary to store DataFrames
dfs = {}

for il, lab in enumerate(['z','L15','DA']):

    nn_features = EncoderPhotometry()
    nn_features.load_state_dict(torch.load(os.path.join(modules_dir,f'modelF_{lab}.pt')))
    nn_z = MeasureZ(num_gauss=6)
    nn_z.load_state_dict(torch.load(os.path.join(modules_dir,f'modelZ_{lab}.pt')))

    temps = Temps_module(nn_features, nn_z)

    z,zerr, pz, flag, odds = temps.get_pz(input_data=torch.Tensor(f_test_specz),
                                          return_pz=True)

    # Create a DataFrame with the desired columns
    df = pd.DataFrame(np.c_[z, flag, odds, specz_test],
                      columns=['z','zflag', 'odds' ,'ztarget'])

    # Calculate additional columns or operations if needed
    df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)

    # Drop any rows with NaN values
    df = df.dropna()

    # Assign the DataFrame to a key in the dictionary
    dfs[lab] = df

# -

# ### STATISTICS BASED ON OUR QUALITY CUT

# +
bin_edges = stats.mstats.mquantiles(df.zflag, np.arange(0,1.01,0.05))
scatter, eta, xlab, xmag, xzs, flagmean = [],[],[], [], [], []

for k in range(len(bin_edges)-1):
    edge_min = bin_edges[k]
    edge_max = bin_edges[k+1]

    df_bin = df[(df.zflag > edge_min)]


    xlab.append(np.round(len(df_bin)/len(df),2)*100)
    xzs.append(0.5*(df_bin.ztarget.min()+df_bin.ztarget.max()))
    flagmean.append(np.mean(df_bin.zflag))
    scatter.append(nmad(df_bin.zwerr))
    eta.append(len(df_bin[np.abs(df_bin.zwerr)>0.15])/len(df)*100)


# -

# ### STATISTICS BASED ON ODDS

# +
bin_edges = stats.mstats.mquantiles(df.odds, np.arange(0,1.01,0.05))
scatter_odds, eta_odds,xlab_odds, oddsmean = [],[],[], []

for k in range(len(bin_edges)-1):
    edge_min = bin_edges[k]
    edge_max = bin_edges[k+1]

    df_bin = df[(df.odds > edge_min)]


    xlab_odds.append(np.round(len(df_bin)/len(df),2)*100)
    oddsmean.append(np.mean(df_bin.zflag))
    scatter_odds.append(nmad(df_bin.zwerr))
    eta_odds.append(len(df_bin[np.abs(df_bin.zwerr)>0.15])/len(df)*100)


# -

# ### PLOTS

# +
plt.plot(xlab_odds,scatter_odds, marker = '.', color ='crimson', label=r'$\theta(\Delta z)$', ls='--', alpha=0.5)
plt.plot(xlab,scatter, marker = '.', color ='navy',label=r'$\xi = \theta(\Delta z)$')


plt.ylabel(r'NMAD [$\Delta z\ /\ (1 + z_{\rm s})$]', fontsize=16)
plt.xlabel('Completeness', fontsize=16)

plt.yticks(fontsize=12)
plt.xticks(np.arange(5,101,10), fontsize=12)
plt.legend(fontsize=14)

plt.savefig('Flag_nmad_zspec.pdf', bbox_inches='tight')
plt.show()

# +
plt.plot(xlab_odds,eta_odds, marker='.', color ='crimson', label=r'$\theta(\Delta z)$', ls='--', alpha=0.5)
plt.plot(xlab,eta, marker='.', color ='navy',label=r'$\xi = \theta(\Delta z)$')

plt.yticks(fontsize=12)
plt.xticks(np.arange(5,101,10), fontsize=12)
plt.ylabel(r'$\eta$ [%]', fontsize=16)
plt.xlabel('Completeness', fontsize=16)
plt.legend()

plt.savefig('Flag_eta_zspec.pdf', bbox_inches='tight')

plt.show()
# -
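The deleted notebook (superseded by the new Qualitycut.py) builds its completeness curves by sweeping a threshold over the 5% quantiles of a quality statistic (zflag or odds) and recomputing NMAD and the outlier rate eta on the surviving objects. A compact, self-contained version of that sweep; the flag and residuals here are fabricated purely for illustration:

import numpy as np
from scipy import stats

# Toy data: residuals plus a quality flag that loosely tracks |residual|.
rng = np.random.default_rng(1)
zwerr = rng.normal(0, 0.05, 5000)
zflag = -np.abs(zwerr) + rng.normal(0, 0.02, 5000)

# Threshold at each 5% quantile of the flag, as the deleted code does.
bin_edges = stats.mstats.mquantiles(zflag, np.arange(0, 1.01, 0.05))
for edge in bin_edges[:-1]:
    sel = zwerr[zflag > edge]
    completeness = 100 * len(sel) / len(zwerr)
    nmad_sel = 1.4826 * np.median(np.abs(sel - np.median(sel)))
    eta = 100 * np.sum(np.abs(sel) > 0.15) / len(zwerr)  # outliers vs. full sample, as in the notebook
    print(f'{completeness:5.1f}%  NMAD={nmad_sel:.4f}  eta={eta:.2f}%')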
notebooks/{Fig2_NMAD.py → NMAD.py}
RENAMED
@@ -6,15 +6,15 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.
+# jupytext_version: 1.16.2
 # kernelspec:
-# display_name:
+# display_name: temps
 # language: python
-# name:
+# name: temps
 # ---

 # %% [markdown]
-# # FIGURE
+# # FIGURE METRICS

 # %% [markdown]
 # ## METRICS FOR THE DIFFERENT METHODS ON THE WIDE FIELD SAMPLE
@@ -43,15 +43,14 @@


 # %%
-
-import sys
-sys.path.append('../temps')
-
-from archive import archive
-from utils import nmad
-from temps_arch import EncoderPhotometry, MeasureZ
-from temps import Temps_module
+import temps

+# %%
+from temps.archive import Archive
+from temps.utils import nmad
+from temps.temps_arch import EncoderPhotometry, MeasureZ
+from temps.temps import TempsModule
+from temps.plots import plot_photoz


 # %%
@@ -62,15 +61,13 @@

 # %%
 #define here the directory containing the photometric catalogues
-parent_dir = '/data/astro/
-modules_dir = '../data/models/'
+parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+modules_dir = Path('../data/models/')

 # %%
-#load catalogue and apply cuts
-
 filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
-
-hdu_list = fits.open(
+path_file = parent_dir / filename_valid # Creating the path to the file
+hdu_list = fits.open(path_file)
 cat = Table(hdu_list[1].data).to_pandas()
 cat = cat[cat['FLAG_PHOT']==0]
 cat = cat[cat['mu_class_L07']==1]
@@ -78,7 +75,6 @@
 cat = cat[cat['MAG_VIS']<25]


-
 # %%
 ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii]> 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
 specz_or_photo = [0 if cat['z_spec_S15'].values[ii]> 0 else 1 for ii in range(len(cat))]
@@ -87,7 +83,7 @@
 zsflag = cat['reliable_S15']

 # %%
-photoz_archive =
+photoz_archive = Archive(path = parent_dir,only_zspec=False)
 f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
 col, colerr = photoz_archive._to_colors(f, ferr)

@@ -101,20 +97,21 @@
 for il, lab in enumerate(['z','L15','DA']):

     nn_features = EncoderPhotometry()
-    nn_features.load_state_dict(
+    nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt', map_location=torch.device('cpu')))
     nn_z = MeasureZ(num_gauss=6)
-    nn_z.load_state_dict(
+    nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt', map_location=torch.device('cpu')))

+    temps_module = TempsModule(nn_features, nn_z)

-    z,
+    z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
+                                      return_pz=True,
+                                      return_flag=True)
     # Create a DataFrame with the desired columns
-    df = pd.DataFrame(np.c_[ID, VISmag,z,
-                      columns=['ID','VISmag','z',
+    df = pd.DataFrame(np.c_[ID, VISmag,z, odds, ztarget,zsflag, specz_or_photo],
+                      columns=['ID','VISmag','z','odds', 'ztarget','zsflag','S15_L15_flag'])

     # Calculate additional columns or operations if needed
-    df['zwerr'] = (df.
+    df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)

     # Drop any rows with NaN values
     df = df.dropna()
@@ -135,36 +132,24 @@
 # %%
 if not eval_methods:
     dfs = {}
-    dfs['z'] = pd.read_csv(
-    dfs['L15'] = pd.read_csv(
-    dfs['DA'] = pd.read_csv(
+    dfs['z'] = pd.read_csv(parent_dir / 'predictions_specztraining.csv', header=0)
+    dfs['L15'] = pd.read_csv(parent_dir / 'predictions_speczL15training.csv', header=0)
+    dfs['DA'] = pd.read_csv(parent_dir / 'predictions_speczDAtraining.csv', header=0)


 # %% [markdown]
 # ### MAKE PLOT

 # %%
-            nbins=8,
-            xvariable='VISmag',
-            metric='nmad',
-            type_bin='bin',
-            label_list = ['zs','zs+L15',r'TEMPS'],
-            save=False,
-            samp='L15'
-           )
+df_list = [dfs['z'], dfs['L15'], dfs['DA']]

 # %%
 plot_photoz(df_list,
             nbins=8,
             xvariable='VISmag',
-            metric='
+            metric='nmad',
             type_bin='bin',
             label_list = ['zs','zs+L15',r'TEMPS'],
             save=False,
             samp='L15'
            )
-
-# %%
-
-# %%
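For reference, the nmad statistic imported from temps.utils above is the scatter metric used throughout these notebooks: the normalised median absolute deviation of the scaled residuals. A minimal sketch under the usual convention (the 1.4826 factor makes the MAD consistent with a Gaussian sigma); the function name is illustrative and the packaged implementation may differ in detail:

import numpy as np

def nmad_sketch(zwerr):
    # zwerr is (z - ztarget) / (1 + ztarget), as computed in the notebooks
    zwerr = np.asarray(zwerr)
    return 1.4826 * np.median(np.abs(zwerr - np.median(zwerr)))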
notebooks/{Fig3_PIT_CRPS.py → PIT_CRPS.py}
RENAMED
@@ -1,93 +1,84 @@
 # ---
 # jupyter:
 # jupytext:
-# formats: ipynb,py:percent
 # text_representation:
 # extension: .py
-# format_name:
+# format_name: light
-# format_version: '1.
+# format_version: '1.5'
-# jupytext_version: 1.
+# jupytext_version: 1.16.2
 # kernelspec:
-# display_name:
+# display_name: temps
 # language: python
-# name:
+# name: temps
 # ---

-#
-# # FIGURE 3 IN THE PAPER
+# # $p(z)$ DISTRIBUTIONS

-# %% [markdown]
 # ## PIT AND CRPS FOR THE THREE METHODS

-# %% [markdown]
 # ### LOAD PYTHON MODULES

-# %%
 # %load_ext autoreload
 # %autoreload 2

+import temps
+
 import pandas as pd
 import numpy as np
 import os
 from astropy.io import fits
 from astropy.table import Table
 import torch
+from pathlib import Path

-# %%
 #matplotlib settings
 from matplotlib import rcParams
 import matplotlib.pyplot as plt
 rcParams["mathtext.fontset"] = "stix"
 rcParams["font.family"] = "STIXGeneral"

-#
-import
-
-from utils import nmad
-from plots import plot_PIT, plot_crps
-from temps_arch import EncoderPhotometry, MeasureZ
-from temps import Temps_module
+# +
+from temps.temps import TempsModule
+from temps.archive import Archive
+from temps.utils import nmad
+from temps.temps_arch import EncoderPhotometry, MeasureZ
+from temps.plots import plot_photoz, plot_PIT, plot_crps

+# -

-# %% [markdown]
 # ### LOAD DATA

-#
+#define here the directory containing the photometric catalogues
+parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+modules_dir = Path('../data/models/')
+
+photoz_archive = Archive(path = parent_dir,
                          only_zspec=False,
                          flags_kept=[1. , 1.1, 1.4, 1.5, 2,2.1,2.4,2.5,3., 3.1, 3.4, 3.5, 4., 9. , 9.1, 9.3, 9.4, 9.5,11.1, 11.5, 12.1, 12.5, 13. , 13.1, 13.5, 14, ],
                          target_test='L15')
 f_test, ferr_test, specz_test ,VIS_mag_test = photoz_archive.get_testing_data()


-# %% [markdown]
 # ## CREATE PIT; CRPS; SPECTROSCOPIC SAMPLE

-# %% [markdown]
 # This loads pre-trained models (for the sake of time). You can learn how to train the models in the Tutorial notebook.

-# %%
 # Initialize an empty dictionary to store DataFrames
 crps_dict = {}
 pit_dict = {}
 for il, lab in enumerate(['z','L15','DA']):

     nn_features = EncoderPhotometry()
-    nn_features.load_state_dict(torch.load(
+    nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
     nn_z = MeasureZ(num_gauss=6)
-    nn_z.load_state_dict(torch.load(
+    nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt',map_location=torch.device('cpu')))

+    temps_module = TempsModule(nn_features, nn_z)

-    pit_list =
-    crps_list =
+    pit_list = temps_module.calculate_pit(input_data=torch.Tensor(f_test), target_data=torch.Tensor(specz_test))
+    crps_list = temps_module.calculate_crps(input_data=torch.Tensor(f_test), target_data=specz_test)


     # Assign the DataFrame to a key in the dictionary
@@ -95,7 +86,7 @@
     pit_dict[lab] = pit_list


-#
+# +
 plot_PIT(pit_dict['z'],
          pit_dict['L15'],
          pit_dict['DA'],
@@ -106,7 +97,7 @@


-#
+# +
 plot_crps(crps_dict['z'],
           crps_dict['L15'],
           crps_dict['DA'],
@@ -116,5 +107,6 @@


+# -
+
-# %%
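As background for the calculate_pit call above: a PIT value is the cumulative p(z) evaluated at the true redshift, so a well-calibrated model produces uniformly distributed PITs. A minimal sketch for PDFs tabulated on a grid; the names pit_sketch and zgrid are illustrative, not part of the temps API:

import numpy as np

def pit_sketch(pz, zgrid, z_true):
    # pz: (n_gal, n_grid) tabulated PDFs; zgrid: (n_grid,); z_true: (n_gal,)
    cdf = np.cumsum(pz, axis=1)
    cdf = cdf / cdf[:, -1:]                    # normalise each CDF to end at 1
    idx = np.clip(np.searchsorted(zgrid, z_true), 0, len(zgrid) - 1)
    return cdf[np.arange(len(z_true)), idx]    # CDF evaluated at the true z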
notebooks/Qualitycut.py
ADDED
@@ -0,0 +1,241 @@
+# ---
+# jupyter:
+# jupytext:
+# text_representation:
+# extension: .py
+# format_name: light
+# format_version: '1.5'
+# jupytext_version: 1.16.2
+# kernelspec:
+# display_name: temps
+# language: python
+# name: temps
+# ---
+
+# # QUALITY CUTS
+
+# %load_ext autoreload
+# %autoreload 2
+
+import pandas as pd
+import numpy as np
+import os
+import torch
+from scipy import stats
+from pathlib import Path
+
+#matplotlib settings
+from matplotlib import rcParams
+import matplotlib.pyplot as plt
+rcParams["mathtext.fontset"] = "stix"
+rcParams["font.family"] = "STIXGeneral"
+
+from temps.archive import Archive
+from temps.utils import nmad, caluclate_eta
+from temps.temps_arch import EncoderPhotometry, MeasureZ
+from temps.temps import TempsModule
+
+
+# ### LOAD DATA (ONLY SPECZ)
+
+#define here the directory containing the photometric catalogues
+parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+modules_dir = Path('../data/models/')
+
+photoz_archive = Archive(path = parent_dir,only_zspec=True,flags_kept=[1. , 1.1, 1.4, 1.5, 2,2.1,2.4,2.5,3., 3.1, 3.4, 3.5, 4., 9. , 9.1, 9.3, 9.4, 9.5,11.1, 11.5, 12.1, 12.5, 13. , 13.1, 13.5, 14, ])
+f_test_specz, ferr_test_specz, specz_test ,VIS_mag_test = photoz_archive.get_testing_data()
+
+
+# ### LOAD TRAINED MODELS AND EVALUATE PDF OF RANDOM EXAMPLES
+
+# Initialize an empty dictionary to store DataFrames
+dfs = {}
+pzs = np.zeros(shape = (3,11016,1000))
+for il, lab in enumerate(['z','L15','DA']):
+
+    nn_features = EncoderPhotometry()
+    nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
+    nn_z = MeasureZ(num_gauss=6)
+    nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt', map_location=torch.device('cpu')))
+
+    temps_module = TempsModule(nn_features, nn_z)
+
+    z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(f_test_specz),
+                                      return_pz=True)
+
+    pzs[il] = pz
+
+    # Create a DataFrame with the desired columns
+    df = pd.DataFrame(np.c_[z, odds, specz_test],
+                      columns=['z', 'odds' ,'ztarget'])
+
+    # Calculate additional columns or operations if needed
+    df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
+
+    # Drop any rows with NaN values
+    df = df.dropna()
+
+    # Assign the DataFrame to a key in the dictionary
+    dfs[lab] = df
+
+
+# ### STATS
+
+# +
+#odds_test = [0, 0.01, 0.03, 0.05, 0.07, 0.1, 0.13, 0.15]
+odds_test = np.arange(0,0.15,0.01)
+
+df = dfs['DA'].copy()
+zgrid = np.linspace(0, 5, 1000)
+pz = pzs[2]
+# -
+
+diff_matrix = np.abs(df.z.values[:,None] - zgrid[None,:])
+idx_peak = np.argmax(pz,1)
+idx = np.argmin(diff_matrix,1)
+
+odds_cat = np.zeros(shape = (len(odds_test),len(df)))
+for ii, odds_ in enumerate(odds_test):
+    diff_matrix_upper = np.abs((df.z.values+odds_)[:,None] - zgrid[None,:])
+    diff_matrix_lower = np.abs((df.z.values-odds_)[:,None] - zgrid[None,:])
+
+    idx = np.argmin(diff_matrix,1)
+    idx_upper = np.argmin(diff_matrix_upper,1)
+    idx_lower = np.argmin(diff_matrix_lower,1)
+
+    odds = []
+    for jj in range(len(pz)):
+        odds.append(pz[jj,idx_lower[jj]:(idx_upper[jj]+1)].sum())
+
+    odds_cat[ii] = np.array(odds)
+
+
+odds_df = pd.DataFrame(odds_cat.T, columns=[f'odds_{x}' for x in odds_test])
+df = pd.concat([df, odds_df], axis=1)
+
+
+# ## statistics on ODDS
+
+# +
+scatter_odds, eta_odds,xlab_odds, oddsmean = [],[],[], []
+
+for c in complenteness:
+    percentile_cutoff = df['odds'].quantile(c)
+
+    df_bin = df[(df.odds > percentile_cutoff)]
+
+    xlab_odds.append((1-c)*100)
+    oddsmean.append(np.mean(df_bin.odds))
+    scatter_odds.append(nmad(df_bin.zwerr))
+    eta_odds.append(caluclate_eta(df_bin))
+    if np.round(c,1) ==0.3:
+        percentiles_cutoff = [df[f'odds_{col}'].quantile(c) for col in odds_test]
+        scatters_odds = [nmad(df[df[f'odds_{col}'] > percentile_cutoff].zwerr) for (col, percentile_cutoff) in zip(odds_test,percentiles_cutoff)]
+        etas_odds = [caluclate_eta(df[df[f'odds_{col}'] > percentile_cutoff]) for (col, percentile_cutoff) in zip(odds_test,percentiles_cutoff)]
+
+
+# -
+
+df_completeness = pd.DataFrame(np.c_[xlab_odds,scatter_odds, eta_odds],
+                               columns = ['completeness', 'sigma_odds', 'eta_odds'])
+
+# ## PLOTS
+
+# +
+# Initialize the figure and axis
+fig, ax1 = plt.subplots(figsize=(7, 5))
+
+# First plot (Sigma) - using the left y-axis
+color = 'crimson'
+ax1.plot(df_completeness.completeness,
+         df_completeness.sigma_odds,
+         marker='.',
+         color=color,
+         label=r'NMAD',
+         ls='-',
+         alpha=0.5,
+        )
+
+
+ax1.set_xlabel('Completeness', fontsize=16)
+ax1.set_ylabel(r'NMAD [$\Delta z$]', color=color, fontsize=16)
+ax1.tick_params(axis='x', labelsize=14)
+ax1.tick_params(axis='y', which='major', labelsize = 14, width=2.5, length=3, labelcolor=color)
+ax1.set_xticks(np.arange(5, 101, 10))
+
+ax2 = ax1.twinx()  # Create another y-axis that shares the same x-axis
+color = 'navy'
+ax2.plot(df_completeness.completeness,
+         df_completeness.eta_odds,
+         marker='.',
+         color=color,
+         label=r'$\eta$ [%]',
+         ls='--',
+         alpha=0.5)
+
+ax2.set_ylabel(r'$\eta$ [%]', color=color, fontsize=16)
+
+# Adjust notation to allow comparison
+ax1.yaxis.get_major_formatter().set_powerlimits((0, 0))  # Adjust scientific notation for Sigma
+ax2.yaxis.get_major_formatter().set_powerlimits((0, 0))  # Adjust scientific notation for Eta
+ax2.tick_params(axis='x', labelsize=14)
+ax2.tick_params(axis='y', which='major', labelsize = 14, width=2.5, length=3, labelcolor=color)
+
+# Final adjustments
+fig.tight_layout()
+fig.legend(bbox_to_anchor = [-0.18,0.75,0.5,0.2], fontsize = 14)
+#plt.savefig('Flag_nmad_eta_sigma_comparison.pdf', bbox_inches='tight')
+plt.show()
+
+
+# +
+# Initialize the figure and axis
+fig, ax1 = plt.subplots(figsize=(7, 5))
+
+# First plot (Sigma) - using the left y-axis
+color = 'crimson'
+ax1.plot(odds_test,
+         scatters_odds,
+         marker='.',
+         color=color,
+         label=r'NMAD',
+         ls='-',
+         alpha=0.5,
+        )
+
+
+ax1.set_xlabel(r'$\delta z$ (ODDS)', fontsize=16)
+ax1.set_ylabel(r'NMAD [$\Delta z$]', color=color, fontsize=16)
+ax1.tick_params(axis='x', labelsize=14)
+ax1.tick_params(axis='y', which='major', labelsize = 14, width=2.5, length=3, labelcolor=color)
+ax1.set_xticks(np.arange(0,0.16,0.02))
+
+ax2 = ax1.twinx()  # Create another y-axis that shares the same x-axis
+color = 'navy'
+ax2.plot(odds_test,
+         etas_odds,
+         marker='.',
+         color=color,
+         label=r'$\eta$ [%]',
+         ls='--',
+         alpha=0.5)
+
+ax2.set_ylabel(r'$\eta$ [%]', color=color, fontsize=16)
+
+# Adjust notation to allow comparison
+ax1.yaxis.get_major_formatter().set_powerlimits((0, 0))  # Adjust scientific notation for Sigma
+ax2.yaxis.get_major_formatter().set_powerlimits((0, 0))  # Adjust scientific notation for Eta
+ax2.tick_params(axis='x', labelsize=14)
+ax2.tick_params(axis='y', which='major', labelsize = 14, width=2.5, length=3, labelcolor=color)
+
+# Final adjustments
+fig.tight_layout()
+fig.legend(bbox_to_anchor = [0.10,0.75,0.5,0.2], fontsize = 14)
+#plt.savefig('ODDS_study.pdf', bbox_inches='tight')
+plt.show()
+
+# -
+
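The double loop in the new Qualitycut.py computes the ODDS quality flag: the fraction of each p(z) contained within a window of half-width delta_z around the point estimate. A vectorised sketch of the same quantity, assuming a uniform redshift grid; odds_sketch is an illustrative name, not the temps implementation:

import numpy as np

def odds_sketch(pz, zgrid, z_point, delta_z):
    # pz: (n_gal, n_grid) PDFs; zgrid: (n_grid,); z_point: (n_gal,) estimates
    pz = pz / pz.sum(axis=1, keepdims=True)        # normalise each PDF
    lo = np.searchsorted(zgrid, z_point - delta_z)
    hi = np.searchsorted(zgrid, z_point + delta_z)
    return np.array([pz[i, lo[i]:hi[i] + 1].sum() for i in range(len(pz))])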
notebooks/Table_metrics.py
CHANGED
@@ -5,11 +5,11 @@
 # extension: .py
 # format_name: light
 # format_version: '1.5'
-# jupytext_version: 1.
+# jupytext_version: 1.16.2
 # kernelspec:
-# display_name:
+# display_name: temps
 # language: python
-# name:
+# name: temps
 # ---

 # # TABLE METRICS
@@ -24,6 +24,7 @@
 from scipy import stats
 from astropy.io import fits
 from astropy.table import Table
+from pathlib import Path

 #matplotlib settings
 from matplotlib import rcParams
@@ -31,27 +32,22 @@
 import matplotlib.pyplot as plt
 rcParams["mathtext.fontset"] = "stix"
 rcParams["font.family"] = "STIXGeneral"

-
-import
-
-#from insight import Insight_module
-from archive import archive
-from utils import nmad, select_cut
-from temps_arch import EncoderPhotometry, MeasureZ
-from temps import Temps_module
+from temps.archive import Archive
+from temps.utils import nmad, select_cut
+from temps.temps_arch import EncoderPhotometry, MeasureZ
+from temps.temps import TempsModule


 # ## LOAD DATA

 #define here the directory containing the photometric catalogues
-parent_dir = '/data/astro/
-modules_dir = '../data/models/'
+parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+modules_dir = Path('../data/models/')

 # +
 filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

-hdu_list = fits.open(
+hdu_list = fits.open(parent_dir / filename_valid)
 cat = Table(hdu_list[1].data).to_pandas()
 cat = cat[cat['FLAG_PHOT']==0]
 cat = cat[cat['mu_class_L07']==1]
@@ -74,7 +70,7 @@

 # ### EXTRACT PHOTOMETRY

-photoz_archive =
+photoz_archive = Archive(path = parent_dir,only_zspec=False)
 f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
 col, colerr = photoz_archive._to_colors(f, ferr)

@@ -84,19 +80,19 @@
 # Initialize an empty dictionary to store DataFrames
 lab='DA'
 nn_features = EncoderPhotometry()
-nn_features.load_state_dict(torch.load(
+nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt', map_location=torch.device('cpu')))
 nn_z = MeasureZ(num_gauss=6)
-nn_z.load_state_dict(torch.load(
+nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt', map_location=torch.device('cpu')))

+temps_module = TempsModule(nn_features, nn_z)

-z,
+z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
                                   return_pz=True)


 # Create a DataFrame with the desired columns
-df = pd.DataFrame(np.c_[z,
-                  columns=['z',
+df = pd.DataFrame(np.c_[z, odds, cat.ztarget, cat.reliable_S15, cat.specz_or_photo],
+                  columns=['z', 'odds' ,'ztarget','reliable_S15', 'specz_or_photo'])

 # Calculate additional columns or operations if needed
 df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
@@ -130,10 +126,12 @@

 df_euclid = df[(df.z >0.2)&(df.z < 2.6)]

+df_euclid
+
 # +
 df_selected, cut, dfcuts = select_cut(df_euclid,
                                       completenss_lim=None,
-                                      nmad_lim=0.
+                                      nmad_lim= 0.05,
                                       outliers_lim=None,
                                       return_df=True)
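select_cut is called here with nmad_lim=0.05, i.e. it searches for a quality cut whose surviving sample meets the target NMAD. A minimal sketch of that idea, assuming the cut is applied to the odds column; the real temps.utils.select_cut signature and return values differ (it also supports completeness and outlier-fraction targets):

import numpy as np
import pandas as pd

def select_cut_sketch(df: pd.DataFrame, nmad_lim: float, ncuts: int = 100):
    # Scan odds thresholds from loose to tight; stop at the first one
    # whose selected sample reaches the requested NMAD.
    def nmad(x):
        x = np.asarray(x)
        return 1.4826 * np.median(np.abs(x - np.median(x)))
    for cut in np.quantile(df['odds'], np.linspace(0.0, 0.99, ncuts)):
        selected = df[df['odds'] > cut]
        if len(selected) > 0 and nmad(selected['zwerr']) <= nmad_lim:
            return selected, cut
    return df, None  # target NMAD never reached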
notebooks/nz.py
ADDED
@@ -0,0 +1,215 @@
+# ---
+# jupyter:
+# jupytext:
+# text_representation:
+# extension: .py
+# format_name: light
+# format_version: '1.5'
+# jupytext_version: 1.16.2
+# kernelspec:
+# display_name: temps
+# language: python
+# name: temps
+# ---
+
+# # FIGURE 5 IN THE PAPER
+
+# ## n(z) distributions
+
+# %load_ext autoreload
+# %autoreload 2
+
+import pandas as pd
+import numpy as np
+from astropy.io import fits
+from astropy.table import Table
+import torch
+from pathlib import Path
+
+#matplotlib settings
+from matplotlib import rcParams
+import matplotlib.pyplot as plt
+rcParams["mathtext.fontset"] = "stix"
+rcParams["font.family"] = "STIXGeneral"
+
+from temps.archive import Archive
+from temps.utils import nmad
+from temps.temps_arch import EncoderPhotometry, MeasureZ
+from temps.temps import TempsModule
+from temps.plots import plot_nz
+
+eval_methods=False
+
+# ### LOAD DATA
+
+#define here the directory containing the photometric catalogues
+parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+modules_dir = Path('../data/models/')
+
+# +
+filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
+
+hdu_list = fits.open(parent_dir / filename_valid)
+cat = Table(hdu_list[1].data).to_pandas()
+cat = cat[cat['FLAG_PHOT']==0]
+cat = cat[cat['mu_class_L07']==1]
+cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
+cat = cat[cat['MAG_VIS']<25]
+
+# -
+
+ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii]> 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
+specz_or_photo = [0 if cat['z_spec_S15'].values[ii]> 0 else 1 for ii in range(len(cat))]
+ID = cat['ID']
+VISmag = cat['MAG_VIS']
+zsflag = cat['reliable_S15']
+
+photoz_archive = Archive(path = parent_dir,only_zspec=False)
+f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
+col, colerr = photoz_archive._to_colors(f, ferr)
+
+# ### LOAD TRAINED MODELS AND EVALUATE PDFs AND REDSHIFT
+
+if eval_methods:
+    dfs = {}
+
+    for il, lab in enumerate(['z','L15','DA']):
+
+        nn_features = EncoderPhotometry()
+        nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
+        nn_z = MeasureZ(num_gauss=6)
+        nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt',map_location=torch.device('cpu')))
+
+        temps_module = TempsModule(nn_features, nn_z)
+
+        z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
+                                          return_pz=True)
+        # Create a DataFrame with the desired columns
+        df = pd.DataFrame(np.c_[ID, VISmag,z, odds, ztarget,zsflag, specz_or_photo],
+                          columns=['ID','VISmag','z','odds', 'ztarget','zsflag','S15_L15_flag'])
+
+        # Calculate additional columns or operations if needed
+        df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
+
+        # Drop any rows with NaN values
+        df = df.dropna()
+
+        # Assign the DataFrame to a key in the dictionary
+        dfs[lab] = df
+
+
+# ### LOAD CATALOGUES IF AVAILABLE
+
+if not eval_methods:
+
+    df_zs = pd.read_csv(parent_dir / 'predictions_specztraining.csv', header=0)
+    df_zsL15 = pd.read_csv(parent_dir / 'predictions_speczL15training.csv', header=0)
+    df_DA = pd.read_csv(parent_dir / 'predictions_speczDAtraining.csv', header=0)
+
+    dfs = {}
+    dfs['z'] = df_zs
+    dfs['L15'] = df_zsL15
+    dfs['DA'] = df_DA
+
+# +
+import matplotlib.pyplot as plt
+from matplotlib import gridspec
+
+# Create figure and grid specification
+fig = plt.figure(figsize=(8, 10))
+gs = gridspec.GridSpec(5, 1, height_ratios=[0.1, 1, 1,1,1])
+
+# Upper panel (very thin) with shaded areas
+ax1 = plt.subplot(gs[0])
+ax1.set_yticks([])
+
+ax1.set_ylabel('Bins', fontsize=10)
+
+# Define the ranges for shaded areas
+#z_ranges = [[0.15, 0.35], [0.35, 0.55], [0.55, 0.85], [0.85, 1.05], [1.05, 1.35],
+#            [1.35, 1.55],# [1.55, 1.85], [1.85, 2], [2, 2.5], [2.5, 3], [3, 4]]
+
+z_ranges = [[0.15, 0.5], [0.5, 1], [1, 1.5], [1.5,2]]#, [2, 3], [3,4]]#,
+#[1.35, 1.55],# [1.55, 1.85], [1.85, 2], [2, 2.5], [2.5, 3], [3, 4]]
+
+colors = ['deepskyblue', 'forestgreen', 'coral', 'grey', 'pink', 'goldenrod',
+          'cyan', 'seagreen', 'salmon', 'steelblue', 'orange']
+
+# Plot shaded areas
+x_values = [0, 1, 2]  # Example x values, adjust as needed
+for i, (start, end) in enumerate(z_ranges):
+    ax1.fill_betweenx(x_values, start, end, color=colors[i], alpha=0.5)
+
+# Middle panel (equally thick)
+ax2 = plt.subplot(gs[1])
+for i, (start, end) in enumerate(z_ranges):
+    dfplot_z = dfs['z'][(dfs['z']['ztarget'] > start) & (dfs['z']['ztarget'] < end)]
+    ax2.hist(dfplot_z.ztarget, bins=50, color=colors[i], histtype='step', linestyle='-', density=True, range=(0, 4))
+
+# Bottom panel (equally thick)
+ax3 = plt.subplot(gs[2])
+for i, (start, end) in enumerate(z_ranges):
+    dfplot_z = dfs['z'][(dfs['z']['z'] > start) & (dfs['z']['z'] < end)]
+    ax3.hist(dfplot_z.ztarget, bins=50, color=colors[i], histtype='step', linestyle='-', density=True, range=(0, 4))
+
+# Bottom panel (equally thick)
+ax4 = plt.subplot(gs[3])
+for i, (start, end) in enumerate(z_ranges):
+    dfplot_z = dfs['L15'][(dfs['L15']['z'] > start) & (dfs['L15']['z'] < end)]
+    print(len(dfplot_z))
+    ax4.hist(dfplot_z.ztarget, bins=50, color=colors[i], histtype='step', linestyle='-', density=True, range=(0, 4))
+
+ax5 = plt.subplot(gs[4])
+for i, (start, end) in enumerate(z_ranges):
+    dfplot_z = dfs['DA'][(dfs['DA']['z'] > start) & (dfs['DA']['z'] < end)]
+    ax5.hist(dfplot_z.ztarget, bins=50, color=colors[i], histtype='step', linestyle='-', density=True, range=(0, 4))
+
+plt.tight_layout()
+plt.show()
+
+# -
+
+def plot_nz(df_list,
+            zcuts = [0.1, 0.5, 1, 1.5, 2, 3, 4],
+            save=False):
+    # Plot properties
+    plt.rcParams['font.family'] = 'serif'
+    plt.rcParams['font.size'] = 16
+
+    cmap = plt.get_cmap('Dark2')  # Choose a colormap for coloring lines
+
+    # Create subplots
+    fig, axs = plt.subplots(3, 1, figsize=(20, 8), sharex=True)
+
+    for i, df in enumerate(df_list):
+        dfplot = df_list[i].copy()  # Assuming df_list contains dataframes
+        ax = axs[i]  # Selecting the appropriate subplot
+
+        for iz in range(len(zcuts)-1):
+            dfplot_z = dfplot[(dfplot['ztarget'] > zcuts[iz]) & (dfplot['ztarget'] < zcuts[iz + 1])]
+            color = cmap(iz)  # Get a different color for each redshift
+
+            zt_mean = np.median(dfplot_z.ztarget.values)
+            zp_mean = np.median(dfplot_z.z.values)
+
+            # Plot histogram on the selected subplot
+            ax.hist(dfplot_z.z, bins=50, color=color, histtype='step', linestyle='-', density=True, range=(0, 4))
+            ax.axvline(zt_mean, color=color, linestyle='-', lw=2)
+            ax.axvline(zp_mean, color=color, linestyle='--', lw=2)
+
+        ax.set_ylabel(f'Frequency', fontsize=14)
+        ax.grid(False)
+        ax.set_xlim(0, 3.5)
+
+    axs[-1].set_xlabel(f'$z$', fontsize=18)
+
+    if save:
+        plt.savefig(f'nz_hist.pdf', dpi=300, bbox_inches='tight')
+
+    plt.show()
+
+plot_nz(df_list)
+
notebooks/{Fig4_pz_examples.py → pz_examples.py}
RENAMED
@@ -1,70 +1,55 @@
 # ---
 # jupyter:
 # jupytext:
-# formats: ipynb,py:percent
 # text_representation:
 # extension: .py
-# format_name:
+# format_name: light
-# format_version: '1.
+# format_version: '1.5'
-# jupytext_version: 1.
+# jupytext_version: 1.16.2
 # kernelspec:
-# display_name:
+# display_name: temps
 # language: python
-# name:
+# name: temps
 # ---

-#
-# # FIGURE 4 IN THE PAPER
+# # $p(z)$ examples

-# %% [markdown]
 # ## IMPACT OF TEMPS ON CONCRETE P(Z) EXAMPLES

-# %% [markdown]
 # ### LOAD PYTHON MODULES

-# %%
 # %load_ext autoreload
 # %autoreload 2

-# %%
 import pandas as pd
 import numpy as np
 import os
 from astropy.io import fits
 from astropy.table import Table
 import torch
+from pathlib import Path

-# %%
 #matplotlib settings
 from matplotlib import rcParams
 import matplotlib.pyplot as plt
 rcParams["mathtext.fontset"] = "stix"
 rcParams["font.family"] = "STIXGeneral"

-
-import
-
-#from insight_arch import EncoderPhotometry, MeasureZ
-#from insight import Insight_module
-from archive import archive
-from utils import nmad
-from temps_arch import EncoderPhotometry, MeasureZ
-from temps import Temps_module
+from temps.archive import Archive
+from temps.utils import nmad
+from temps.temps_arch import EncoderPhotometry, MeasureZ
+from temps.temps import TempsModule


-# %% [markdown]
 # ### LOAD DATA

-# %%
 #define here the directory containing the photometric catalogues
-parent_dir = '/data/astro/
-modules_dir = '../data/models/'
+parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+modules_dir = Path('../data/models/')

-# %%
 filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
-
-hdu_list = fits.open(
+path_file = parent_dir / filename_valid # Creating the path to the file
+hdu_list = fits.open(path_file)
 cat = Table(hdu_list[1].data).to_pandas()
 cat = cat[cat['FLAG_PHOT']==0]
 cat = cat[cat['mu_class_L07']==1]
@@ -72,46 +57,41 @@
 cat = cat[cat['MAG_VIS']<25]


-# %%
 ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii]> 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
 specz_or_photo = [0 if cat['z_spec_S15'].values[ii]> 0 else 1 for ii in range(len(cat))]
 ID = cat['ID']
 VISmag = cat['MAG_VIS']
 zsflag = cat['reliable_S15']

-
-photoz_archive = archive(path = parent_dir,only_zspec=False)
+photoz_archive = Archive(path = parent_dir,only_zspec=False)
 f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
 col, colerr = photoz_archive._to_colors(f, ferr)

-# %% [markdown]
 # ### LOAD TRAINED MODELS AND EVALUATE PDF OF RANDOM EXAMPLES

-# %% [markdown]
 # The notebook 'Tutorial_temps' gives an example of how to train and save models.

-# %%
 # Initialize an empty dictionary to store DataFrames
 ii = np.random.randint(0,len(col),1)
 pz_dict = {}
 for il, lab in enumerate(['z','L15','DA']):

     nn_features = EncoderPhotometry()
-    nn_features.load_state_dict(torch.load(
+    nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
     nn_z = MeasureZ(num_gauss=6)
-    nn_z.load_state_dict(torch.load(
+    nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt',map_location=torch.device('cpu')))

+    temps_module = TempsModule(nn_features, nn_z)

-    z,
+    z, pz, fodds = temps_module.get_pz(input_data=torch.Tensor(col[ii]),return_pz=True)

     # Assign the DataFrame to a key in the dictionary
     pz_dict[lab] = pz

-#
+# +
 cmap = plt.get_cmap('Dark2')

 plt.plot(np.linspace(0,5,1000),pz_dict['z'][0],label='z', color = cmap(0), ls ='--')
@@ -124,5 +104,6 @@
 plt.xlabel(r'$z$', fontsize=14)
 plt.ylabel('Probability', fontsize=14)
 #plt.savefig(f'pz_{ii[0]}.pdf', bbox_inches='tight')
+# -

-# %%
temps/archive.py
CHANGED
@@ -1,42 +1,62 @@
 import numpy as np
 import pandas as pd
 from astropy.io import fits
-import os
 from astropy.table import Table
 from scipy.spatial import KDTree
+from matplotlib import pyplot as plt
+from matplotlib import rcParams
+from pathlib import Path
+from loguru import logger

-import matplotlib.pyplot as plt

-from matplotlib import rcParams
 rcParams["mathtext.fontset"] = "stix"
 rcParams["font.family"] = "STIXGeneral"

-class archive():
-    def __init__(self, path, aperture=2, drop_stars=True, clean_photometry=True, convert_colors=True, extinction_corr=True, only_zspec=True, target_test='specz', flags_kept=[3,3.1,3.4,3.5,4]):
+class Archive:
+    def __init__(self, path,
+                 aperture=2,
+                 drop_stars=True,
+                 clean_photometry=True,
+                 convert_colors=True,
+                 extinction_corr=True,
+                 only_zspec=True,
+                 all_apertures=False,
+                 target_test='specz', flags_kept=[3, 3.1, 3.4, 3.5, 4]):

+        logger.info("Starting archive")
         self.aperture = aperture
-        self.
+        self.all_apertures = all_apertures
+        self.flags_kept = flags_kept

+        filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
+        filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

+        # Use Path for file handling
+        path_calib = Path(path) / filename_calib
+        path_valid = Path(path) / filename_valid

+        # Open the calibration FITS file
+        with fits.open(path_calib) as hdu_list:
+            cat = Table(hdu_list[1].data).to_pandas()
+            cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]

+        # Open the validation FITS file
+        with fits.open(path_valid) as hdu_list:
+            cat_test = Table(hdu_list[1].data).to_pandas()

+        # Store the catalogs for later use
+        self.cat = cat
+        self.cat_test = cat_test

         if drop_stars==True:
+            logger.info("dropping stars...")
             cat = cat[cat.mu_class_L07==1]
             cat_test = cat_test[cat_test.mu_class_L07==1]

         if clean_photometry==True:
+            logger.info("cleaning stars...")
             cat = self._clean_photometry(cat)
             cat_test = self._clean_photometry(cat_test)

@@ -55,6 +75,7 @@

         self._set_training_data(cat,
+                                cat_test,
                                 only_zspec=only_zspec,
                                 extinction_corr=extinction_corr,
                                 convert_colors=convert_colors)
@@ -65,17 +86,51 @@

     def _extract_fluxes(self,catalogue):
+        if self.all_apertures:
+            columns_f = [f'FLUX_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H']]
+            columns_ferr = [f'FLUXERR_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H'] ]
+        else:
+            columns_f = [f'FLUX_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
+            columns_ferr = [f'FLUXERR_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]

         f = catalogue[columns_f].values
         ferr = catalogue[columns_ferr].values
         return f, ferr

+    def _extract_magnitudes(self,catalogue):
+        if self.all_apertures:
+            columns_m = [f'MAG_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H']]
+            columns_merr = [f'MAGERR_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H'] ]
+        else:
+            columns_m = [f'MAG_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
+            columns_merr = [f'MAGERR_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
+
+        m = catalogue[columns_m].values
+        merr = catalogue[columns_merr].values
+        return m, merr
+
     def _to_colors(self, flux, fluxerr):
         """ Convert fluxes to colors"""
+
+        if self.all_apertures:
+
+            for a in range(3):
+                lim1 = 7*a
+                lim2 = 7*(a+1)
+                c = flux[:,lim1:(lim2-1)] / flux[:,(lim1+1):lim2]
+                cerr = np.sqrt((fluxerr[:,lim1:(lim2-1)]/ flux[:,(lim1+1):lim2])**2 + (flux[:,lim1:(lim2-1)] / flux[:,(lim1+1):lim2]**2)**2 * fluxerr[:,(lim1+1):lim2]**2)
+
+                if a==0:
+                    color = c
+                    color_err = cerr
+                else:
+                    color = np.concatenate((color,c),axis=1)
+                    color_err = np.concatenate((color_err,cerr),axis=1)
+
+        else:
+            color = flux[:,:-1] / flux[:,1:]
+
+            color_err = np.sqrt((fluxerr[:,:-1]/ flux[:,1:])**2 + (flux[:,:-1] / flux[:,1:]**2)**2 * fluxerr[:,1:]**2)
         return color,color_err

     def _set_combiend_target(self, catalogue):
@@ -92,13 +147,20 @@

         return catalogue

-    def _correct_extinction(self,catalogue, f):
+    def _correct_extinction(self,catalogue, f, return_ext_corr=False):
         """Corrects for extinction"""
         ext_correction_cols = [f'EB_V_corr_FLUX_{x}' for x in ['G','R','I','Z','Y','J','H']]
+        if self.all_apertures:
+            ext_correction = catalogue[ext_correction_cols].values
+            ext_correction = np.concatenate((ext_correction,ext_correction,ext_correction),axis=1)
+        else:
+            ext_correction = catalogue[ext_correction_cols].values

         f = f * ext_correction
+        if return_ext_corr:
+            return f, ext_correction
+        else:
+            return f

     def _select_only_zspec(self,catalogue,cat_flag=None):
         """Selects only galaxies with spectroscopic redshift"""
@@ -158,22 +220,24 @@
         return catalogue_valid

-    def _set_training_data(self,catalogue, only_zspec=True, extinction_corr=True, convert_colors=True):
+    def _set_training_data(self,catalogue, catalogue_da, only_zspec=True, extinction_corr=True, convert_colors=True):

-        cat_da = self._exclude_only_zspec(
+        cat_da = self._exclude_only_zspec(catalogue_da)
         target_z_train_DA = cat_da['photo_z_L15'].values

         if only_zspec:
+            logger.info("Selecting only galaxies with spectroscopic redshift")
             catalogue = self._select_only_zspec(catalogue, cat_flag='Calib')
             catalogue = self._clean_zspec_sample(catalogue, flags_kept=self.flags_kept)
         else:
+            logger.info("Selecting galaxies with spectroscopic redshift and high-precision photo-z")
             catalogue = self._take_zspec_and_photoz(catalogue, cat_flag='Calib')

         self.cat_train=catalogue
         f, ferr = self._extract_fluxes(catalogue)
+
         f_DA, ferr_DA = self._extract_fluxes(cat_da)
         idx = np.random.randint(0, len(f_DA), len(f))
         f_DA, ferr_DA = f_DA[idx], ferr_DA[idx]
@@ -182,9 +246,11 @@

         if extinction_corr==True:
+            logger.info("Correcting MW extinction")
             f = self._correct_extinction(catalogue,f)
+
         if convert_colors==True:
+            logger.info("Converting to colors")
             col, colerr = self._to_colors(f, ferr)
             col_DA, colerr_DA = self._to_colors(f_DA, ferr_DA)
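The single-aperture branch of _to_colors divides each flux by the flux in the next band and propagates the errors with the first-order quotient rule. The same computation as a standalone sketch (the function name is illustrative, not part of the package):

import numpy as np

def fluxes_to_colors(f, ferr):
    # f, ferr: (n_gal, n_bands) fluxes and errors in G,R,I,Z,Y,J,H order
    color = f[:, :-1] / f[:, 1:]
    # var(a/b) ~ (da/b)^2 + (a*db/b^2)^2, expanded to first order
    color_err = np.sqrt((ferr[:, :-1] / f[:, 1:])**2
                        + (f[:, :-1] / f[:, 1:]**2)**2 * ferr[:, 1:]**2)
    return color, color_err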
temps/plots.py
CHANGED
@@ -1,7 +1,7 @@
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
-from utils import nmad
+from temps.utils import nmad

 import numpy as np
 import matplotlib.pyplot as plt
@@ -181,68 +181,7 @@ def plot_PIT(pit_list_1, pit_list_2 = None, pit_list_3=None, sample='specz', lab
     # Show the plot
     plt.show()

-
-import numpy as np
-import matplotlib.pyplot as plt
-from scipy import stats
-
-def plot_photoz(df_list, nbins, xvariable, metric, type_bin='bin', label_list=None, samp='zs', save=False):
-    #plot properties
-    plt.rcParams['font.family'] = 'serif'
-    plt.rcParams['font.size'] = 12
-
-    bin_edges = stats.mstats.mquantiles(df_list[0][xvariable].values, np.linspace(0.05, 1, nbins))
-    print(bin_edges)
-    cmap = plt.get_cmap('Dark2')  # Choose a colormap for coloring lines
-    plt.figure(figsize=(6, 5))
-    ls = ['--', ':', '-']
-
-    for i, df in enumerate(df_list):
-        ydata, xlab = [], []
-
-        for k in range(len(bin_edges)-1):
-            edge_min = bin_edges[k]
-            edge_max = bin_edges[k+1]
-
-            mean_mag = (edge_max + edge_min) / 2
-
-            if type_bin == 'bin':
-                df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
-            elif type_bin == 'cum':
-                df_plot = df[(df[xvariable] < edge_max)]
-            else:
-                raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
-
-            xlab.append(mean_mag)
-            if metric == 'sig68':
-                ydata.append(sigma68(df_plot.zwerr))
-            elif metric == 'bias':
-                ydata.append(np.mean(df_plot.zwerr))
-            elif metric == 'nmad':
-                ydata.append(nmad(df_plot.zwerr))
-            elif metric == 'outliers':
-                ydata.append(len(df_plot[np.abs(df_plot.zwerr) > 0.15]) / len(df_plot)*100)
-
-        print(ydata)
-        color = cmap(i)  # Get a different color for each dataframe
-        plt.plot(xlab, ydata, marker='.', lw=1, label=f'{label_list[i]}', color=color, ls=ls[i])
-
-    if xvariable == 'VISmag':
-        xvariable_lab = 'VIS'
-
-    plt.ylabel(f'{metric} $[\\Delta z]$', fontsize=18)
-    plt.xlabel(f'{xvariable_lab}', fontsize=16)
-    plt.grid(False)
-    plt.legend()
-
-    if save==True:
-        plt.savefig(f'{metric}_{xvariable}_{samp}.pdf', dpi=300, bbox_inches='tight')
-    plt.show()


 def plot_nz(df_list,
@@ -336,3 +275,43 @@ def plot_crps(crps_list_1, crps_list_2 = None, crps_list_3=None, labels=None, s
     # Show the plot
     plt.show()

+
+
+def plot_nz(df, bins=np.arange(0, 5, 0.2)):
+    kwargs = dict(bins=bins, alpha=0.5)
+    plt.hist(df.zs.values, color='grey', ls='-', **kwargs)
+    counts, _ = np.histogram(df.z.values, bins=bins)
+
+    plt.plot((bins[:-1] + bins[1:]) * 0.5, counts, color='purple')
+
+    #plt.legend(fontsize=14)
+    plt.xlabel(r'Redshift', fontsize=14)
+    plt.ylabel(r'Counts', fontsize=14)
+    plt.yscale('log')
+
+    plt.show()
+
+    return
+
+
+def plot_scatter(df, sample='specz', save=True):
+    # Calculate the point density
+    xy = np.vstack([df.zs.values, df.z.values])
+    zd = gaussian_kde(xy)(xy)
+
+    fig, ax = plt.subplots()
+    plt.scatter(df.zs.values, df.z.values, c=zd, s=1)
+    plt.xlim(0, 5)
+    plt.ylim(0, 5)
+
+    plt.xlabel(r'$z_{\rm s}$', fontsize=14)
+    plt.ylabel('$z$', fontsize=14)
+
+    plt.xticks(fontsize=12)
+    plt.yticks(fontsize=12)
+
+    if save == True:
+        plt.savefig(f'{sample}_scatter.pdf', dpi=300, bbox_inches='tight')
+
+    plt.show()
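A quick usage sketch for the two plotting helpers added above. It assumes a DataFrame with a photo-z column `z` and a spec-z column `zs` (mock values below). One caveat: `plot_scatter` calls `gaussian_kde`, so `temps/plots.py` needs `from scipy.stats import gaussian_kde` in scope, an import this hunk does not show.

```python
import numpy as np
import pandas as pd
from temps.plots import plot_nz, plot_scatter  # module path as laid out in this PR

rng = np.random.default_rng(42)
zs = rng.uniform(0, 3, 1_000)                      # mock spectroscopic redshifts
z = zs + rng.normal(0, 0.05, zs.size) * (1 + zs)   # mock photo-z with (1+z) scatter
df = pd.DataFrame({"zs": zs, "z": z})

plot_nz(df)                                   # spec-z histogram vs. binned photo-z counts
plot_scatter(df, sample="specz", save=False)  # density-coloured z vs. zs scatter
```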
temps/temps.py
CHANGED
@@ -1,257 +1,267 @@
+import numpy as np
+import pandas as pd
 import torch
-from torch.utils.data import DataLoader, dataset, TensorDataset
 from torch import nn, optim
+from torch.utils.data import DataLoader, TensorDataset
 from torch.optim import lr_scheduler
-import numpy as np
-import pandas as pd
-from astropy.io import fits
-import os
-from astropy.table import Table
-from scipy.spatial import KDTree
-from scipy.special import erf
 from scipy.stats import norm
+from loguru import logger
+from tqdm import tqdm  # Import tqdm for progress bars
+
+# Local imports
+from temps.utils import maximum_mean_discrepancy


-class Temps_module():
+class TempsModule:
+    """Class for managing temperature-related models and training."""
+
+    def __init__(
+        self,
+        model_f,
+        model_z,
+        batch_size=100,
+        rejection_param=1,
+        da=True,
+        verbose=False,
+    ):
+        self.model_z = model_z
+        self.model_f = model_f
+        self.da = da
+        self.verbose = verbose
+        self.ngaussians = model_z.ngaussians
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.batch_size = batch_size
+        self.rejection_parameter = rejection_param
+
+    def _get_dataloaders(
+        self, input_data, target_data, input_data_da=None, val_fraction=0.1
+    ):
+        """Create training and validation dataloaders."""
         input_data = torch.Tensor(input_data)
         target_data = torch.Tensor(target_data)
+        input_data_da = (
+            torch.Tensor(input_data_da)
+            if input_data_da is not None
+            else input_data.clone()
+        )
+
+        dataset = TensorDataset(input_data, input_data_da, target_data)
+        train_dataset, val_dataset = torch.utils.data.random_split(
+            dataset,
+            [int(len(dataset) * (1 - val_fraction)), int(len(dataset) * val_fraction)],
+        )
+        loader_train = DataLoader(
+            train_dataset, batch_size=self.batch_size, shuffle=True
+        )
+        loader_val = DataLoader(val_dataset, batch_size=64, shuffle=True)

         return loader_train, loader_val

+    def _loss_function(self, mean, std, logmix, true):
+        """Compute the negative log-likelihood of the Gaussian mixture."""
+        log_prob = (
+            logmix - 0.5 * (mean - true[:, None]).pow(2) / std.pow(2) - torch.log(std)
+        )
+        log_prob = torch.logsumexp(log_prob, dim=1)
         loss = -log_prob.mean()
+        return loss

+    def _loss_function_da(self, f1, f2):
+        """Compute the KL divergence loss for domain adaptation."""
         kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
         loss = kl_loss(f1, f2)
-        #print('f1',f1)
-        #print('f2',f2)
-        return loss
+        return torch.log(loss)

     def _to_numpy(self, x):
+        """Convert a tensor to a NumPy array."""
         return x.detach().cpu().numpy()

-    def train(self, input_data,
-              input_data_DA,
-              target_data,
-              target_data_DA,
-              nepochs=10,
-              step_size = 100,
-              val_fraction=0.1,
-              lr=1e-3,
-              weight_decay=0):
-        self.modelZ = self.modelZ.train()
-        self.modelF = self.modelF.train()
-
-        loader_train, loader_val = self._get_dataloaders(input_data, target_data, input_data_DA, target_data_DA, val_fraction=0.1)
-        optimizerZ = optim.Adam(self.modelZ.parameters(), lr=lr, weight_decay=weight_decay)
-        optimizerF = optim.Adam(self.modelF.parameters(), lr=lr, weight_decay=weight_decay)
-
-        schedulerZ = torch.optim.lr_scheduler.StepLR(optimizerZ, step_size=step_size, gamma =0.1)
-        schedulerF = torch.optim.lr_scheduler.StepLR(optimizerF, step_size=step_size, gamma =0.1)
-
-        self.modelZ = self.modelZ.to(self.device)
-        self.modelF = self.modelF.to(self.device)
+    def train(
+        self,
+        input_data,
+        input_data_da,
+        target_data,
+        nepochs=10,
+        step_size=100,
+        val_fraction=0.1,
+        lr=1e-3,
+        weight_decay=0,
+    ):
+        """Train the models using provided data."""
+        self.model_z.train()
+        self.model_f.train()
+
+        loader_train, loader_val = self._get_dataloaders(
+            input_data, target_data, input_data_da, val_fraction
+        )
+        optimizer_z = optim.Adam(
+            self.model_z.parameters(), lr=lr, weight_decay=weight_decay
+        )
+        optimizer_f = optim.Adam(
+            self.model_f.parameters(), lr=lr, weight_decay=weight_decay
+        )
+
+        scheduler_z = lr_scheduler.StepLR(optimizer_z, step_size=step_size, gamma=0.1)
+        scheduler_f = lr_scheduler.StepLR(optimizer_f, step_size=step_size, gamma=0.1)
+
+        self.model_z.to(self.device)
+        self.model_f.to(self.device)
+
+        loss_train, loss_validation = [], []

+        for epoch in range(nepochs):
+            _loss_train, _loss_validation = [], []
+            logger.info(f"Epoch {epoch + 1}/{nepochs} starting...")
+            for input_data, input_data_da, target_data in tqdm(
+                loader_train, desc="Training", unit="batch"
+            ):
+                input_data = input_data.to(self.device)
+                target_data = target_data.to(self.device)
                 if self.da:
                     input_data_da = input_data_da.to(self.device)
-                    target_data_DA = target_data_DA.to(self.device)

+                optimizer_f.zero_grad()
+                optimizer_z.zero_grad()

-                features = self.modelF(input_data)
-                if self.da:
-                    features_DA = self.modelF(input_data_da)
+                features = self.model_f(input_data)
+                features_da = self.model_f(input_data_da) if self.da else None

-                mu, logsig, logmix_coeff = self.modelZ(features)
-                logsig = torch.clamp(logsig, -6, 2)
+                mu, logsig, logmix_coeff = self.model_z(features)
+                logsig = torch.clamp(logsig, -6, 2)
                 sig = torch.exp(logsig)

-                    loss = lossZ +1e3*lossDA
-                else:
-                    loss = lossZ
-
-                _loss_train.append(lossZ.item())
+                loss_z = self._loss_function(mu, sig, logmix_coeff, target_data)
+                loss = loss_z + (
+                    1e3
+                    * maximum_mean_discrepancy(
+                        features, features_da, kernel_type="rbf"
+                    ).sum()
+                    if self.da
+                    else 0
+                )
+
+                _loss_train.append(loss_z.item())
                 loss.backward()
+                optimizer_f.step()
+                optimizer_z.step()
+
+            scheduler_f.step()
+            scheduler_z.step()
+
+            loss_train.append(np.mean(_loss_train))
+            _loss_validation = self._validate(loader_val, target_data)

-            if self.verbose:
-                print(f'training_loss:{loss}',f'testing_loss:{loss_val}')
+            logger.info(
+                f"Epoch {epoch + 1}: Training Loss: {np.mean(_loss_train):.4f}, "
+                f"Validation Loss: {np.mean(_loss_validation):.4f}"
+            )

+    def _validate(self, loader_val, target_data):
+        """Validate the model on the validation dataset."""
+        self.model_z.eval()
+        self.model_f.eval()
+        _loss_validation = []

+        with torch.no_grad():
+            for input_data, _, target_data in tqdm(
+                loader_val, desc="Validating", unit="batch"
+            ):
                 input_data = input_data.to(self.device)
                 target_data = target_data.to(self.device)

+                features = self.model_f(input_data)
+                mu, logsig, logmix_coeff = self.model_z(features)
+                logsig = torch.clamp(logsig, -6, 2)
                 sig = torch.exp(logsig)

                 loss_val = self._loss_function(mu, sig, logmix_coeff, target_data)
                 _loss_validation.append(loss_val.item())

+        return _loss_validation

     def get_features(self, input_data):
+        """Get features from the model."""
-        self.modelF = self.modelF.eval()
+        self.model_f.eval()
         input_data = input_data.to(self.device)
-        features = self.modelF(input_data)
-        return features.detach().cpu().numpy()
+        features = self.model_f(input_data)
+        return self._to_numpy(features)

-    def get_pz(self,input_data, return_pz=True, return_flag=True,
-        self.modelZ = self.modelZ.eval()
-        self.modelF = self.modelF.eval()
+    def get_pz(self, input_data, return_pz=True, return_flag=True, return_odds=False):
+        """Get the predicted z values and their uncertainties."""
+        logger.info("Predicting photo-z for the input galaxies...")
+        self.model_z.eval()
+        self.model_f.eval()

         input_data = input_data.to(self.device)
-        features = self.modelF(input_data)
-        mu, logsig, logmix_coeff = self.modelZ(features)
-        logsig = torch.clamp(logsig,-6,2)
+        features = self.model_f(input_data)
+        mu, logsig, logmix_coeff = self.model_z(features)
+        logsig = torch.clamp(logsig, -6, 2)
         sig = torch.exp(logsig)

         mix_coeff = torch.exp(logmix_coeff)
+        z = (mix_coeff * mu).sum(dim=1)
-        zerr = torch.sqrt( (mix_coeff * sig**2).sum(1) + (mix_coeff * (mu - mu.mean(1)[:,None])**2).sum(1))
+        zerr = torch.sqrt(
+            (mix_coeff * sig**2).sum(dim=1)
+            + (mix_coeff * (mu - mu.mean(dim=1, keepdim=True)) ** 2).sum(dim=1)
+        )

-        mu, mix_coeff, sig = mu.detach().cpu().numpy(), mix_coeff.detach().cpu().numpy(), sig.detach().cpu().numpy()
+        mu, mix_coeff, sig = map(self._to_numpy, (mu, mix_coeff, sig))

-        if return_pz==True:
-            zgrid = np.linspace(0, 5, 1000)
-            pdf_mixture = np.zeros(shape=(len(input_data), len(zgrid)))
-            for ii in range(len(input_data)):
-                for i in range(self.ngaussians):
-                    pdf_mixture[ii] += mix_coeff[ii,i] * norm.pdf(zgrid, mu[ii,i], sig[ii,i])
-            if return_flag==True:
-                #narrow peak
-                pdf_mixture = pdf_mixture / pdf_mixture.sum(1)[:,None]
-                diff_matrix = np.abs(self._to_numpy(z)[:,None] - zgrid[None,:])
-                #odds
-                idx_peak = np.argmax(pdf_mixture,1)
-                zpeak = zgrid[idx_peak]
-                diff_matrix_upper = np.abs((zpeak+0.05)[:,None] - zgrid[None,:])
-                diff_matrix_lower = np.abs((zpeak-0.05)[:,None] - zgrid[None,:])
-
-                idx = np.argmin(diff_matrix,1)
-                idx_upper = np.argmin(diff_matrix_upper,1)
-                idx_lower = np.argmin(diff_matrix_lower,1)
-
-                p_z_x = np.zeros(shape=(len(z)))
-                odds = np.zeros(shape=(len(z)))
-
-                for ii in range(len(z)):
-                    p_z_x[ii] = pdf_mixture[ii,idx[ii]]
-                    odds[ii] = pdf_mixture[ii,:idx_upper[ii]].sum() - pdf_mixture[ii,:idx_lower[ii]].sum()
-
-            return self._to_numpy(z),self._to_numpy(zerr), pdf_mixture
-        else:
-            return self._to_numpy(z),self._to_numpy(zerr)
+        if return_pz:
+            logger.info("Returning p(z)")
+            return self._calculate_pdf(z, mu, sig, mix_coeff, return_flag)
+        else:
+            return self._to_numpy(z), self._to_numpy(zerr)
+
+    def _calculate_pdf(self, z, mu, sig, mix_coeff, return_flag):
+        """Calculate the probability density function."""
+        zgrid = np.linspace(0, 5, 1000)
+        pz = np.zeros((len(z), len(zgrid)))
+
+        for ii in range(len(z)):
+            for i in range(self.ngaussians):
+                pz[ii] += mix_coeff[ii, i] * norm.pdf(zgrid, mu[ii, i], sig[ii, i])
+
+        if return_flag:
+            logger.info("Calculating and returning ODDS")
+            pz /= pz.sum(axis=1, keepdims=True)
+            return self._calculate_odds(z, pz, zgrid)
+        return self._to_numpy(z), pz
+
+    def _calculate_odds(self, z, pz, zgrid):
+        """Calculate odds based on the PDF."""
+        logger.info('Calculating ODDS values')
+        diff_matrix = np.abs(self._to_numpy(z)[:, None] - zgrid[None, :])
+        idx_peak = np.argmax(pz, axis=1)
+        zpeak = zgrid[idx_peak]
+        idx_upper = np.argmin(np.abs((zpeak + 0.05)[:, None] - zgrid[None, :]), axis=1)
+        idx_lower = np.argmin(np.abs((zpeak - 0.05)[:, None] - zgrid[None, :]), axis=1)
+
+        odds = []
+        for jj in range(len(pz)):
+            odds.append(pz[jj, idx_lower[jj]:(idx_upper[jj] + 1)].sum())
+
+        odds = np.array(odds)
+        return self._to_numpy(z), pz, odds
+
+    def calculate_pit(self, input_data, target_data):
+        logger.info('Calculating PIT values')

         pit_list = []

-        self.modelF = self.modelF.eval()
-        self.modelF = self.modelF.to(self.device)
-        self.modelZ = self.modelZ.eval()
-        self.modelZ = self.modelZ.to(self.device)
+        self.model_f = self.model_f.eval()
+        self.model_f = self.model_f.to(self.device)
+        self.model_z = self.model_z.eval()
+        self.model_z = self.model_z.to(self.device)

         input_data = input_data.to(self.device)

-        features = self.modelF(input_data)
-        mu, logsig, logmix_coeff = self.modelZ(features)
+        features = self.model_f(input_data)
+        mu, logsig, logmix_coeff = self.model_z(features)

         logsig = torch.clamp(logsig,-6,2)
         sig = torch.exp(logsig)
@@ -267,7 +277,8 @@ class Temps_module():

         return pit_list

+    def calculate_crps(self, input_data, target_data):
+        logger.info('Calculating CRPS values')

         def measure_crps(cdf, t):
             zgrid = np.linspace(0,4,1000)
@@ -281,16 +292,16 @@ class Temps_module():

         crps_list = []

-        self.modelF = self.modelF.eval()
-        self.modelF = self.modelF.to(self.device)
-        self.modelZ = self.modelZ.eval()
-        self.modelZ = self.modelZ.to(self.device)
+        self.model_f = self.model_f.eval()
+        self.model_f = self.model_f.to(self.device)
+        self.model_z = self.model_z.eval()
+        self.model_z = self.model_z.to(self.device)

         input_data = input_data.to(self.device)

-        features = self.modelF(input_data)
-        mu, logsig, logmix_coeff = self.modelZ(features)
+        features = self.model_f(input_data)
+        mu, logsig, logmix_coeff = self.model_z(features)
         logsig = torch.clamp(logsig,-6,2)
         sig = torch.exp(logsig)
@@ -302,21 +313,19 @@ class Temps_module():
         z = (mix_coeff * mu).sum(1)

         x = np.linspace(0, 4, 1000)
+        pz = np.zeros(shape=(len(target_data), len(x)))
         for ii in range(len(input_data)):
             for i in range(6):
+                pz[ii] += mix_coeff[ii,i] * norm.pdf(x, mu[ii,i], sig[ii,i])

+        pz = pz / pz.sum(1)[:,None]

+        cdf_z = np.cumsum(pz,1)

+        crps_value = measure_crps(cdf_z, target_data)

         return crps_value
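To see how the refactored class fits together, here is a minimal end-to-end sketch under stated assumptions: random arrays stand in for real catalogue colours, the encoder is assumed to take 6 colours as input (match this to your photometry), and the tiny epoch count is only for illustration:

```python
import numpy as np
import torch
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.temps import TempsModule

rng = np.random.default_rng(0)
col = rng.normal(size=(1_000, 6)).astype("float32")     # labelled (spec-z) colours
col_da = rng.normal(size=(1_000, 6)).astype("float32")  # unlabelled target colours
ztarget = rng.uniform(0, 3, 1_000).astype("float32")

nn_features = EncoderPhotometry()       # assumption: input width matches 6 colours
nn_z = MeasureZ(num_gauss=6)
temps_module = TempsModule(nn_features, nn_z, batch_size=100, da=True)

# train() builds the dataloaders itself, so NumPy arrays are fine here.
temps_module.train(col, col_da, ztarget, nepochs=2, lr=1e-3)

# With return_pz=True and return_flag=True this path returns (z, p(z), ODDS).
z, pz, odds = temps_module.get_pz(torch.Tensor(col), return_pz=True, return_flag=True)
```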
temps/temps_arch.py
CHANGED
@@ -20,52 +20,46 @@ class EncoderPhotometry(nn.Module):
             nn.Linear(50, 20),
             nn.Dropout(dropout_prob),
             nn.ReLU(),
-            nn.Linear(20, 10)
+            nn.Linear(20, 10),
         )

     def forward(self, x):
         f = self.features(x)
         f = F.log_softmax(f, dim=1)
         return f


 class MeasureZ(nn.Module):
     def __init__(self, num_gauss=10, dropout_prob=0):
         super(MeasureZ, self).__init__()

-        self.ngaussians=num_gauss
+        self.ngaussians = num_gauss
         self.measure_mu = nn.Sequential(
             nn.Linear(10, 20),
             nn.Dropout(dropout_prob),
             nn.ReLU(),
-            nn.Linear(20, num_gauss)
+            nn.Linear(20, num_gauss),
         )

         self.measure_coeffs = nn.Sequential(
             nn.Linear(10, 20),
             nn.Dropout(dropout_prob),
             nn.ReLU(),
-            nn.Linear(20, num_gauss)
+            nn.Linear(20, num_gauss),
         )

         self.measure_sigma = nn.Sequential(
             nn.Linear(10, 20),
             nn.Dropout(dropout_prob),
             nn.ReLU(),
-            nn.Linear(20, num_gauss)
+            nn.Linear(20, num_gauss),
         )

     def forward(self, f):
         mu = self.measure_mu(f)
         sigma = self.measure_sigma(f)
         logmix_coeff = self.measure_coeffs(f)

-        logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, 1)[:,None]
+        logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, 1)[:, None]

         return mu, sigma, logmix_coeff
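The three heads above parametrize a Gaussian mixture over redshift. A small sketch of the output semantics, assuming only what the hunk shows (the heads take the encoder's 10-d features, per `nn.Linear(10, 20)`) plus the clamp-and-exponentiate convention used downstream in temps/temps.py:

```python
import torch
from temps.temps_arch import MeasureZ

measure_z = MeasureZ(num_gauss=6)
f = torch.randn(4, 10)                   # stand-in for EncoderPhotometry output
mu, logsig, logmix_coeff = measure_z(f)  # each of shape (4, 6)

# Downstream code clamps logsig and exponentiates, i.e. the redshift PDF is
# p(z|x) = sum_i exp(logmix_coeff[:, i]) * N(z; mu[:, i], exp(logsig[:, i])).
sig = torch.exp(torch.clamp(logsig, -6, 2))
mix = torch.exp(logmix_coeff)            # rows sum to 1 via the logsumexp above
print(mu.shape, sig.shape, mix.sum(dim=1))
```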
temps/utils.py
CHANGED
@@ -3,113 +3,22 @@ import pandas as pd
 import matplotlib.pyplot as plt
 from scipy import stats
 import torch
+from loguru import logger


+def caluclate_eta(df):
+    return len(df[np.abs(df.zwerr) > 0.15]) / len(df) * 100
+

 def nmad(data):
     return 1.4826 * np.median(np.abs(data - np.median(data)))

-def sigma68(data): return 0.5*(pd.Series(data).quantile(q = 0.84) - pd.Series(data).quantile(q = 0.16))
+
+def sigma68(data):
+    return 0.5 * (pd.Series(data).quantile(q=0.84) - pd.Series(data).quantile(q=0.16))

-def plot_photoz(df_list, nbins, xvariable, metric, type_bin='bin', label_list=None, samp='zs', save=False):
-    #plot properties
-    plt.rcParams['font.family'] = 'serif'
-    plt.rcParams['font.size'] = 12
-
-    bin_edges = stats.mstats.mquantiles(df_list[0][xvariable].values, np.linspace(0.05, 1, nbins))
-    print(bin_edges)
-    cmap = plt.get_cmap('Dark2')  # Choose a colormap for coloring lines
-    plt.figure(figsize=(6, 5))
-
-    for i, df in enumerate(df_list):
-        ydata, xlab = [], []
-
-        for k in range(len(bin_edges)-1):
-            edge_min = bin_edges[k]
-            edge_max = bin_edges[k+1]
-
-            mean_mag = (edge_max + edge_min) / 2
-
-            if type_bin == 'bin':
-                df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
-            elif type_bin == 'cum':
-                df_plot = df[(df[xvariable] < edge_max)]
-            else:
-                raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
-
-            xlab.append(mean_mag)
-            if metric == 'sig68':
-                ydata.append(sigma68(df_plot.zwerr))
-            elif metric == 'bias':
-                ydata.append(np.mean(df_plot.zwerr))
-            elif metric == 'nmad':
-                ydata.append(nmad(df_plot.zwerr))
-            elif metric == 'outliers':
-                ydata.append(len(df_plot[np.abs(df_plot.zwerr) > 0.15]) / len(df_plot)*100)
-
-        print(ydata)
-        color = cmap(i)  # Get a different color for each dataframe
-        plt.plot(xlab, ydata, ls='-', marker='.', lw=1, label=f'{label_list[i]}', color=color)
-
-    if xvariable == 'VISmag':
-        xvariable_lab = 'VIS'
-
-    plt.ylabel(f'{metric} $[\\Delta z]$', fontsize=18)
-    plt.xlabel(f'{xvariable_lab}', fontsize=16)
-    plt.grid(False)
-    plt.legend()
-
-    if save==True:
-        plt.savefig(f'{metric}_{xvariable}_{samp}.pdf', dpi=300, bbox_inches='tight')
-    plt.show()
-
-
-def plot_nz(df, bins=np.arange(0,5,0.2)):
-    kwargs = dict(bins=bins, alpha=0.5)
-    plt.hist(df.zs.values, color='grey', ls='-', **kwargs)
-    counts, _ = np.histogram(df.z.values, bins=bins)
-
-    plt.plot((bins[:-1]+bins[1:])*0.5, counts, color='purple')
-
-    #plt.legend(fontsize=14)
-    plt.xlabel(r'Redshift', fontsize=14)
-    plt.ylabel(r'Counts', fontsize=14)
-    plt.yscale('log')
-
-    plt.show()
-
-    return
-
-
-def plot_scatter(df, sample='specz', save=True):
-    # Calculate the point density
-    xy = np.vstack([df.zs.values, df.z.values])
-    zd = gaussian_kde(xy)(xy)
-
-    fig, ax = plt.subplots()
-    plt.scatter(df.zs.values, df.z.values, c=zd, s=1)
-    plt.xlim(0,5)
-    plt.ylim(0,5)
-
-    plt.xlabel(r'$z_{\rm s}$', fontsize = 14)
-    plt.ylabel('$z$', fontsize = 14)
-
-    plt.xticks(fontsize = 12)
-    plt.yticks(fontsize = 12)
-
-    if save==True:
-        plt.savefig(f'{sample}_scatter.pdf', dpi = 300, bbox_inches='tight')
-
-    plt.show()

-def maximum_mean_discrepancy(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
+
+def maximum_mean_discrepancy(x, y, kernel_type="rbf", kernel_mul=2.0, kernel_num=5):
     """
     Compute the Maximum Mean Discrepancy (MMD) between two sets of samples.
@@ -130,7 +39,8 @@ def maximum_mean_discrepancy(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num
     mmd_loss = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel)
     return mmd_loss

-def compute_kernel(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
+
+def compute_kernel(x, y, kernel_type="rbf", kernel_mul=2.0, kernel_num=5):
     """
     Compute the kernel matrix based on the chosen kernel type.
@@ -151,73 +61,77 @@ def compute_kernel(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
     x = x.unsqueeze(1).expand(x_size, y_size, dim)
     y = y.unsqueeze(0).expand(x_size, y_size, dim)

     kernel_input = (x - y).pow(2).mean(2)

-    if kernel_type == 'linear':
+    if kernel_type == "linear":
         kernel_matrix = kernel_input
-    elif kernel_type == 'poly':
+    elif kernel_type == "poly":
         kernel_matrix = (1 + kernel_input / kernel_mul).pow(kernel_num)
-    elif kernel_type == 'rbf':
+    elif kernel_type == "rbf":
         kernel_matrix = torch.exp(-kernel_input / (2 * kernel_mul**2))
-    elif kernel_type == 'sigmoid':
+    elif kernel_type == "sigmoid":
         kernel_matrix = torch.tanh(kernel_mul * kernel_input)
     else:
-        raise ValueError("Invalid kernel type. Supported types are 'linear', 'poly', 'rbf', and 'sigmoid'.")
+        raise ValueError(
+            "Invalid kernel type. Supported types are 'linear', 'poly', 'rbf', and 'sigmoid'."
+        )

     return kernel_matrix


-def select_cut(df, completenss_lim=None, nmad_lim=None, outliers_lim=None, return_df=False):
-
-    if (completenss_lim is None)&(nmad_lim is None)&(outliers_lim is None):
-        raise(ValueError("Select at least one cut"))
+def select_cut(
+    df, completenss_lim=None, nmad_lim=None, outliers_lim=None, return_df=False
+):
+
+    if (completenss_lim is None) & (nmad_lim is None) & (outliers_lim is None):
+        raise (ValueError("Select at least one cut"))
     elif sum(c is not None for c in [completenss_lim, nmad_lim, outliers_lim]) > 1:
         raise ValueError("Select only one cut at a time")
+
     else:
+        bin_edges = stats.mstats.mquantiles(df.odds, np.arange(0, 1.01, 0.1))
+        scatter, eta, cmptnss, nobj = [], [], [], []

+        for k in range(len(bin_edges) - 1):
             edge_min = bin_edges[k]
+            edge_max = bin_edges[k + 1]

+            df_bin = df[(df.odds > edge_min)]

+            cmptnss.append(np.round(len(df_bin) / len(df), 2) * 100)
             scatter.append(nmad(df_bin.zwerr))
+            eta.append(len(df_bin[np.abs(df_bin.zwerr) > 0.15]) / len(df_bin) * 100)
             nobj.append(len(df_bin))
+
+        dfcuts = pd.DataFrame(
+            data=np.c_[
+                np.round(bin_edges[:-1], 5),
+                np.round(nobj, 1),
+                np.round(cmptnss, 1),
+                np.round(scatter, 3),
+                np.round(eta, 2),
+            ],
+            columns=["flagcut", "Nobj", "completeness", "nmad", "eta"],
+        )
+
         if completenss_lim is not None:
+            logger.info("Selecting cut based on completeness")
+            selected_cut = dfcuts[dfcuts["completeness"] <= completenss_lim].iloc[0]
+
         elif nmad_lim is not None:
+            logger.info("Selecting cut based on nmad")
+            selected_cut = dfcuts[dfcuts["nmad"] <= nmad_lim].iloc[0]

         elif outliers_lim is not None:
+            logger.info("Selecting cut based on outliers")
+            selected_cut = dfcuts[dfcuts["eta"] <= outliers_lim].iloc[0]

+        logger.info(
+            f"This cut provides completeness of {selected_cut['completeness']}, nmad={selected_cut['nmad']} and eta={selected_cut['eta']}"
+        )

+        df_cut = df[(df.odds > selected_cut["flagcut"])]
-        if return_df==True:
-            return df_cut, selected_cut['flagcut'], dfcuts
+        if return_df == True:
+            return df_cut, selected_cut["flagcut"], dfcuts
         else:
+            return selected_cut["flagcut"], dfcuts
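Finally, a usage sketch for the rewritten `select_cut`, assuming a catalogue DataFrame with the `odds` quality flag and the weighted error column `zwerr` that the function expects (mock values below):

```python
import numpy as np
import pandas as pd
from temps.utils import select_cut

rng = np.random.default_rng(1)
df = pd.DataFrame({
    "odds": rng.uniform(0, 1, 5_000),
    "zwerr": rng.normal(0, 0.05, 5_000),
})

# Pick the loosest ODDS cut whose outlier fraction eta stays at or below 5%.
df_cut, flagcut, dfcuts = select_cut(df, outliers_lim=5, return_df=True)
print(f"ODDS > {flagcut:.3f} keeps {len(df_cut)} of {len(df)} objects")
print(dfcuts)  # one row per ODDS quantile: flagcut, Nobj, completeness, nmad, eta
```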