lauracabayol committed
Commit 00f4790 · unverified · 2 Parent(s): d1e8fb6 b25063d

Merge pull request #1 from lauracabayol/clean_code
.gitignore CHANGED
@@ -1,6 +1,7 @@
-/temps/__pycache__/*
-/notebooks/.ipynb_checkpoints/
-/notebooks/*.ipynb
+temps/__pycache__/*
+notebooks/.ipynb_checkpoints/
+notebooks/cache
+notebooks/*.ipynb
 /notebooks/developer_notebooks
 temps/.ipynb_checkpoints/
 *.ipynb
notebooks/{Fig7_colourspace.py → Colourspace.py} RENAMED
@@ -5,11 +5,11 @@
 #       extension: .py
 #       format_name: light
 #       format_version: '1.5'
-#     jupytext_version: 1.14.5
+#     jupytext_version: 1.16.2
 #   kernelspec:
-#     display_name: insight
+#     display_name: temps
 #     language: python
-#     name: insight
+#     name: temps
 # ---

 # # FIGURE COLOURSPACE IN THE PAPER
@@ -23,6 +23,7 @@ import os
 from astropy.io import fits
 from astropy.table import Table
 import torch
+from pathlib import Path

 #matplotlib settings
 from matplotlib import rcParams
@@ -30,19 +31,11 @@ import matplotlib.pyplot as plt
 rcParams["mathtext.fontset"] = "stix"
 rcParams["font.family"] = "STIXGeneral"

-# +
-#insight modules
-import sys
-sys.path.append('../temps')
-
-from archive import archive
-from utils import nmad
-from temps_arch import EncoderPhotometry, MeasureZ
-from temps import Temps_module
-from plots import plot_nz
-
+from temps.archive import Archive
+from temps.utils import nmad
+from temps.temps_arch import EncoderPhotometry, MeasureZ
+from temps.temps import TempsModule

-# -

 def estimate_som_map(df, plot_arg='z', nx=40, ny=40):
     """
@@ -98,10 +91,14 @@ def plot_som_map(som_data, plot_arg = 'z', vmin=0, vmax=1):

 # ### LOAD DATA

+#define here the directory containing the photometric catalogues
+parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+modules_dir = Path('../data/models/')
+
 # +
 filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

-hdu_list = fits.open(os.path.join(parent_dir,filename_valid))
+hdu_list = fits.open(parent_dir/filename_valid)
 cat = Table(hdu_list[1].data).to_pandas()
 cat = cat[cat['FLAG_PHOT']==0]
 cat = cat[cat['mu_class_L07']==1]
@@ -116,27 +113,29 @@ ID = cat['ID']
 VISmag = cat['MAG_VIS']
 zsflag = cat['reliable_S15']

-photoz_archive = archive(path = parent_dir,only_zspec=False)
+photoz_archive = Archive(path = parent_dir,only_zspec=False)
 f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
 col, colerr = photoz_archive._to_colors(f, ferr)

+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
 # +
 dfs = {}

 for il, lab in enumerate(['z','L15','DA']):

     nn_features = EncoderPhotometry()
-    nn_features.load_state_dict(torch.load(os.path.join(modules_dir,f'modelF_{lab}.pt')))
+    nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt', map_location=torch.device('cpu')))
     nn_z = MeasureZ(num_gauss=6)
-    nn_z.load_state_dict(torch.load(os.path.join(modules_dir,f'modelZ_{lab}.pt')))
-
-    temps = Temps_module(nn_features, nn_z)
-
-    z, zerr, pz, flag, odds = temps.get_pz(input_data=torch.Tensor(col),
+    nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt', map_location=torch.device('cpu')))
+
+    temps_module = TempsModule(nn_features, nn_z)
+
+    z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
                                       return_pz=True)
     # Create a DataFrame with the desired columns
-    df = pd.DataFrame(np.c_[ID, VISmag, z, flag, ztarget, zsflag, zerr, specz_or_photo],
-                      columns=['ID','VISmag','z','zflag', 'ztarget','zsflag','zuncert','S15_L15_flag'])
+    df = pd.DataFrame(np.c_[ID, VISmag, z, odds, ztarget, zsflag, specz_or_photo],
+                      columns=['ID','VISmag','z', 'odds','ztarget','zsflag','S15_L15_flag'])

     # Calculate additional columns or operations if needed
     df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
@@ -152,15 +151,15 @@ for il, lab in enumerate(['z','L15','DA']):
 # ### LOAD TRAINED MODELS AND EVALUATE PDFs AND REDSHIFT

 #define here the directory containing the photometric catalogues
-parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
-modules_dir = '../data/models/'
+parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+modules_dir = Path('../data/models/')

 df_z = dfs['z']
 df_z_DA = dfs['DA']

 # ##### LOAD TRAIN SOM ON TRAINING DATA

-df_som = pd.read_csv(os.path.join(parent_dir,'som_dataframe.csv'), header = 0, sep =',')
+df_som = pd.read_csv(parent_dir/'som_dataframe.csv', header = 0, sep =',')
 df_z = df_z.merge(df_som, on = 'ID')
 df_z_DA = df_z_DA.merge(df_som, on = 'ID')

@@ -171,10 +170,10 @@ df_l15 = df_z[(df_z.ztarget>0)]
 df_l15_DA = df_z_DA[(df_z_DA.ztarget>0)]

 df_l15_euclid = df_z[(df_z.VISmag <24.5) & (df_z.z > 0.2) & (df_z.z < 2.6)]
-df_l15_euclid_cut= df_l15_euclid[df_l15_euclid.zflag>0.033]
+df_l15_euclid_cut= df_l15_euclid[df_l15_euclid.odds>df_l15_euclid['odds'].quantile(0.2)]

 df_l15_euclid_da = df_z_DA[(df_z_DA.VISmag <24.5) & (df_z_DA.z > 0.2) & (df_z_DA.z < 2.6)]
-df_l15_euclid_cut_da= df_l15_euclid_da[df_l15_euclid_da.zflag>0.018]
+df_l15_euclid_cut_da= df_l15_euclid_da[df_l15_euclid_da.odds>df_l15_euclid['odds'].quantile(0.2)]

 # ## MAKE SOM PLOT

@@ -186,7 +185,7 @@ fig, axs = plt.subplots(6, 4, figsize=(13, 15), sharex=True, sharey=True, gridsp
 # Plot in the top row (axs[0, i])
 #top row, spectroscopic sample
 columns = ['ztarget','z','zwerr','count']
-titles = [r'$z_{true}$',r'$z$',r'$z_{\rm error}$','Counts']
+titles = [r'$z_{true}$ (A)',r'$z$ (B)',r'$z_{\rm error}$ (C)','Counts']
 limits = [[0,4],[0,4],[-0.5,0.5],[0,50]]
 for ii in range(4):
     som_data = estimate_som_map(df_zspec, plot_arg=columns[ii], nx=40, ny=40)
@@ -245,13 +244,13 @@ axs[4, 0].set_ylabel(r'$y$', fontsize=14)
 axs[5, 0].set_ylabel(r'$y$', fontsize=14)


-fig.text(0.09, 0.815, r'$z_{\rm s}$ sample', va='center', rotation='vertical', fontsize=16)
-fig.text(0.09, 0.69, r'L15 sample', va='center', rotation='vertical', fontsize=16)
-fig.text(0.09, 0.56, r'L15 sample + DA', va='center', rotation='vertical', fontsize=14)
-fig.text(0.09, 0.44, r'$Euclid$ sample + DA', va='center', rotation='vertical', fontsize=14)
-fig.text(0.09, 0.3, r'$Euclid$ sample + QC', va='center', rotation='vertical', fontsize=14)
+fig.text(0.09, 0.815, r'$z_{\rm s}$ samp. (1)', va='center', rotation='vertical', fontsize=16)
+fig.text(0.09, 0.69, r'L15 samp. (2)', va='center', rotation='vertical', fontsize=16)
+fig.text(0.09, 0.56, r'L15 samp. + DA (3)', va='center', rotation='vertical', fontsize=14)
+fig.text(0.09, 0.44, r'$Euclid$ samp. + DA (4)', va='center', rotation='vertical', fontsize=14)
+fig.text(0.09, 0.3, r'$Euclid$ samp. + QC (5)', va='center', rotation='vertical', fontsize=14)

-fig.text(0.09, 0.17, r'$Euclid$ sample + DA + QC', va='center', rotation='vertical', fontsize=13)
+fig.text(0.09, 0.17, r'(5) + DA', va='center', rotation='vertical', fontsize=13)


 plt.savefig('SOM_colourspace.pdf', format='pdf', bbox_inches='tight', dpi=300)
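The substantive change in this notebook is the quality cut: the fixed, per-model `zflag` thresholds (0.033 and 0.018) become a 20%-quantile cut on the new `odds` output, which keeps a fixed completeness regardless of how the score is calibrated. A minimal, self-contained sketch of that pattern (toy data; column names mirror the notebook):

```python
import numpy as np
import pandas as pd

# Toy catalogue with an 'odds' quality score per source (random stand-ins).
rng = np.random.default_rng(0)
df = pd.DataFrame({'z': rng.uniform(0.2, 2.6, 1000),
                   'odds': rng.beta(5, 2, 1000)})

# Quantile-based quality cut: discard the 20% of sources with the lowest odds.
# Unlike a fixed threshold, this keeps 80% of the sample by construction.
threshold = df['odds'].quantile(0.2)
df_cut = df[df['odds'] > threshold]
print(f"kept {len(df_cut)}/{len(df)} sources above odds = {threshold:.3f}")
```

Note that the new DA cut reuses `df_l15_euclid['odds'].quantile(0.2)` rather than the DA frame's own quantile; if a per-sample threshold is intended, `df_l15_euclid_da['odds'].quantile(0.2)` would be the one to use.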
notebooks/Comparison_methodology.py ADDED
@@ -0,0 +1,517 @@
+# ---
+# jupyter:
+#   jupytext:
+#     text_representation:
+#       extension: .py
+#       format_name: light
+#       format_version: '1.5'
+#     jupytext_version: 1.16.2
+#   kernelspec:
+#     display_name: temps
+#     language: python
+#     name: temps
+# ---
+
+# +
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+from astropy.io import fits
+import os
+from astropy.table import Table
+
+from temps.utils import nmad
+from scipy import stats
+from pathlib import Path
+# -
+
+#define here the directory containing the photometric catalogues
+parent_dir = '/data/astro/scratch/lcabayol/EUCLID/DAz/DC2_results_to_share/'
+
+
+# +
+# List of FITS files to be processed
+fits_files = [
+    'GDE_RF_full.fits',
+    'GDE_PHOSPHOROS_V2_full.fits',
+    'OIL_LEPHARE_full.fits',
+    'JDV_DNF_A_full.fits',
+    'JSP_FRANKENZ_full.fits',
+    'MBR_METAPHOR_full.fits',
+    'GDE_ADABOOST_full.fits',
+    'CSC_GPZ_best_full.fits',
+    'SFO_CPZ_full.fits',
+    'AAL_NNPZ_V3_full.fits'
+]
+
+# Corresponding redshift column names
+redshift_columns = [
+    'REDSHIFT_RF',
+    'REDSHIFT_PHOSPHOROS',
+    'REDSHIFT_LEPHARE',
+    'REDSHIFT_DNF',
+    'REDSHIFT_FRANKENZ',
+    'REDSHIFT_METAPHOR',
+    'REDSHIFT_ADABOOST',
+    'REDSHIFT_GPZ',
+    'REDSHIFT_CPZ',
+    'REDSHIFT_NNPZ'
+]
+
+# Initialize an empty DataFrame for merging
+merged_df = pd.DataFrame()
+
+# Process each FITS file
+for fits_file, redshift_col in zip(fits_files, redshift_columns):
+    print(fits_file)
+    # Open the FITS file
+    hdu_list = fits.open(os.path.join(parent_dir,fits_file))
+    df = Table(hdu_list[1].data).to_pandas()
+    df = df[df.REDSHIFT!=0]
+    df = df[['ID', 'VIS','SPECZ', 'REDSHIFT']].rename(columns={'REDSHIFT': redshift_col})
+    # Merge with the main DataFrame
+    if merged_df.empty:
+        merged_df = df
+    else:
+        merged_df = pd.merge(merged_df, df, on=['ID', 'VIS', 'SPECZ'], how='outer')
+
+
+# -
+
+# ## OPEN DATA
+
+# +
+modules_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
+
+hdu_list = fits.open(modules_dir/filename_valid)
+cat_full = Table(hdu_list[1].data).to_pandas()
+
+cat_full = cat_full[['ID','z_spec_S15','reliable_S15','mu_class_L07']]
+
+merged_df['reliable_S15'] = cat_full.reliable_S15
+merged_df['z_spec_S15'] = cat_full.z_spec_S15
+merged_df['mu_class_L07'] = cat_full.mu_class_L07
+merged_df['ID_catfull'] = cat_full.ID
+# -
+
+merged_df_specz = merged_df[(merged_df.z_spec_S15>0)&(merged_df.SPECZ>0)&(merged_df.reliable_S15==1)&(merged_df.mu_class_L07==1)&(merged_df.VIS!=np.inf)]
+
+# ## ONLY SPECZ SAMPLE
+
+scatter, outliers = [],[]
+for im, method in enumerate(redshift_columns):
+    print(method)
+    df_method = merged_df_specz.dropna(subset=method)
+    zerr = (df_method.SPECZ - df_method[method]) / (1 + df_method.SPECZ)
+    print(len(zerr[np.abs(zerr)>0.15]) / len(zerr))
+    scatter.append(nmad(zerr))
+    outliers.append(len(zerr[np.abs(zerr)>0.15]) / len(df_method))
+
+
+# +
+labs = [
+    'RF',
+    'PHOSPHOROS',
+    'LEPHARE',
+    'DNF',
+    'FRANKENZ',
+    'METAPHOR',
+    'ADABOOST',
+    'GPZ',
+    'CPZ',
+    'NNPZ',
+]
+
+# Colors from colormap
+cmap = plt.get_cmap('tab20')
+colors = [cmap(i / len(labs)) for i in range(len(labs))]
+
+# Plotting
+plt.figure(figsize=(10, 6))
+for i in range(len(labs)):
+    plt.scatter(outliers[i]*100, scatter[i], color=colors[i], label=labs[i], marker = '^')
+
+# Adding legend
+plt.legend(fontsize=12)
+plt.ylabel(r'NMAD $[\Delta z]$', fontsize=14)
+plt.xlabel('Outlier fraction [%]', fontsize=14)
+plt.xticks(fontsize=14)
+plt.yticks(fontsize=14)
+
+plt.xlim(5,35)
+plt.ylim(0,0.14)
+
+# Display plot
+plt.show()
+# -
+
+# ### ADD TEMPS PREDICTIONS
+
+import torch
+from temps.archive import Archive
+from temps.temps_arch import EncoderPhotometry, MeasureZ
+from temps.temps import TempsModule
+
+# +
+data_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
+
+hdu_list = fits.open(data_dir/filename_valid)
+cat_phot = Table(hdu_list[1].data).to_pandas()
+# -
+
+cat_phot = cat_phot[cat_phot.ID.isin(merged_df_specz.ID_catfull)]
+
+# +
+photoz_archive = Archive(path = '/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5',
+                         only_zspec=True)
+f, ferr = photoz_archive._extract_fluxes(catalogue= cat_phot)
+col, colerr = photoz_archive._to_colors(f, ferr)
+
+ID = cat_phot.ID
+
+# +
+modules_dir = Path('/nfs/pic.es/user/l/lcabayol/EUCLID/TEMPS/data/models')
+
+nn_features = EncoderPhotometry()
+nn_features.load_state_dict(torch.load(modules_dir / f'modelF_DA.pt', map_location=torch.device('cpu')))
+nn_z = MeasureZ(num_gauss=6)
+nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_DA.pt', map_location=torch.device('cpu')))
+
+temps_module = TempsModule(nn_features, nn_z)
+
+z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
+                                  return_pz=True)
+df = pd.DataFrame(np.c_[ID, z],
+                  columns=['ID','TEMPS'])
+
+df = df.dropna()
+# -
+
+merged_df_specz = merged_df_specz.merge(df, left_on='ID_catfull', right_on='ID')
+
+# Corresponding redshift column names
+redshift_columns = redshift_columns + ['TEMPS']
+
+scatter, outliers = [],[]
+for im, method in enumerate(redshift_columns):
+    print(method)
+    df_method = merged_df_specz.dropna(subset=method)
+    zerr = (df_method.SPECZ - df_method[method]) / (1 + df_method.SPECZ)
+    print(len(zerr[np.abs(zerr)>0.15]) / len(zerr))
+    scatter.append(nmad(zerr))
+    outliers.append(len(zerr[np.abs(zerr)>0.15]) / len(df_method))
+
+
+# +
+labs = [
+    'RF',
+    'PHOSPHOROS',
+    'LEPHARE',
+    'DNF',
+    'FRANKENZ',
+    'METAPHOR',
+    'ADABOOST',
+    'GPZ',
+    'CPZ',
+    'NNPZ',
+    'TEMPS'
+]
+
+# Colors from colormap
+cmap = plt.get_cmap('tab20')
+colors = [cmap(i / len(labs)) for i in range(len(labs))]
+
+# Plotting
+plt.figure(figsize=(10, 6))
+for i in range(len(labs)):
+    plt.scatter(outliers[i]*100, scatter[i], color=colors[i], label=labs[i], marker = '^')
+
+# Adding legend
+plt.legend(fontsize=12)
+plt.ylabel(r'NMAD $[\Delta z]$', fontsize=14)
+plt.xlabel('Outlier fraction [%]', fontsize=14)
+plt.xticks(fontsize=14)
+plt.yticks(fontsize=14)
+
+plt.xlim(5,35)
+plt.ylim(0,0.14)
+
+# Display plot
+plt.show()
+# -
+
+# ## ANOTHER SELECTION
+
+# +
+# List of FITS files to be processed
+fits_files = [
+    'GDE_RF_full.fits',
+    'GDE_PHOSPHOROS_V2_full.fits',
+    'OIL_LEPHARE_full.fits',
+    'JDV_DNF_A_full.fits',
+    'JSP_FRANKENZ_full.fits',
+    'MBR_METAPHOR_full.fits',
+    'GDE_ADABOOST_full.fits',
+    'CSC_GPZ_best_full.fits',
+    'SFO_CPZ_full.fits',
+    'AAL_NNPZ_V3_full.fits'
+]
+
+# Corresponding redshift column names
+redshift_columns = [
+    'REDSHIFT_RF',
+    'REDSHIFT_PHOSPHOROS',
+    'REDSHIFT_LEPHARE',
+    'REDSHIFT_DNF',
+    'REDSHIFT_FRANKENZ',
+    'REDSHIFT_METAPHOR',
+    'REDSHIFT_ADABOOST',
+    'REDSHIFT_GPZ',
+    'REDSHIFT_CPZ',
+    'REDSHIFT_NNPZ'
+]
+
+use_columns = [
+    'USE_RF',
+    'USE_PHOSPHOROS',
+    'USE_LEPHARE',
+    'USE_DNF',
+    'USE_FRANKENZ',
+    'USE_METAPHOR',
+    'USE_ADABOOST',
+    'USE_GPZ',
+    'USE_CPZ',
+    'USE_NNPZ'
+]
+
+# Initialize an empty DataFrame for merging
+merged_df = pd.DataFrame()
+
+# Process each FITS file
+for fits_file, redshift_col, use_col in zip(fits_files, redshift_columns, use_columns):
+    print(fits_file)
+    # Open the FITS file
+    hdu_list = fits.open(os.path.join(parent_dir,fits_file))
+    df = Table(hdu_list[1].data).to_pandas()
+    df = df[df.REDSHIFT!=0]
+    df = df[['ID', 'VIS', 'SPECZ', 'REDSHIFT', 'L15PHZ', 'USE']].rename(columns={'REDSHIFT': redshift_col, 'USE': use_col})
+    # Merge with the main DataFrame
+    if merged_df.empty:
+        merged_df = df
+    else:
+        merged_df = pd.merge(merged_df, df, on=['ID', 'VIS', 'SPECZ','L15PHZ'], how='outer')
+
+
+# -
+
+merged_df['comp_z'] = np.where(merged_df['SPECZ'] > 0, merged_df['SPECZ'], merged_df['L15PHZ'])
+#merged_df = merged_df[(merged_df.comp_z>0)&(merged_df.comp_z<4)&(merged_df.VIS>23.5)]
+merged_df = merged_df[(merged_df.comp_z>0)&(merged_df.comp_z<4)&(merged_df.VIS<25)]
+
+# +
+modules_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
+
+hdu_list = fits.open(modules_dir/filename_valid)
+cat_full = Table(hdu_list[1].data).to_pandas()
+
+merged_df['ID_catfull'] = cat_full.ID
+
+# +
+data_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
+
+hdu_list = fits.open(data_dir/filename_valid)
+cat_phot = Table(hdu_list[1].data).to_pandas()
+# -
+
+cat_phot = cat_phot[cat_phot.ID.isin(merged_df.ID_catfull)]
+
+# +
+photoz_archive = Archive(path = '/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5',
+                         only_zspec=False)
+f, ferr = photoz_archive._extract_fluxes(catalogue= cat_phot)
+col, colerr = photoz_archive._to_colors(f, ferr)
+
+ID = cat_phot.ID
+
+# +
+modules_dir = Path('/nfs/pic.es/user/l/lcabayol/EUCLID/TEMPS/data/models')
+
+nn_features = EncoderPhotometry()
+nn_features.load_state_dict(torch.load(modules_dir/f'modelF_DA.pt', map_location=torch.device('cpu')))
+nn_z = MeasureZ(num_gauss=6)
+nn_z.load_state_dict(torch.load(modules_dir/f'modelZ_DA.pt', map_location=torch.device('cpu')))
+
+temps_module = TempsModule(nn_features, nn_z)
+
+z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
+                                  return_pz=True)
+
+nn_features = EncoderPhotometry()
+nn_features.load_state_dict(torch.load(modules_dir/f'modelF_z.pt', map_location=torch.device('cpu')))
+nn_z = MeasureZ(num_gauss=6)
+nn_z.load_state_dict(torch.load(modules_dir/f'modelZ_z.pt', map_location=torch.device('cpu')))
+
+temps_module = TempsModule(nn_features, nn_z)
+znoda, pz, odds_noda = temps_module.get_pz(input_data=torch.Tensor(col),
+                                           return_pz=True)
+
+nn_features = EncoderPhotometry()
+nn_features.load_state_dict(torch.load(modules_dir/f'modelF_L15.pt', map_location=torch.device('cpu')))
+nn_z = MeasureZ(num_gauss=6)
+nn_z.load_state_dict(torch.load(modules_dir/f'modelZ_L15.pt', map_location=torch.device('cpu')))
+
+temps_module = TempsModule(nn_features, nn_z)
+z_L15, pz, odds_L15 = temps_module.get_pz(input_data=torch.Tensor(col),
+                                          return_pz=True)
+
+df = pd.DataFrame(np.c_[ID, z, odds, znoda, odds_noda, z_L15, odds_L15],
+                  columns=['ID','TEMPS', 'flag_TEMPS', 'TEMPS_noda', 'flag_TEMPSnoda', 'TEMPS_L15', 'flag_L15'])
+
+df = df.dropna()
+
+# +
+percent=0.3
+df['USE_TEMPS'] = np.zeros(shape=len(df))
+# Calculate the 30th-percentile value of 'flag_TEMPS'
+threshold = df['flag_TEMPS'].quantile(percent)
+
+# Set 'USE_TEMPS' to 1 if 'flag_TEMPS' is in the top 70% (at or above the threshold)
+df['USE_TEMPS'] = np.where(df['flag_TEMPS'] >= threshold, 1, 0)
+
+# +
+percent=0.3
+df['USE_TEMPS_noda'] = np.zeros(shape=len(df))
+# Calculate the 30th-percentile value of 'flag_TEMPSnoda'
+threshold = df['flag_TEMPSnoda'].quantile(percent)
+
+# Set 'USE_TEMPS_noda' to 1 if 'flag_TEMPSnoda' is in the top 70% (at or above the threshold)
+df['USE_TEMPS_noda'] = np.where(df['flag_TEMPSnoda'] >= threshold, 1, 0)
+
+# +
+percent=0.3
+df['USE_TEMPS_L15'] = np.zeros(shape=len(df))
+# Calculate the 30th-percentile value of 'flag_L15'
+threshold = df['flag_L15'].quantile(percent)
+
+# Set 'USE_TEMPS_L15' to 1 if 'flag_L15' is in the top 70% (at or above the threshold)
+df['USE_TEMPS_L15'] = np.where(df['flag_L15'] >= threshold, 1, 0)
+# -
+
+merged_df_temps = merged_df.merge(df, left_on='ID_catfull', right_on='ID')
+
+# Corresponding redshift column names
+redshift_columns = [
+    'REDSHIFT_RF',
+    'REDSHIFT_PHOSPHOROS',
+    'REDSHIFT_LEPHARE',
+    'REDSHIFT_DNF',
+    'REDSHIFT_FRANKENZ',
+    'REDSHIFT_METAPHOR',
+    'REDSHIFT_ADABOOST',
+    'REDSHIFT_GPZ',
+    'REDSHIFT_CPZ',
+    'REDSHIFT_NNPZ'
+]
+
+redshift_columns = redshift_columns + ['TEMPS', 'TEMPS_noda', 'TEMPS_L15']
+use_columns = use_columns + ['USE_TEMPS','USE_TEMPS_noda', 'USE_TEMPS_L15']
+
+merged_df_temps = merged_df_temps[merged_df_temps.VIS <25]
+
+
+scatter, outliers, size = [],[],[]
+for method, use in zip(redshift_columns, use_columns):
+    print(method)
+    #df_method = merged_df_temps.dropna(subset=method)
+    df_method = merged_df_temps[(merged_df_temps.loc[:, method]>0.2)&(merged_df_temps.loc[:, method]<2.6)]
+    df_method = df_method[df_method.VIS<24.5]
+    norm_size = len(df_method)
+    df_method = df_method[df_method.loc[:, use]==1]
+    zerr = (df_method.comp_z - df_method[method]) / (1 + df_method.comp_z)
+    scatter.append(nmad(zerr))
+    outliers.append(len(zerr[np.abs(zerr)>0.15]) / len(df_method))
+    size.append(len(df_method)/norm_size)
+    print(nmad(zerr), len(zerr[np.abs(zerr)>0.15]) / len(df_method), len(df_method)/norm_size)
+
+
+scatter_faint, outliers_faint, size_faint = [],[],[]
+for method, use in zip(redshift_columns, use_columns):
+    print(method)
+    #df_method = merged_df_temps.dropna(subset=method)
+    df_method = merged_df_temps[(merged_df_temps.loc[:,'VIS']>23.5)&(merged_df_temps.loc[:,'VIS']<25)]
+    #df_method = df_method[df_method.loc[:, use]==1]
+    #df_method = merged_df_temps[(merged_df_temps.loc[:,'VIS']>23.5)&(merged_df_temps.loc[:,'VIS']<24.5)]
+    zerr = (df_method.comp_z - df_method[method]) / (1 + df_method.comp_z)
+    scatter_faint.append(nmad(zerr))
+    outliers_faint.append(len(zerr[np.abs(zerr)>0.15]) / len(df_method))
+    size_faint.append(len(df_method))
+    print(nmad(zerr), len(zerr[np.abs(zerr)>0.15]) / len(df_method), len(df_method))
+
+
+# +
+import matplotlib.pyplot as plt
+import numpy as np
+from pastamarkers import markers
+
+# Define labels for the models
+labs = [
+    'RF', 'PHOSPHOROS', 'LEPHARE', 'DNF', 'FRANKENZ', 'METAPHOR',
+    'ADABOOST', 'GPZ', 'CPZ', 'NNPZ', 'TEMPS', 'TEMPS - no DA', 'TEMPS - L15'
+]
+
+markers_pasta = [markers.penne, markers.conchiglie, markers.tortellini, markers.creste, markers.spaghetti, markers.ravioli, markers.tagliatelle, markers.mezzelune, markers.puntine, markers.stelline, 's', 'o', '^']
+
+labs_faint = [f"{lab}_faint" for lab in labs]  # Labels for the faint data
+
+
+# Colors from colormap
+cmap = plt.get_cmap('tab20')
+colors = [cmap(i / len(labs)) for i in range(len(labs))]
+
+# Create subplots with 2 panels stacked vertically
+fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12), sharex=False)
+
+# Plotting for the top panel
+for i in range(len(labs)):
+    if labs[i] == 'TEMPS - no DA' or labs[i] == 'TEMPS - L15':
+        ax1.scatter(np.nan, np.nan, color=colors[i], label=labs[i], marker=markers_pasta[i], s=300)
+    elif labs[i]=='CPZ':
+        ax1.scatter(outliers[i] * 100, scatter[i], color=colors[i], label=labs[i], marker=markers_pasta[i], s=300)
+        ax1.text(outliers[i] * 100 - 0.2, scatter[i] + 0.001, f'{int(np.around(size[i] * 100))}', fontsize=12, verticalalignment='bottom')
+
+    elif labs[i]=='ADABOOST':
+        ax1.scatter(outliers[i] * 100, scatter[i], color=colors[i], label=labs[i], marker=markers_pasta[i], s=300)
+        ax1.text(outliers[i] * 100 - 0.5, scatter[i] - 0.004, f'{int(np.around(size[i] * 100))}', fontsize=12, verticalalignment='bottom')
+
+    else:
+        ax1.scatter(outliers[i] * 100, scatter[i], color=colors[i], label=labs[i], marker=markers_pasta[i], s=300)
+        ax1.text(outliers[i] * 100 - 0.5, scatter[i] + 0.001, f'{int(np.around(size[i] * 100))}', fontsize=12, verticalalignment='bottom')
+
+# Customizations for the top plot
+ax1.set_ylabel(r'NMAD $[\Delta z]$', fontsize=24)
+ax1.legend(fontsize=14)
+ax1.tick_params(axis='both', which='major', labelsize=20)
+
+# Plotting for the bottom panel (faint data)
+for i in range(len(labs)):
+    ax2.scatter(outliers_faint[i] * 100, scatter_faint[i], color=colors[i], label=labs[i], marker=markers_pasta[i], s=300)
+
+# Customizations for the bottom plot
+ax2.set_ylabel(r'NMAD $[\Delta z]$', fontsize=24)
+ax2.set_xlabel('Outlier fraction [%]', fontsize=24)
+ax2.tick_params(axis='both', which='major', labelsize=20)
+
+# Display the plot
+plt.tight_layout()
+#plt.savefig('Comparison_paper.pdf', bbox_inches='tight')
+plt.show()
+
+# -
+
+cat_val_z = cat_val[['RA','DEC']].merge(cat_all[['RA','DEC','z_spec_S15','photo_z_L15','reliable_S15','mu_class_L07']], on = ['RA','DEC'])
+
+merged_df = merged_df.merge(cat_val_z, on = ['RA','DEC'])
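The scatter/outlier loops above rely on `nmad` from `temps.utils`. A minimal sketch of the convention it presumably follows (the 1.4826-scaled median absolute deviation, which matches a Gaussian's standard deviation), next to the outlier fraction computed alongside it:

```python
import numpy as np

def nmad(x):
    """Normalized median absolute deviation:
    1.4826 * median(|x - median(x)|)."""
    x = np.asarray(x)
    return 1.4826 * np.median(np.abs(x - np.median(x)))

# Residuals normalized as in the notebook: (specz - photoz) / (1 + specz)
zerr = np.array([0.01, -0.02, 0.005, 0.30, -0.01, 0.02])
print("NMAD:", nmad(zerr))                        # scatter, robust to outliers
print("outliers:", np.mean(np.abs(zerr) > 0.15))  # fraction with |dz/(1+z)| > 0.15
```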
notebooks/Feature_space.py CHANGED
@@ -5,11 +5,11 @@
 #       extension: .py
 #       format_name: light
 #       format_version: '1.5'
-#     jupytext_version: 1.14.5
+#     jupytext_version: 1.16.2
 #   kernelspec:
-#     display_name: insight
+#     display_name: temps
 #     language: python
-#     name: insight
+#     name: temps
 # ---

 # # DOMAIN ADAPTATION INTUITION
@@ -23,6 +23,9 @@ import os
 from astropy.io import fits
 from astropy.table import Table
 import torch
+from pathlib import Path
+import seaborn as sns
+

 #matplotlib settings
 from matplotlib import rcParams
@@ -30,28 +33,22 @@ import matplotlib.pyplot as plt
 rcParams["mathtext.fontset"] = "stix"
 rcParams["font.family"] = "STIXGeneral"

-# +
-#insight modules
-import sys
-sys.path.append('../temps')
-
-from archive import archive
-from utils import nmad
-from temps_arch import EncoderPhotometry, MeasureZ
-from temps import Temps_module
-from plots import plot_nz
-# -
+from temps.archive import Archive
+from temps.utils import nmad
+from temps.temps_arch import EncoderPhotometry, MeasureZ
+from temps.temps import TempsModule
+from temps.plots import plot_nz

 # ## LOAD DATA

 #define here the directory containing the photometric catalogues
-parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
-modules_dir = '../data/models/'
+parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+modules_dir = Path('../data/models/')

 # +
 filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'

-hdu_list = fits.open(os.path.join(parent_dir,filename_valid))
+hdu_list = fits.open(parent_dir/filename_valid)
 cat = Table(hdu_list[1].data).to_pandas()
 cat = cat[cat['FLAG_PHOT']==0]
 cat = cat[cat['mu_class_L07']==1]
@@ -70,7 +67,7 @@ cat['specz_or_photo']=specz_or_photo

 # ### EXTRACT PHOTOMETRY

-photoz_archive = archive(path = parent_dir,only_zspec=False)
+photoz_archive = Archive(path = parent_dir,only_zspec=False)
 f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
 col, colerr = photoz_archive._to_colors(f, ferr)

@@ -80,7 +77,7 @@ features_all = np.zeros((3,len(cat),10))
 for il, lab in enumerate(['z','L15','DA']):

     nn_features = EncoderPhotometry()
-    nn_features.load_state_dict(torch.load(os.path.join(modules_dir,f'modelF_{lab}.pt')))
+    nn_features.load_state_dict(torch.load(modules_dir/f'modelF_{lab}.pt', map_location=torch.device('cpu')))

     features = nn_features(torch.Tensor(col))
     features = features.detach().cpu().numpy()
@@ -132,7 +129,7 @@ autoencoder = Autoencoder(input_dim=10,
 criterion = nn.L1Loss()
 optimizer = optim.Adam(autoencoder.parameters(), lr=0.0001)

-# +
+# + jupyter={"outputs_hidden": true}
 # Define the number of epochs
 num_epochs = 100
 for epoch in range(num_epochs):
@@ -162,9 +159,7 @@ print('Training finished')

 # #### EVALUATE AUTOENCODER

-# + [markdown] jupyter={"source_hidden": true}
 # cat.to_csv('features_cat.csv', header=True, sep=',')
-# -

 indexes_specz = cat[(cat.specz_or_photo==0)&(cat.reliable_S15>0)].reset_index().index

@@ -173,6 +168,8 @@ for i in range(3):
     _, features = autoencoder(torch.Tensor(features_all[i]))
     features_all_reduced[i] = features.detach().cpu().numpy()

+features_all.shape
+
 # ### Plot the features

 start = 0
@@ -182,7 +179,6 @@ values_not_in_indexes_specz = all_values - set(indexes_specz)
 indexes_nospecz = sorted(values_not_in_indexes_specz)

 # +
-import seaborn as sns

 # Create subplots with three panels
 fig, axs = plt.subplots(1, 3, figsize=(15, 5))
@@ -223,14 +219,14 @@ axs[1].set_title('Trained on L15')

 # Third subplot
 features_all_reduced_nospecz = pd.DataFrame(features_all_reduced[2, indexes_nospecz, :]).drop_duplicates().values
-sns.kdeplot(x=features_all_reduced_nospecz[:, 0],
-            y=features_all_reduced_nospecz[:, 1],
+sns.kdeplot(x=features_all_reduced[2, indexes_nospecz, 0],
+            y=features_all_reduced[2, indexes_nospecz, 1],
             clip=(-1, 5),
             ax=axs[2],
             color='salmon',
             label='Wide-field sample')
-sns.kdeplot(x=features_all_reduced_specz[:, 0],
-            y=features_all_reduced_specz[:, 1],
+sns.kdeplot(x=features_all_reduced[2, indexes_specz, 0],
+            y=features_all_reduced[2, indexes_specz, 1],
             clip=(-1, 5),
             ax=axs[2],
             color='lightskyblue',
@@ -252,200 +248,7 @@ axs[2].legend(legend_handles, legend_labels, loc='upper right', fontsize=16)
 # Adjust layout
 plt.tight_layout()

-plt.savefig('Contourplot.pdf', bbox_inches='tight')
-plt.show()
-
-# -
-
-
-
-
-
-
-
-np.savetxt('features.txt',features_all_reduced.reshape(3*164816, 2))
-
-
-
-
-
-
-
-
-
-
-
-# +
-photoz_archive = archive(path = parent_dir,only_zspec=False)
-
-fig, ax = plt.subplots(ncols = 3, figsize=(15,4), sharex=True, sharey=True)
-colors = ['navy', 'goldenrod']
-titles = [r'Training: $z_s$', r'Training: L15',r'Training: $z_s$ + DA']
-x_min, x_max = -5,5
-y_min, y_max = -5,5
-x_grid, y_grid = np.meshgrid(np.linspace(x_min, x_max, 10), np.linspace(y_min, y_max, 10))
-xy_grid = np.vstack([x_grid.ravel(), y_grid.ravel()])
-density_grid = density_estimation(xy_grid).reshape(x_grid.shape)
-for il, lab in enumerate(['z','L15','DA']):
-
-
-    nn_features = EncoderPhotometry()
-    nn_features.load_state_dict(torch.load(os.path.join(modules_dir,f'modelF_{lab}.pt')))
-
-    for it, target_type in enumerate(['L15','zs']):
-        if target_type=='zs':
-            cat_sub = photoz_archive._select_only_zspec(cat)
-            cat_sub = photoz_archive._clean_zspec_sample(cat_sub)
-
-        elif target_type=='L15':
-            cat_sub = photoz_archive._exclude_only_zspec(cat)
-        else:
-            assert False
-
-        cat_sub = photoz_archive._clean_photometry(cat_sub)
-        print(cat_sub.shape)
-
-
-
-        f, ferr = photoz_archive._extract_fluxes(cat_sub)
-        col, colerr = photoz_archive._to_colors(f, ferr)
-
-        features = nn_features(torch.Tensor(col))
-        features = features.detach().cpu().numpy()
-
-
-        #xy = np.vstack([features[:1000,0], features[:1000,1]])
-        #zd = gaussian_kde(xy)(xy)
-        #ax[il].scatter(features[:1000,0], features[:1000,1],c=zd, s=3)
-
-        xy = np.vstack([features[:,0], features[:,1]])
-        density_estimation = gaussian_kde(xy)
-
-        # Define grid for plotting density lines
-
-        xy_grid = np.vstack([x_grid.ravel(), y_grid.ravel()])
-        density_grid = density_estimation(xy_grid).reshape(x_grid.shape)
-
-        # Plot contour lines representing density
-        ax[il].contour(x_grid, y_grid, density_grid, colors=colors[it], label = f'{target_type}')
-
-
-
-    ax[il].set_title(titles[il])
-    ax[il].set_xlim(-5,5)
-    ax[il].set_ylim(-5,5)
-
-
-ax[0].set_ylabel('Feature 1', fontsize=14)
-#plt.ylabel('Feature 2', fontsize=14)
-
-#assert False
-
-# -
-
-H
-
-H
-
-xedges
-
-yedges
-
-# +
-import matplotlib.colors as colors
-from matplotlib import path
-import numpy as np
-from matplotlib import pyplot as plt
-try:
-    from astropy.convolution import Gaussian2DKernel, convolve
-    astro_smooth = True
-except ImportError as IE:
-    astro_smooth = False
-
-np.random.seed(123)
-#t = np.linspace(-5,1.2,1000)
-x = features[:1000,0]
-y = features[:1000,1]
-
-H, xedges, yedges = np.histogram2d(x,y, bins=(10,10))
-xmesh, ymesh = np.meshgrid(xedges[:-1], yedges[:-1])
-
-# Smooth the contours (if astropy is installed)
-if astro_smooth:
-    kernel = Gaussian2DKernel(x_stddev=1.)
-    H=convolve(H,kernel)
-
-fig,ax = plt.subplots(1, figsize=(7,6))
-clevels = ax.contour(xmesh,ymesh,H.T,lw=.9,cmap='winter')#,zorder=90)
-ax.scatter(x,y,s=3)
-#ax.set_xlim(-20,5)
-#ax.set_ylim(-20,5)
-
-# Identify points within contours
-#p = clevels.collections[0].get_paths()
-#inside = np.full_like(x,False,dtype=bool)
-#for level in p:
-#    inside |= level.contains_points(zip(*(x,y)))
-
-#ax.plot(x[~inside],y[~inside],'kx')
-#plt.show(block=False)
-# -
-
-density_grid
-
-features.shape, zd.shape
-
-# + jupyter={"outputs_hidden": true}
-xy = np.vstack([features[:,0], features[:,1]])
-zd = gaussian_kde(xy)(xy)
-plt.scatter(features[:,0], features[:,1],c=zd)
-
-
-# +
-# Make the base corner plot
-figure = corner.corner(features[:,:2], quantiles=[0.16, 0.84], show_titles=False, color ='crimson')
-corner.corner(samples2, fig=fig)
-ndim=2
-# Extract the axes
-axes = np.array(figure.axes).reshape((ndim, ndim))
-
-
-for a in axes[np.triu_indices(ndim)]:
-    a.remove()
-
-# +
-import numpy as np
-import matplotlib.pyplot as plt
-from scipy.stats import gaussian_kde
-
-# Assuming 'features' is your data array with shape (n_samples, 2)
-
-# Calculate the density estimate
-xy = np.vstack([features[:,0], features[:,1]])
-density_estimation = gaussian_kde(xy)
-
-# Define grid for plotting density lines
-
-xy_grid = np.vstack([x_grid.ravel(), y_grid.ravel()])
-density_grid = density_estimation(xy_grid).reshape(x_grid.shape)
-
-# Plot contour lines representing density
-plt.contour(x_grid, y_grid, density_grid, colors='black')
-
-# Optionally, you can add a scatter plot on top of the density lines for better visualization
-#plt.scatter(features[:,0], features[:,1], color='blue', alpha=0.5)
-
-# Set labels and title
-plt.xlabel('Feature 1')
-plt.ylabel('Feature 2')
-plt.title('Density Lines Plot')
-
-# Show plot
+#plt.savefig('Contourplot.pdf', bbox_inches='tight')
 plt.show()

 # -
@@ -454,41 +257,6 @@ plt.show()



-corner_plot = corner.corner(Arinyo_preds,
-                            labels=[r'$b$', r'$\beta$', '$q_1$', '$k_{vav}$','$a_v$','$b_v$','$k_p$','$q_2$'],
-                            truths=Arinyo_coeffs_central[test_snap],
-                            truth_color='crimson')
-
-import corner
-figure = corner.corner(features, quantiles=[0.16, 0.5, 0.84], show_titles=False)
-axes = np.array(fig.axes).reshape((ndim, ndim))
-for a in axes[np.triu_indices(ndim)]:
-    a.remove()



-# +
-# My data
-x = features[:,0]
-y = features[:,1]
-
-# Perform the kernel density estimate
-k = stats.gaussian_kde(np.vstack([x, y]))
-xi, yi = np.mgrid[-5:5,-5:5]
-zi = k(np.vstack([xi.flatten(), yi.flatten()]))
-
-
-
-fig = plt.figure()
-ax = fig.gca()
-
-
-CS = ax.contour(xi, yi, zi.reshape(xi.shape), colors='crimson')
-
-ax.set_xlim(-5, 5)
-ax.set_ylim(-5, 5)
-
-plt.show()
-# -
notebooks/Fig6_qualitycut.py DELETED
@@ -1,164 +0,0 @@
-# ---
-# jupyter:
-#   jupytext:
-#     text_representation:
-#       extension: .py
-#       format_name: light
-#       format_version: '1.5'
-#     jupytext_version: 1.14.5
-#   kernelspec:
-#     display_name: insight
-#     language: python
-#     name: insight
-# ---
-
-# # FIGURE 6 IN THE PAPER
-
-# ## QUALITY CUTS
-
-# %load_ext autoreload
-# %autoreload 2
-
-import pandas as pd
-import numpy as np
-import os
-import torch
-from scipy import stats
-
-#matplotlib settings
-from matplotlib import rcParams
-import matplotlib.pyplot as plt
-rcParams["mathtext.fontset"] = "stix"
-rcParams["font.family"] = "STIXGeneral"
-
-#insight modules
-import sys
-sys.path.append('../temps')
-#from insight_arch import EncoderPhotometry, MeasureZ
-#from insight import Insight_module
-from archive import archive
-from utils import nmad
-from temps_arch import EncoderPhotometry, MeasureZ
-from temps import Temps_module
-
-
-# ### LOAD DATA (ONLY SPECZ)
-
-#define here the directory containing the photometric catalogues
-parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
-modules_dir = '../data/models/'
-
-photoz_archive = archive(path = parent_dir,only_zspec=True,flags_kept=[1. , 1.1, 1.4, 1.5, 2,2.1,2.4,2.5,3., 3.1, 3.4, 3.5, 4., 9. , 9.1, 9.3, 9.4, 9.5,11.1, 11.5, 12.1, 12.5, 13. , 13.1, 13.5, 14, ])
-f_test_specz, ferr_test_specz, specz_test ,VIS_mag_test = photoz_archive.get_testing_data()
-
-
-# ### LOAD TRAINED MODELS AND EVALUATE PDF OF RANDOM EXAMPLES
-
-# +
-# Initialize an empty dictionary to store DataFrames
-dfs = {}
-
-for il, lab in enumerate(['z','L15','DA']):
-
-    nn_features = EncoderPhotometry()
-    nn_features.load_state_dict(torch.load(os.path.join(modules_dir,f'modelF_{lab}.pt')))
-    nn_z = MeasureZ(num_gauss=6)
-    nn_z.load_state_dict(torch.load(os.path.join(modules_dir,f'modelZ_{lab}.pt')))
-
-    temps = Temps_module(nn_features, nn_z)
-
-    z, zerr, pz, flag, odds = temps.get_pz(input_data=torch.Tensor(f_test_specz),
-                                           return_pz=True)
-
-
-    # Create a DataFrame with the desired columns
-    df = pd.DataFrame(np.c_[z, flag, odds, specz_test],
-                      columns=['z','zflag', 'odds' ,'ztarget'])
-
-    # Calculate additional columns or operations if needed
-    df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
-
-    # Drop any rows with NaN values
-    df = df.dropna()
-
-    # Assign the DataFrame to a key in the dictionary
-    dfs[lab] = df
-
-# -
-
-# ### STATISTICS BASED ON OUR QUALITY CUT
-
-# +
-bin_edges = stats.mstats.mquantiles(df.zflag, np.arange(0,1.01,0.05))
-scatter, eta, xlab, xmag, xzs, flagmean = [],[],[],[],[],[]
-
-for k in range(len(bin_edges)-1):
-    edge_min = bin_edges[k]
-    edge_max = bin_edges[k+1]
-
-    df_bin = df[(df.zflag > edge_min)]
-
-
-    xlab.append(np.round(len(df_bin)/len(df),2)*100)
-    xzs.append(0.5*(df_bin.ztarget.min()+df_bin.ztarget.max()))
-    flagmean.append(np.mean(df_bin.zflag))
-    scatter.append(nmad(df_bin.zwerr))
-    eta.append(len(df_bin[np.abs(df_bin.zwerr)>0.15])/len(df)*100)
-
-
-# -
-
-# ### STATISTICS BASED ON ODDS
-
-# +
-bin_edges = stats.mstats.mquantiles(df.odds, np.arange(0,1.01,0.05))
-scatter_odds, eta_odds, xlab_odds, oddsmean = [],[],[],[]
-
-for k in range(len(bin_edges)-1):
-    edge_min = bin_edges[k]
-    edge_max = bin_edges[k+1]
-
-    df_bin = df[(df.odds > edge_min)]
-
-
-    xlab_odds.append(np.round(len(df_bin)/len(df),2)*100)
-    oddsmean.append(np.mean(df_bin.zflag))
-    scatter_odds.append(nmad(df_bin.zwerr))
-    eta_odds.append(len(df_bin[np.abs(df_bin.zwerr)>0.15])/len(df)*100)
-
-
-# -
-
-# ### PLOTS
-
-# +
-plt.plot(xlab_odds,scatter_odds, marker = '.', color ='crimson', label=r'$\theta(\Delta z)$', ls='--', alpha=0.5)
-plt.plot(xlab,scatter, marker = '.', color ='navy',label=r'$\xi = \theta(\Delta z)$')
-
-
-plt.ylabel(r'NMAD [$\Delta z\ /\ (1 + z_{\rm s})$]', fontsize=16)
-plt.xlabel('Completeness', fontsize=16)
-
-plt.yticks(fontsize=12)
-plt.xticks(np.arange(5,101,10), fontsize=12)
-plt.legend(fontsize=14)
-
-plt.savefig('Flag_nmad_zspec.pdf', bbox_inches='tight')
-plt.show()
-
-# +
-plt.plot(xlab_odds,eta_odds, marker='.', color ='crimson', label=r'$\theta(\Delta z)$', ls='--', alpha=0.5)
-plt.plot(xlab,eta, marker='.', color ='navy',label=r'$\xi = \theta(\Delta z)$')
-
-plt.yticks(fontsize=12)
-plt.xticks(np.arange(5,101,10), fontsize=12)
-plt.ylabel(r'$\eta$ [%]', fontsize=16)
-plt.xlabel('Completeness', fontsize=16)
-plt.legend()
-
-plt.savefig('Flag_eta_zspec.pdf', bbox_inches='tight')
-
-plt.show()
-# -
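Although this notebook is deleted, its central computation — sweeping quantile thresholds of a quality score and tracking NMAD and outlier rate against the completeness each cut leaves — is the same trade-off the surviving notebooks plot. A compact sketch of that sweep on toy inputs, under the same `nmad` assumption as in the sketch above:

```python
import numpy as np
from scipy import stats

def nmad(x):
    return 1.4826 * np.median(np.abs(x - np.median(x)))

rng = np.random.default_rng(1)
zwerr = rng.normal(0, 0.05, 5000)   # toy normalized residuals dz/(1+z)
odds = rng.uniform(0, 1, 5000)      # toy quality score

# Quantile bin edges over the score, as in the deleted notebook.
bin_edges = stats.mstats.mquantiles(odds, np.arange(0, 1.01, 0.05))
for edge in bin_edges[:-1]:
    keep = zwerr[odds > edge]       # cumulative cut: keep everything above the edge
    completeness = 100 * len(keep) / len(zwerr)
    eta = 100 * np.mean(np.abs(keep) > 0.15)
    print(f"{completeness:5.1f}%  NMAD={nmad(keep):.4f}  eta={eta:.2f}%")
```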
notebooks/{Fig2_NMAD.py → NMAD.py} RENAMED
@@ -6,15 +6,15 @@
 #       extension: .py
 #       format_name: percent
 #       format_version: '1.3'
-#     jupytext_version: 1.14.5
+#     jupytext_version: 1.16.2
 #   kernelspec:
-#     display_name: insight
+#     display_name: temps
 #     language: python
-#     name: insight
+#     name: temps
 # ---

 # %% [markdown]
-# # FIGURE 2 IN THE PAPER
+# # FIGURE METRICS

 # %% [markdown]
 # ## METRICS FOR THE DIFFERENT METHODS ON THE WIDE FIELD SAMPLE
@@ -43,15 +43,14 @@ rcParams["font.family"] = "STIXGeneral"


 # %%
-#insight modules
-import sys
-sys.path.append('../temps')
-
-from archive import archive
-from utils import nmad
-from temps_arch import EncoderPhotometry, MeasureZ
-from temps import Temps_module
+import temps

+# %%
+from temps.archive import Archive
+from temps.utils import nmad
+from temps.temps_arch import EncoderPhotometry, MeasureZ
+from temps.temps import TempsModule
+from temps.plots import plot_photoz


 # %%
@@ -62,15 +61,13 @@

 # %%
 #define here the directory containing the photometric catalogues
-parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
-modules_dir = '../data/models/'
+parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
+modules_dir = Path('../data/models/')

 # %%
-#load catalogue and apply cuts
-
 filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
-
-hdu_list = fits.open(os.path.join(parent_dir,filename_valid))
+path_file = parent_dir / filename_valid  # Creating the path to the file
+hdu_list = fits.open(path_file)
 cat = Table(hdu_list[1].data).to_pandas()
 cat = cat[cat['FLAG_PHOT']==0]
 cat = cat[cat['mu_class_L07']==1]
@@ -78,7 +75,6 @@ cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
 cat = cat[cat['MAG_VIS']<25]


-
 # %%
 ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii]> 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
 specz_or_photo = [0 if cat['z_spec_S15'].values[ii]> 0 else 1 for ii in range(len(cat))]
@@ -87,7 +83,7 @@ VISmag = cat['MAG_VIS']
 zsflag = cat['reliable_S15']

 # %%
-photoz_archive = archive(path = parent_dir,only_zspec=False)
+photoz_archive = Archive(path = parent_dir,only_zspec=False)
 f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
 col, colerr = photoz_archive._to_colors(f, ferr)

@@ -101,20 +97,21 @@ if eval_methods:
     for il, lab in enumerate(['z','L15','DA']):

         nn_features = EncoderPhotometry()
-        nn_features.load_state_dict(torch.load(os.path.join(modules_dir,f'modelF_{lab}.pt')))
+        nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt', map_location=torch.device('cpu')))
         nn_z = MeasureZ(num_gauss=6)
-        nn_z.load_state_dict(torch.load(os.path.join(modules_dir,f'modelZ_{lab}.pt')))
+        nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt', map_location=torch.device('cpu')))

-        temps = Temps_module(nn_features, nn_z)
+        temps_module = TempsModule(nn_features, nn_z)

-        z,zerr, zmode,pz, flag, odds = temps.get_pz(input_data=torch.Tensor(col),
-                                                    return_pz=True)
+        z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
+                                          return_pz=True,
+                                          return_flag=True)
         # Create a DataFrame with the desired columns
-        df = pd.DataFrame(np.c_[ID, VISmag,z, zmode, flag, ztarget,zsflag,zerr, specz_or_photo],
-                          columns=['ID','VISmag','z', 'zmode','zflag', 'ztarget','zsflag','zuncert','S15_L15_flag'])
+        df = pd.DataFrame(np.c_[ID, VISmag,z, odds, ztarget,zsflag, specz_or_photo],
+                          columns=['ID','VISmag','z','odds', 'ztarget','zsflag','S15_L15_flag'])

         # Calculate additional columns or operations if needed
-        df['zwerr'] = (df.zmode - df.ztarget) / (1 + df.ztarget)
+        df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)

         # Drop any rows with NaN values
         df = df.dropna()
@@ -135,36 +132,24 @@ dfs['DA']['zwerr'] = (dfs['DA'].z - dfs['DA'].ztarget) / (1 + dfs['DA'].ztarget)
 # %%
 if not eval_methods:
     dfs = {}
-    dfs['z'] = pd.read_csv(os.path.join(parent_dir, 'predictions_specztraining.csv'), header=0)
-    dfs['L15'] = pd.read_csv(os.path.join(parent_dir, 'predictions_speczL15training.csv'), header=0)
-    dfs['DA'] = pd.read_csv(os.path.join(parent_dir, 'predictions_speczDAtraining.csv'), header=0)
+    dfs['z'] = pd.read_csv(parent_dir / 'predictions_specztraining.csv', header=0)
+    dfs['L15'] = pd.read_csv(parent_dir / 'predictions_speczL15training.csv', header=0)
+    dfs['DA'] = pd.read_csv(parent_dir / 'predictions_speczDAtraining.csv', header=0)


 # %% [markdown]
 # ### MAKE PLOT

 # %%
-plot_photoz(df_list,
-            nbins=8,
-            xvariable='VISmag',
-            metric='nmad',
-            type_bin='bin',
-            label_list = ['zs','zs+L15',r'TEMPS'],
-            save=False,
-            samp='L15'
-            )
+df_list = [dfs['z'], dfs['L15'], dfs['DA']]

 # %%
 plot_photoz(df_list,
             nbins=8,
             xvariable='VISmag',
-            metric='outliers',
+            metric='nmad',
             type_bin='bin',
             label_list = ['zs','zs+L15',r'TEMPS'],
             save=False,
             samp='L15'
             )
-
-# %%
-
-# %%
notebooks/{Fig3_PIT_CRPS.py → PIT_CRPS.py} RENAMED
@@ -1,93 +1,84 @@
 # ---
 # jupyter:
 #   jupytext:
-#     formats: ipynb,py:percent
 #     text_representation:
 #       extension: .py
-#       format_name: percent
-#       format_version: '1.3'
-#     jupytext_version: 1.14.5
+#       format_name: light
+#       format_version: '1.5'
+#     jupytext_version: 1.16.2
 #   kernelspec:
-#     display_name: insight
+#     display_name: temps
 #     language: python
-#     name: insight
+#     name: temps
 # ---

-# %% [markdown]
-# # FIGURE 3 IN THE PAPER
+# # $p(z)$ DISTRIBUTIONS

-# %% [markdown]
 # ## PIT AND CRPS FOR THE THREE METHODS

-# %% [markdown]
 # ### LOAD PYTHON MODULES

-# %%
 # %load_ext autoreload
 # %autoreload 2

-# %%
+import temps
+
 import pandas as pd
 import numpy as np
 import os
 from astropy.io import fits
 from astropy.table import Table
 import torch
+from pathlib import Path

-
-# %%
 #matplotlib settings
 from matplotlib import rcParams
 import matplotlib.pyplot as plt
 rcParams["mathtext.fontset"] = "stix"
 rcParams["font.family"] = "STIXGeneral"

-# %%
-#insight modules
-import sys
-sys.path.append('../temps')
-#from insight_arch import EncoderPhotometry, MeasureZ
-#from insight import Insight_module
-from archive import archive
-from utils import nmad
-from plots import plot_PIT, plot_crps
-from temps_arch import EncoderPhotometry, MeasureZ
-from temps import Temps_module


-# %% [markdown]
 # ### LOAD DATA

-# %%
-photoz_archive = archive(path = parent_dir,
                          only_zspec=False,
                          flags_kept=[1. , 1.1, 1.4, 1.5, 2,2.1,2.4,2.5,3., 3.1, 3.4, 3.5, 4., 9. , 9.1, 9.3, 9.4, 9.5,11.1, 11.5, 12.1, 12.5, 13. , 13.1, 13.5, 14, ],
                          target_test='L15')
 f_test, ferr_test, specz_test ,VIS_mag_test = photoz_archive.get_testing_data()


-# %% [markdown]
 # ## CREATE PIT; CRPS; SPECTROSCOPIC SAMPLE

-# %% [markdown]
 # This loads pre-trained models (for the sake of time). You can learn how to train the models in the Tutorial notebook.

-# %%
 # Initialize an empty dictionary to store DataFrames
 crps_dict = {}
 pit_dict = {}
 for il, lab in enumerate(['z','L15','DA']):

     nn_features = EncoderPhotometry()
-    nn_features.load_state_dict(torch.load(os.path.join(modules_dir,f'modelF_{lab}.pt')))
     nn_z = MeasureZ(num_gauss=6)
-    nn_z.load_state_dict(torch.load(os.path.join(modules_dir,f'modelZ_{lab}.pt')))

-    temps = Temps_module(nn_features, nn_z)


-    pit_list = temps.pit(input_data=torch.Tensor(f_test), target_data=torch.Tensor(specz_test))
-    crps_list = temps.crps(input_data=torch.Tensor(f_test), target_data=specz_test)


     # Assign the DataFrame to a key in the dictionary
@@ -95,7 +86,7 @@ for il, lab in enumerate(['z','L15','DA']):
     pit_dict[lab] = pit_list


-# %%
 plot_PIT(pit_dict['z'],
          pit_dict['L15'],
          pit_dict['DA'],
@@ -106,7 +97,7 @@ plot_PIT(pit_dict['z'],



-# %%
 plot_crps(crps_dict['z'],
           crps_dict['L15'],
           crps_dict['DA'],
@@ -116,5 +107,6 @@ plot_crps(crps_dict['z'],



-# %%
39
 
40
+ # +
41
+ from temps.temps import TempsModule
42
+ from temps.archive import Archive
43
+ from temps.utils import nmad
44
+ from temps.temps_arch import EncoderPhotometry, MeasureZ
45
+ from temps.plots import plot_photoz, plot_PIT, plot_crps
46
+
 
 
 
 
47
 
48
+ # -
49
 
 
50
  # ### LOAD DATA
51
 
52
+ #define here the directory containing the photometric catalogues
53
+ parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
54
+ modules_dir = Path('../data/models/')
55
+
56
+ photoz_archive = Archive(path = parent_dir,
57
  only_zspec=False,
58
  flags_kept=[1. , 1.1, 1.4, 1.5, 2,2.1,2.4,2.5,3., 3.1, 3.4, 3.5, 4., 9. , 9.1, 9.3, 9.4, 9.5,11.1, 11.5, 12.1, 12.5, 13. , 13.1, 13.5, 14, ],
59
  target_test='L15')
60
  f_test, ferr_test, specz_test ,VIS_mag_test = photoz_archive.get_testing_data()
61
 
62
 
 
63
  # ## CREATE PIT; CRPS; SPECTROSCOPIC SAMPLE
64
 
 
65
  # This loads pre-trained models (for the sake of time). You can learn how to train the models in the Tutorial notebook.
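+ # PIT (probability integral transform) is the cumulative p(z) evaluated at
+ # the true redshift: for well-calibrated PDFs the PIT histogram is flat.
+ # CRPS (continuous ranked probability score) integrates the squared
+ # difference between the predicted CDF and a step function at the true
+ # redshift, so lower values indicate sharper, better-centred PDFs.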
66
 
 
67
  # Initialize an empty dictionary to store DataFrames
68
  crps_dict = {}
69
  pit_dict = {}
70
  for il, lab in enumerate(['z','L15','DA']):
71
 
72
  nn_features = EncoderPhotometry()
73
+ nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
74
  nn_z = MeasureZ(num_gauss=6)
75
+ nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt',map_location=torch.device('cpu')))
76
 
77
+ temps_module = TempsModule(nn_features, nn_z)
78
 
79
 
80
+ pit_list = temps_module.calculate_pit(input_data=torch.Tensor(f_test), target_data=torch.Tensor(specz_test))
81
+ crps_list = temps_module.calculate_crps(input_data=torch.Tensor(f_test), target_data=specz_test)
82
 
83
 
84
  # Assign the DataFrame to a key in the dictionary
 
86
  pit_dict[lab] = pit_list
87
 
88
 
89
+ # +
90
  plot_PIT(pit_dict['z'],
91
  pit_dict['L15'],
92
  pit_dict['DA'],
 
97
 
98
 
99
 
100
+ # +
101
  plot_crps(crps_dict['z'],
102
  crps_dict['L15'],
103
  crps_dict['DA'],
 
107
 
108
 
109
 
110
+ # -
111
+
112
 
 
notebooks/Qualitycut.py ADDED
@@ -0,0 +1,241 @@
1
+ # ---
2
+ # jupyter:
3
+ # jupytext:
4
+ # text_representation:
5
+ # extension: .py
6
+ # format_name: light
7
+ # format_version: '1.5'
8
+ # jupytext_version: 1.16.2
9
+ # kernelspec:
10
+ # display_name: temps
11
+ # language: python
12
+ # name: temps
13
+ # ---
14
+
15
+ # # QUALITY CUTS
16
+
17
+ # %load_ext autoreload
18
+ # %autoreload 2
19
+
20
+ import pandas as pd
21
+ import numpy as np
22
+ import os
23
+ import torch
24
+ from scipy import stats
25
+ from pathlib import Path
26
+
27
+ #matplotlib settings
28
+ from matplotlib import rcParams
29
+ import matplotlib.pyplot as plt
30
+ rcParams["mathtext.fontset"] = "stix"
31
+ rcParams["font.family"] = "STIXGeneral"
32
+
33
+ from temps.archive import Archive
34
+ from temps.utils import nmad, caluclate_eta
35
+ from temps.temps_arch import EncoderPhotometry, MeasureZ
36
+ from temps.temps import TempsModule
37
+
38
+
39
+ # ### LOAD DATA (ONLY SPECZ)
40
+
41
+ #define here the directory containing the photometric catalogues
42
+ parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
43
+ modules_dir = Path('../data/models/')
44
+
45
+ photoz_archive = Archive(path = parent_dir,only_zspec=True,flags_kept=[1. , 1.1, 1.4, 1.5, 2,2.1,2.4,2.5,3., 3.1, 3.4, 3.5, 4., 9. , 9.1, 9.3, 9.4, 9.5,11.1, 11.5, 12.1, 12.5, 13. , 13.1, 13.5, 14, ])
46
+ f_test_specz, ferr_test_specz, specz_test ,VIS_mag_test = photoz_archive.get_testing_data()
47
+
48
+
49
+ # ### LOAD TRAINED MODELS AND EVALUATE PDF OF RANDOM EXAMPLES
50
+
51
+ # Initialize an empty dictionary to store DataFrames
52
+ dfs = {}
53
+ pzs = np.zeros(shape = (3,11016,1000))
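+ # shape: 3 models ('z', 'L15', 'DA') x hard-coded test-sample size x 1000
+ # points of the redshift grid (zgrid below spans 0 < z < 5)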
54
+ for il, lab in enumerate(['z','L15','DA']):
55
+
56
+ nn_features = EncoderPhotometry()
57
+ nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
58
+ nn_z = MeasureZ(num_gauss=6)
59
+ nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt', map_location=torch.device('cpu')))
60
+
61
+ temps_module = TempsModule(nn_features, nn_z)
62
+
63
+ z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(f_test_specz),
64
+ return_pz=True)
65
+
66
+ pzs[il] = pz
67
+
68
+ # Create a DataFrame with the desired columns
69
+ df = pd.DataFrame(np.c_[z, odds, specz_test],
70
+ columns=['z', 'odds' ,'ztarget'])
71
+
72
+ # Calculate additional columns or operations if needed
73
+ df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
74
+
75
+ # Drop any rows with NaN values
76
+ df = df.dropna()
77
+
78
+ # Assign the DataFrame to a key in the dictionary
79
+ dfs[lab] = df
80
+
81
+
82
+ # ### STATS
83
+
84
+ # +
85
+ #odds_test = [0, 0.01, 0.03, 0.05, 0.07, 0.1, 0.13, 0.15]
86
+ odds_test = np.arange(0,0.15,0.01)
87
+
88
+ df = dfs['DA'].copy()
89
+ zgrid = np.linspace(0, 5, 1000)
90
+ pz = pzs[2]
91
+ # -
92
+
93
+ diff_matrix = np.abs(df.z.values[:,None] - zgrid[None,:])
94
+ idx_peak = np.argmax(pz,1)
95
+ idx = np.argmin(diff_matrix,1)
96
+
97
+ odds_cat = np.zeros(shape = (len(odds_test),len(df)))
98
+ for ii, odds_ in enumerate(odds_test):
99
+ diff_matrix_upper = np.abs((df.z.values+odds_)[:,None] - zgrid[None,:])
100
+ diff_matrix_lower = np.abs((df.z.values-odds_)[:,None] - zgrid[None,:])
101
+
102
+ idx = np.argmin(diff_matrix,1)
103
+ idx_upper = np.argmin(diff_matrix_upper,1)
104
+ idx_lower = np.argmin(diff_matrix_lower,1)
105
+
106
+ odds = []
107
+ for jj in range(len(pz)):
108
+ odds.append(pz[jj,idx_lower[jj]:(idx_upper[jj]+1)].sum())
109
+
110
+ odds_cat[ii] = np.array(odds)
111
+
112
+
113
+ odds_df = pd.DataFrame(odds_cat.T, columns=[f'odds_{x}' for x in odds_test])
114
+ df = pd.concat([df, odds_df], axis=1)
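+ # Each odds_{dz} column stores, per galaxy, the p(z) probability enclosed in
+ # z +/- dz around the point estimate; values close to 1 flag narrow, reliable
+ # PDFs. For a single galaxy jj and half-width dz this reduces to (assuming
+ # the rows of pz are normalised to sum to one):
+ #
+ # odds_jj = pz[jj, idx_lower[jj]:idx_upper[jj] + 1].sum()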
115
+
116
+
117
+ # ## statistics on ODDS
118
+
119
+ # +
120
+ scatter_odds, eta_odds,xlab_odds, oddsmean = [],[],[], []
121
+
122
+ completeness = np.arange(0, 1, 0.05)  # assumed cut grid; 'complenteness' was never defined in the original
+ for c in completeness:
123
+ percentile_cutoff = df['odds'].quantile(c)
124
+
125
+ df_bin = df[(df.odds > percentile_cutoff)]
126
+
127
+ xlab_odds.append((1-c)*100)
128
+ oddsmean.append(np.mean(df_bin.odds))
129
+ scatter_odds.append(nmad(df_bin.zwerr))
130
+ eta_odds.append(caluclate_eta(df_bin))
131
+ if np.round(c,1) ==0.3:
132
+ percentiles_cutoff = [df[f'odds_{col}'].quantile(c) for col in odds_test]
133
+ scatters_odds = [nmad(df[df[f'odds_{col}'] > percentile_cutoff].zwerr) for (col, percentile_cutoff) in zip(odds_test,percentiles_cutoff)]
134
+ etas_odds = [caluclate_eta(df[df[f'odds_{col}'] > percentile_cutoff]) for (col, percentile_cutoff) in zip(odds_test,percentiles_cutoff)]
135
+
136
+
137
+
138
+
139
+ # -
140
+
141
+ df_completeness = pd.DataFrame(np.c_[xlab_odds,scatter_odds, eta_odds],
142
+ columns = ['completeness', 'sigma_odds', 'eta_odds'])
143
+
144
+ # ## PLOTS
145
+
146
+ # +
147
+ # Initialize the figure and axis
148
+ fig, ax1 = plt.subplots(figsize=(7, 5))
149
+
150
+ # First plot (Sigma) - using the left y-axis
151
+ color = 'crimson'
152
+ ax1.plot(df_completeness.completeness,
153
+ df_completeness.sigma_odds,
154
+ marker='.',
155
+ color=color,
156
+ label=r'NMAD',
157
+ ls='-',
158
+ alpha=0.5,
159
+ )
160
+
161
+
162
+ ax1.set_xlabel('Completeness', fontsize=16)
163
+ ax1.set_ylabel(r'NMAD [$\Delta z$]', color=color, fontsize=16)
164
+ ax1.tick_params(axis='x', labelsize=14)
165
+ ax1.tick_params(axis='y', which='major', labelsize = 14, width=2.5, length=3, labelcolor=color)
166
+ ax1.set_xticks(np.arange(5, 101, 10))
167
+
168
+ ax2 = ax1.twinx() # Create another y-axis that shares the same x-axis
169
+ color = 'navy'
170
+ ax2.plot(df_completeness.completeness,
171
+ df_completeness.eta_odds,
172
+ marker='.',
173
+ color=color,
174
+ label=r'$\eta$ [%]',
175
+ ls='--',
176
+ alpha=0.5)
177
+
178
+ ax2.set_ylabel(r'$\eta$ [%]', color=color, fontsize=16)
179
+
180
+ # Adjust notation to allow comparison
181
+ ax1.yaxis.get_major_formatter().set_powerlimits((0, 0)) # Adjust scientific notation for Sigma
182
+ ax2.yaxis.get_major_formatter().set_powerlimits((0, 0)) # Adjust scientific notation for Eta
183
+ ax2.tick_params(axis='x', labelsize=14)
184
+ ax2.tick_params(axis='y', which='major', labelsize = 14, width=2.5, length=3, labelcolor=color)
185
+
186
+ # Final adjustments
187
+ fig.tight_layout()
188
+ fig.legend(bbox_to_anchor = [-0.18,0.75,0.5,0.2], fontsize = 14)
189
+ #plt.savefig('Flag_nmad_eta_sigma_comparison.pdf', bbox_inches='tight')
190
+ plt.show()
191
+
192
+
193
+ # +
194
+ # Initialize the figure and axis
195
+ fig, ax1 = plt.subplots(figsize=(7, 5))
196
+
197
+ # First plot (Sigma) - using the left y-axis
198
+ color = 'crimson'
199
+ ax1.plot(odds_test,
200
+ scatters_odds,
201
+ marker='.',
202
+ color=color,
203
+ label=r'NMAD',
204
+ ls='-',
205
+ alpha=0.5,
206
+ )
207
+
208
+
209
+ ax1.set_xlabel(r'$\delta z$ (ODDS)', fontsize=16)
210
+ ax1.set_ylabel(r'NMAD [$\Delta z$]', color=color, fontsize=16)
211
+ ax1.tick_params(axis='x', labelsize=14)
212
+ ax1.tick_params(axis='y', which='major', labelsize = 14, width=2.5, length=3, labelcolor=color)
213
+ ax1.set_xticks(np.arange(0,0.16,0.02))
214
+
215
+ ax2 = ax1.twinx() # Create another y-axis that shares the same x-axis
216
+ color = 'navy'
217
+ ax2.plot(odds_test,
218
+ etas_odds,
219
+ marker='.',
220
+ color=color,
221
+ label=r'$\eta$ [%]',
222
+ ls='--',
223
+ alpha=0.5)
224
+
225
+ ax2.set_ylabel(r'$\eta$ [%]', color=color, fontsize=16)
226
+
227
+ # Adjust notation to allow comparison
228
+ ax1.yaxis.get_major_formatter().set_powerlimits((0, 0)) # Adjust scientific notation for Sigma
229
+ ax2.yaxis.get_major_formatter().set_powerlimits((0, 0)) # Adjust scientific notation for Eta
230
+ ax2.tick_params(axis='x', labelsize=14)
231
+ ax2.tick_params(axis='y', which='major', labelsize = 14, width=2.5, length=3, labelcolor=color)
232
+
233
+ # Final adjustments
234
+ fig.tight_layout()
235
+ fig.legend(bbox_to_anchor = [0.10,0.75,0.5,0.2], fontsize = 14)
236
+ #plt.savefig('ODDS_study.pdf', bbox_inches='tight')
237
+ plt.show()
238
+
239
+ # -
240
+
241
+
notebooks/Table_metrics.py CHANGED
@@ -5,11 +5,11 @@
5
  # extension: .py
6
  # format_name: light
7
  # format_version: '1.5'
8
- # jupytext_version: 1.14.5
9
  # kernelspec:
10
- # display_name: insight
11
  # language: python
12
- # name: insight
13
  # ---
14
 
15
  # # TABLE METRICS
@@ -24,6 +24,7 @@ import torch
24
  from scipy import stats
25
  from astropy.io import fits
26
  from astropy.table import Table
 
27
 
28
  #matplotlib settings
29
  from matplotlib import rcParams
@@ -31,27 +32,22 @@ import matplotlib.pyplot as plt
31
  rcParams["mathtext.fontset"] = "stix"
32
  rcParams["font.family"] = "STIXGeneral"
33
 
34
- #insight modules
35
- import sys
36
- sys.path.append('../temps')
37
- #from insight_arch import EncoderPhotometry, MeasureZ
38
- #from insight import Insight_module
39
- from archive import archive
40
- from utils import nmad, select_cut
41
- from temps_arch import EncoderPhotometry, MeasureZ
42
- from temps import Temps_module
43
 
44
 
45
  # ## LOAD DATA
46
 
47
  #define here the directory containing the photometric catalogues
48
- parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
49
- modules_dir = '../data/models/'
50
 
51
  # +
52
  filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
53
 
54
- hdu_list = fits.open(os.path.join(parent_dir,filename_valid))
55
  cat = Table(hdu_list[1].data).to_pandas()
56
  cat = cat[cat['FLAG_PHOT']==0]
57
  cat = cat[cat['mu_class_L07']==1]
@@ -74,7 +70,7 @@ cat = cat[cat.ztarget>0]
74
 
75
  # ### EXTRACT PHOTOMETRY
76
 
77
- photoz_archive = archive(path = parent_dir,only_zspec=False)
78
  f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
79
  col, colerr = photoz_archive._to_colors(f, ferr)
80
 
@@ -84,19 +80,19 @@ col, colerr = photoz_archive._to_colors(f, ferr)
84
  # Initialize an empty dictionary to store DataFrames
85
  lab='DA'
86
  nn_features = EncoderPhotometry()
87
- nn_features.load_state_dict(torch.load(os.path.join(modules_dir,f'modelF_{lab}.pt')))
88
  nn_z = MeasureZ(num_gauss=6)
89
- nn_z.load_state_dict(torch.load(os.path.join(modules_dir,f'modelZ_{lab}.pt')))
90
 
91
- temps = Temps_module(nn_features, nn_z)
92
 
93
- z,zerr, pz, flag, odds = temps.get_pz(input_data=torch.Tensor(col),
94
  return_pz=True)
95
 
96
 
97
  # Create a DataFrame with the desired columns
98
- df = pd.DataFrame(np.c_[z, flag, odds, cat.ztarget, cat.reliable_S15, cat.specz_or_photo],
99
- columns=['z','zflag', 'odds' ,'ztarget','reliable_S15', 'specz_or_photo'])
100
 
101
  # Calculate additional columns or operations if needed
102
  df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
@@ -130,10 +126,12 @@ print(dfcuts.to_latex(float_format="%.3f",
130
 
131
  df_euclid = df[(df.z >0.2)&(df.z < 2.6)]
132
 
 
 
133
  # +
134
  df_selected, cut, dfcuts = select_cut(df_euclid,
135
  completenss_lim=None,
136
- nmad_lim=0.055,
137
  outliers_lim=None,
138
  return_df=True)
139
 
 
5
  # extension: .py
6
  # format_name: light
7
  # format_version: '1.5'
8
+ # jupytext_version: 1.16.2
9
  # kernelspec:
10
+ # display_name: temps
11
  # language: python
12
+ # name: temps
13
  # ---
14
 
15
  # # TABLE METRICS
 
24
  from scipy import stats
25
  from astropy.io import fits
26
  from astropy.table import Table
27
+ from pathlib import Path
28
 
29
  #matplotlib settings
30
  from matplotlib import rcParams
 
32
  rcParams["mathtext.fontset"] = "stix"
33
  rcParams["font.family"] = "STIXGeneral"
34
 
35
+ from temps.archive import Archive
36
+ from temps.utils import nmad, select_cut
37
+ from temps.temps_arch import EncoderPhotometry, MeasureZ
38
+ from temps.temps import TempsModule
 
 
 
 
 
39
 
40
 
41
  # ## LOAD DATA
42
 
43
  #define here the directory containing the photometric catalogues
44
+ parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
45
+ modules_dir = Path('../data/models/')
46
 
47
  # +
48
  filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
49
 
50
+ hdu_list = fits.open(parent_dir / filename_valid)
51
  cat = Table(hdu_list[1].data).to_pandas()
52
  cat = cat[cat['FLAG_PHOT']==0]
53
  cat = cat[cat['mu_class_L07']==1]
 
70
 
71
  # ### EXTRACT PHOTOMETRY
72
 
73
+ photoz_archive = Archive(path = parent_dir,only_zspec=False)
74
  f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
75
  col, colerr = photoz_archive._to_colors(f, ferr)
76
 
 
80
  # Initialize an empty dictionary to store DataFrames
81
  lab='DA'
82
  nn_features = EncoderPhotometry()
83
+ nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt', map_location=torch.device('cpu')))
84
  nn_z = MeasureZ(num_gauss=6)
85
+ nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt', map_location=torch.device('cpu')))
86
 
87
+ temps_module = TempsModule(nn_features, nn_z)
88
 
89
+ z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
90
  return_pz=True)
91
 
92
 
93
  # Create a DataFrame with the desired columns
94
+ df = pd.DataFrame(np.c_[z, odds, cat.ztarget, cat.reliable_S15, cat.specz_or_photo],
95
+ columns=['z', 'odds' ,'ztarget','reliable_S15', 'specz_or_photo'])
96
 
97
  # Calculate additional columns or operations if needed
98
  df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
 
126
 
127
  df_euclid = df[(df.z >0.2)&(df.z < 2.6)]
128
 
129
+ df_euclid
130
+
131
  # +
132
  df_selected, cut, dfcuts = select_cut(df_euclid,
133
  completenss_lim=None,
134
+ nmad_lim= 0.05,
135
  outliers_lim=None,
136
  return_df=True)
137
 
notebooks/nz.py ADDED
@@ -0,0 +1,215 @@
1
+ # ---
2
+ # jupyter:
3
+ # jupytext:
4
+ # text_representation:
5
+ # extension: .py
6
+ # format_name: light
7
+ # format_version: '1.5'
8
+ # jupytext_version: 1.16.2
9
+ # kernelspec:
10
+ # display_name: temps
11
+ # language: python
12
+ # name: temps
13
+ # ---
14
+
15
+ # # FIGURE 5 IN THE PAPER
16
+
17
+ # ## n(z) distributions
18
+
19
+ # %load_ext autoreload
20
+ # %autoreload 2
21
+
22
+ import pandas as pd
23
+ import numpy as np
24
+ from astropy.io import fits
25
+ from astropy.table import Table
26
+ import torch
27
+ from pathlib import Path
28
+
29
+ #matplotlib settings
30
+ from matplotlib import rcParams
31
+ import matplotlib.pyplot as plt
32
+ rcParams["mathtext.fontset"] = "stix"
33
+ rcParams["font.family"] = "STIXGeneral"
34
+
35
+ from temps.archive import Archive
36
+ from temps.utils import nmad
37
+ from temps.temps_arch import EncoderPhotometry, MeasureZ
38
+ from temps.temps import TempsModule
39
+ from temps.plots import plot_nz
40
+
41
+ eval_methods=False
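+ # Toggle: with eval_methods=True the three networks are re-evaluated on the
+ # catalogue below; with False the pre-computed prediction CSVs are loaded.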
42
+
43
+ # ### LOAD DATA
44
+
45
+ #define here the directory containing the photometric catalogues
46
+ parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
47
+ modules_dir = Path('../data/models/')
48
+
49
+ # +
50
+ filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
51
+
52
+ hdu_list = fits.open(parent_dir / filename_valid)
53
+ cat = Table(hdu_list[1].data).to_pandas()
54
+ cat = cat[cat['FLAG_PHOT']==0]
55
+ cat = cat[cat['mu_class_L07']==1]
56
+ cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
57
+ cat = cat[cat['MAG_VIS']<25]
58
+
59
+ # -
60
+
61
+ ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii]> 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
62
+ specz_or_photo = [0 if cat['z_spec_S15'].values[ii]> 0 else 1 for ii in range(len(cat))]
63
+ ID = cat['ID']
64
+ VISmag = cat['MAG_VIS']
65
+ zsflag = cat['reliable_S15']
66
+
67
+ photoz_archive = Archive(path = parent_dir,only_zspec=False)
68
+ f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
69
+ col, colerr = photoz_archive._to_colors(f, ferr)
70
+
71
+ # ### LOAD TRAINED MODELS AND EVALUATE PDFs AND REDSHIFT
72
+
73
+ if eval_methods:
74
+ dfs = {}
75
+
76
+ for il, lab in enumerate(['z','L15','DA']):
77
+
78
+ nn_features = EncoderPhotometry()
79
+ nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
80
+ nn_z = MeasureZ(num_gauss=6)
81
+ nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt',map_location=torch.device('cpu')))
82
+
83
+ temps_module = TempsModule(nn_features, nn_z)
84
+
85
+ z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
86
+ return_pz=True)
87
+ # Create a DataFrame with the desired columns
88
+ df = pd.DataFrame(np.c_[ID, VISmag,z, odds, ztarget,zsflag, specz_or_photo],
89
+ columns=['ID','VISmag','z','odds', 'ztarget','zsflag','S15_L15_flag'])
90
+
91
+ # Calculate additional columns or operations if needed
92
+ df['zwerr'] = (df.z - df.ztarget) / (1 + df.ztarget)
93
+
94
+ # Drop any rows with NaN values
95
+ df = df.dropna()
96
+
97
+ # Assign the DataFrame to a key in the dictionary
98
+ dfs[lab] = df
99
+
100
+
101
+ # ### LOAD CATALOGUES IF AVAILABLE
102
+
103
+ if not eval_methods:
104
+
105
+ df_zs = pd.read_csv(parent_dir / 'predictions_specztraining.csv', header=0)
106
+ df_zsL15 = pd.read_csv(parent_dir / 'predictions_speczL15training.csv', header=0)
107
+ df_DA = pd.read_csv(parent_dir / 'predictions_speczDAtraining.csv', header=0)
108
+
109
+
110
+ dfs = {}
111
+ dfs['z'] = df_zs
112
+ dfs['L15'] = df_zsL15
113
+ dfs['DA'] = df_DA
114
+
115
+ # +
116
+ import matplotlib.pyplot as plt
117
+ from matplotlib import gridspec
118
+
119
+ # Create figure and grid specification
120
+ fig = plt.figure(figsize=(8, 10))
121
+ gs = gridspec.GridSpec(5, 1, height_ratios=[0.1, 1, 1,1,1])
122
+
123
+ # Upper panel (very thin) with shaded areas
124
+ ax1 = plt.subplot(gs[0])
125
+ ax1.set_yticks([])
126
+
127
+ ax1.set_ylabel('Bins', fontsize=10)
128
+
129
+ # Define the ranges for shaded areas
130
+ #z_ranges = [[0.15, 0.35], [0.35, 0.55], [0.55, 0.85], [0.85, 1.05], [1.05, 1.35],
131
+ # [1.35, 1.55],# [1.55, 1.85], [1.85, 2], [2, 2.5], [2.5, 3], [3, 4]]
132
+
133
+ z_ranges = [[0.15, 0.5], [0.5, 1], [1, 1.5], [1.5,2]]#, [2, 3], [3,4]]#,
134
+ #[1.35, 1.55],# [1.55, 1.85], [1.85, 2], [2, 2.5], [2.5, 3], [3, 4]]
135
+
136
+ colors = ['deepskyblue', 'forestgreen', 'coral', 'grey', 'pink', 'goldenrod',
137
+ 'cyan', 'seagreen', 'salmon', 'steelblue', 'orange']
138
+
139
+ # Plot shaded areas
140
+ x_values = [0, 1, 2] # Example x values, adjust as needed
141
+ for i, (start, end) in enumerate(z_ranges):
142
+ ax1.fill_betweenx(x_values, start, end, color=colors[i], alpha=0.5)
143
+
144
+ # Middle panel (equally thick)
145
+ ax2 = plt.subplot(gs[1])
146
+ for i, (start, end) in enumerate(z_ranges):
147
+ dfplot_z = dfs['z'][(dfs['z']['ztarget'] > start) & (dfs['z']['ztarget'] < end)]
148
+ ax2.hist(dfplot_z.ztarget, bins=50, color=colors[i], histtype='step', linestyle='-', density=True, range=(0, 4))
149
+
150
+ # Bottom panel (equally thick)
151
+ ax3 = plt.subplot(gs[2])
152
+ for i, (start, end) in enumerate(z_ranges):
153
+ dfplot_z = dfs['z'][(dfs['z']['z'] > start) & (dfs['z']['z'] < end)]
154
+ ax3.hist(dfplot_z.ztarget, bins=50, color=colors[i], histtype='step', linestyle='-', density=True, range=(0, 4))
155
+
156
+ # Bottom panel (equally thick)
157
+ ax4 = plt.subplot(gs[3])
158
+ for i, (start, end) in enumerate(z_ranges):
159
+ dfplot_z = dfs['L15'][(dfs['L15']['z'] > start) & (dfs['L15']['z'] < end)]
160
+ print(len(dfplot_z))
161
+ ax4.hist(dfplot_z.ztarget, bins=50, color=colors[i], histtype='step', linestyle='-', density=True, range=(0, 4))
162
+
163
+ ax5 = plt.subplot(gs[4])
164
+ for i, (start, end) in enumerate(z_ranges):
165
+ dfplot_z = dfs['DA'][(dfs['DA']['z'] > start) & (dfs['DA']['z'] < end)]
166
+ ax5.hist(dfplot_z.ztarget, bins=50, color=colors[i], histtype='step', linestyle='-', density=True, range=(0, 4))
167
+
168
+ plt.tight_layout()
169
+ plt.show()
170
+
171
+ # -
172
+
173
+ def plot_nz(df_list,
174
+ zcuts = [0.1, 0.5, 1, 1.5, 2, 3, 4],
175
+ save=False):
176
+ # Plot properties
177
+ plt.rcParams['font.family'] = 'serif'
178
+ plt.rcParams['font.size'] = 16
179
+
180
+ cmap = plt.get_cmap('Dark2') # Choose a colormap for coloring lines
181
+
182
+ # Create subplots
183
+ fig, axs = plt.subplots(3, 1, figsize=(20, 8), sharex=True)
184
+
185
+ for i, df in enumerate(df_list):
186
+ dfplot = df_list[i].copy() # Assuming df_list contains dataframes
187
+ ax = axs[i] # Selecting the appropriate subplot
188
+
189
+ for iz in range(len(zcuts)-1):
190
+ dfplot_z = dfplot[(dfplot['ztarget'] > zcuts[iz]) & (dfplot['ztarget'] < zcuts[iz + 1])]
191
+ color = cmap(iz) # Get a different color for each redshift
192
+
193
+ zt_mean = np.median(dfplot_z.ztarget.values)
194
+ zp_mean = np.median(dfplot_z.z.values)
195
+
196
+
197
+ # Plot histogram on the selected subplot
198
+ ax.hist(dfplot_z.z, bins=50, color=color, histtype='step', linestyle='-', density=True, range=(0, 4))
199
+ ax.axvline(zt_mean, color=color, linestyle='-', lw=2)
200
+ ax.axvline(zp_mean, color=color, linestyle='--', lw=2)
201
+
202
+ ax.set_ylabel(f'Frequency', fontsize=14)
203
+ ax.grid(False)
204
+ ax.set_xlim(0, 3.5)
205
+
206
+ axs[-1].set_xlabel(f'$z$', fontsize=18)
207
+
208
+ if save:
209
+ plt.savefig(f'nz_hist.pdf', dpi=300, bbox_inches='tight')
210
+
211
+ plt.show()
212
+
213
+ plot_nz(df_list)
214
+
215
+
notebooks/{Fig4_pz_examples.py β†’ pz_examples.py} RENAMED
@@ -1,70 +1,55 @@
1
  # ---
2
  # jupyter:
3
  # jupytext:
4
- # formats: ipynb,py:percent
5
  # text_representation:
6
  # extension: .py
7
- # format_name: percent
8
- # format_version: '1.3'
9
- # jupytext_version: 1.14.5
10
  # kernelspec:
11
- # display_name: insight
12
  # language: python
13
- # name: insight
14
  # ---
15
 
16
- # %% [markdown]
17
- # # FIGURE 4 IN THE PAPER
18
 
19
- # %% [markdown]
20
  # ## IMPACT OF TEMPS ON CONCRETE P(Z) EXAMPLES
21
 
22
- # %% [markdown]
23
  # ### LOAD PYTHON MODULES
24
 
25
- # %%
26
  # %load_ext autoreload
27
  # %autoreload 2
28
 
29
- # %%
30
  import pandas as pd
31
  import numpy as np
32
  import os
33
  from astropy.io import fits
34
  from astropy.table import Table
35
  import torch
 
36
 
37
- # %%
38
  #matplotlib settings
39
  from matplotlib import rcParams
40
  import matplotlib.pyplot as plt
41
  rcParams["mathtext.fontset"] = "stix"
42
  rcParams["font.family"] = "STIXGeneral"
43
 
44
- # %%
45
- #insight modules
46
- import sys
47
- sys.path.append('../temps')
48
- #from insight_arch import EncoderPhotometry, MeasureZ
49
- #from insight import Insight_module
50
- from archive import archive
51
- from utils import nmad
52
- from temps_arch import EncoderPhotometry, MeasureZ
53
- from temps import Temps_module
54
 
55
 
56
- # %% [markdown]
57
  # ### LOAD DATA
58
 
59
- # %%
60
  #define here the directory containing the photometric catalogues
61
- parent_dir = '/data/astro/scratch2/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5'
62
- modules_dir = '../data/models/'
63
 
64
- # %%
65
  filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
66
-
67
- hdu_list = fits.open(os.path.join(parent_dir,filename_valid))
68
  cat = Table(hdu_list[1].data).to_pandas()
69
  cat = cat[cat['FLAG_PHOT']==0]
70
  cat = cat[cat['mu_class_L07']==1]
@@ -72,46 +57,41 @@ cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
72
  cat = cat[cat['MAG_VIS']<25]
73
 
74
 
75
- # %%
76
  ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii]> 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
77
  specz_or_photo = [0 if cat['z_spec_S15'].values[ii]> 0 else 1 for ii in range(len(cat))]
78
  ID = cat['ID']
79
  VISmag = cat['MAG_VIS']
80
  zsflag = cat['reliable_S15']
81
 
82
- # %%
83
- photoz_archive = archive(path = parent_dir,only_zspec=False)
84
  f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
85
  col, colerr = photoz_archive._to_colors(f, ferr)
86
 
87
- # %% [markdown]
88
  # ### LOAD TRAINED MODELS AND EVALUATE PDF OF RANDOM EXAMPLES
89
 
90
- # %% [markdown]
91
  # The notebook 'Tutorial_temps' gives an example of how to train and save models.
92
 
93
- # %%
94
  # Initialize an empty dictionary to store DataFrames
95
  ii = np.random.randint(0,len(col),1)
96
  pz_dict = {}
97
  for il, lab in enumerate(['z','L15','DA']):
98
 
99
  nn_features = EncoderPhotometry()
100
- nn_features.load_state_dict(torch.load(os.path.join(modules_dir,f'modelF_{lab}.pt')))
101
  nn_z = MeasureZ(num_gauss=6)
102
- nn_z.load_state_dict(torch.load(os.path.join(modules_dir,f'modelZ_{lab}.pt')))
103
 
104
- temps = Temps_module(nn_features, nn_z)
105
 
106
 
107
- z,zerr, pz, flag,_ = temps.get_pz(input_data=torch.Tensor(col[ii]),return_pz=True)
108
 
109
 
110
  # Assign the DataFrame to a key in the dictionary
111
  pz_dict[lab] = pz
112
 
113
 
114
- # %%
115
  cmap = plt.get_cmap('Dark2')
116
 
117
  plt.plot(np.linspace(0,5,1000),pz_dict['z'][0],label='z', color = cmap(0), ls ='--')
@@ -124,5 +104,6 @@ plt.legend()
124
  plt.xlabel(r'$z$', fontsize=14)
125
  plt.ylabel('Probability', fontsize=14)
126
  #plt.savefig(f'pz_{ii[0]}.pdf', bbox_inches='tight')
 
 
127
 
128
- # %%
 
1
  # ---
2
  # jupyter:
3
  # jupytext:
 
4
  # text_representation:
5
  # extension: .py
6
+ # format_name: light
7
+ # format_version: '1.5'
8
+ # jupytext_version: 1.16.2
9
  # kernelspec:
10
+ # display_name: temps
11
  # language: python
12
+ # name: temps
13
  # ---
14
 
15
+ # # $p(z)$ examples
 
16
 
 
17
  # ## IMPACT OF TEMPS ON CONCRETE P(Z) EXAMPLES
18
 
 
19
  # ### LOAD PYTHON MODULES
20
 
 
21
  # %load_ext autoreload
22
  # %autoreload 2
23
 
 
24
  import pandas as pd
25
  import numpy as np
26
  import os
27
  from astropy.io import fits
28
  from astropy.table import Table
29
  import torch
30
+ from pathlib import Path
31
 
 
32
  #matplotlib settings
33
  from matplotlib import rcParams
34
  import matplotlib.pyplot as plt
35
  rcParams["mathtext.fontset"] = "stix"
36
  rcParams["font.family"] = "STIXGeneral"
37
 
38
+ from temps.archive import Archive
39
+ from temps.utils import nmad
40
+ from temps.temps_arch import EncoderPhotometry, MeasureZ
41
+ from temps.temps import TempsModule
 
 
 
 
 
 
42
 
43
 
 
44
  # ### LOAD DATA
45
 
 
46
  #define here the directory containing the photometric catalogues
47
+ parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
48
+ modules_dir = Path('../data/models/')
49
 
 
50
  filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
51
+ path_file = parent_dir / filename_valid # Creating the path to the file
52
+ hdu_list = fits.open(path_file)
53
  cat = Table(hdu_list[1].data).to_pandas()
54
  cat = cat[cat['FLAG_PHOT']==0]
55
  cat = cat[cat['mu_class_L07']==1]
 
57
  cat = cat[cat['MAG_VIS']<25]
58
 
59
 
 
60
  ztarget = [cat['z_spec_S15'].values[ii] if cat['z_spec_S15'].values[ii]> 0 else cat['photo_z_L15'].values[ii] for ii in range(len(cat))]
61
  specz_or_photo = [0 if cat['z_spec_S15'].values[ii]> 0 else 1 for ii in range(len(cat))]
62
  ID = cat['ID']
63
  VISmag = cat['MAG_VIS']
64
  zsflag = cat['reliable_S15']
65
 
66
+ photoz_archive = Archive(path = parent_dir,only_zspec=False)
 
67
  f, ferr = photoz_archive._extract_fluxes(catalogue= cat)
68
  col, colerr = photoz_archive._to_colors(f, ferr)
69
 
 
70
  # ### LOAD TRAINED MODELS AND EVALUATE PDF OF RANDOM EXAMPLES
71
 
 
72
  # The notebook 'Tutorial_temps' gives an example of how to train and save models.
73
 
 
74
  # Initialize an empty dictionary to store DataFrames
75
  ii = np.random.randint(0,len(col),1)
76
  pz_dict = {}
77
  for il, lab in enumerate(['z','L15','DA']):
78
 
79
  nn_features = EncoderPhotometry()
80
+ nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
81
  nn_z = MeasureZ(num_gauss=6)
82
+ nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt',map_location=torch.device('cpu')))
83
 
84
+ temps_module = TempsModule(nn_features, nn_z)
85
 
86
 
87
+ z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col[ii]),return_pz=True)
88
 
89
 
90
  # Assign the DataFrame to a key in the dictionary
91
  pz_dict[lab] = pz
92
 
93
 
94
+ # +
95
  cmap = plt.get_cmap('Dark2')
96
 
97
  plt.plot(np.linspace(0,5,1000),pz_dict['z'][0],label='z', color = cmap(0), ls ='--')
 
104
  plt.xlabel(r'$z$', fontsize=14)
105
  plt.ylabel('Probability', fontsize=14)
106
  #plt.savefig(f'pz_{ii[0]}.pdf', bbox_inches='tight')
107
+ # -
108
+
109
 
 
temps/archive.py CHANGED
@@ -1,42 +1,62 @@
1
  import numpy as np
2
  import pandas as pd
3
  from astropy.io import fits
4
- import os
5
  from astropy.table import Table
6
  from scipy.spatial import KDTree
 
 
 
 
7
 
8
- import matplotlib.pyplot as plt
9
 
10
- from matplotlib import rcParams
11
  rcParams["mathtext.fontset"] = "stix"
12
  rcParams["font.family"] = "STIXGeneral"
13
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- class archive():
16
- def __init__(self, path, aperture=2, drop_stars=True, clean_photometry=True, convert_colors=True, extinction_corr=True, only_zspec=True, target_test='specz', flags_kept=[3,3.1,3.4,3.5,4]):
17
 
 
18
  self.aperture = aperture
19
- self.flags_kept=flags_kept
 
20
 
 
 
21
 
 
 
 
22
 
23
- filename_calib='euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
24
- filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
 
 
25
 
26
- hdu_list = fits.open(os.path.join(path,filename_calib))
27
- cat = Table(hdu_list[1].data).to_pandas()
28
- cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
29
-
30
 
31
- hdu_list = fits.open(os.path.join(path,filename_valid))
32
- cat_test = Table(hdu_list[1].data).to_pandas()
 
33
 
34
 
35
  if drop_stars==True:
 
36
  cat = cat[cat.mu_class_L07==1]
37
  cat_test = cat_test[cat_test.mu_class_L07==1]
38
 
39
  if clean_photometry==True:
 
40
  cat = self._clean_photometry(cat)
41
  cat_test = self._clean_photometry(cat_test)
42
 
@@ -55,6 +75,7 @@ class archive():
55
 
56
 
57
  self._set_training_data(cat,
 
58
  only_zspec=only_zspec,
59
  extinction_corr=extinction_corr,
60
  convert_colors=convert_colors)
@@ -65,17 +86,51 @@ class archive():
65
 
66
 
67
  def _extract_fluxes(self,catalogue):
68
- columns_f = [f'FLUX_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
69
- columns_ferr = [f'FLUXERR_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
 
 
 
 
70
 
71
  f = catalogue[columns_f].values
72
  ferr = catalogue[columns_ferr].values
73
  return f, ferr
74
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def _to_colors(self, flux, fluxerr):
76
  """ Convert fluxes to colors"""
77
- color = flux[:,:-1] / flux[:,1:]
78
- color_err = fluxerr[:,:-1]**2 / flux[:,1:]**2 + flux[:,:-1]**2 / flux[:,1:]**4 * fluxerr[:,:-1]**2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  return color,color_err
80
 
81
  def _set_combiend_target(self, catalogue):
@@ -92,13 +147,20 @@ class archive():
92
 
93
  return catalogue
94
 
95
- def _correct_extinction(self,catalogue, f):
96
  """Corrects for extinction"""
97
  ext_correction_cols = [f'EB_V_corr_FLUX_{x}' for x in ['G','R','I','Z','Y','J','H']]
98
- ext_correction = catalogue[ext_correction_cols].values
 
 
 
 
99
 
100
  f = f * ext_correction
101
- return f
 
 
 
102
 
103
  def _select_only_zspec(self,catalogue,cat_flag=None):
104
  """Selects only galaxies with spectroscopic redshift"""
@@ -158,22 +220,24 @@ class archive():
158
  return catalogue_valid
159
 
160
 
161
- def _set_training_data(self,catalogue, only_zspec=True, extinction_corr=True, convert_colors=True):
162
 
163
- cat_da = self._exclude_only_zspec(catalogue)
164
  target_z_train_DA = cat_da['photo_z_L15'].values
165
 
166
 
167
  if only_zspec:
 
168
  catalogue = self._select_only_zspec(catalogue, cat_flag='Calib')
169
  catalogue = self._clean_zspec_sample(catalogue, flags_kept=self.flags_kept)
170
  else:
 
171
  catalogue = self._take_zspec_and_photoz(catalogue, cat_flag='Calib')
172
 
173
 
174
  self.cat_train=catalogue
175
  f, ferr = self._extract_fluxes(catalogue)
176
-
177
  f_DA, ferr_DA = self._extract_fluxes(cat_da)
178
  idx = np.random.randint(0, len(f_DA), len(f))
179
  f_DA, ferr_DA = f_DA[idx], ferr_DA[idx]
@@ -182,9 +246,11 @@ class archive():
182
 
183
 
184
  if extinction_corr==True:
 
185
  f = self._correct_extinction(catalogue,f)
186
-
187
  if convert_colors==True:
 
188
  col, colerr = self._to_colors(f, ferr)
189
  col_DA, colerr_DA = self._to_colors(f_DA, ferr_DA)
190
 
 
1
  import numpy as np
2
  import pandas as pd
3
  from astropy.io import fits
 
4
  from astropy.table import Table
5
  from scipy.spatial import KDTree
6
+ from matplotlib import pyplot as plt
7
+ from matplotlib import rcParams
8
+ from pathlib import Path
9
+ from loguru import logger
10
 
 
11
 
 
12
  rcParams["mathtext.fontset"] = "stix"
13
  rcParams["font.family"] = "STIXGeneral"
14
 
15
+ class Archive:
16
+ def __init__(self, path,
17
+ aperture=2,
18
+ drop_stars=True,
19
+ clean_photometry=True,
20
+ convert_colors=True,
21
+ extinction_corr=True,
22
+ only_zspec=True,
23
+ all_apertures=False,
24
+ target_test='specz', flags_kept=[3, 3.1, 3.4, 3.5, 4]):
25
 
 
 
26
 
27
+ logger.info("Starting archive")
28
  self.aperture = aperture
29
+ self.all_apertures = all_apertures
30
+ self.flags_kept = flags_kept
31
 
32
+ filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
33
+ filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
34
 
35
+ # Use Path for file handling
36
+ path_calib = Path(path) / filename_calib
37
+ path_valid = Path(path) / filename_valid
38
 
39
+ # Open the calibration FITS file
40
+ with fits.open(path_calib) as hdu_list:
41
+ cat = Table(hdu_list[1].data).to_pandas()
42
+ cat = cat[(cat['z_spec_S15'] > 0) | (cat['photo_z_L15'] > 0)]
43
 
44
+ # Open the validation FITS file
45
+ with fits.open(path_valid) as hdu_list:
46
+ cat_test = Table(hdu_list[1].data).to_pandas()
 
47
 
48
+ # Store the catalogs for later use
49
+ self.cat = cat
50
+ self.cat_test = cat_test
51
 
52
 
53
  if drop_stars==True:
54
+ logger.info("dropping stars...")
55
  cat = cat[cat.mu_class_L07==1]
56
  cat_test = cat_test[cat_test.mu_class_L07==1]
57
 
58
  if clean_photometry==True:
59
+ logger.info("cleaning stars...")
60
  cat = self._clean_photometry(cat)
61
  cat_test = self._clean_photometry(cat_test)
62
 
 
75
 
76
 
77
  self._set_training_data(cat,
78
+ cat_test,
79
  only_zspec=only_zspec,
80
  extinction_corr=extinction_corr,
81
  convert_colors=convert_colors)
 
86
 
87
 
88
  def _extract_fluxes(self,catalogue):
89
+ if self.all_apertures:
90
+ columns_f = [f'FLUX_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H']]
91
+ columns_ferr = [f'FLUXERR_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H'] ]
92
+ else:
93
+ columns_f = [f'FLUX_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
94
+ columns_ferr = [f'FLUXERR_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
95
 
96
  f = catalogue[columns_f].values
97
  ferr = catalogue[columns_ferr].values
98
  return f, ferr
99
 
100
+ def _extract_magnitudes(self,catalogue):
101
+ if self.all_apertures:
102
+ columns_m = [f'MAG_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H']]
103
+ columns_merr = [f'MAGERR_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H'] ]
104
+ else:
105
+ columns_m = [f'MAG_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
106
+ columns_merr = [f'MAGERR_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
107
+
108
+ m = catalogue[columns_m].values
109
+ merr = catalogue[columns_merr].values
110
+ return m, merr
111
+
112
  def _to_colors(self, flux, fluxerr):
113
  """ Convert fluxes to colors"""
114
+
115
+ if self.all_apertures:
116
+
117
+ for a in range(3):
118
+ lim1 = 7*a
119
+ lim2 = 7*(a+1)
120
+ c = flux[:,lim1:(lim2-1)] / flux[:,(lim1+1):lim2]
121
+ cerr = np.sqrt((fluxerr[:,lim1:(lim2-1)]/ flux[:,(lim1+1):lim2])**2 + (flux[:,lim1:(lim2-1)] / flux[:,(lim1+1):lim2]**2)**2 * fluxerr[:,(lim1+1):lim2]**2)
122
+
123
+ if a==0:
124
+ color = c
125
+ color_err = cerr
126
+ else:
127
+ color = np.concatenate((color,c),axis=1)
128
+ color_err = np.concatenate((color_err,cerr),axis=1)
129
+
130
+ else:
131
+ color = flux[:,:-1] / flux[:,1:]
132
+
133
+ color_err = np.sqrt((fluxerr[:,:-1]/ flux[:,1:])**2 + (flux[:,:-1] / flux[:,1:]**2)**2 * fluxerr[:,1:]**2)
134
  return color,color_err
135
 
136
  def _set_combiend_target(self, catalogue):
 
147
 
148
  return catalogue
149
 
150
+ def _correct_extinction(self,catalogue, f, return_ext_corr=False):
151
  """Corrects for extinction"""
152
  ext_correction_cols = [f'EB_V_corr_FLUX_{x}' for x in ['G','R','I','Z','Y','J','H']]
153
+ if self.all_apertures:
154
+ ext_correction = catalogue[ext_correction_cols].values
155
+ ext_correction = np.concatenate((ext_correction,ext_correction,ext_correction),axis=1)
156
+ else:
157
+ ext_correction = catalogue[ext_correction_cols].values
158
 
159
  f = f * ext_correction
160
+ if return_ext_corr:
161
+ return f, ext_correction
162
+ else:
163
+ return f
164
 
165
  def _select_only_zspec(self,catalogue,cat_flag=None):
166
  """Selects only galaxies with spectroscopic redshift"""
 
220
  return catalogue_valid
221
 
222
 
223
+ def _set_training_data(self,catalogue, catalogue_da, only_zspec=True, extinction_corr=True, convert_colors=True):
224
 
225
+ cat_da = self._exclude_only_zspec(catalogue_da)
226
  target_z_train_DA = cat_da['photo_z_L15'].values
227
 
228
 
229
  if only_zspec:
230
+ logger.info("Selecting only galaxies with spectroscopic redshift")
231
  catalogue = self._select_only_zspec(catalogue, cat_flag='Calib')
232
  catalogue = self._clean_zspec_sample(catalogue, flags_kept=self.flags_kept)
233
  else:
234
+ logger.info("Selecting galaxies with spectroscopic redshift and high-precision photo-z")
235
  catalogue = self._take_zspec_and_photoz(catalogue, cat_flag='Calib')
236
 
237
 
238
  self.cat_train=catalogue
239
  f, ferr = self._extract_fluxes(catalogue)
240
+
241
  f_DA, ferr_DA = self._extract_fluxes(cat_da)
242
  idx = np.random.randint(0, len(f_DA), len(f))
243
  f_DA, ferr_DA = f_DA[idx], ferr_DA[idx]
 
246
 
247
 
248
  if extinction_corr==True:
249
+ logger.info("Correcting MW extinction")
250
  f = self._correct_extinction(catalogue,f)
251
+
252
  if convert_colors==True:
253
+ logger.info("Converting to colors")
254
  col, colerr = self._to_colors(f, ferr)
255
  col_DA, colerr_DA = self._to_colors(f_DA, ferr_DA)
256
 
temps/plots.py CHANGED
@@ -1,7 +1,7 @@
1
  import numpy as np
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
- from utils import nmad
5
 
6
  import numpy as np
7
  import matplotlib.pyplot as plt
@@ -181,68 +181,7 @@ def plot_PIT(pit_list_1, pit_list_2 = None, pit_list_3=None, sample='specz', lab
181
  # Show the plot
182
  plt.show()
183
 
184
-
185
- import numpy as np
186
- import matplotlib.pyplot as plt
187
- from scipy import stats
188
-
189
- def plot_photoz(df_list, nbins, xvariable, metric, type_bin='bin',label_list=None, samp='zs', save=False):
190
- #plot properties
191
- plt.rcParams['font.family'] = 'serif'
192
- plt.rcParams['font.size'] = 12
193
-
194
-
195
-
196
-
197
- bin_edges = stats.mstats.mquantiles(df_list[0][xvariable].values, np.linspace(0.05, 1, nbins))
198
- print(bin_edges)
199
- cmap = plt.get_cmap('Dark2') # Choose a colormap for coloring lines
200
- plt.figure(figsize=(6, 5))
201
- ls = ['--',':','-']
202
-
203
- for i, df in enumerate(df_list):
204
- ydata, xlab = [], []
205
-
206
- for k in range(len(bin_edges)-1):
207
- edge_min = bin_edges[k]
208
- edge_max = bin_edges[k+1]
209
-
210
- mean_mag = (edge_max + edge_min) / 2
211
-
212
- if type_bin == 'bin':
213
- df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
214
- elif type_bin == 'cum':
215
- df_plot = df[(df[xvariable] < edge_max)]
216
- else:
217
- raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
218
-
219
- xlab.append(mean_mag)
220
- if metric == 'sig68':
221
- ydata.append(sigma68(df_plot.zwerr))
222
- elif metric == 'bias':
223
- ydata.append(np.mean(df_plot.zwerr))
224
- elif metric == 'nmad':
225
- ydata.append(nmad(df_plot.zwerr))
226
- elif metric == 'outliers':
227
- ydata.append(len(df_plot[np.abs(df_plot.zwerr) > 0.15]) / len(df_plot)*100)
228
-
229
- print(ydata)
230
- color = cmap(i) # Get a different color for each dataframe
231
- plt.plot(xlab, ydata,marker='.', lw=1, label=f'{label_list[i]}', color=color, ls=ls[i])
232
-
233
- if xvariable == 'VISmag':
234
- xvariable_lab = 'VIS'
235
-
236
-
237
 
238
- plt.ylabel(f'{metric} $[\\Delta z]$', fontsize=18)
239
- plt.xlabel(f'{xvariable_lab}', fontsize=16)
240
- plt.grid(False)
241
- plt.legend()
242
-
243
- if save==True:
244
- plt.savefig(f'{metric}_{xvariable}_{samp}.pdf', dpi=300, bbox_inches='tight')
245
- plt.show()
246
 
247
 
248
  def plot_nz(df_list,
@@ -336,3 +275,43 @@ def plot_crps(crps_list_1, crps_list_2 = None, crps_list_3=None, labels=None, s
336
  # Show the plot
337
  plt.show()
338
 
 
1
  import numpy as np
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
+ from temps.utils import nmad
5
 
6
  import numpy as np
7
  import matplotlib.pyplot as plt
 
181
  # Show the plot
182
  plt.show()
183
 
184
 
185
 
186
 
187
  def plot_nz(df_list,
 
275
  # Show the plot
276
  plt.show()
277
 
278
+
279
+
280
+ def plot_nz_counts(df, bins=np.arange(0,5,0.2)):  # renamed: avoid shadowing plot_nz(df_list, ...) defined above
281
+ kwargs=dict( bins=bins,alpha=0.5)
282
+ plt.hist(df.zs.values, color='grey', ls='-' ,**kwargs)
283
+ counts, _ = np.histogram(df.z.values, bins=bins)
284
+
285
+ plt.plot((bins[:-1]+bins[1:])*0.5,counts, color ='purple')
286
+
287
+ #plt.legend(fontsize=14)
288
+ plt.xlabel(r'Redshift', fontsize=14)
289
+ plt.ylabel(r'Counts', fontsize=14)
290
+ plt.yscale('log')
291
+
292
+ plt.show()
293
+
294
+ return
295
+
296
+
297
+ def plot_scatter(df, sample='specz', save=True):
298
+ from scipy.stats import gaussian_kde  # missing at module level; needed for the density estimate below
+ # Calculate the point density
299
+ xy = np.vstack([df.zs.values,df.z.values])
300
+ zd = gaussian_kde(xy)(xy)
301
+
302
+ fig, ax = plt.subplots()
303
+ plt.scatter(df.zs.values, df.z.values,c=zd, s=1)
304
+ plt.xlim(0,5)
305
+ plt.ylim(0,5)
306
+
307
+ plt.xlabel(r'$z_{\rm s}$', fontsize = 14)
308
+ plt.ylabel('$z$', fontsize = 14)
309
+
310
+ plt.xticks(fontsize = 12)
311
+ plt.yticks(fontsize = 12)
312
+
313
+ if save==True:
314
+ plt.savefig(f'{sample}_scatter.pdf', dpi = 300, bbox_inches='tight')
315
+
316
+ plt.show()
317
+
temps/temps.py CHANGED
@@ -1,257 +1,267 @@
 
 
1
  import torch
2
- from torch.utils.data import DataLoader, dataset, TensorDataset
3
  from torch import nn, optim
 
4
  from torch.optim import lr_scheduler
5
- import numpy as np
6
- import pandas as pd
7
- from astropy.io import fits
8
- import os
9
- from astropy.table import Table
10
- from scipy.spatial import KDTree
11
- from scipy.special import erf
12
  from scipy.stats import norm
13
- import sys
14
-
15
- sys.path.append('/.')
16
- from utils import maximum_mean_discrepancy, compute_kernel
17
-
18
- class Temps_module():
19
- """ Define class"""
20
-
21
- def __init__(self, modelF, modelZ, batch_size=100,rejection_param=1, da=True, verbose=False):
22
- self.modelZ=modelZ
23
- self.modelF=modelF
24
- self.da=da
25
- self.verbose=verbose
26
- self.ngaussians=modelZ.ngaussians
27
-
28
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
- self.batch_size=batch_size
30
- self.rejection_parameter=rejection_param
31
-
32
-
33
-
34
- def _get_dataloaders(self, input_data, target_data, input_data_DA, target_data_DA, val_fraction=0.1):
 
 
  input_data = torch.Tensor(input_data)
36
  target_data = torch.Tensor(target_data)
37
- if input_data_DA is not None:
38
- input_data_DA = torch.Tensor(input_data_DA)
39
- target_data_DA = torch.Tensor(target_data_DA)
40
- else:
41
- input_data_DA = input_data.clone()
42
- target_data_DA = target_data.clone()
43
-
44
- dataset = TensorDataset(input_data, input_data_DA, target_data, target_data_DA)
45
- trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
46
- loader_train = DataLoader(trainig_dataset, batch_size=self.batch_size, shuffle = True)
47
- loader_val = DataLoader(val_dataset, batch_size=64, shuffle = True)
 
 
 
 
48
 
49
  return loader_train, loader_val
50
 
51
-
52
-
53
-
54
- def _loss_function(self,mean, std, logmix, true):
55
-
56
- log_prob = logmix - 0.5*(mean - true[:,None]).pow(2) / std.pow(2) - torch.log(std)
57
- log_prob = torch.logsumexp(log_prob, 1)
58
  loss = -log_prob.mean()
59
-
60
- return loss
61
-
62
- def _loss_function_DA(self,f1, f2):
63
- kl_loss = nn.KLDivLoss(reduction="batchmean",log_target=True)
64
  loss = kl_loss(f1, f2)
65
- loss = torch.log(loss)
66
- #print('f1',f1)
67
- #print('f2',f2)
68
-
69
- return loss
70
 
71
- def _to_numpy(self,x):
 
72
  return x.detach().cpu().numpy()
73
-
74
-
75
-
76
- def train(self,input_data,
77
- input_data_DA,
78
- target_data,
79
- target_data_DA,
80
- nepochs=10,
81
- step_size = 100,
82
- val_fraction=0.1,
83
- lr=1e-3,
84
- weight_decay=0):
85
- self.modelZ = self.modelZ.train()
86
- self.modelF = self.modelF.train()
87
-
88
- loader_train, loader_val = self._get_dataloaders(input_data, target_data, input_data_DA, target_data_DA, val_fraction=0.1)
89
- optimizerZ = optim.Adam(self.modelZ.parameters(), lr=lr, weight_decay=weight_decay)
90
- optimizerF = optim.Adam(self.modelF.parameters(), lr=lr, weight_decay=weight_decay)
91
-
92
- schedulerZ = torch.optim.lr_scheduler.StepLR(optimizerZ, step_size=step_size, gamma =0.1)
93
- schedulerF = torch.optim.lr_scheduler.StepLR(optimizerF, step_size=step_size, gamma =0.1)
94
-
95
- self.modelZ = self.modelZ.to(self.device)
96
- self.modelF = self.modelF.to(self.device)
97
 
98
- self.loss_train, self.loss_validation = [],[]
99
-
100
- for epoch in range(nepochs):
101
- for input_data, input_data_da, target_data, target_data_DA in loader_train:
102
- _loss_train, _loss_validation = [],[]
 
 
103
 
104
- input_data = input_data.to(self.device)
105
- target_data = target_data.to(self.device)
106
-
 
 
 
 
 
 
107
  if self.da:
108
  input_data_da = input_data_da.to(self.device)
109
- target_data_DA = target_data_DA.to(self.device)
110
 
111
- optimizerF.zero_grad()
112
- optimizerZ.zero_grad()
113
 
114
- features = self.modelF(input_data)
115
- if self.da:
116
- features_DA = self.modelF(input_data_da)
117
 
118
- mu, logsig, logmix_coeff = self.modelZ(features)
119
- logsig = torch.clamp(logsig,-6,2)
120
  sig = torch.exp(logsig)
121
 
122
-                 lossZ = self._loss_function(mu, sig, logmix_coeff, target_data)
-
-                 #mu, logsig, logmix_coeff = self.modelZ(features_DA)
-                 #logsig = torch.clamp(logsig,-6,2)
-                 #sig = torch.exp(logsig)
-
-                 #lossZ_DA = self._loss_function(mu, sig, logmix_coeff, target_data_DA)
-
-                 if self.da:
-                     lossDA = maximum_mean_discrepancy(features, features_DA, kernel_type='rbf')
-                     lossDA = lossDA.sum()
-                     loss = lossZ +1e3*lossDA
-                 else:
-                     loss = lossZ
-
-                 _loss_train.append(lossZ.item())
-
                  loss.backward()
-                 optimizerF.step()
-                 optimizerZ.step()
-
-                 schedulerF.step()
-                 schedulerZ.step()
-
-             self.loss_train.append(np.mean(_loss_train))
-
-             for input_data, _, target_data, _ in loader_val:
-
                  input_data = input_data.to(self.device)
                  target_data = target_data.to(self.device)

-                 features = self.modelF(input_data)
-                 mu, logsig, logmix_coeff = self.modelZ(features)
-
-                 logsig = torch.clamp(logsig,-6,2)
                  sig = torch.exp(logsig)

                  loss_val = self._loss_function(mu, sig, logmix_coeff, target_data)
                  _loss_validation.append(loss_val.item())

-             self.loss_validation.append(np.mean(_loss_validation))
-
-             if self.verbose:
-                 print(f'training_loss:{loss}',f'testing_loss:{loss_val}')
-
      def get_features(self, input_data):
-         self.modelF = self.modelF.eval()
-         self.modelF = self.modelF.to(self.device)
-
          input_data = input_data.to(self.device)
-
-         features = self.modelF(input_data)
-
-         return features.detach().cpu().numpy()
-
-     def get_pz(self,input_data, return_pz=True, return_flag=True, retrun_odds=False):
-         self.modelZ = self.modelZ.eval()
-         self.modelZ = self.modelZ.to(self.device)
-         self.modelF = self.modelF.eval()
-         self.modelF = self.modelF.to(self.device)

          input_data = input_data.to(self.device)
-
-         features = self.modelF(input_data)
-         mu, logsig, logmix_coeff = self.modelZ(features)
-         logsig = torch.clamp(logsig,-6,2)
          sig = torch.exp(logsig)

          mix_coeff = torch.exp(logmix_coeff)

-         z = (mix_coeff * mu).sum(1)
-         zerr = torch.sqrt( (mix_coeff * sig**2).sum(1) + (mix_coeff * (mu - mu.mean(1)[:,None])**2).sum(1))
-
-         mu, mix_coeff, sig = mu.detach().cpu().numpy(), mix_coeff.detach().cpu().numpy(), sig.detach().cpu().numpy()
-
-         if return_pz==True:
-             zgrid = np.linspace(0, 5, 1000)
-             pdf_mixture = np.zeros(shape=(len(input_data), len(zgrid)))
-             for ii in range(len(input_data)):
-                 for i in range(self.ngaussians):
-                     pdf_mixture[ii] += mix_coeff[ii,i] * norm.pdf(zgrid, mu[ii,i], sig[ii,i])
-             if return_flag==True:
-                 #narrow peak
-                 pdf_mixture = pdf_mixture / pdf_mixture.sum(1)[:,None]
-                 diff_matrix = np.abs(self._to_numpy(z)[:,None] - zgrid[None,:])
-                 #odds
-                 idx_peak = np.argmax(pdf_mixture,1)
-                 zpeak = zgrid[idx_peak]
-                 diff_matrix_upper = np.abs((zpeak+0.05)[:,None] - zgrid[None,:])
-                 diff_matrix_lower = np.abs((zpeak-0.05)[:,None] - zgrid[None,:])
-
-                 idx = np.argmin(diff_matrix,1)
-                 idx_upper = np.argmin(diff_matrix_upper,1)
-                 idx_lower = np.argmin(diff_matrix_lower,1)
-
-                 p_z_x = np.zeros(shape=(len(z)))
-                 odds = np.zeros(shape=(len(z)))
-
-                 for ii in range(len(z)):
-                     p_z_x[ii] = pdf_mixture[ii,idx[ii]]
-                     odds[ii] = pdf_mixture[ii,:idx_upper[ii]].sum() - pdf_mixture[ii,:idx_lower[ii]].sum()
-
-                 return self._to_numpy(z),self._to_numpy(zerr), pdf_mixture, p_z_x, odds
-             else:
-                 return self._to_numpy(z),self._to_numpy(zerr), pdf_mixture
-
          else:
-             return self._to_numpy(z),self._to_numpy(zerr)
-
-     def pit(self, input_data, target_data):

          pit_list = []

-         self.modelF = self.modelF.eval()
-         self.modelF = self.modelF.to(self.device)
-         self.modelZ = self.modelZ.eval()
-         self.modelZ = self.modelZ.to(self.device)

          input_data = input_data.to(self.device)

-         features = self.modelF(input_data)
-         mu, logsig, logmix_coeff = self.modelZ(features)

          logsig = torch.clamp(logsig,-6,2)
          sig = torch.exp(logsig)
@@ -267,7 +277,8 @@ class Temps_module():

          return pit_list

-     def crps(self, input_data, target_data):

          def measure_crps(cdf, t):
              zgrid = np.linspace(0,4,1000)
@@ -281,16 +292,16 @@ class Temps_module():

          crps_list = []

-         self.modelF = self.modelF.eval()
-         self.modelF = self.modelF.to(self.device)
-         self.modelZ = self.modelZ.eval()
-         self.modelZ = self.modelZ.to(self.device)

          input_data = input_data.to(self.device)

-         features = self.modelF(input_data)
-         mu, logsig, logmix_coeff = self.modelZ(features)
          logsig = torch.clamp(logsig,-6,2)
          sig = torch.exp(logsig)

@@ -302,21 +313,19 @@ class Temps_module():
          z = (mix_coeff * mu).sum(1)

          x = np.linspace(0, 4, 1000)
-         pdf_mixture = np.zeros(shape=(len(target_data), len(x)))
          for ii in range(len(input_data)):
              for i in range(6):
-                 pdf_mixture[ii] += mix_coeff[ii,i] * norm.pdf(x, mu[ii,i], sig[ii,i])

-         pdf_mixture = pdf_mixture / pdf_mixture.sum(1)[:,None]

-         cdf_mixture = np.cumsum(pdf_mixture,1)

-         crps_value = measure_crps(cdf_mixture, target_data)

          return crps_value
-
+ import numpy as np
+ import pandas as pd
  import torch
  from torch import nn, optim
+ from torch.utils.data import DataLoader, TensorDataset
  from torch.optim import lr_scheduler
  from scipy.stats import norm
+ from loguru import logger
+ from tqdm import tqdm  # Import tqdm for progress bars
+
+ # Local imports
+ from temps.utils import maximum_mean_discrepancy
+
+
+ class TempsModule:
+     """Class for managing temperature-related models and training."""
+
+     def __init__(
+         self,
+         model_f,
+         model_z,
+         batch_size=100,
+         rejection_param=1,
+         da=True,
+         verbose=False,
+     ):
+         self.model_z = model_z
+         self.model_f = model_f
+         self.da = da
+         self.verbose = verbose
+         self.ngaussians = model_z.ngaussians
+
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.batch_size = batch_size
+         self.rejection_parameter = rejection_param
+
+     def _get_dataloaders(
+         self, input_data, target_data, input_data_da=None, val_fraction=0.1
+     ):
+         """Create training and validation dataloaders."""
          input_data = torch.Tensor(input_data)
          target_data = torch.Tensor(target_data)
+         input_data_da = (
+             torch.Tensor(input_data_da)
+             if input_data_da is not None
+             else input_data.clone()
+         )
+
+         dataset = TensorDataset(input_data, input_data_da, target_data)
+         train_dataset, val_dataset = torch.utils.data.random_split(
+             dataset,
+             [int(len(dataset) * (1 - val_fraction)), int(len(dataset) * val_fraction)],
+         )
+         loader_train = DataLoader(
+             train_dataset, batch_size=self.batch_size, shuffle=True
+         )
+         loader_val = DataLoader(val_dataset, batch_size=64, shuffle=True)

          return loader_train, loader_val

+     def _loss_function(self, mean, std, logmix, true):
+         """Compute the loss function."""
+         log_prob = (
+             logmix - 0.5 * (mean - true[:, None]).pow(2) / std.pow(2) - torch.log(std)
+         )
+         log_prob = torch.logsumexp(log_prob, dim=1)
          loss = -log_prob.mean()
+         return loss
+
+     def _loss_function_da(self, f1, f2):
+         """Compute the KL divergence loss for domain adaptation."""
+         kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
          loss = kl_loss(f1, f2)
+         return torch.log(loss)

+     def _to_numpy(self, x):
+         """Convert a tensor to a NumPy array."""
          return x.detach().cpu().numpy()

+     def train(
+         self,
+         input_data,
+         input_data_da,
+         target_data,
+         nepochs=10,
+         step_size=100,
+         val_fraction=0.1,
+         lr=1e-3,
+         weight_decay=0,
+     ):
+         """Train the models using provided data."""
+         self.model_z.train()
+         self.model_f.train()
+
+         loader_train, loader_val = self._get_dataloaders(
+             input_data, target_data, input_data_da, val_fraction
+         )
+         optimizer_z = optim.Adam(
+             self.model_z.parameters(), lr=lr, weight_decay=weight_decay
+         )
+         optimizer_f = optim.Adam(
+             self.model_f.parameters(), lr=lr, weight_decay=weight_decay
+         )
+
+         scheduler_z = lr_scheduler.StepLR(optimizer_z, step_size=step_size, gamma=0.1)
+         scheduler_f = lr_scheduler.StepLR(optimizer_f, step_size=step_size, gamma=0.1)
+
+         self.model_z.to(self.device)
+         self.model_f.to(self.device)
+
+         loss_train, loss_validation = [], []

+         for epoch in range(nepochs):
+             _loss_train, _loss_validation = [], []
+             logger.info(f"Epoch {epoch + 1}/{nepochs} starting...")
+             for input_data, input_data_da, target_data in tqdm(
+                 loader_train, desc="Training", unit="batch"
+             ):
+                 input_data, target_data = input_data.to(self.device), target_data.to(
+                     self.device
+                 )
                  if self.da:
                      input_data_da = input_data_da.to(self.device)

+                 optimizer_f.zero_grad()
+                 optimizer_z.zero_grad()

+                 features = self.model_f(input_data)
+                 features_da = self.model_f(input_data_da) if self.da else None

+                 mu, logsig, logmix_coeff = self.model_z(features)
+                 logsig = torch.clamp(logsig, -6, 2)
                  sig = torch.exp(logsig)

+                 loss_z = self._loss_function(mu, sig, logmix_coeff, target_data)
+                 loss = loss_z + (
+                     1e3
+                     * maximum_mean_discrepancy(
+                         features, features_da, kernel_type="rbf"
+                     ).sum()
+                     if self.da
+                     else 0
+                 )
+
+                 _loss_train.append(loss_z.item())

                  loss.backward()
+                 optimizer_f.step()
+                 optimizer_z.step()
+
+                 scheduler_f.step()
+                 scheduler_z.step()
+
+             loss_train.append(np.mean(_loss_train))
+             _loss_validation = self._validate(loader_val, target_data)
+
+             logger.info(
+                 f"Epoch {epoch + 1}: Training Loss: {np.mean(_loss_train):.4f}, Validation Loss: {np.mean(_loss_validation):.4f}"
+             )

+     def _validate(self, loader_val, target_data):
+         """Validate the model on the validation dataset."""
+         self.model_z.eval()
+         self.model_f.eval()
+         _loss_validation = []

+         with torch.no_grad():
+             for input_data, _, target_data in tqdm(
+                 loader_val, desc="Validating", unit="batch"
+             ):
                  input_data = input_data.to(self.device)
                  target_data = target_data.to(self.device)

+                 features = self.model_f(input_data)
+                 mu, logsig, logmix_coeff = self.model_z(features)
+                 logsig = torch.clamp(logsig, -6, 2)
                  sig = torch.exp(logsig)

                  loss_val = self._loss_function(mu, sig, logmix_coeff, target_data)
                  _loss_validation.append(loss_val.item())

+         return _loss_validation

      def get_features(self, input_data):
+         """Get features from the model."""
+         self.model_f.eval()
          input_data = input_data.to(self.device)
+         features = self.model_f(input_data)
+         return self._to_numpy(features)

+     def get_pz(self, input_data, return_pz=True, return_flag=True, return_odds=False):
+         """Get the predicted z values and their uncertainties."""
+         logger.info("Predicting photo-z for the input galaxies...")
+         self.model_z.eval()
+         self.model_f.eval()

          input_data = input_data.to(self.device)
+         features = self.model_f(input_data)
+         mu, logsig, logmix_coeff = self.model_z(features)
+         logsig = torch.clamp(logsig, -6, 2)
          sig = torch.exp(logsig)

          mix_coeff = torch.exp(logmix_coeff)
+         z = (mix_coeff * mu).sum(dim=1)
+         zerr = torch.sqrt(
+             (mix_coeff * sig**2).sum(dim=1)
+             + (mix_coeff * (mu - mu.mean(dim=1, keepdim=True)) ** 2).sum(dim=1)
+         )

+         mu, mix_coeff, sig = map(self._to_numpy, (mu, mix_coeff, sig))

+         if return_pz:
+             logger.info("Returning p(z)")
+             return self._calculate_pdf(z, mu, sig, mix_coeff, return_flag)
          else:
+             return self._to_numpy(z), self._to_numpy(zerr)
+
+     def _calculate_pdf(self, z, mu, sig, mix_coeff, return_flag):
+         """Calculate the probability density function."""
+         zgrid = np.linspace(0, 5, 1000)
+         pz = np.zeros((len(z), len(zgrid)))
+
+         for ii in range(len(z)):
+             for i in range(self.ngaussians):
+                 pz[ii] += mix_coeff[ii, i] * norm.pdf(
+                     zgrid, mu[ii, i], sig[ii, i]
+                 )
+
+         if return_flag:
+             logger.info("Calculating and returning ODDS")
+             pz /= pz.sum(axis=1, keepdims=True)
+             return self._calculate_odds(z, pz, zgrid)
+         return self._to_numpy(z), pz
+
+     def _calculate_odds(self, z, pz, zgrid):
+         """Calculate odds based on the PDF."""
+         logger.info('Calculating ODDS values')
+         diff_matrix = np.abs(self._to_numpy(z)[:, None] - zgrid[None, :])
+         idx_peak = np.argmax(pz, axis=1)
+         zpeak = zgrid[idx_peak]
+         idx_upper = np.argmin(np.abs((zpeak + 0.05)[:, None] - zgrid[None, :]), axis=1)
+         idx_lower = np.argmin(np.abs((zpeak - 0.05)[:, None] - zgrid[None, :]), axis=1)
+
+         odds = []
+         for jj in range(len(pz)):
+             odds.append(pz[jj,idx_lower[jj]:(idx_upper[jj]+1)].sum())
+
+         odds = np.array(odds)
+         return self._to_numpy(z), pz, odds
+
+     def calculate_pit(self, input_data, target_data):
+         logger.info('Calculating PIT values')

          pit_list = []

+         self.model_f = self.model_f.eval()
+         self.model_f = self.model_f.to(self.device)
+         self.model_z = self.model_z.eval()
+         self.model_z = self.model_z.to(self.device)

          input_data = input_data.to(self.device)

+         features = self.model_f(input_data)
+         mu, logsig, logmix_coeff = self.model_z(features)

          logsig = torch.clamp(logsig,-6,2)
          sig = torch.exp(logsig)
@@ -267,7 +277,8 @@ class Temps_module():

          return pit_list

+     def calculate_crps(self, input_data, target_data):
+         logger.info('Calculating CRPS values')

          def measure_crps(cdf, t):
              zgrid = np.linspace(0,4,1000)
@@ -281,16 +292,16 @@ class Temps_module():

          crps_list = []

+         self.model_f = self.model_f.eval()
+         self.model_f = self.model_f.to(self.device)
+         self.model_z = self.model_z.eval()
+         self.model_z = self.model_z.to(self.device)

          input_data = input_data.to(self.device)

+         features = self.model_f(input_data)
+         mu, logsig, logmix_coeff = self.model_z(features)
          logsig = torch.clamp(logsig,-6,2)
          sig = torch.exp(logsig)

@@ -302,21 +313,19 @@ class Temps_module():
          z = (mix_coeff * mu).sum(1)

          x = np.linspace(0, 4, 1000)
+         pz = np.zeros(shape=(len(target_data), len(x)))
          for ii in range(len(input_data)):
              for i in range(6):
+                 pz[ii] += mix_coeff[ii,i] * norm.pdf(x, mu[ii,i], sig[ii,i])

+         pz = pz / pz.sum(1)[:,None]

+         cdf_z = np.cumsum(pz,1)

+         crps_value = measure_crps(cdf_z, target_data)

          return crps_value
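For orientation, here is a minimal usage sketch of the refactored TempsModule API above. The synthetic arrays, the assumed 6-feature input width of the encoder (its first layer is not shown in this diff), and the `temps.temps` import path are illustrative assumptions, not part of this commit:

import numpy as np
import torch

from temps.temps import TempsModule                     # assumed module path
from temps.temps_arch import EncoderPhotometry, MeasureZ

rng = np.random.default_rng(0)
n_features = 6                                           # ASSUMED encoder input width
colours_lab = rng.normal(size=(1000, n_features)).astype(np.float32)  # labelled sample
colours_tgt = rng.normal(size=(1000, n_features)).astype(np.float32)  # unlabelled target sample
z_lab = rng.uniform(0, 3, size=1000).astype(np.float32)

temps_module = TempsModule(EncoderPhotometry(), MeasureZ(num_gauss=6), batch_size=100, da=True)

# Trains the encoder and the mixture head jointly; with da=True the MMD term
# (weighted by 1e3, as in the loss above) pulls the encoded features of the
# labelled and target samples together.
temps_module.train(colours_lab, colours_tgt, z_lab, nepochs=10, lr=1e-3)

# With return_pz=True and return_flag=True the method returns the point
# estimate, the normalised p(z) on a grid over 0 < z < 5, and the ODDS values.
z, pz, odds = temps_module.get_pz(torch.Tensor(colours_tgt), return_pz=True, return_flag=True)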
temps/temps_arch.py CHANGED
@@ -20,52 +20,46 @@ class EncoderPhotometry(nn.Module):
              nn.Linear(50, 20),
              nn.Dropout(dropout_prob),
              nn.ReLU(),
-             nn.Linear(20, 10)
          )
-
      def forward(self, x):
          f = self.features(x)
-         f = F.log_softmax(f, dim=1)
          return f

-

  class MeasureZ(nn.Module):
      def __init__(self, num_gauss=10, dropout_prob=0):
          super(MeasureZ, self).__init__()
-
-         self.ngaussians=num_gauss
          self.measure_mu = nn.Sequential(
              nn.Linear(10, 20),
              nn.Dropout(dropout_prob),
              nn.ReLU(),
-             nn.Linear(20, num_gauss)
          )

          self.measure_coeffs = nn.Sequential(
              nn.Linear(10, 20),
              nn.Dropout(dropout_prob),
              nn.ReLU(),
-             nn.Linear(20, num_gauss)
          )

          self.measure_sigma = nn.Sequential(
              nn.Linear(10, 20),
              nn.Dropout(dropout_prob),
              nn.ReLU(),
-             nn.Linear(20, num_gauss)
          )
-
-
      def forward(self, f):
          mu = self.measure_mu(f)
          sigma = self.measure_sigma(f)
          logmix_coeff = self.measure_coeffs(f)
-
-         logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, 1)[:,None]
-
-         return mu, sigma, logmix_coeff
-
              nn.Linear(50, 20),
              nn.Dropout(dropout_prob),
              nn.ReLU(),
+             nn.Linear(20, 10),
          )
+
      def forward(self, x):
          f = self.features(x)
+         f = F.log_softmax(f, dim=1)
          return f


  class MeasureZ(nn.Module):
      def __init__(self, num_gauss=10, dropout_prob=0):
          super(MeasureZ, self).__init__()
+
+         self.ngaussians = num_gauss
          self.measure_mu = nn.Sequential(
              nn.Linear(10, 20),
              nn.Dropout(dropout_prob),
              nn.ReLU(),
+             nn.Linear(20, num_gauss),
          )

          self.measure_coeffs = nn.Sequential(
              nn.Linear(10, 20),
              nn.Dropout(dropout_prob),
              nn.ReLU(),
+             nn.Linear(20, num_gauss),
          )

          self.measure_sigma = nn.Sequential(
              nn.Linear(10, 20),
              nn.Dropout(dropout_prob),
              nn.ReLU(),
+             nn.Linear(20, num_gauss),
          )
+
      def forward(self, f):
          mu = self.measure_mu(f)
          sigma = self.measure_sigma(f)
          logmix_coeff = self.measure_coeffs(f)

+         logmix_coeff = logmix_coeff - torch.logsumexp(logmix_coeff, 1)[:, None]

+         return mu, sigma, logmix_coeff
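A quick, self-contained check of the mixture head above (the `temps.temps_arch` import path is assumed). The 10-dimensional input matches the `nn.Linear(10, 20)` layers shown in the diff; note that the `sigma` head returns a raw value that TempsModule treats as log sigma, clamping and exponentiating it downstream:

import torch
from temps.temps_arch import MeasureZ   # assumed import path

head = MeasureZ(num_gauss=6)

f = torch.randn(4, 10)                  # stand-in for EncoderPhotometry features
mu, logsig, logmix = head(f)            # each output has shape (4, 6)

# The logsumexp normalisation in forward() makes the mixture weights sum to one.
print(torch.exp(logmix).sum(dim=1))     # ~tensor([1., 1., 1., 1.])

# Downstream, TempsModule recovers positive Gaussian widths via clamp + exp.
sig = torch.exp(torch.clamp(logsig, -6, 2))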
temps/utils.py CHANGED
@@ -3,113 +3,22 @@ import pandas as pd
  import matplotlib.pyplot as plt
  from scipy import stats
  import torch
- from scipy.stats import gaussian_kde

- def nmad(data):
-     return 1.4826 * np.median(np.abs(data - np.median(data)))
-
- def sigma68(data): return 0.5*(pd.Series(data).quantile(q = 0.84) - pd.Series(data).quantile(q = 0.16))
-
- def plot_photoz(df_list, nbins, xvariable, metric, type_bin='bin',label_list=None, samp='zs', save=False):
-     #plot properties
-     plt.rcParams['font.family'] = 'serif'
-     plt.rcParams['font.size'] = 12
-
-     bin_edges = stats.mstats.mquantiles(df_list[0][xvariable].values, np.linspace(0.05, 1, nbins))
-     print(bin_edges)
-     cmap = plt.get_cmap('Dark2')  # Choose a colormap for coloring lines
-     plt.figure(figsize=(6, 5))
-
-     for i, df in enumerate(df_list):
-         ydata, xlab = [], []
-
-         for k in range(len(bin_edges)-1):
-             edge_min = bin_edges[k]
-             edge_max = bin_edges[k+1]
-
-             mean_mag = (edge_max + edge_min) / 2
-
-             if type_bin == 'bin':
-                 df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
-             elif type_bin == 'cum':
-                 df_plot = df[(df[xvariable] < edge_max)]
-             else:
-                 raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
-
-             xlab.append(mean_mag)
-             if metric == 'sig68':
-                 ydata.append(sigma68(df_plot.zwerr))
-             elif metric == 'bias':
-                 ydata.append(np.mean(df_plot.zwerr))
-             elif metric == 'nmad':
-                 ydata.append(nmad(df_plot.zwerr))
-             elif metric == 'outliers':
-                 ydata.append(len(df_plot[np.abs(df_plot.zwerr) > 0.15]) / len(df_plot)*100)
-
-         print(ydata)
-         color = cmap(i)  # Get a different color for each dataframe
-         plt.plot(xlab, ydata, ls='-', marker='.', lw=1, label=f'{label_list[i]}', color=color)
-
-     if xvariable == 'VISmag':
-         xvariable_lab = 'VIS'
-
-     plt.ylabel(f'{metric} $[\\Delta z]$', fontsize=18)
-     plt.xlabel(f'{xvariable_lab}', fontsize=16)
-     plt.grid(False)
-     plt.legend()
-
-     if save==True:
-         plt.savefig(f'{metric}_{xvariable}_{samp}.pdf', dpi=300, bbox_inches='tight')
-     plt.show()
-
-
- def plot_nz(df, bins=np.arange(0,5,0.2)):
-     kwargs=dict( bins=bins,alpha=0.5)
-     plt.hist(df.zs.values, color='grey', ls='-' ,**kwargs)
-     counts, _, =np.histogram(df.z.values, bins=bins)
-
-     plt.plot((bins[:-1]+bins[1:])*0.5,counts, color ='purple')
-
-     #plt.legend(fontsize=14)
-     plt.xlabel(r'Redshift', fontsize=14)
-     plt.ylabel(r'Counts', fontsize=14)
-     plt.yscale('log')
-
-     plt.show()
-
-     return
-
-
- def plot_scatter(df, sample='specz', save=True):
-     # Calculate the point density
-     xy = np.vstack([df.zs.values,df.z.values])
-     zd = gaussian_kde(xy)(xy)
-
-     fig, ax = plt.subplots()
-     plt.scatter(df.zs.values, df.z.values,c=zd, s=1)
-     plt.xlim(0,5)
-     plt.ylim(0,5)

-     plt.xlabel(r'$z_{\rm s}$', fontsize = 14)
-     plt.ylabel('$z$', fontsize = 14)

-     plt.xticks(fontsize = 12)
-     plt.yticks(fontsize = 12)

-     if save==True:
-         plt.savefig(f'{sample}_scatter.pdf', dpi = 300, bbox_inches='tight')

-     plt.show()
-
-
- def maximum_mean_discrepancy(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
      """
      Compute the Maximum Mean Discrepancy (MMD) between two sets of samples.
@@ -130,7 +39,8 @@ def maximum_mean_discrepancy(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num
      mmd_loss = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel)
      return mmd_loss

- def compute_kernel(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):

      """
      Compute the kernel matrix based on the chosen kernel type.
@@ -151,73 +61,77 @@ def compute_kernel(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
      x = x.unsqueeze(1).expand(x_size, y_size, dim)
      y = y.unsqueeze(0).expand(x_size, y_size, dim)

-     kernel_input = (x - y).pow(2).mean(2)  # Pairwise squared Euclidean distances

-     if kernel_type == 'linear':
          kernel_matrix = kernel_input
-     elif kernel_type == 'poly':
          kernel_matrix = (1 + kernel_input / kernel_mul).pow(kernel_num)
-     elif kernel_type == 'rbf':
          kernel_matrix = torch.exp(-kernel_input / (2 * kernel_mul**2))
-     elif kernel_type == 'sigmoid':
          kernel_matrix = torch.tanh(kernel_mul * kernel_input)
      else:
-         raise ValueError("Invalid kernel type. Supported types are 'linear', 'poly', 'rbf', and 'sigmoid'.")

      return kernel_matrix


- def select_cut(df,
-                completenss_lim=None,
-                nmad_lim = None,
-                outliers_lim=None,
-                return_df=False):
-
-     if (completenss_lim is None)&(nmad_lim is None)&(outliers_lim is None):
-         raise(ValueError("Select at least one cut"))
      elif sum(c is not None for c in [completenss_lim, nmad_lim, outliers_lim]) > 1:
          raise ValueError("Select only one cut at a time")
-
      else:
-         bin_edges = stats.mstats.mquantiles(df.zflag, np.arange(0,1.01,0.1))
-         scatter, eta, cmptnss, nobj = [],[],[], []

-         for k in range(len(bin_edges)-1):
              edge_min = bin_edges[k]
-             edge_max = bin_edges[k+1]

-             df_bin = df[(df.zflag > edge_min)]
-
-             cmptnss.append(np.round(len(df_bin)/len(df),2)*100)
              scatter.append(nmad(df_bin.zwerr))
-             eta.append(len(df_bin[np.abs(df_bin.zwerr)>0.15])/len(df_bin)*100)
              nobj.append(len(df_bin))
-
-         dfcuts = pd.DataFrame(data=np.c_[np.round(bin_edges[:-1],5), np.round(nobj,1), np.round(cmptnss,1), np.round(scatter,3), np.round(eta,2)], columns=['flagcut', 'Nobj','completeness', 'nmad', 'eta'])
-
          if completenss_lim is not None:
-             print('Selecting cut based on completeness')
-             selected_cut = dfcuts[dfcuts['completeness'] <= completenss_lim].iloc[0]
-
          elif nmad_lim is not None:
-             print('Selecting cut based on nmad')
-             selected_cut = dfcuts[dfcuts['nmad'] <= nmad_lim].iloc[0]

-
          elif outliers_lim is not None:
-             print('Selecting cut based on outliers')
-             selected_cut = dfcuts[dfcuts['eta'] <= outliers_lim].iloc[0]

-         print(f"This cut provides completeness of {selected_cut['completeness']}, nmad={selected_cut['nmad']} and eta={selected_cut['eta']}")
-
-         df_cut = df[(df.zflag > selected_cut['flagcut'])]
-         if return_df==True:
-             return df_cut, selected_cut['flagcut'], dfcuts
          else:
-             return selected_cut['flagcut'], dfcuts
-
  import matplotlib.pyplot as plt
  from scipy import stats
  import torch
+ from loguru import logger

+ def caluclate_eta(df):
+     return len(df[np.abs(df.zwerr)>0.15])/len(df) *100

+ def nmad(data):
+     return 1.4826 * np.median(np.abs(data - np.median(data)))

+ def sigma68(data):
+     return 0.5 * (pd.Series(data).quantile(q=0.84) - pd.Series(data).quantile(q=0.16))

+ def maximum_mean_discrepancy(x, y, kernel_type="rbf", kernel_mul=2.0, kernel_num=5):
      """
      Compute the Maximum Mean Discrepancy (MMD) between two sets of samples.
@@ -130,7 +39,8 @@ def maximum_mean_discrepancy(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num
      mmd_loss = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel)
      return mmd_loss

+
+ def compute_kernel(x, y, kernel_type="rbf", kernel_mul=2.0, kernel_num=5):
      """
      Compute the kernel matrix based on the chosen kernel type.
@@ -151,73 +61,77 @@ def compute_kernel(x, y, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
      x = x.unsqueeze(1).expand(x_size, y_size, dim)
      y = y.unsqueeze(0).expand(x_size, y_size, dim)

+     kernel_input = (x - y).pow(2).mean(2)

+     if kernel_type == "linear":
          kernel_matrix = kernel_input
+     elif kernel_type == "poly":
          kernel_matrix = (1 + kernel_input / kernel_mul).pow(kernel_num)
+     elif kernel_type == "rbf":
          kernel_matrix = torch.exp(-kernel_input / (2 * kernel_mul**2))
+     elif kernel_type == "sigmoid":
          kernel_matrix = torch.tanh(kernel_mul * kernel_input)
      else:
+         raise ValueError(
+             "Invalid kernel type. Supported types are 'linear', 'poly', 'rbf', and 'sigmoid'."
+         )

      return kernel_matrix


+ def select_cut(
+     df, completenss_lim=None, nmad_lim=None, outliers_lim=None, return_df=False
+ ):
+
+     if (completenss_lim is None) & (nmad_lim is None) & (outliers_lim is None):
+         raise (ValueError("Select at least one cut"))
      elif sum(c is not None for c in [completenss_lim, nmad_lim, outliers_lim]) > 1:
          raise ValueError("Select only one cut at a time")
+
      else:
+         bin_edges = stats.mstats.mquantiles(df.odds, np.arange(0, 1.01, 0.1))
+         scatter, eta, cmptnss, nobj = [], [], [], []

+         for k in range(len(bin_edges) - 1):
              edge_min = bin_edges[k]
+             edge_max = bin_edges[k + 1]

+             df_bin = df[(df.odds > edge_min)]

+             cmptnss.append(np.round(len(df_bin) / len(df), 2) * 100)
              scatter.append(nmad(df_bin.zwerr))
+             eta.append(len(df_bin[np.abs(df_bin.zwerr) > 0.15]) / len(df_bin) * 100)
              nobj.append(len(df_bin))
+
+         dfcuts = pd.DataFrame(
+             data=np.c_[
+                 np.round(bin_edges[:-1], 5),
+                 np.round(nobj, 1),
+                 np.round(cmptnss, 1),
+                 np.round(scatter, 3),
+                 np.round(eta, 2),
+             ],
+             columns=["flagcut", "Nobj", "completeness", "nmad", "eta"],
+         )
+
          if completenss_lim is not None:
+             logger.info("Selecting cut based on completeness")
+             selected_cut = dfcuts[dfcuts["completeness"] <= completenss_lim].iloc[0]
+
          elif nmad_lim is not None:
+             logger.info("Selecting cut based on nmad")
+             selected_cut = dfcuts[dfcuts["nmad"] <= nmad_lim].iloc[0]

          elif outliers_lim is not None:
+             logger.info("Selecting cut based on outliers")
+             selected_cut = dfcuts[dfcuts["eta"] <= outliers_lim].iloc[0]

+         logger.info(
+             f"This cut provides completeness of {selected_cut['completeness']}, nmad={selected_cut['nmad']} and eta={selected_cut['eta']}"
+         )

+         df_cut = df[(df.odds > selected_cut["flagcut"])]
+         if return_df == True:
+             return df_cut, selected_cut["flagcut"], dfcuts
          else:
+             return selected_cut["flagcut"], dfcuts
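Finally, a minimal sketch of the odds-based `select_cut` above on a synthetic catalogue. The toy DataFrame and its values are assumptions for illustration; the only requirement taken from the diff is the new column convention, `odds` for the per-object quality flag and `zwerr` for the weighted redshift error:

import numpy as np
import pandas as pd
from temps.utils import select_cut      # assumed import path

rng = np.random.default_rng(1)
df = pd.DataFrame({
    "odds": rng.uniform(0, 1, size=5_000),      # per-object ODDS, e.g. from get_pz
    "zwerr": rng.normal(0, 0.05, size=5_000),   # (z - z_true) / (1 + z_true)
})

# Picks the first (loosest) ODDS threshold whose outlier fraction eta stays
# at or below 5%, then keeps only the objects above that threshold.
df_cut, flagcut, dfcuts = select_cut(df, outliers_lim=5, return_df=True)
print(dfcuts[["flagcut", "completeness", "nmad", "eta"]])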