dbuscombe committed on
Commit
d86998c
1 Parent(s): 9eafdd1
.gitattributes CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ examples/*.* filter=lfs diff=lfs merge=lfs -text
36
+ weights/*.* filter=lfs diff=lfs merge=lfs -text
37
+ app_files/*.* filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,112 @@
1
+ ## Daniel Buscombe, Marda Science LLC 2023
2
+ # This file contains many functions originally from Doodleverse https://github.com/Doodleverse programs
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import sys, json, os
7
+ sys.path.insert(1, 'app_files'+os.sep+'src')
8
+ from sedinet_eval import *
9
+
10
+ ###===================================================
11
+ def estimate_siso_simo_1image(vars, im, greyscale,
12
+ dropout, weights_path): # numclass, name, mode, res_folder,
13
+ # batch_size, ):#, scale): #
14
+ """
15
+ This function uses a sedinet model for continuous prediction on 1 image
16
+ """
17
+ SM = make_sedinet_siso_simo(vars, greyscale, dropout)
18
+ SM.load_weights(weights_path)
19
+
20
+ # im = Image.open(image).convert('LA')
21
+ #im = im.resize((IM_HEIGHT, IM_HEIGHT))
22
+ im = Image.fromarray(im)
23
+ im = np.array(im)[:,:,0]
24
+ nx,ny = np.shape(im)
25
+ if (nx!=IM_HEIGHT) or (ny!=IM_HEIGHT):
26
+ im = im[int(nx/2)-int(IM_HEIGHT/2):int(nx/2)+int(IM_HEIGHT/2), int(ny/2)-int(IM_HEIGHT/2):int(ny/2)+int(IM_HEIGHT/2)]
27
+
28
+ if DO_STANDARDIZE==True:
29
+ im = do_standardize(im)
30
+ else:
31
+ im = np.array(im) / 255.0
32
+
33
+ result = SM.predict(np.expand_dims(np.expand_dims(im, axis=2), axis=0))
34
+ result = [float(r[0]) for r in result]
35
+
36
+ return result
37
+
38
+ ###===================================================
39
+ def grainsize(input_img, dims=(1024, 1024)):
40
+
41
+ configfile = 'weights/config_usace_combined2021_2022_v12.json'
42
+ weights_path = 'weights/sandsnap_merged_1024_modelrevOct2022_v12_simo_batch10_im1024_9vars_mse_noaug.hdf5'
43
+
44
+ # load the user configs
45
+ with open(os.getcwd()+os.sep+configfile) as f:
46
+ config = json.load(f)
47
+
48
+ ###===================================================
49
+ dropout = config["dropout"]
50
+ greyscale = config['greyscale']
51
+
52
+ try:
53
+ greyscale = config['greyscale']
54
+ except:
55
+ greyscale = 'true'
56
+
57
+ #output variables
58
+ vars = [k for k in config.keys() if not np.any([k.startswith('base'), k.startswith('MAX_LR'),
59
+ k.startswith('MIN_LR'), k.startswith('DO_AUG'), k.startswith('SHALLOW'),
60
+ k.startswith('res_folder'), k.startswith('train_csvfile'), k.startswith('csvfile'),
61
+ k.startswith('test_csvfile'), k.startswith('name'), k.startswith('val_csvfile'),
62
+ k.startswith('greyscale'), k.startswith('aux_in'),
63
+ k.startswith('dropout'), k.startswith('N'),k.startswith('scale'),
64
+ k.startswith('numclass')])]
65
+ vars = sorted(vars)
66
+
67
+ #this relates to 'mimo' and 'miso' modes that are planned for the future but not currently implemented
68
+ auxin = [k for k in config.keys() if k.startswith('aux_in')]
69
+
70
+ if len(auxin) > 0:
71
+ auxin = config[auxin[0]] ##at least for now, just one 'auxiliary' (numerical/categorical) input in addition to imagery
72
+ if len(vars) ==1:
73
+ mode = 'miso'
74
+ elif len(vars) >1:
75
+ mode = 'mimo'
76
+ else:
77
+ if len(vars) ==1:
78
+ mode = 'siso'
79
+ elif len(vars) >1:
80
+ mode = 'simo'
81
+
82
+ print("Mode: %s" % (mode))
83
+
84
+ result = estimate_siso_simo_1image(vars, input_img, greyscale,
85
+ dropout, weights_path)
86
+
87
+ result = np.array(result)
88
+ print(result)
89
+
90
+ plt.plot(np.hstack((result[:3], result[4:])),[10,16,25,50,65,75,84,90], 'k-o')
91
+ plt.xlabel('Grain size (pixels)')
92
+ plt.ylabel('Percent finer')
93
+ plt.savefig("psd.png", dpi=300, bbox_inches="tight")
94
+
95
+ return 'mean grain size = %f pixels' % (result[4]), '90th percentile grain size = %f pixels' % (result[-1]), plt
96
+
97
+ title = "SandSnap/SediNet Model Demo - Measure grain size from an image of sand!"
98
+ description = "Allows upload of imagery and download of grain size statistics. Statistics are unscaled (i.e. in pixels)"
99
+
100
+ examples = [
101
+ ['examples/IMG_20210922_170908944_cropped.jpg'],
102
+ ['examples/20210208_172834_cropped.jpg'],
103
+ ['examples/20220101_165359_cropped.jpg']
104
+ ]
105
+
106
+ inp = gr.Image()
107
+ out2 = gr.Plot(type='matplotlib')
108
+
109
+ Segapp = gr.Interface(grainsize, inp, ["text", "text", out2], title = title, description = description, examples=examples)
110
+ #, allow_flagging='manual', flagging_options=["bad", "ok", "good", "perfect"], flagging_dir="flagged")
111
+
112
+ Segapp.launch(enable_queue=True)
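For orientation, a minimal sketch of calling the grainsize function above outside the Gradio interface, assuming app.py's imports are already in scope; the image used is one of the bundled examples listed above:

# Sketch only: exercise grainsize() directly on one of the bundled example images
from PIL import Image
import numpy as np

img = np.array(Image.open('examples/IMG_20210922_170908944_cropped.jpg'))
mean_txt, p90_txt, plot_handle = grainsize(img)   # two summary strings + the matplotlib handle used for the PSD plot
print(mean_txt)   # 'mean grain size = ... pixels'
print(p90_txt)    # '90th percentile grain size = ... pixels'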
app_files/src/__pycache__/defaults.cpython-311.pyc ADDED
Binary file (608 Bytes).

app_files/src/__pycache__/imports.cpython-311.pyc ADDED
Binary file (3.15 kB).

app_files/src/__pycache__/sedinet_eval.cpython-311.pyc ADDED
Binary file (13.1 kB).

app_files/src/__pycache__/sedinet_models.cpython-311.pyc ADDED
Binary file (6.8 kB).

app_files/src/__pycache__/sedinet_utils.cpython-311.pyc ADDED
Binary file (111 kB).
 
app_files/src/defaults.py ADDED
@@ -0,0 +1,110 @@
1
+ # Written by Dr Daniel Buscombe, Marda Science LLC
2
+ # for the SandSnap Program
3
+ #
4
+ # MIT License
5
+ #
6
+ # Copyright (c) 2020-2021, Marda Science LLC
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+ #
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+ #
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+
27
+ ##> Release v1.4 (Aug 2021)
28
+
29
+ ## Contains values for defaults that you may change.
30
+ ## They are listed in order of likelihood that you might change them:
31
+
32
+ # size of image in pixels. keep this consistent in training and application
33
+ # suggested: 512 -- 1024 (larger = larger GPU required)
34
+ # integer
35
+ IM_HEIGHT = 1024
36
+ IM_WIDTH = IM_HEIGHT
37
+
38
+ # number of images to feed the network per step in each epoch # suggested: as many as your GPU memory allows
39
+ # integer
40
+
41
+ # BATCH_SIZE =8
42
+ # BATCH_SIZE =10
43
+ BATCH_SIZE =12
44
+
45
+ #use an ensemble of batch sizes like this
46
+ #BATCH_SIZE = [7,12,14]
47
+
48
+ # if True, use a smaller (shallower) network architecture
49
+ ##True or False ##False=larger network
50
+ SHALLOW = False #True
51
+
52
+ ## if True, carry out data augmentation. 2 x number of images used in training
53
+ ##True or False
54
+ DO_AUG = False #True
55
+
56
+ # maximum learning rate ##1e-1 -- 1e-5
57
+ MAX_LR = 1e-4
58
+ # MAX_LR = 1e-5
59
+ # MAX_LR = 5e-3
60
+ # MAX_LR = 5e-4
61
+
62
+ # max. number of training epochs (20 -- 1000)
63
+ # integer
64
+ NUM_EPOCHS = 300
65
+
66
+ ## loss function for continuous models (2 choices)
67
+ #CONT_LOSS = 'pinball'
68
+ CONT_LOSS = 'mse'
69
+
70
+ ## loss function for categorical (discrete) models (2 choices)
71
+ CAT_LOSS = 'focal'
72
+ #CAT_LOSS = 'categorical_crossentropy'
73
+
74
+ # optimizer (gradient descent solver) good alternative == 'rmsprop'
75
+ OPT = 'adam'
76
+
77
+ # base number of conv2d filters in categorical models
78
+ # integer
79
+ BASE_CAT = 30
80
+
81
+ # base number of conv2d filters in continuous models
82
+ # integer
83
+ # BASE_CONT = 30
84
+ BASE_CONT = 10
85
+
86
+ # number of Dense units for continuous prediction
87
+ # integer
88
+ # CONT_DENSE_UNITS = 3072
89
+ CONT_DENSE_UNITS = 2048
90
+ # CONT_DENSE_UNITS = 1024
91
+
92
+ # number of Dense units for categorical prediction
93
+ # integer
94
+ CAT_DENSE_UNITS = 128
95
+
96
+ # set to False if you wish to use cpu (not recommended)
97
+ ##True or False
98
+ USE_GPU = True
99
+
100
+ ## standardize imagery (recommended)
101
+ DO_STANDARDIZE = True
102
+
103
+
104
+ # STOP_PATIENCE = 10
105
+
106
+ # FACTOR = 0.2
107
+
108
+ # MIN_DELTA = 0.0001
109
+
110
+ # MIN_LR = 1e-4
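DO_STANDARDIZE above is what switches app.py and the data generators from simple /255 scaling to the do_standardize() routine in sedinet_utils.py, which is truncated in this commit view. A loosely assumed sketch of what per-image standardization typically looks like (the actual implementation may differ):

import numpy as np

def do_standardize_sketch(im):
    # Assumption: per-image z-scoring; illustrative only, since the real
    # do_standardize() in sedinet_utils.py is not shown in this excerpt
    im = np.array(im, dtype='float32')
    return (im - im.mean()) / (im.std() + 1e-8)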
app_files/src/imports.py ADDED
@@ -0,0 +1,123 @@
1
+ # Written by Dr Daniel Buscombe, Marda Science LLC
2
+ # for the SandSnap Program
3
+ #
4
+ # MIT License
5
+ #
6
+ # Copyright (c) 2020-2021, Marda Science LLC
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+ #
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+ #
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+
27
+ ##> Release v1.4 (Aug 2021)
28
+
29
+ ###===================================================
30
+ # import libraries
31
+ import gc, os, sys, shutil
32
+
33
+ ###===================================================
34
+ # import and set global variables from defaults.py
35
+ from defaults import *
36
+
37
+ global IM_HEIGHT, IM_WIDTH
38
+
39
+ global NUM_EPOCHS, SHALLOW
40
+
41
+ global VALID_BATCH_SIZE, BATCH_SIZE
42
+
43
+ VALID_BATCH_SIZE = BATCH_SIZE
44
+
45
+ global MAX_LR, OPT, USE_GPU, DO_AUG, DO_STANDARDIZE
46
+
47
+
48
+ # global STOP_PATIENCE, FACTOR, MIN_DELTA, MIN_LR
49
+
50
+ # global MIN_DELTA, FACTOR, STOP_PATIENCE
51
+ ##====================================================
52
+
53
+ # import tensorflow.compat.v1 as tf1
54
+ # config = tf1.ConfigProto()
55
+ # config.gpu_options.allow_growth = True # dynamically grow the memory used on the GPU
56
+ # config.log_device_placement = True # to log device placement (on which device the operation ran)
57
+ # sess = tf1.Session(config=config)
58
+ # tf1.keras.backend.set_session(sess)
59
+
60
+ # PREDICT = False
61
+ #
62
+ # ##OS
63
+ # if PREDICT == True:
64
+ # os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
65
+
66
+ ##TF/keras
67
+ if USE_GPU == True:
68
+ ##use the first available GPU
69
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
70
+ else:
71
+ ## to use the CPU (not recommended):
72
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
73
+
74
+ import numpy as np
75
+ import tensorflow as tf
76
+
77
+ # from tensorflow.keras import mixed_precision
78
+ # mixed_precision.set_global_policy('mixed_float16')
79
+
80
+ SEED=42
81
+ np.random.seed(SEED)
82
+ AUTO = tf.data.experimental.AUTOTUNE # used in tf.data.Dataset API
83
+
84
+ tf.random.set_seed(SEED)
85
+
86
+ print("Version: ", tf.__version__)
87
+ print("Eager mode: ", tf.executing_eagerly())
88
+ print('GPU name: ', tf.config.experimental.list_physical_devices('GPU'))
89
+ print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
90
+
91
+ from tensorflow.keras.layers import Input, Dense, MaxPool2D, GlobalMaxPool2D
92
+ from tensorflow.keras.layers import Dropout, MaxPooling2D, GlobalAveragePooling2D
93
+ from tensorflow.keras.models import Model, Sequential
94
+ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, LearningRateScheduler
95
+ from tensorflow.keras.layers import DepthwiseConv2D, Conv2D, SeparableConv2D
96
+ from tensorflow.keras.layers import BatchNormalization, Activation, concatenate
97
+
98
+ try:
99
+ from tensorflow.keras.utils import plot_model
100
+ except:
101
+ pass
102
+
103
+ import tensorflow.keras.backend as K
104
+ from tensorflow.keras.utils import to_categorical
105
+ import tensorflow_addons as tfa
106
+
107
+ ##SKLEARN
108
+ from sklearn.preprocessing import RobustScaler #MinMaxScaler
109
+ from sklearn.metrics import confusion_matrix, classification_report
110
+
111
+ ##OTHER
112
+ from PIL import Image
113
+ from glob import glob
114
+ import matplotlib.pyplot as plt
115
+ import pandas as pd
116
+ import itertools
117
+ import joblib
118
+ import random
119
+ from tempfile import TemporaryFile
120
+ import tensorflow_addons as tfa
121
+ import tqdm
122
+
123
+ from skimage.transform import AffineTransform, warp #rotate,
app_files/src/sedinet_eval.py ADDED
@@ -0,0 +1,287 @@
1
+ # Written by Dr Daniel Buscombe, Marda Science LLC
2
+ # for the SandSnap Program
3
+ #
4
+ # MIT License
5
+ #
6
+ # Copyright (c) 2020-2021, Marda Science LLC
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+ #
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+ #
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+
27
+ ##> Release v1.4 (Aug 2021)
28
+
29
+ ###===================================================
30
+ # import libraries
31
+ from sedinet_models import *
32
+
33
+ ###===================================================
34
+ def get_data_generator(df, indices, greyscale, tilesize,batch_size=16):
35
+ """
36
+ This function generates data for a batch of images and no metric, for "unseen" samples
37
+ """
38
+
39
+ for_training = False
40
+ images = []
41
+ while True:
42
+ for i in indices:
43
+ r = df.iloc[i]
44
+ file = r['files']
45
+
46
+ # if greyscale==True:
47
+ # im = Image.open(file).convert('LA')
48
+ # else:
49
+ # im = Image.open(file)
50
+ # im = im.resize((IM_HEIGHT, IM_HEIGHT))
51
+ # im = np.array(im) / 255.0
52
+
53
+ if greyscale==True:
54
+ im = Image.open(file).convert('LA')
55
+ #im = im.resize((IM_HEIGHT, IM_HEIGHT))
56
+ im = np.array(im)[:,:,0]
57
+ nx,ny = np.shape(im)
58
+ if (nx!=tilesize) or (ny!=tilesize):
59
+ im = im[int(nx/2)-int(tilesize/2):int(nx/2)+int(tilesize/2), int(ny/2)-int(tilesize/2):int(ny/2)+int(tilesize/2)]
60
+
61
+ else:
62
+ im = Image.open(file)
63
+ #im = im.resize((IM_HEIGHT, IM_HEIGHT))
64
+ im = np.array(im)
65
+ nx,ny,nz = np.shape(im)
66
+ if (nx!=tilesize) or (ny!=tilesize):
67
+ im = im[int(nx/2)-int(tilesize/2):int(nx/2)+int(tilesize/2), int(ny/2)-int(tilesize/2):int(ny/2)+int(tilesize/2)]
68
+
69
+ if greyscale==True:
70
+ images.append(np.expand_dims(im, axis=2)) #[:,:,0]
71
+ else:
72
+ images.append(im)
73
+
74
+ if len(images) >= batch_size:
75
+ yield np.array(images)
76
+ images = []
77
+ if not for_training:
78
+ break
79
+
80
+ ###===================================================
81
+ def get_data_generator_1vars(df, indices, for_training, vars, greyscale,
82
+ tilesize, batch_size=16):
83
+ """
84
+ This function generates data for a batch of images and 1 associated metric
85
+ """
86
+ images, p1s = [], []
87
+ while True:
88
+ for i in indices:
89
+ r = df.iloc[i]
90
+ file, p1 = r['files'], r[vars[0]]
91
+ #im = Image.open(file).convert('LA')
92
+ #im = im.resize((IM_HEIGHT, IM_HEIGHT))
93
+ #im = np.array(im) / 255.0
94
+ #im2 = np.rot90(im)
95
+
96
+ # if greyscale==True:
97
+ # im = Image.open(file).convert('LA')
98
+ # else:
99
+ # im = Image.open(file)
100
+ # im = im.resize((IM_HEIGHT, IM_HEIGHT))
101
+ # im = np.array(im) / 255.0
102
+
103
+ if greyscale==True:
104
+ im = Image.open(file).convert('LA')
105
+ #im = im.resize((IM_HEIGHT, IM_HEIGHT))
106
+ im = np.array(im)[:,:,0]
107
+ nx,ny = np.shape(im)
108
+ if (nx!=tilesize) or (ny!=tilesize):
109
+ im = im[int(nx/2)-int(tilesize/2):int(nx/2)+int(tilesize/2), int(ny/2)-int(tilesize/2):int(ny/2)+int(tilesize/2)]
110
+
111
+ else:
112
+ im = Image.open(file)
113
+ #im = im.resize((IM_HEIGHT, IM_HEIGHT))
114
+ im = np.array(im)
115
+ nx,ny,nz = np.shape(im)
116
+ if (nx!=tilesize) or (ny!=tilesize):
117
+ im = im[int(nx/2)-int(tilesize/2):int(nx/2)+int(tilesize/2), int(ny/2)-int(tilesize/2):int(ny/2)+int(tilesize/2)]
118
+
119
+
120
+ if greyscale==True:
121
+ images.append(np.expand_dims(im, axis=2))
122
+ else:
123
+ images.append(im)
124
+
125
+ p1s.append(p1)
126
+ if len(images) >= batch_size:
127
+ yield np.array(images), [np.array(p1s)]
128
+ images, p1s = [], []
129
+ if not for_training:
130
+ break
131
+
132
+ ###===================================================
133
+ def estimate_categorical(vars, csvfile, res_folder, dropout,
134
+ numclass, greyscale, name, mode):
135
+ """
136
+ This function uses a SediNet model for categorical prediction
137
+ """
138
+
139
+ ID_MAP = dict(zip(np.arange(numclass), [str(k) for k in range(numclass)]))
140
+
141
+ ##======================================
142
+ ## this randomly selects imagery for training and testing imagery sets
143
+ ## while also making sure that both training and testing sets have
144
+ ## at least 3 examples of each category
145
+ test_idx, test_df, _ = get_df(csvfile,fortrain=True)
146
+
147
+ # for 16GB RAM, used maximum of 200 samples to test on
148
+ # need to change batch generator into a better keras one
149
+
150
+ valid_gen = get_data_generator_1image(test_df, test_idx, True, ID_MAP,
151
+ vars[0], len(train_idx), greyscale, False, IM_HEIGHT) #np.min((200, len(train_idx))),
152
+
153
+ if SHALLOW is True:
154
+ if DO_AUG is True:
155
+ weights_path = name+"_"+mode+"_batch"+str(BATCH_SIZE)+"_im"+str(IM_HEIGHT)+\
156
+ "_"+str(IM_WIDTH)+"_shallow_"+vars[0]+"_"+CAT_LOSS+"_aug.hdf5"
157
+ else:
158
+ weights_path = name+"_"+mode+"_batch"+str(BATCH_SIZE)+"_im"+str(IM_HEIGHT)+\
159
+ "_"+str(IM_WIDTH)+"_shallow_"+vars[0]+"_"+CAT_LOSS+"_noaug.hdf5"
160
+ else:
161
+ if DO_AUG is True:
162
+ weights_path = name+"_"+mode+"_batch"+str(BATCH_SIZE)+"_im"+str(IM_HEIGHT)+\
163
+ "_"+str(IM_WIDTH)+"_"+vars[0]+"_"+CAT_LOSS+"_aug.hdf5"
164
+ else:
165
+ weights_path = name+"_"+mode+"_batch"+str(BATCH_SIZE)+"_im"+str(IM_HEIGHT)+\
166
+ "_"+str(IM_WIDTH)+"_"+vars[0]+"_"+CAT_LOSS+"_noaug.hdf5"
167
+
168
+
169
+ if not os.path.exists(weights_path):
170
+ weights_path = res_folder + os.sep+ weights_path
171
+ print("Using %s" % (weights_path))
172
+
173
+ if numclass>0:
174
+ ID_MAP = dict(zip(np.arange(numclass), [str(k) for k in range(numclass)]))
175
+
176
+ SM = make_cat_sedinet(ID_MAP, dropout)
177
+
178
+ if type(BATCH_SIZE)==list:
179
+ predict_test_train_cat(test_df, None, test_idx, None, vars[0],
180
+ SMs, [i for i in ID_MAP.keys()], weights_path, greyscale,
181
+ name, DO_AUG, IM_HEIGHT)
182
+ else:
183
+ predict_test_train_cat(test_df, None, test_idx, None, vars[0],
184
+ SM, [i for i in ID_MAP.keys()], weights_path, greyscale,
185
+ name, DO_AUG, IM_HEIGHT)
186
+
187
+ K.clear_session()
188
+
189
+ ##===================================
190
+ ## move model files and plots to the results folder
191
+ tidy(name, res_folder)
192
+
193
+ ###===================================================
194
+ def estimate_siso_simo(vars, csvfile, greyscale,
195
+ dropout, numclass, name, mode, res_folder,#scale,
196
+ batch_size, weights_path):
197
+ """
198
+ This function uses a sedinet model for continuous prediction
199
+ """
200
+
201
+ if not os.path.exists(weights_path):
202
+ weights_path = res_folder + os.sep+ weights_path
203
+ print("Using %s" % (weights_path))
204
+
205
+ ##======================================
206
+ ## this randomly selects imagery for training and testing imagery sets
207
+ ## while also making sure that both training and testing sets have
208
+ ## at least 3 examples of each category
209
+ #train_idx, train_df = get_df(train_csvfile)
210
+ train_idx, train_df,split = get_df(csvfile)
211
+
212
+ ##==============================================
213
+ ## create a sedinet model to estimate category
214
+ SM = make_sedinet_siso_simo(vars, greyscale, dropout)
215
+
216
+ # if scale==True:
217
+ # CS = []
218
+ # for var in vars:
219
+ # cs = RobustScaler() #MinMaxScaler()
220
+ # if split:
221
+ # cs.fit_transform(
222
+ # np.r_[train_df[0][var].values].reshape(-1,1)
223
+ # )
224
+ # else:
225
+ # cs.fit_transform(
226
+ # np.r_[train_df[var].values].reshape(-1,1)
227
+ # )
228
+ # CS.append(cs)
229
+ # del cs
230
+ # else:
231
+ # CS = []
232
+
233
+
234
+ do_aug = False
235
+ for_training = False
236
+ if type(train_df)==list:
237
+ print('Reading in all files and memory mapping in batches ... takes a while')
238
+ train_gen = []
239
+ for df,id in zip(train_df,train_idx):
240
+ train_gen.append(get_data_generator_Nvars_siso_simo(df, id, for_training,
241
+ vars, len(id), greyscale, do_aug, DO_STANDARDIZE, IM_HEIGHT))#CS,
242
+
243
+ x_train = []; vals = []; files = []
244
+ for gen in train_gen:
245
+ a, b = next(gen)
246
+ outfile = TemporaryFile()
247
+ files.append(outfile)
248
+ dt = a.dtype; sh = a.shape
249
+ fp = np.memmap(outfile, dtype=dt, mode='w+', shape=sh)
250
+ fp[:] = a[:]
251
+ fp.flush()
252
+ del a
253
+ del fp
254
+ a = np.memmap(outfile, dtype=dt, mode='r', shape=sh)
255
+ x_train.append(a)
256
+ vals.append(b)
257
+
258
+ else:
259
+ train_gen = get_data_generator_Nvars_siso_simo(train_df, train_idx, for_training,
260
+ vars, len(train_idx), greyscale,do_aug, DO_STANDARDIZE, IM_HEIGHT)# CS,
261
+
262
+ x_train, vals = next(train_gen)
263
+
264
+ # test model
265
+ # if numclass==0:
266
+ x_test=None
267
+ test_vals = None
268
+ if type(BATCH_SIZE)==list:
269
+ predict_test_train_siso_simo(x_train, vals, x_test, test_vals, vars, #train_df, None, train_idx, None,
270
+ SMs, weights_path, name, mode, greyscale, #CS,
271
+ dropout, DO_AUG, DO_STANDARDIZE,counter)#scale,
272
+ else:
273
+ if type(x_train)==list:
274
+ for counter, x in enumerate(x_train):
275
+ #print(counter)
276
+ predict_test_train_siso_simo(x, vals[counter], x_test, test_vals, vars,
277
+ SM, weights_path, name, mode, greyscale, #CS,
278
+ dropout, DO_AUG, DO_STANDARDIZE,counter)#scale,
279
+ else:
280
+ predict_test_train_siso_simo(x_train,vals, x_test, test_vals, vars,
281
+ SM, weights_path, name, mode, greyscale,# CS,
282
+ dropout, DO_AUG, DO_STANDARDIZE, 0) # counter is not defined in this single-array branch, so pass 0
283
+ K.clear_session()
284
+
285
+ ##===================================
286
+ ## move model files and plots to the results folder
287
+ tidy(name, res_folder)
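The same centre-crop-to-tilesize slicing appears in each generator above (and in app.py). A small helper capturing that pattern, as a sketch; the function name is illustrative and not part of the codebase:

import numpy as np

def center_crop(im, tilesize):
    # Mirrors the inline slicing used in the generators above: take a square
    # window of side ~tilesize about the image centre (2D greyscale or 3D RGB)
    nx, ny = np.shape(im)[:2]
    if (nx != tilesize) or (ny != tilesize):
        cx, cy, h = int(nx / 2), int(ny / 2), int(tilesize / 2)
        im = im[cx - h:cx + h, cy - h:cy + h]
    return im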
app_files/src/sedinet_infer.py ADDED
@@ -0,0 +1,544 @@
1
+ # Written by Dr Daniel Buscombe, Marda Science LLC
2
+ # for the SandSnap Program
3
+ #
4
+ # MIT License
5
+ #
6
+ # Copyright (c) 2020-2021, Marda Science LLC
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+ #
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+ #
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+
27
+ ##> Release v1.4 (Aug 2021)
28
+
29
+ from sedinet_models import *
30
+
31
+ ###===================================================
32
+ def run_training_siso_simo(vars, train_csvfile, test_csvfile, val_csvfile, name, res_folder,
33
+ mode, greyscale, dropout, numclass): #scale
34
+ """
35
+ This function generates, trains and evaluates a sedinet model for
36
+ continuous prediction
37
+ """
38
+
39
+ if numclass>0:
40
+ ID_MAP = dict(zip(np.arange(numclass), [str(k) for k in range(numclass)]))
41
+
42
+ # ##======================================
43
+ # ## this randomly selects imagery for training and testing imagery sets
44
+ # ## while also making sure that both training and tetsing sets have
45
+ # ## at least 3 examples of each category
46
+ # train_idx, train_df, _ = get_df(train_csvfile,fortrain=True)
47
+ # test_idx, test_df, _ = get_df(test_csvfile,fortrain=True)
48
+
49
+ ##==============================================
50
+ ## create a sedinet model to estimate category
51
+ if numclass>0:
52
+ SM = make_cat_sedinet(ID_MAP, dropout)
53
+ else:
54
+ SM = make_sedinet_siso_simo(vars, greyscale, dropout)
55
+
56
+ # if scale==True:
57
+ # CS = []
58
+ # for var in vars:
59
+ # cs = RobustScaler() ##alternative = MinMaxScaler()
60
+ # cs.fit_transform(
61
+ # np.r_[train_df[var].values, test_df[var].values].reshape(-1,1)
62
+ # )
63
+ # CS.append(cs)
64
+ # del cs
65
+ # else:
66
+ # CS = []
67
+
68
+ ##==============================================
69
+ ## train model
70
+ if numclass==0:
71
+ if type(BATCH_SIZE)==list:
72
+ SMs = []; weights_path = []
73
+ for batch_size, valid_batch_size in zip(BATCH_SIZE, VALID_BATCH_SIZE):
74
+ sm, wp,train_df, test_df, val_df, train_idx, test_idx, val_idx = train_sedinet_siso_simo(SM, name,
75
+ train_csvfile, test_csvfile, val_csvfile, vars, mode, greyscale, #CS,
76
+ dropout, batch_size, valid_batch_size,
77
+ res_folder)#, scale)
78
+ SMs.append(sm)
79
+ weights_path.append(wp)
80
+ gc.collect()
81
+
82
+ else:
83
+ SM, weights_path,train_df, test_df, val_df, train_idx, test_idx, val_idx = train_sedinet_siso_simo(SM, name,
84
+ train_csvfile, test_csvfile, val_csvfile, vars, mode, greyscale, #CS,
85
+ dropout, BATCH_SIZE, VALID_BATCH_SIZE,
86
+ res_folder)#, scale)
87
+ else:
88
+ if type(BATCH_SIZE)==list:
89
+ SMs = []; weights_path = []
90
+ for batch_size, valid_batch_size in zip(BATCH_SIZE, VALID_BATCH_SIZE):
91
+ sm, wp = train_sedinet_cat(SM, train_df, test_df, train_idx,
92
+ test_idx, ID_MAP, vars, greyscale, name, mode,
93
+ batch_size, valid_batch_size, res_folder)
94
+ SMs.append(sm)
95
+ weights_path.append(wp)
96
+ gc.collect()
97
+
98
+ else:
99
+ SM, weights_path = train_sedinet_cat(SM, train_df, test_df, train_idx,
100
+ test_idx, ID_MAP, vars, greyscale, name, mode,
101
+ BATCH_SIZE, VALID_BATCH_SIZE, res_folder)
102
+
103
+
104
+ classes = np.arange(len(ID_MAP))
105
+
106
+ K.clear_session()
107
+
108
+ ##==============================================
109
+ # test model
110
+ do_aug = False
111
+ for_training = False
112
+ if type(test_df)==list:
113
+ print('Reading in all train files and memory mapping in batches ... takes a while')
114
+ test_gen = []
115
+ for df,id in zip(test_df,test_idx):
116
+ test_gen.append(get_data_generator_Nvars_siso_simo(df, id, for_training,
117
+ vars, len(id), greyscale, do_aug, DO_STANDARDIZE, IM_HEIGHT)) #CS,
118
+
119
+ x_test = []; test_vals = []; files = []
120
+ for gen in test_gen:
121
+ a, b = next(gen)
122
+ outfile = TemporaryFile()
123
+ files.append(outfile)
124
+ dt = a.dtype; sh = a.shape
125
+ fp = np.memmap(outfile, dtype=dt, mode='w+', shape=sh)
126
+ fp[:] = a[:]
127
+ fp.flush()
128
+ del a
129
+ del fp
130
+ a = np.memmap(outfile, dtype=dt, mode='r', shape=sh)
131
+ x_test.append(a)
132
+ test_vals.append(b)
133
+
134
+
135
+ else:
136
+ # train_gen = get_data_generator_Nvars_siso_simo(train_df, train_idx, for_training,
137
+ # vars, len(train_idx), greyscale, do_aug, DO_STANDARDIZE, IM_HEIGHT)#CS,
138
+
139
+ # x_train, train_vals = next(train_gen)
140
+
141
+ test_gen = get_data_generator_Nvars_siso_simo(test_df, test_idx, for_training,
142
+ vars, len(test_idx), greyscale, do_aug, DO_STANDARDIZE, IM_HEIGHT)
143
+
144
+ x_test, test_vals = next(test_gen)
145
+
146
+ # if numclass==0:
147
+ # # suffix = 'train'
148
+ # if type(BATCH_SIZE)==list:
149
+ # count_in = 0
150
+ # predict_train_siso_simo(x_train, train_vals, vars, #train_df, test_df, train_idx, test_idx, vars, x_test, test_vals,
151
+ # SMs, weights_path, name, mode, greyscale,# CS,
152
+ # dropout, DO_AUG,DO_STANDARDIZE, count_in)#scale,
153
+ # else:
154
+ # if type(x_train)==list:
155
+ # for count_in, (a, b) in enumerate(zip(x_train, train_vals)): #x_test, test_vals
156
+ # predict_train_siso_simo(a, b, vars, #train_df, test_df, train_idx, test_idx, vars, c, d,
157
+ # SM, weights_path, name, mode, greyscale,# CS,
158
+ # dropout, DO_AUG,DO_STANDARDIZE, count_in)#scale,
159
+ # plot_all_save_all(weights_path, vars)
160
+
161
+ # else:
162
+ # count_in = 0; consolidate = False
163
+ # predict_train_siso_simo(x_train, train_vals, vars, #train_df, test_df, train_idx, test_idx, vars, x_test, test_vals,
164
+ # SM, weights_path, name, mode, greyscale,# CS,
165
+ # dropout, DO_AUG,DO_STANDARDIZE, count_in)#scale,
166
+
167
+
168
+ if numclass==0:
169
+ if type(BATCH_SIZE)==list:
170
+ count_in = 0
171
+ predict_train_siso_simo(x_test, test_vals, vars,
172
+ SMs, weights_path, name, mode, greyscale,
173
+ dropout, DO_AUG,DO_STANDARDIZE, count_in)
174
+ else:
175
+ if type(x_test)==list:
176
+ for count_in, (a, b) in enumerate(zip(x_test, test_vals)):
177
+ predict_train_siso_simo(a, b, vars,
178
+ SM, weights_path, name, mode, greyscale,
179
+ dropout, DO_AUG,DO_STANDARDIZE, count_in)
180
+ plot_all_save_all(weights_path, vars)
181
+
182
+ else:
183
+ count_in = 0; #consolidate = False
184
+ predict_train_siso_simo(x_test, test_vals, vars,
185
+ SM, weights_path, name, mode, greyscale,
186
+ dropout, DO_AUG,DO_STANDARDIZE, count_in)
187
+
188
+ else:
189
+ if type(BATCH_SIZE)==list:
190
+ predict_test_train_cat(train_df, test_df, train_idx, test_idx, vars[0],
191
+ SMs, [i for i in ID_MAP.keys()], weights_path, greyscale,
192
+ name, DO_AUG,DO_STANDARDIZE)
193
+ else:
194
+ predict_test_train_cat(train_df, test_df, train_idx, test_idx, vars[0],
195
+ SM, [i for i in ID_MAP.keys()], weights_path, greyscale,
196
+ name, DO_AUG,DO_STANDARDIZE)
197
+
198
+ K.clear_session()
199
+
200
+ #
201
+
202
+ ##===================================
203
+ ## move model files and plots to the results folder
204
+ tidy(name, res_folder)
205
+
206
+
207
+ ###==================================
208
+ def train_sedinet_cat(SM, train_csvfile, test_csvfile, #train_df, test_df, train_idx, test_idx,
209
+ ID_MAP, vars, greyscale, name, mode, batch_size, valid_batch_size,
210
+ res_folder):
211
+ """
212
+ This function trains an implementation of SediNet
213
+ """
214
+ ##================================
215
+ ## create training and testing file generators, set the weights path,
216
+ ## plot the model, and create a callback list for model training
217
+ for_training=True
218
+ train_gen = get_data_generator_1image(train_df, train_idx, for_training, ID_MAP,
219
+ vars[0], batch_size, greyscale, DO_AUG, DO_STANDARDIZE, IM_HEIGHT) ##BATCH_SIZE
220
+ do_aug = False
221
+ valid_gen = get_data_generator_1image(test_df, test_idx, for_training, ID_MAP,
222
+ vars[0], valid_batch_size, greyscale, do_aug, DO_STANDARDIZE, IM_HEIGHT) ##VALID_BATCH_SIZE
223
+
224
+ if SHALLOW is True:
225
+ if DO_AUG is True:
226
+ weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
227
+ "_"+str(IM_WIDTH)+"_shallow_"+vars[0]+"_"+CAT_LOSS+"_aug.hdf5"
228
+ else:
229
+ weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
230
+ "_"+str(IM_WIDTH)+"_shallow_"+vars[0]+"_"+CAT_LOSS+"_noaug.hdf5"
231
+ else:
232
+ if DO_AUG is True:
233
+ weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
234
+ "_"+str(IM_WIDTH)+"_"+vars[0]+"_"+CAT_LOSS+"_aug.hdf5"
235
+ else:
236
+ weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
237
+ "_"+str(IM_WIDTH)+"_"+vars[0]+"_"+CAT_LOSS+"_noaug.hdf5"
238
+
239
+ if os.path.exists(weights_path):
240
+ SM.load_weights(weights_path)
241
+ print("==========================================")
242
+ print("Loading weights that already exist: %s" % (weights_path) )
243
+ print("Skipping model training")
244
+
245
+ elif os.path.exists(res_folder+os.sep+weights_path):
246
+ weights_path = res_folder+os.sep+weights_path
247
+ SM.load_weights(weights_path)
248
+ print("==========================================")
249
+ print("Loading weights that already exist: %s" % (weights_path) )
250
+ print("Skipping model training")
251
+
252
+ else:
253
+
254
+ try:
255
+ plot_model(SM, weights_path.replace('.hdf5', '_model.png'),
256
+ show_shapes=True, show_layer_names=True)
257
+ except:
258
+ pass
259
+
260
+ callbacks_list = [
261
+ ModelCheckpoint(weights_path, monitor='val_loss', verbose=1,
262
+ save_best_only=True, mode='min',
263
+ save_weights_only = True)
264
+ ]
265
+
266
+ print("=========================================")
267
+ print("[INFORMATION] schematic of the model has been written out to: "+\
268
+ weights_path.replace('.hdf5', '_model.png'))
269
+ print("[INFORMATION] weights will be written out to: "+weights_path)
270
+
271
+ ##==============================================
272
+ ## set checkpoint file and parameters that control early stopping,
273
+ ## and reduction of learning rate if and when validation
274
+ ## scores plateau upon successive epochs
275
+ # reduceloss_plat = ReduceLROnPlateau(monitor='val_loss', factor=FACTOR,
276
+ # patience=STOP_PATIENCE, verbose=1, mode='auto', min_delta=MIN_DELTA,
277
+ # cooldown=STOP_PATIENCE, min_lr=MIN_LR)
278
+ #
279
+ earlystop = EarlyStopping(monitor="val_loss", mode="min", patience=10)
280
+
281
+ model_checkpoint = ModelCheckpoint(weights_path, monitor='val_loss',
282
+ verbose=1, save_best_only=True, mode='min',
283
+ save_weights_only = True)
284
+
285
+ ##==============================================
286
+ ## train the model
287
+
288
+ ## with non-adaptive exponentially decreasing learning rate
289
+ #exponential_decay_fn = exponential_decay(MAX_LR, NUM_EPOCHS)
290
+
291
+ #lr_scheduler = LearningRateScheduler(exponential_decay_fn)
292
+
293
+ callbacks_list = [model_checkpoint, earlystop] #lr_scheduler
294
+
295
+ ## train the model
296
+ history = SM.fit(train_gen,
297
+ steps_per_epoch=len(train_idx)//batch_size, ##BATCH_SIZE
298
+ epochs=NUM_EPOCHS,
299
+ callbacks=callbacks_list,
300
+ validation_data=valid_gen, #use_multiprocessing=True,
301
+ validation_steps=len(test_idx)//valid_batch_size) #max_queue_size=10 ##VALID_BATCH_SIZE
302
+
303
+ ###===================================================
304
+ ## Plot the loss and accuracy as a function of epoch
305
+ plot_train_history_1var(history)
306
+ # plt.savefig(vars+'_'+str(IM_HEIGHT)+'_batch'+str(batch_size)+'_history.png', ##BATCH_SIZE
307
+ # dpi=300, bbox_inches='tight')
308
+ plt.savefig(weights_path.replace('.hdf5','_history.png'),dpi=300, bbox_inches='tight')
309
+ plt.close('all')
310
+
311
+ # serialize model to JSON to use later to predict
312
+ model_json = SM.to_json()
313
+ with open(weights_path.replace('.hdf5','.json'), "w") as json_file:
314
+ json_file.write(model_json)
315
+
316
+ return SM, weights_path
317
+
318
+
319
+ ###===================================================
320
+ def train_sedinet_siso_simo(SM, name, train_csvfile, test_csvfile, val_csvfile, #train_df, test_df, train_idx, test_idx,
321
+ vars, mode, greyscale, dropout, batch_size, valid_batch_size,#CS,
322
+ res_folder):#, scale):
323
+ """
324
+ This function trains an implementation of sedinet
325
+ """
326
+
327
+ ##==============================================
328
+ ## create training and testing file generators, set the weights path,
329
+ ## plot the model, and create a callback list for model training
330
+
331
+ # get a string saying how many variables, fr the output files
332
+ varstring = str(len(vars))+'vars' #''.join([str(k)+'_' for k in vars])
333
+
334
+ # mae the appropriate weights file
335
+ if SHALLOW is True:
336
+ if DO_AUG is True:
337
+ # if len(CS)>0:#scale is True:
338
+ # weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
339
+ # "_"+str(IM_WIDTH)+"_shallow_"+varstring+"_"+CONT_LOSS+"_aug_scale.hdf5"
340
+ # else:
341
+ weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
342
+ "_"+str(IM_WIDTH)+"_shallow_"+varstring+"_"+CONT_LOSS+"_aug.hdf5"
343
+ else:
344
+ # if len(CS)>0:#scale is True:
345
+ # weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
346
+ # "_"+str(IM_WIDTH)+"_shallow_"+varstring+"_"+CONT_LOSS+"_noaug_scale.hdf5"
347
+ # else:
348
+ weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
349
+ "_"+str(IM_WIDTH)+"_shallow_"+varstring+"_"+CONT_LOSS+"_noaug.hdf5"
350
+ else:
351
+ if DO_AUG is True:
352
+ # if len(CS)>0:#scale is True:
353
+ # weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
354
+ # "_"+str(IM_WIDTH)+"_"+varstring+"_"+CONT_LOSS+"_aug_scale.hdf5"
355
+ # else:
356
+ weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
357
+ "_"+str(IM_WIDTH)+"_"+varstring+"_"+CONT_LOSS+"_aug.hdf5"
358
+ else:
359
+ # if len(CS)>0:#scale is True:
360
+ # weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
361
+ # "_"+str(IM_WIDTH)+"_"+varstring+"_"+CONT_LOSS+"_noaug_scale.hdf5"
362
+ # else:
363
+ weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
364
+ "_"+varstring+"_"+CONT_LOSS+"_noaug.hdf5"
365
+
366
+
367
+ # if it already exists, skip training
368
+ if os.path.exists(weights_path):
369
+ SM.load_weights(weights_path)
370
+ print("==========================================")
371
+ print("Loading weights that already exist: %s" % (weights_path) )
372
+ print("Skipping model training")
373
+
374
+ ##======================================
375
+ ## this randomly selects imagery for training and testing imagery sets
376
+ ## while also making sure that both training and tetsing sets have
377
+ ## at least 3 examples of each category
378
+ train_idx, train_df, _ = get_df(train_csvfile,fortrain=False)
379
+ test_idx, test_df, _ = get_df(test_csvfile,fortrain=False)
380
+ val_idx, val_df, _ = get_df(val_csvfile,fortrain=False) # val_df (not test_df) is needed by valid_gen below
381
+
382
+
383
+ for_training = False
384
+ train_gen = get_data_generator_Nvars_siso_simo(train_df, train_idx, for_training,
385
+ vars, batch_size, greyscale,
386
+ DO_AUG, DO_STANDARDIZE, IM_HEIGHT) # CS,
387
+ do_aug = False
388
+ valid_gen = get_data_generator_Nvars_siso_simo(val_df, val_idx, for_training,
389
+ vars, valid_batch_size, greyscale,
390
+ do_aug, DO_STANDARDIZE, IM_HEIGHT) ##only augment training # CS,
391
+
392
+ # do_aug = False
393
+ # test_gen = get_data_generator_Nvars_siso_simo(test_df, test_idx, for_training,
394
+ # vars, valid_batch_size, greyscale,
395
+ # do_aug, DO_STANDARDIZE, IM_HEIGHT) ##only augment training # CS,
396
+
397
+ # if it already exists in res_folder, skip training
398
+ elif os.path.exists(res_folder+os.sep+weights_path):
399
+ weights_path = res_folder+os.sep+weights_path
400
+ SM.load_weights(weights_path)
401
+ print("==========================================")
402
+ print("Loading weights that already exist: %s" % (weights_path) )
403
+ print("Skipping model training")
404
+
405
+ ##======================================
406
+ ## this randomly selects imagery for training and testing imagery sets
407
+ ## while also making sure that both training and tetsing sets have
408
+ ## at least 3 examples of each category
409
+ train_idx, train_df, _ = get_df(train_csvfile,fortrain=False)
410
+ test_idx, test_df, _ = get_df(test_csvfile,fortrain=False)
411
+ val_idx, val_df, _ = get_df(val_csvfile,fortrain=False)
412
+
413
+ for_training = False
414
+ train_gen = get_data_generator_Nvars_siso_simo(train_df, train_idx, for_training,
415
+ vars, batch_size, greyscale,
416
+ DO_AUG, DO_STANDARDIZE, IM_HEIGHT) # CS,
417
+ do_aug = False
418
+ valid_gen = get_data_generator_Nvars_siso_simo(val_df, val_idx, for_training,
419
+ vars, valid_batch_size, greyscale,
420
+ do_aug, DO_STANDARDIZE, IM_HEIGHT) ##only augment training # CS,
421
+
422
+ # do_aug = False
423
+ # test_gen = get_data_generator_Nvars_siso_simo(test_df, test_idx, for_training,
424
+ # vars, valid_batch_size, greyscale,
425
+ # do_aug, DO_STANDARDIZE, IM_HEIGHT) ##only augment training # CS,
426
+
427
+ else: #train
428
+
429
+ ##======================================
430
+ ## this randomly selects imagery for training and testing imagery sets
431
+ ## while also making sure that both training and tetsing sets have
432
+ ## at least 3 examples of each category
433
+ train_idx, train_df, _ = get_df(train_csvfile,fortrain=True)
434
+ test_idx, test_df, _ = get_df(test_csvfile,fortrain=True)
435
+ val_idx, val_df, _ = get_df(val_csvfile,fortrain=True)
436
+
437
+ for_training = True
438
+ train_gen = get_data_generator_Nvars_siso_simo(train_df, train_idx, for_training,
439
+ vars, batch_size, greyscale,
440
+ DO_AUG, DO_STANDARDIZE, IM_HEIGHT) # CS,
441
+ # do_aug = False
442
+ # test_gen = get_data_generator_Nvars_siso_simo(test_df, test_idx, for_training,
443
+ # vars, valid_batch_size, greyscale,
444
+ # do_aug, DO_STANDARDIZE, IM_HEIGHT) ##only augment training # CS,
445
+
446
+ do_aug = False
447
+ valid_gen = get_data_generator_Nvars_siso_simo(val_df, val_idx, for_training,
448
+ vars, valid_batch_size, greyscale,
449
+ do_aug, DO_STANDARDIZE, IM_HEIGHT) ##only augment training # CS,
450
+
451
+ # if scaler=true (CS=[]), dump out scalers to pickle file
452
+ # if len(CS)==0:
453
+ # pass
454
+ # else:
455
+ # joblib.dump(CS, weights_path.replace('.hdf5','_scaler.pkl'))
456
+ # print('Wrote scaler to pkl file')
457
+
458
+ try: # plot the model if pydot/graphviz installed
459
+ plot_model(SM, weights_path.replace('.hdf5', '_model.png'),
460
+ show_shapes=True, show_layer_names=True)
461
+ print("model schematic written to: "+\
462
+ weights_path.replace('.hdf5', '_model.png'))
463
+ except:
464
+ pass
465
+
466
+ print("==========================================")
467
+ print("weights will be written out to: "+weights_path)
468
+
469
+ ##==============================================
470
+ ## set checkpoint file and parameters that control early stopping,
471
+ ## and reduction of learning rate if and when validation scores plateau upon successive epochs
472
+ # reduceloss_plat = ReduceLROnPlateau(monitor='val_loss', factor=FACTOR,
473
+ # patience=STOP_PATIENCE, verbose=1, mode='auto',
474
+ # min_delta=MIN_DELTA, cooldown=5,
475
+ # min_lr=MIN_LR)
476
+
477
+ earlystop = EarlyStopping(monitor="val_loss", mode="min",
478
+ patience=10)
479
+
480
+ # set model checkpoint. only save best weights, based on min validation loss
481
+ model_checkpoint = ModelCheckpoint(weights_path, monitor='val_loss', verbose=1,
482
+ save_best_only=True, mode='min',
483
+ save_weights_only = True)
484
+
485
+
486
+ #tqdm_callback = tfa.callbacks.TQDMProgressBar()
487
+ # callbacks_list = [model_checkpoint, reduceloss_plat, earlystop] #, tqdm_callback]
488
+
489
+ try: #write summary of the model to txt file
490
+ with open(weights_path.replace('.hdf5','') + '_report.txt','w') as fh:
491
+ # Pass the file handle in as a lambda function to make it callable
492
+ SM.summary(print_fn=lambda x: fh.write(x + '\n'))
493
+ fh.close()
494
+ print("model summary written to: "+ \
495
+ weights_path.replace('.hdf5','') + '_report.txt')
496
+ with open(weights_path.replace('.hdf5','') + '_report.txt','r') as fh:
497
+ tmp = fh.readlines()
498
+ print("===============================================")
499
+ print("Total parameters: %s" %\
500
+ (''.join(tmp).split('Total params:')[-1].split('\n')[0]))
501
+ fh.close()
502
+ print("===============================================")
503
+ except:
504
+ pass
505
+
506
+ ##==============================================
507
+ ## train the model
508
+
509
+ ## non-adaptive exponentially decreasing learning rate
510
+ # exponential_decay_fn = exponential_decay(MAX_LR, NUM_EPOCHS)
511
+
512
+ #lr_scheduler = LearningRateScheduler(exponential_decay_fn)
513
+
514
+ callbacks_list = [model_checkpoint, earlystop] #lr_scheduler
515
+
516
+ ## train the model
517
+ history = SM.fit(train_gen,
518
+ steps_per_epoch=len(train_idx)//batch_size, ##BATCH_SIZE
519
+ epochs=NUM_EPOCHS,
520
+ callbacks=callbacks_list,
521
+ validation_data=valid_gen, #use_multiprocessing=True,
522
+ validation_steps=len(val_idx)//valid_batch_size) #max_queue_size=10 ##VALID_BATCH_SIZE
523
+
524
+
525
+ ###===================================================
526
+ ## Plot the loss and accuracy as a function of epoch
527
+ if len(vars)==1:
528
+ plot_train_history_1var_mae(history)
529
+ else:
530
+ plot_train_history_Nvar(history, vars, len(vars))
531
+
532
+ varstring = ''.join([str(k)+'_' for k in vars])
533
+ plt.savefig(weights_path.replace('.hdf5', '_history.png'), dpi=300,
534
+ bbox_inches='tight')
535
+ plt.close('all')
536
+
537
+ # serialize model to JSON to use later to predict
538
+ model_json = SM.to_json()
539
+ with open(weights_path.replace('.hdf5','.json'), "w") as json_file:
540
+ json_file.write(model_json)
541
+
542
+ return SM, weights_path,train_df, test_df, val_df, train_idx, test_idx, val_idx
543
+
544
+ #
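As a worked example of the weights-file naming logic above (non-shallow, no-augmentation, continuous-loss branch), the settings used by the demo reproduce the filename that app.py loads:

# Sketch: the no-aug, non-shallow branch of the naming logic above, with the demo's settings
name, mode, batch_size, IM_HEIGHT, CONT_LOSS = (
    'sandsnap_merged_1024_modelrevOct2022_v12', 'simo', 10, 1024, 'mse')
varstring = str(9) + 'vars'   # nine continuous output variables
weights_path = name + "_" + mode + "_batch" + str(batch_size) + "_im" + str(IM_HEIGHT) + \
               "_" + varstring + "_" + CONT_LOSS + "_noaug.hdf5"
print(weights_path)
# sandsnap_merged_1024_modelrevOct2022_v12_simo_batch10_im1024_9vars_mse_noaug.hdf5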
app_files/src/sedinet_models.py ADDED
@@ -0,0 +1,144 @@
1
+ # Written by Dr Daniel Buscombe, Marda Science LLC
2
+ # for the SandSnap Program
3
+ #
4
+ # MIT License
5
+ #
6
+ # Copyright (c) 2020-2021, Marda Science LLC
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+ #
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+ #
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+
27
+ ##> Release v1.4 (Aug 2021)
28
+
29
+ ###===================================================
30
+ # import libraries
31
+ from sedinet_utils import *
32
+
33
+ ###===================================================
34
+ def conv_block2(inp, filters=32, bn=True, pool=True, drop=True):
35
+ """
36
+ This function generates a SediNet convolutional block
37
+ """
38
+ # _ = Conv2D(filters=filters, kernel_size=3, activation='relu',
39
+ # kernel_initializer='he_uniform')(inp)
40
+
41
+ #relu creating dead neurons?
42
+ _ = SeparableConv2D(filters=filters, kernel_size=3, activation='relu')(inp) #'relu' #kernel_initializer='he_uniform'
43
+ if bn:
44
+ _ = BatchNormalization()(_)
45
+ if pool:
46
+ _ = MaxPool2D()(_)
47
+ if drop:
48
+ _ = Dropout(0.2)(_)
49
+ return _
50
+
51
+ ###===================================================
52
+ def make_cat_sedinet(ID_MAP, dropout):
53
+ """
54
+ This function creates an implementation of SediNet for estimating
55
+ sediment category
56
+ """
57
+
58
+ base = BASE_CAT ##30
59
+
60
+ input_layer = Input(shape=(IM_HEIGHT, IM_WIDTH, 3))
61
+ _ = conv_block2(input_layer, filters=base, bn=False, pool=False, drop=False) #x #
62
+ _ = conv_block2(_, filters=base*2, bn=False, pool=True,drop=False)
63
+ _ = conv_block2(_, filters=base*3, bn=False, pool=True,drop=False)
64
+ _ = conv_block2(_, filters=base*4, bn=False, pool=True,drop=False)
65
+
66
+ bottleneck = GlobalMaxPool2D()(_)
67
+ bottleneck = Dropout(dropout)(bottleneck)
68
+
69
+ # for class prediction
70
+ _ = Dense(units=CAT_DENSE_UNITS, activation='relu')(bottleneck) ##128
71
+ output = Dense(units=len(ID_MAP), activation='softmax', name='output')(_)
72
+
73
+ model = Model(inputs=input_layer, outputs=[output])
74
+
75
+ OPT = tf.keras.optimizers.Adam(learning_rate=MAX_LR)
76
+
77
+ if CAT_LOSS == 'focal':
78
+ model.compile(optimizer=OPT,
79
+ loss={'output': tfa.losses.SigmoidFocalCrossEntropy() },
80
+ metrics={'output': 'accuracy'})
81
+ else:
82
+ model.compile(optimizer=OPT, #'adam',
83
+ loss={'output': CAT_LOSS}, #'categorical_crossentropy'
84
+ metrics={'output': 'accuracy'})
85
+
86
+
87
+ print("==========================================")
88
+ print('[INFORMATION] Model summary:')
89
+ model.summary()
90
+ return model
91
+
92
+
93
+ ###===================================================
94
+ def make_sedinet_siso_simo(vars, greyscale, dropout):
95
+ """
96
+ This function creates an implementation of SediNet for estimating
97
+ sediment metric on a continuous scale
98
+ """
99
+
100
+ base = BASE_CONT ##30 ## suggested range = 20 -- 40
101
+ if greyscale==True:
102
+ input_layer = Input(shape=(IM_HEIGHT, IM_WIDTH, 1))
103
+ else:
104
+ input_layer = Input(shape=(IM_HEIGHT, IM_WIDTH, 3))
105
+
106
+ _ = conv_block2(input_layer, filters=base, bn=False, pool=False, drop=False) #x #
107
+ _ = conv_block2(_, filters=base*2, bn=False, pool=True,drop=False)
108
+ _ = conv_block2(_, filters=base*3, bn=False, pool=True,drop=False)
109
+ _ = conv_block2(_, filters=base*4, bn=False, pool=True,drop=False)
110
+ _ = conv_block2(_, filters=base*5, bn=False, pool=True,drop=False)
111
+
112
+ if not SHALLOW:
113
+ _ = conv_block2(_, filters=base*6, bn=False, pool=True,drop=False)
114
+ _ = conv_block2(_, filters=base*7, bn=False, pool=True,drop=False)
115
+ _ = conv_block2(_, filters=base*8, bn=False, pool=True,drop=False)
116
+ _ = conv_block2(_, filters=base*9, bn=False, pool=True,drop=False)
117
+
118
+ _ = BatchNormalization(axis=-1)(_)
119
+ bottleneck = GlobalMaxPool2D()(_)
120
+ bottleneck = Dropout(dropout)(bottleneck)
121
+
122
+ units = CONT_DENSE_UNITS ## suggested range 512 -- 1024
123
+ _ = Dense(units=units, activation='relu')(bottleneck) #'relu'
124
+
125
+ ##would it be better to predict the full vector directly instead of one by one?
126
+ outputs = []
127
+ for var in vars:
128
+ outputs.append(Dense(units=1, activation='linear', name=var+'_output')(_) ) #relu
129
+
130
+ if CONT_LOSS == 'pinball':
131
+ loss = dict(zip([k+"_output" for k in vars], [tfa.losses.PinballLoss(tau=.5) for k in vars]))
132
+ else: ## 'mse'
133
+ loss = dict(zip([k+"_output" for k in vars], ['mse' for k in vars])) #loss = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE) # Sum of squared error
134
+
135
+ metrics = dict(zip([k+"_output" for k in vars], ['mae' for k in vars]))
136
+
137
+ OPT = tf.keras.optimizers.Adam(learning_rate=MAX_LR)
138
+
139
+ model = Model(inputs=input_layer, outputs=outputs)
140
+ model.compile(optimizer=OPT,loss=loss, metrics=metrics)
141
+ #print("==========================================")
142
+ #print('[INFORMATION] Model summary:')
143
+ #model.summary()
144
+ return model
app_files/src/sedinet_utils.py ADDED
@@ -0,0 +1,2117 @@
1
+ # Written by Dr Daniel Buscombe, Marda Science LLC
2
+ # for the SandSnap Program
3
+ #
4
+ # MIT License
5
+ #
6
+ # Copyright (c) 2020-2021, Marda Science LLC
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+ #
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+ #
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+
27
+ ##> Release v1.4 (Aug 2021)
28
+
29
+ from imports import *
30
+ from matplotlib import MatplotlibDeprecationWarning
31
+
32
+ import warnings
33
+ warnings.filterwarnings(action="ignore",category=MatplotlibDeprecationWarning)
34
+
35
+ ###===================================================
36
+ ## FUNCTIONS FOR LEARNING RATE SCHEDULER
37
+
38
+ def exponential_decay(lr0, s):
39
+ def exponential_decay_fn(epoch):
40
+ return lr0 * 0.1 **(epoch / s)
41
+ return exponential_decay_fn
42
+
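As a quick illustration (a sketch under assumed values of lr0 and s, not part of the source), the closure returned above drops the learning rate by a factor of 10 every s epochs and is intended to be handed to a Keras LearningRateScheduler callback:

    decay_fn = exponential_decay(1e-4, 20)   # lr0=1e-4, tenfold drop every 20 epochs
    print(decay_fn(0))    # 0.0001
    print(decay_fn(20))   # 1e-05
    print(decay_fn(40))   # 1e-06

    # typical use (assumes tensorflow is imported as tf, as elsewhere in this repo):
    # callback = tf.keras.callbacks.LearningRateScheduler(decay_fn)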
43
+ ###===================================================
44
+ ## IMAGE AUGMENTATION FUNCTIONS (for DO_AUG=True)
45
+
46
+
47
+ # def h_flip(image):
48
+ # return np.fliplr(image)
49
+
50
+ def v_flip(image):
51
+ return np.flipud(image)
52
+
53
+ def warp_shift(image):
54
+ shift= random.randint(25,200)
55
+ transform = AffineTransform(translation=(0,shift))
56
+ warp_image = warp(image, transform, mode="wrap")
57
+ return warp_image
58
+
59
+ def apply_aug(im):
60
+ return [im,v_flip(warp_shift(im))] #, clockwise_rotation(im), h_flip(im)]
61
+
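For illustration (a sketch, not from the source): apply_aug returns a two-element list per image, the original plus a vertically flipped, randomly warp-shifted copy, which is why the batch generators below append each label twice when do_aug is True:

    import numpy as np

    im = np.random.rand(1024, 1024)        # a single greyscale tile
    batch = apply_aug(im)                  # [im, v_flip(warp_shift(im))]
    print(len(batch))                      # 2
    print(batch[0].shape, batch[1].shape)  # (1024, 1024) (1024, 1024)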
62
+
63
+ ##========================================================
64
+ def rescale(dat,
65
+ mn,
66
+ mx):
67
+ '''
68
+ rescales the input array dat to the range [mn, mx]
69
+ '''
70
+ m = min(dat.flatten())
71
+ M = max(dat.flatten())
72
+ return (mx-mn)*(dat-m)/(M-m)+mn
73
+
74
+ def do_standardize(img):
75
+ #standardization using adjusted standard deviation
76
+ N = np.shape(img)[0] * np.shape(img)[1]
77
+ s = np.maximum(np.std(img), 1.0/np.sqrt(N))
78
+ m = np.mean(img)
79
+ img = (img - m) / s
80
+ img = rescale(img, 0, 1)
81
+ del m, s, N
82
+
83
+ return img
84
+
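To make the behaviour concrete (an illustrative sketch only): do_standardize centres the image, divides by an adjusted standard deviation max(std, 1/sqrt(N)), then uses rescale to map the result onto [0, 1], so the output always spans the unit interval regardless of the input range:

    import numpy as np

    img = np.array([[0., 50.], [100., 200.]])
    N = img.size                                   # number of pixels
    s = np.maximum(np.std(img), 1.0 / np.sqrt(N))  # adjusted standard deviation
    z = (img - np.mean(img)) / s
    out = (z - z.min()) / (z.max() - z.min())      # equivalent to rescale(z, 0, 1)
    print(out.min(), out.max())                    # 0.0 1.0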
85
+ ###===================================================
86
+ ### IMAGE BATCH GENERATOR FUNCTIONS
87
+
88
+ def get_data_generator_Nvars_siso_simo(df, indices, for_training, vars,
89
+ batch_size, greyscale, do_aug,#CS,
90
+ standardize, tilesize):
91
+ """
92
+ This function generates data for a batch of images and N associated metrics
93
+ """
94
+
95
+ ##print(do_aug)
96
+
97
+ if len(vars)==1:
98
+ images, p1s = [], []
99
+ elif len(vars)==2:
100
+ images, p1s, p2s = [], [], []
101
+ elif len(vars)==3:
102
+ images, p1s, p2s, p3s = [], [], [], []
103
+ elif len(vars)==4:
104
+ images, p1s, p2s, p3s, p4s = [], [], [], [], []
105
+ elif len(vars)==5:
106
+ images, p1s, p2s, p3s, p4s, p5s = [], [], [], [], [], []
107
+ elif len(vars)==6:
108
+ images, p1s, p2s, p3s, p4s, p5s, p6s =\
109
+ [], [], [], [], [], [], []
110
+ elif len(vars)==7:
111
+ images, p1s, p2s, p3s, p4s, p5s, p6s, p7s =\
112
+ [], [], [], [], [], [], [], []
113
+ elif len(vars)==8:
114
+ images, p1s, p2s, p3s, p4s, p5s, p6s, p7s, p8s =\
115
+ [], [], [], [], [], [], [], [], []
116
+ elif len(vars)==9:
117
+ images, p1s, p2s, p3s, p4s, p5s, p6s, p7s, p8s, p9s =\
118
+ [], [], [], [], [], [], [], [], [], []
119
+
120
+ while True:
121
+ for i in indices:
122
+ r = df.iloc[i]
123
+ if len(vars)==1:
124
+ file, p1 = r['filenames'], r[vars[0]]
125
+ if len(vars)==2:
126
+ file, p1, p2 = r['filenames'], r[vars[0]], r[vars[1]]
127
+ if len(vars)==3:
128
+ file, p1, p2, p3 = \
129
+ r['filenames'], r[vars[0]], r[vars[1]], r[vars[2]]
130
+ if len(vars)==4:
131
+ file, p1, p2, p3, p4 = \
132
+ r['filenames'], r[vars[0]], r[vars[1]], r[vars[2]], r[vars[3]]
133
+ if len(vars)==5:
134
+ file, p1, p2, p3, p4, p5 = \
135
+ r['filenames'], r[vars[0]], r[vars[1]], r[vars[2]], r[vars[3]], r[vars[4]]
136
+ if len(vars)==6:
137
+ file, p1, p2, p3, p4, p5, p6 = \
138
+ r['filenames'], r[vars[0]], r[vars[1]], r[vars[2]], r[vars[3]], r[vars[4]], r[vars[5]]
139
+ if len(vars)==7:
140
+ file, p1, p2, p3, p4, p5, p6, p7 = \
141
+ r['filenames'], r[vars[0]], r[vars[1]], r[vars[2]], r[vars[3]], r[vars[4]], r[vars[5]], r[vars[6]]
142
+ if len(vars)==8:
143
+ file, p1, p2, p3, p4, p5, p6, p7, p8 = \
144
+ r['filenames'], r[vars[0]], r[vars[1]], r[vars[2]], r[vars[3]], r[vars[4]], r[vars[5]], r[vars[6]], r[vars[7]]
145
+ elif len(vars)==9:
146
+ file, p1, p2, p3, p4, p5, p6, p7, p8, p9 = \
147
+ r['filenames'], r[vars[0]], r[vars[1]], r[vars[2]], r[vars[3]], r[vars[4]], r[vars[5]], r[vars[6]], r[vars[7]], r[vars[8]]
148
+
149
+ if greyscale==True:
150
+ im = Image.open(file).convert('LA')
151
+ #im = im.resize((IM_HEIGHT, IM_HEIGHT))
152
+ im = np.array(im)[:,:,0]
153
+ nx,ny = np.shape(im)
154
+ if (nx!=tilesize) or (ny!=tilesize):
155
+ im = im[int(nx/2)-int(tilesize/2):int(nx/2)+int(tilesize/2), int(ny/2)-int(tilesize/2):int(ny/2)+int(tilesize/2)]
156
+
157
+ else:
158
+ im = Image.open(file)
159
+ #im = im.resize((IM_HEIGHT, IM_HEIGHT))
160
+ im = np.array(im)
161
+ nx,ny,nz = np.shape(im)
162
+ if (nx!=tilesize) or (ny!=tilesize):
163
+ im = im[int(nx/2)-int(tilesize/2):int(nx/2)+int(tilesize/2), int(ny/2)-int(tilesize/2):int(ny/2)+int(tilesize/2)]
164
+
165
+ if standardize==True:
166
+ im = do_standardize(im)
167
+ else:
168
+ im = np.array(im) / 255.0
169
+
170
+ #if np.ndim(im)==2:
171
+ # im = np.dstack((im, im , im)) ##np.expand_dims(im[:,:,0], axis=2)
172
+
173
+ #im = im[:,:,:3]
174
+
175
+ if greyscale==True:
176
+ if do_aug==True:
177
+ aug = apply_aug(im)
178
+ images.append(aug)
179
+ if len(vars)==1:
180
+ p1s.append([p1 for k in range(2)])
181
+ elif len(vars)==2:
182
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
183
+ elif len(vars)==3:
184
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
185
+ p3s.append([p3 for k in range(2)]);
186
+ elif len(vars)==4:
187
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
188
+ p3s.append([p3 for k in range(2)]); p4s.append([p4 for k in range(2)])
189
+ elif len(vars)==5:
190
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
191
+ p3s.append([p3 for k in range(2)]); p4s.append([p4 for k in range(2)])
192
+ p5s.append([p5 for k in range(2)]);
193
+ elif len(vars)==6:
194
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
195
+ p3s.append([p3 for k in range(2)]); p4s.append([p4 for k in range(2)])
196
+ p5s.append([p5 for k in range(2)]); p6s.append([p6 for k in range(2)])
197
+ elif len(vars)==7:
198
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
199
+ p3s.append([p3 for k in range(2)]); p4s.append([p4 for k in range(2)])
200
+ p5s.append([p5 for k in range(2)]); p6s.append([p6 for k in range(2)])
201
+ p7s.append([p7 for k in range(2)]);
202
+ elif len(vars)==8:
203
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
204
+ p3s.append([p3 for k in range(2)]); p4s.append([p4 for k in range(2)])
205
+ p5s.append([p5 for k in range(2)]); p6s.append([p6 for k in range(2)])
206
+ p7s.append([p7 for k in range(2)]); p8s.append([p8 for k in range(2)])
207
+ elif len(vars)==9:
208
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
209
+ p3s.append([p3 for k in range(2)]); p4s.append([p4 for k in range(2)])
210
+ p5s.append([p5 for k in range(2)]); p6s.append([p6 for k in range(2)])
211
+ p7s.append([p7 for k in range(2)]); p8s.append([p8 for k in range(2)])
212
+ p9s.append([p9 for k in range(2)])
213
+
214
+ else:
215
+ images.append(np.expand_dims(im, axis=2))
216
+ if len(vars)==1:
217
+ p1s.append(p1)
218
+ elif len(vars)==2:
219
+ p1s.append(p1); p2s.append(p2)
220
+ elif len(vars)==3:
221
+ p1s.append(p1); p2s.append(p2)
222
+ p3s.append(p3);
223
+ elif len(vars)==4:
224
+ p1s.append(p1); p2s.append(p2)
225
+ p3s.append(p3); p4s.append(p4)
226
+ elif len(vars)==5:
227
+ p1s.append(p1); p2s.append(p2)
228
+ p3s.append(p3); p4s.append(p4)
229
+ p5s.append(p5);
230
+ elif len(vars)==6:
231
+ p1s.append(p1); p2s.append(p2)
232
+ p3s.append(p3); p4s.append(p4)
233
+ p5s.append(p5); p6s.append(p6)
234
+ elif len(vars)==7:
235
+ p1s.append(p1); p2s.append(p2)
236
+ p3s.append(p3); p4s.append(p4)
237
+ p5s.append(p5); p6s.append(p6)
238
+ p7s.append(p7);
239
+ elif len(vars)==8:
240
+ p1s.append(p1); p2s.append(p2)
241
+ p3s.append(p3); p4s.append(p4)
242
+ p5s.append(p5); p6s.append(p6)
243
+ p7s.append(p7); p8s.append(p8)
244
+ elif len(vars)==9:
245
+ p1s.append(p1); p2s.append(p2)
246
+ p3s.append(p3); p4s.append(p4)
247
+ p5s.append(p5); p6s.append(p6)
248
+ p7s.append(p7); p8s.append(p8)
249
+ p9s.append(p9)
250
+
251
+ else:
252
+ if do_aug==True:
253
+ aug = apply_aug(im)
254
+ images.append(aug)
255
+ if len(vars)==1:
256
+ p1s.append([p1 for k in range(2)])
257
+ elif len(vars)==2:
258
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
259
+ elif len(vars)==3:
260
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
261
+ p3s.append([p3 for k in range(2)]);
262
+ elif len(vars)==4:
263
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
264
+ p3s.append([p3 for k in range(2)]); p4s.append([p4 for k in range(2)])
265
+ elif len(vars)==5:
266
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
267
+ p3s.append([p3 for k in range(2)]); p4s.append([p4 for k in range(2)])
268
+ p5s.append([p5 for k in range(2)]);
269
+ elif len(vars)==6:
270
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
271
+ p3s.append([p3 for k in range(2)]); p4s.append([p4 for k in range(2)])
272
+ p5s.append([p5 for k in range(2)]); p6s.append([p6 for k in range(2)])
273
+ elif len(vars)==7:
274
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
275
+ p3s.append([p3 for k in range(2)]); p4s.append([p4 for k in range(2)])
276
+ p5s.append([p5 for k in range(2)]); p6s.append([p6 for k in range(2)])
277
+ p7s.append([p7 for k in range(2)]);
278
+ elif len(vars)==8:
279
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
280
+ p3s.append([p3 for k in range(2)]); p4s.append([p4 for k in range(2)])
281
+ p5s.append([p5 for k in range(2)]); p6s.append([p6 for k in range(2)])
282
+ p7s.append([p7 for k in range(2)]); p8s.append([p8 for k in range(2)])
283
+ elif len(vars)==9:
284
+ p1s.append([p1 for k in range(2)]); p2s.append([p2 for k in range(2)])
285
+ p3s.append([p3 for k in range(2)]); p4s.append([p4 for k in range(2)])
286
+ p5s.append([p5 for k in range(2)]); p6s.append([p6 for k in range(2)])
287
+ p7s.append([p7 for k in range(2)]); p8s.append([p8 for k in range(2)])
288
+ p9s.append([p9 for k in range(2)])
289
+
290
+ else:
291
+ images.append(im)
292
+ if len(vars)==1:
293
+ p1s.append(p1)
294
+ elif len(vars)==2:
295
+ p1s.append(p1); p2s.append(p2)
296
+ elif len(vars)==3:
297
+ p1s.append(p1); p2s.append(p2)
298
+ p3s.append(p3);
299
+ elif len(vars)==4:
300
+ p1s.append(p1); p2s.append(p2)
301
+ p3s.append(p3); p4s.append(p4)
302
+ elif len(vars)==5:
303
+ p1s.append(p1); p2s.append(p2)
304
+ p3s.append(p3); p4s.append(p4)
305
+ p5s.append(p5);
306
+ elif len(vars)==6:
307
+ p1s.append(p1); p2s.append(p2)
308
+ p3s.append(p3); p4s.append(p4)
309
+ p5s.append(p5); p6s.append(p6)
310
+ elif len(vars)==7:
311
+ p1s.append(p1); p2s.append(p2)
312
+ p3s.append(p3); p4s.append(p4)
313
+ p5s.append(p5); p6s.append(p6)
314
+ p7s.append(p7);
315
+ elif len(vars)==8:
316
+ p1s.append(p1); p2s.append(p2)
317
+ p3s.append(p3); p4s.append(p4)
318
+ p5s.append(p5); p6s.append(p6)
319
+ p7s.append(p7); p8s.append(p8)
320
+ elif len(vars)==9:
321
+ p1s.append(p1); p2s.append(p2)
322
+ p3s.append(p3); p4s.append(p4)
323
+ p5s.append(p5); p6s.append(p6)
324
+ p7s.append(p7); p8s.append(p8)
325
+ p9s.append(p9)
326
+
327
+ if len(images) >= batch_size:
328
+ if len(vars)==1:
329
+ # if len(CS)==0:
330
+ p1s = np.squeeze(np.array(p1s))
331
+ # else:
332
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
333
+ if do_aug==True:
334
+ if len(images) >= batch_size:
335
+ if greyscale==False:
336
+ images = np.array(np.vstack(images))
337
+ else:
338
+ images = np.expand_dims(np.array(np.vstack(images)), axis=-1)
339
+ p1s = np.expand_dims(np.vstack(p1s).flatten(),axis=-1)
340
+ yield images,[p1s]
341
+ else:
342
+ if len(images) >= batch_size:
343
+ #p1s = np.expand_dims(np.vstack(p1s).flatten(),axis=-1)
344
+ yield np.array(images),[np.array(p1s)]
345
+ images, p1s = [], []
346
+
347
+ elif len(vars)==2:
348
+ # if len(CS)==0:
349
+ p1s = np.squeeze(np.array(p1s))
350
+ p2s = np.squeeze(np.array(p2s))
351
+ # else:
352
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
353
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
354
+ if do_aug==True:
355
+ if len(images) >= batch_size:
356
+ if greyscale==False:
357
+ images = np.array(np.vstack(images))
358
+ else:
359
+ images = np.expand_dims(np.array(np.vstack(images)), axis=-1)
360
+ p1s = np.expand_dims(np.vstack(p1s).flatten(),axis=-1)
361
+ p2s = np.expand_dims(np.vstack(p2s).flatten(),axis=-1)
362
+ yield images,[p1s, p2s]
363
+ else:
364
+ if len(images) >= batch_size:
365
+ yield np.array(images),[np.array(p1s), np.array(p2s)]
366
+ images, p1s, p2s = [], [], []
367
+
368
+ elif len(vars)==3:
369
+ # if len(CS)==0:
370
+ p1s = np.squeeze(np.array(p1s))
371
+ p2s = np.squeeze(np.array(p2s))
372
+ p3s = np.squeeze(np.array(p3s))
373
+ # else:
374
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
375
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
376
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
377
+ if do_aug==True:
378
+ if len(images) >= batch_size:
379
+ if greyscale==False:
380
+ images = np.array(np.vstack(images))
381
+ else:
382
+ images = np.expand_dims(np.array(np.vstack(images)), axis=-1)
383
+ p1s = np.expand_dims(np.vstack(p1s).flatten(),axis=-1)
384
+ p2s = np.expand_dims(np.vstack(p2s).flatten(),axis=-1)
385
+ p3s = np.expand_dims(np.vstack(p3s).flatten(),axis=-1)
386
+ yield images,[p1s, p2s, p3s]
387
+ else:
388
+ if len(images) >= batch_size:
389
+ yield np.array(images),[np.array(p1s), np.array(p2s), np.array(p3s)]
390
+ images, p1s, p2s, p3s = [], [], [], []
391
+
392
+ elif len(vars)==4:
393
+ # if len(CS)==0:
394
+ p1s = np.squeeze(np.array(p1s))
395
+ p2s = np.squeeze(np.array(p2s))
396
+ p3s = np.squeeze(np.array(p3s))
397
+ p4s = np.squeeze(np.array(p4s))
398
+ # else:
399
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
400
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
401
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
402
+ # p4s = np.squeeze(CS[3].transform(np.array(p4s).reshape(-1, 1)))
403
+ if do_aug==True:
404
+ if len(images) >= batch_size:
405
+ if greyscale==False:
406
+ images = np.array(np.vstack(images))
407
+ else:
408
+ images = np.expand_dims(np.array(np.vstack(images)), axis=-1)
409
+ p1s = np.expand_dims(np.vstack(p1s).flatten(),axis=-1)
410
+ p2s = np.expand_dims(np.vstack(p2s).flatten(),axis=-1)
411
+ p3s = np.expand_dims(np.vstack(p3s).flatten(),axis=-1)
412
+ p4s = np.expand_dims(np.vstack(p4s).flatten(),axis=-1)
413
+ yield images,[p1s, p2s, p3s, p4s]
414
+ else:
415
+ if len(images) >= batch_size:
416
+ yield np.array(images),[np.array(p1s), np.array(p2s), np.array(p3s),
417
+ np.array(p4s)]
418
+ images, p1s, p2s, p3s, p4s = [], [], [], [], []
419
+
420
+ elif len(vars)==5:
421
+ # if len(CS)==0:
422
+ p1s = np.squeeze(np.array(p1s))
423
+ p2s = np.squeeze(np.array(p2s))
424
+ p3s = np.squeeze(np.array(p3s))
425
+ p4s = np.squeeze(np.array(p4s))
426
+ p5s = np.squeeze(np.array(p5s))
427
+ # else:
428
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
429
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
430
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
431
+ # p4s = np.squeeze(CS[3].transform(np.array(p4s).reshape(-1, 1)))
432
+ # p5s = np.squeeze(CS[4].transform(np.array(p5s).reshape(-1, 1)))
433
+ if do_aug==True:
434
+ if len(images) >= batch_size:
435
+ if greyscale==False:
436
+ images = np.array(np.vstack(images))
437
+ else:
438
+ images = np.expand_dims(np.array(np.vstack(images)), axis=-1)
439
+ p1s = np.expand_dims(np.vstack(p1s).flatten(),axis=-1)
440
+ p2s = np.expand_dims(np.vstack(p2s).flatten(),axis=-1)
441
+ p3s = np.expand_dims(np.vstack(p3s).flatten(),axis=-1)
442
+ p4s = np.expand_dims(np.vstack(p4s).flatten(),axis=-1)
443
+ p5s = np.expand_dims(np.vstack(p5s).flatten(),axis=-1)
444
+ yield images,[p1s, p2s, p3s, p4s, p5s]
445
+ else:
446
+ if len(images) >= batch_size:
447
+ yield np.array(images),[np.array(p1s), np.array(p2s), np.array(p3s),
448
+ np.array(p4s), np.array(p5s)]
449
+ images, p1s, p2s, p3s, p4s, p5s = [], [], [], [], [], []
450
+
451
+ elif len(vars)==6:
452
+ # if len(CS)==0:
453
+ p1s = np.squeeze(np.array(p1s))
454
+ p2s = np.squeeze(np.array(p2s))
455
+ p3s = np.squeeze(np.array(p3s))
456
+ p4s = np.squeeze(np.array(p4s))
457
+ p5s = np.squeeze(np.array(p5s))
458
+ p6s = np.squeeze(np.array(p6s))
459
+ # else:
460
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
461
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
462
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
463
+ # p4s = np.squeeze(CS[3].transform(np.array(p4s).reshape(-1, 1)))
464
+ # p5s = np.squeeze(CS[4].transform(np.array(p5s).reshape(-1, 1)))
465
+ # p6s = np.squeeze(CS[5].transform(np.array(p6s).reshape(-1, 1)))
466
+ if do_aug==True:
467
+ if len(images) >= batch_size:
468
+ if greyscale==False:
469
+ images = np.array(np.vstack(images))
470
+ else:
471
+ images = np.expand_dims(np.array(np.vstack(images)), axis=-1)
472
+ p1s = np.expand_dims(np.vstack(p1s).flatten(),axis=-1)
473
+ p2s = np.expand_dims(np.vstack(p2s).flatten(),axis=-1)
474
+ p3s = np.expand_dims(np.vstack(p3s).flatten(),axis=-1)
475
+ p4s = np.expand_dims(np.vstack(p4s).flatten(),axis=-1)
476
+ p5s = np.expand_dims(np.vstack(p5s).flatten(),axis=-1)
477
+ p6s = np.expand_dims(np.vstack(p6s).flatten(),axis=-1)
478
+ yield images,[p1s, p2s, p3s, p4s, p5s, p6s]
479
+ else:
480
+ if len(images) >= batch_size:
481
+ yield np.array(images),[np.array(p1s), np.array(p2s), np.array(p3s),
482
+ np.array(p4s), np.array(p5s), np.array(p6s)]
483
+ images, p1s, p2s, p3s, p4s, p5s, p6s = \
484
+ [], [], [], [], [], [], []
485
+
486
+ elif len(vars)==7:
487
+ # if len(CS)==0:
488
+ p1s = np.squeeze(np.array(p1s))
489
+ p2s = np.squeeze(np.array(p2s))
490
+ p3s = np.squeeze(np.array(p3s))
491
+ p4s = np.squeeze(np.array(p4s))
492
+ p5s = np.squeeze(np.array(p5s))
493
+ p6s = np.squeeze(np.array(p6s))
494
+ p7s = np.squeeze(np.array(p7s))
495
+ # else:
496
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
497
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
498
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
499
+ # p4s = np.squeeze(CS[3].transform(np.array(p4s).reshape(-1, 1)))
500
+ # p5s = np.squeeze(CS[4].transform(np.array(p5s).reshape(-1, 1)))
501
+ # p6s = np.squeeze(CS[5].transform(np.array(p6s).reshape(-1, 1)))
502
+ # p7s = np.squeeze(CS[6].transform(np.array(p7s).reshape(-1, 1)))
503
+ if do_aug==True:
504
+ if len(images) >= batch_size:
505
+ if greyscale==False:
506
+ images = np.array(np.vstack(images))
507
+ else:
508
+ images = np.expand_dims(np.array(np.vstack(images)), axis=-1)
509
+ p1s = np.expand_dims(np.vstack(p1s).flatten(),axis=-1)
510
+ p2s = np.expand_dims(np.vstack(p2s).flatten(),axis=-1)
511
+ p3s = np.expand_dims(np.vstack(p3s).flatten(),axis=-1)
512
+ p4s = np.expand_dims(np.vstack(p4s).flatten(),axis=-1)
513
+ p5s = np.expand_dims(np.vstack(p5s).flatten(),axis=-1)
514
+ p6s = np.expand_dims(np.vstack(p6s).flatten(),axis=-1)
515
+ p7s = np.expand_dims(np.vstack(p7s).flatten(),axis=-1)
516
+ yield images,[p1s, p2s, p3s, p4s, p5s, p6s, p7s]
517
+ else:
518
+ if len(images) >= batch_size:
519
+ yield np.array(images),[np.array(p1s), np.array(p2s), np.array(p3s),
520
+ np.array(p4s), np.array(p5s), np.array(p6s),
521
+ np.array(p7s)]
522
+ images, p1s, p2s, p3s, p4s, p5s, p6s, p7s = \
523
+ [], [], [], [], [], [], [], []
524
+
525
+ elif len(vars)==8:
526
+ # if len(CS)==0:
527
+ p1s = np.squeeze(np.array(p1s))
528
+ p2s = np.squeeze(np.array(p2s))
529
+ p3s = np.squeeze(np.array(p3s))
530
+ p4s = np.squeeze(np.array(p4s))
531
+ p5s = np.squeeze(np.array(p5s))
532
+ p6s = np.squeeze(np.array(p6s))
533
+ p7s = np.squeeze(np.array(p7s))
534
+ p8s = np.squeeze(np.array(p8s))
535
+ # else:
536
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
537
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
538
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
539
+ # p4s = np.squeeze(CS[3].transform(np.array(p4s).reshape(-1, 1)))
540
+ # p5s = np.squeeze(CS[4].transform(np.array(p5s).reshape(-1, 1)))
541
+ # p6s = np.squeeze(CS[5].transform(np.array(p6s).reshape(-1, 1)))
542
+ # p7s = np.squeeze(CS[6].transform(np.array(p7s).reshape(-1, 1)))
543
+ # p8s = np.squeeze(CS[7].transform(np.array(p8s).reshape(-1, 1)))
544
+ if do_aug==True:
545
+ if len(images) >= batch_size:
546
+ if greyscale==False:
547
+ images = np.array(np.vstack(images))
548
+ else:
549
+ images = np.expand_dims(np.array(np.vstack(images)), axis=-1)
550
+ p1s = np.expand_dims(np.vstack(p1s).flatten(),axis=-1)
551
+ p2s = np.expand_dims(np.vstack(p2s).flatten(),axis=-1)
552
+ p3s = np.expand_dims(np.vstack(p3s).flatten(),axis=-1)
553
+ p4s = np.expand_dims(np.vstack(p4s).flatten(),axis=-1)
554
+ p5s = np.expand_dims(np.vstack(p5s).flatten(),axis=-1)
555
+ p6s = np.expand_dims(np.vstack(p6s).flatten(),axis=-1)
556
+ p7s = np.expand_dims(np.vstack(p7s).flatten(),axis=-1)
557
+ p8s = np.expand_dims(np.vstack(p8s).flatten(),axis=-1)
558
+ yield images,[p1s, p2s, p3s, p4s, p5s, p6s, p7s, p8s]
559
+
560
+ else:
561
+ if len(images) >= batch_size:
562
+ yield np.array(images),[np.array(p1s), np.array(p2s), np.array(p3s),
563
+ np.array(p4s), np.array(p5s), np.array(p6s),
564
+ np.array(p7s), np.array(p8s)]
565
+ images, p1s, p2s, p3s, p4s, p5s, p6s, p7s, p8s = \
566
+ [], [], [], [], [], [], [], [], []
567
+
568
+ elif len(vars)==9:
569
+ # if len(CS)==0:
570
+ p1s = np.squeeze(np.array(p1s))
571
+ p2s = np.squeeze(np.array(p2s))
572
+ p3s = np.squeeze(np.array(p3s))
573
+ p4s = np.squeeze(np.array(p4s))
574
+ p5s = np.squeeze(np.array(p5s))
575
+ p6s = np.squeeze(np.array(p6s))
576
+ p7s = np.squeeze(np.array(p7s))
577
+ p8s = np.squeeze(np.array(p8s))
578
+ p9s = np.squeeze(np.array(p9s))
579
+ # else:
580
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
581
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
582
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
583
+ # p4s = np.squeeze(CS[3].transform(np.array(p4s).reshape(-1, 1)))
584
+ # p5s = np.squeeze(CS[4].transform(np.array(p5s).reshape(-1, 1)))
585
+ # p6s = np.squeeze(CS[5].transform(np.array(p6s).reshape(-1, 1)))
586
+ # p7s = np.squeeze(CS[6].transform(np.array(p7s).reshape(-1, 1)))
587
+ # p8s = np.squeeze(CS[7].transform(np.array(p8s).reshape(-1, 1)))
588
+ # p9s = np.squeeze(CS[8].transform(np.array(p9s).reshape(-1, 1)))
589
+
590
+ try:
591
+ if do_aug==True:
592
+ if len(images) >= batch_size:
593
+ if greyscale==False:
594
+ images = np.array(np.vstack(images))
595
+ else:
596
+ images = np.expand_dims(np.array(np.vstack(images)), axis=-1)
597
+ p1s = np.expand_dims(np.vstack(p1s).flatten(),axis=-1)
598
+ p2s = np.expand_dims(np.vstack(p2s).flatten(),axis=-1)
599
+ p3s = np.expand_dims(np.vstack(p3s).flatten(),axis=-1)
600
+ p4s = np.expand_dims(np.vstack(p4s).flatten(),axis=-1)
601
+ p5s = np.expand_dims(np.vstack(p5s).flatten(),axis=-1)
602
+ p6s = np.expand_dims(np.vstack(p6s).flatten(),axis=-1)
603
+ p7s = np.expand_dims(np.vstack(p7s).flatten(),axis=-1)
604
+ p8s = np.expand_dims(np.vstack(p8s).flatten(),axis=-1)
605
+ p9s = np.expand_dims(np.vstack(p9s).flatten(),axis=-1)
606
+ yield images,[p1s, p2s, p3s, p4s, p5s, p6s, p7s, p8s, p9s]
607
+ else:
608
+ if len(images) >= batch_size:
609
+ yield np.array(images),[np.array(p1s), np.array(p2s), np.array(p3s),
610
+ np.array(p4s), np.array(p5s), np.array(p6s),
611
+ np.array(p7s), np.array(p8s), np.array(p9s)]
612
+ except GeneratorExit:
613
+ print("")
614
+ images, p1s, p2s, p3s, p4s, p5s, p6s, p7s, p8s, p9s = \
615
+ [], [], [], [], [], [], [], [], [], []
616
+
617
+ if not for_training:
618
+ break
619
+
620
+
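A usage sketch for the generator above (the csv name and variable names are hypothetical, and the image files are assumed to exist on disk): it yields a batch of image tiles plus one label array per output variable, so a two-variable setup can be consumed like this:

    import numpy as np
    import pandas as pd

    df = pd.read_csv('train.csv')     # needs a 'filenames' column plus one column per variable
    vars = ['P50', 'P84']             # hypothetical grain-size percentiles
    idx = np.random.permutation(len(df))

    gen = get_data_generator_Nvars_siso_simo(df, idx, True, vars,
                                             batch_size=8, greyscale=True, do_aug=False,
                                             standardize=True, tilesize=1024)
    images, (p50, p84) = next(gen)    # images: (8, 1024, 1024, 1); p50, p84: one value per image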
621
+ ###===================================================
622
+ def get_data_generator_1image(df, indices, for_training, ID_MAP,
623
+ var, batch_size, greyscale, do_aug,
624
+ standardize, tilesize):
625
+ """
626
+ This function creates a dataset generator consisting of batches of images
627
+ and corresponding one-hot-encoded labels describing the sediment in each image
628
+ """
629
+ try:
630
+ ID_MAP2 = dict((g, i) for i, g in ID_MAP.items())
631
+ except:
632
+ ID_MAP = dict(zip(np.arange(ID_MAP), [str(k) for k in range(ID_MAP)]))
633
+ ID_MAP2 = dict((g, i) for i, g in ID_MAP.items())
634
+
635
+ images, pops = [], []
636
+ while True:
637
+ for i in indices:
638
+ r = df.iloc[i]
639
+ file, pop = r['filenames'], r[var]
640
+
641
+ # if greyscale==True:
642
+ # im = Image.open(file).convert('LA')
643
+ # else:
644
+ # im = Image.open(file)
645
+ # im = im.resize((IM_HEIGHT, IM_HEIGHT))
646
+ # im = np.array(im) / 255.0
647
+ if greyscale==True:
648
+ im = Image.open(file).convert('LA')
649
+ #im = im.resize((IM_HEIGHT, IM_HEIGHT))
650
+ im = np.array(im)[:,:,0]
651
+ nx,ny = np.shape(im)
652
+ if (nx!=tilesize) or (ny!=tilesize):
653
+ im = im[int(nx/2)-int(tilesize/2):int(nx/2)+int(tilesize/2), int(ny/2)-int(tilesize/2):int(ny/2)+int(tilesize/2)]
654
+
655
+ else:
656
+ im = Image.open(file)
657
+ #im = im.resize((IM_HEIGHT, IM_HEIGHT))
658
+ im = np.array(im)
659
+ nx,ny,nz = np.shape(im)
660
+ if (nx!=tilesize) or (ny!=tilesize):
661
+ im = im[int(nx/2)-int(tilesize/2):int(nx/2)+int(tilesize/2), int(ny/2)-int(tilesize/2):int(ny/2)+int(tilesize/2)]
662
+
663
+ if standardize==True:
664
+ im = do_standardize(im)
665
+
666
+ # if np.ndim(im)==2:
667
+ # im = np.dstack((im, im , im)) ##np.expand_dims(im[:,:,0], axis=2)
668
+ # im = im[:,:,:3]
669
+
670
+ if greyscale==True:
671
+ if do_aug==True:
672
+ aug = apply_aug(im[:,:,0])
673
+ images.append(aug)
674
+ pops.append([to_categorical(pop, len(ID_MAP2)) for k in range(2)]) #3
675
+ else:
676
+ images.append(np.expand_dims(im[:,:,0], axis=2))
677
+ else:
678
+ if do_aug==True:
679
+ aug = apply_aug(im)
680
+ images.append(aug)
681
+ pops.append([to_categorical(pop, len(ID_MAP2)) for k in range(2)])
682
+ else:
683
+ images.append(im)
684
+ pops.append(to_categorical(pop, len(ID_MAP2)))
685
+
686
+ try:
687
+ if do_aug==True:
688
+ if len(images) >= batch_size:
689
+ if greyscale==False:
690
+ images = np.array(np.vstack(images))
691
+ pops = np.array(np.vstack(pops))
692
+ else:
693
+ images = np.expand_dims(np.array(np.vstack(images)), axis=-1)
694
+ pops = np.array(np.vstack(pops))
695
+ yield images, pops
696
+ images, pops = [], []
697
+ else:
698
+ if len(images) >= batch_size:
699
+ yield np.squeeze(np.array(images)),np.array(pops) #[np.array(pops)]
700
+ images, pops = [], []
701
+ except GeneratorExit:
702
+ print("") #pass
703
+
704
+ if not for_training:
705
+ break
706
+
707
+
708
+ ###===================================================
709
+ ### PLOT TRAINING HISTORY FUNCTIONS
710
+
711
+ def plot_train_history_1var(history):
712
+ """
713
+ This function plots loss and accuracy curves from the model training
714
+ """
715
+ fig, axes = plt.subplots(1, 2, figsize=(10, 10))
716
+
717
+ print(history.history.keys())
718
+
719
+ axes[0].plot(history.history['loss'], label='Training loss')
720
+ axes[0].plot(history.history['val_loss'], label='Validation loss')
721
+ axes[0].set_xlabel('Epochs')
722
+ axes[0].legend()
723
+ try:
724
+ axes[1].plot(history.history['acc'], label='pop train accuracy')
725
+ axes[1].plot(history.history['val_acc'], label='pop test accuracy')
726
+ except:
727
+ axes[1].plot(history.history['accuracy'], label='pop train accuracy')
728
+ axes[1].plot(history.history['val_accuracy'], label='pop test accuracy')
729
+ axes[1].set_xlabel('Epochs')
730
+ axes[1].legend()
731
+
732
+
733
+ ###===================================================
734
+ def plot_train_history_Nvar(history, varuse, N):
735
+ """
736
+ This function makes a plot of train/validation error history for each of the N variables,
737
+ plus the overall loss
738
+ """
739
+ fig, axes = plt.subplots(1, N+1, figsize=(20, 5))
740
+ for k in range(N):
741
+ try:
742
+ axes[k].plot(history.history[varuse[k]+'_output_mean_absolute_error'],
743
+ label=varuse[k]+' Train MAE')
744
+ axes[k].plot(history.history['val_'+varuse[k]+'_output_mean_absolute_error'],
745
+ label=varuse[k]+' Val MAE')
746
+ except:
747
+ axes[k].plot(history.history[varuse[k]+'_output_mae'],
748
+ label=varuse[k]+' Train MAE')
749
+ axes[k].plot(history.history['val_'+varuse[k]+'_output_mae'],
750
+ label=varuse[k]+' Val MAE')
751
+ axes[k].set_xlabel('Epochs')
752
+ axes[k].legend()
753
+
754
+ axes[N].plot(history.history['loss'], label='Training loss')
755
+ axes[N].plot(history.history['val_loss'], label='Validation loss')
756
+ axes[N].set_xlabel('Epochs')
757
+ axes[N].legend()
758
+
759
+
760
+ ###===================================================
761
+ def plot_train_history_1var_mae(history):
762
+ """
763
+ This function plots loss and accuracy curves from the model training
764
+ """
765
+ print(history.history.keys())
766
+
767
+ fig, axes = plt.subplots(1, 2, figsize=(10, 10))
768
+
769
+ axes[0].plot(history.history['loss'], label='Training loss')
770
+ axes[0].plot(history.history['val_loss'],
771
+ label='Validation loss')
772
+ axes[0].set_xlabel('Epochs')
773
+ axes[0].legend()
774
+
775
+ try:
776
+ axes[1].plot(history.history['mean_absolute_error'],
777
+ label='pop train MAE')
778
+ axes[1].plot(history.history['val_mean_absolute_error'],
779
+ label='pop test MAE')
780
+ except:
781
+ axes[1].plot(history.history['mae'], label='pop train MAE')
782
+ axes[1].plot(history.history['val_mae'], label='pop test MAE')
783
+
784
+ axes[1].set_xlabel('Epochs')
785
+ axes[1].legend()
786
+
787
+ ###===================================================
788
+ ### PLOT CONFUSION MATRIX FUNCTIONS
789
+
790
+ ###===================================================
791
+ def plot_confusion_matrix(cm, classes,
792
+ normalize=False,
793
+ cmap=plt.cm.Purples,
794
+ dolabels=True):
795
+ """
796
+ This function prints and plots the confusion matrix.
797
+ Normalization can be applied by setting `normalize=True`.
798
+ """
799
+ if normalize:
800
+ cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
801
+ cm[np.isnan(cm)] = 0
802
+
803
+ plt.imshow(cm, interpolation='nearest', cmap=cmap, vmax=1, vmin=0)
804
+ fmt = '.2f' if normalize else 'd'
805
+ thresh = cm.max() / 2.
806
+ if dolabels==True:
807
+ tick_marks = np.arange(len(classes))
808
+ plt.xticks(tick_marks, classes, fontsize=3) #, rotation=60
809
+ plt.yticks(tick_marks, classes, fontsize=3)
810
+
811
+ plt.ylabel('True label',fontsize=4)
812
+ plt.xlabel('Estimated label',fontsize=4)
813
+
814
+ else:
815
+ plt.axis('off')
816
+
817
+ for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
818
+ if cm[i, j]>0:
819
+ plt.text(j, i, str(cm[i, j])[:4],fontsize=5,
820
+ horizontalalignment="center",
821
+ color="white" if cm[i, j] > 0.6 else "black")
822
+ #plt.tight_layout()
823
+
824
+ plt.xlim(-0.5, len(classes))
825
+ plt.ylim(-0.5, len(classes))
826
+ return cm
827
+
828
+ ###===================================================
829
+ def plot_confmat(y_pred, y_true, prefix, classes):
830
+ """
831
+ This function generates and plots a confusion matrix
832
+ """
833
+ base = prefix+'_'
834
+
835
+ y = y_pred.copy()
836
+ del y_pred
837
+ l = y_true.copy()
838
+ del y_true
839
+
840
+ l = l.astype('float')
841
+ ytrue = l.flatten()
842
+ ypred = y.flatten()
843
+
844
+ ytrue = ytrue[~np.isnan(ytrue)]
845
+ ypred = ypred[~np.isnan(ypred)]
846
+
847
+ cm = confusion_matrix(ytrue, ypred)
848
+ cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
849
+ cm[np.isnan(cm)] = 0
850
+
851
+ fig=plt.figure()
852
+ plt.subplot(221)
853
+ plot_confusion_matrix(cm, classes=classes)
854
+
855
+
856
+
857
+ ###===================================================
858
+ ### PREDICTION FUNCTIONS
859
+
860
+ def predict_test_train_cat(train_df, test_df, train_idx, test_idx, var, SM,
861
+ classes, weights_path, greyscale, name, do_aug, tilesize):
862
+ """
863
+ This function makes predictions on test and train data,
864
+ prints a classification report, and prints confusion matrices
865
+ """
866
+ if type(SM) == list:
867
+ counter = 0
868
+ for s,wp in zip(SM, weights_path):
869
+ exec('SM[counter].load_weights(wp)')
870
+ counter += 1
871
+ else:
872
+ SM.load_weights(weights_path)
873
+
874
+ ##==============================================
875
+ ## make predictions on training data
876
+ for_training = False
877
+ train_gen = get_data_generator_1image(train_df, train_idx, for_training,
878
+ len(classes), var, len(train_idx), greyscale, #np.min((200, len(train_idx))),
879
+ do_aug, standardize, tilesize)
880
+ x_train, (trueT)= next(train_gen)
881
+
882
+ PT = []
883
+
884
+ if type(SM) == list:
885
+ #counter = 0
886
+ for s in SM:
887
+ tmp=s.predict(x_train, batch_size=8)
888
+ exec(
889
+ 'PT.append(np.asarray(np.squeeze(tmp)))'
890
+ )
891
+ del tmp
892
+
893
+ predT = np.median(PT, axis=0)
894
+ #predT = np.squeeze(np.asarray(PT))
895
+ del PT
896
+ K.clear_session()
897
+ gc.collect()
898
+
899
+ else:
900
+ predT = SM.predict(x_train, batch_size=8)
901
+ #predT = np.asarray(predT).argmax(axis=-1)
902
+
903
+ del train_gen, x_train
904
+
905
+ if test_df is not None:
906
+ ## make predictions on testing data
907
+ for_training = False
908
+ do_aug = False
909
+ test_gen = get_data_generator_1image(test_df, test_idx, for_training,
910
+ len(classes), var, len(test_idx), greyscale, #np.min((200, len(test_idx))),
911
+ do_aug, standardize, tilesize) #no augmentation on validation data
912
+ x_test, (true)= next(test_gen)
913
+
914
+ PT = []
915
+
916
+ if type(SM) == list:
917
+ #counter = 0
918
+ for s in SM:
919
+ tmp=s.predict(x_test, batch_size=8)
920
+ exec(
921
+ 'PT.append(np.asarray(np.squeeze(tmp)))'
922
+ )
923
+ del tmp
924
+
925
+ pred = np.median(PT, axis=0)
926
+ #pred = np.squeeze(np.asarray(PT))
927
+ del PT
928
+ K.clear_session()
929
+ gc.collect()
930
+
931
+ else:
932
+
933
+ pred = SM.predict(x_test, batch_size=8) #1)
934
+ #pred = np.asarray(pred).argmax(axis=-1)
935
+
936
+ del test_gen, x_test
937
+
938
+ trueT = np.squeeze(np.asarray(trueT).argmax(axis=-1) )
939
+ predT = np.squeeze(np.asarray(predT).argmax(axis=-1))#[0])
940
+
941
+ if test_df is not None:
942
+ pred = np.squeeze(np.asarray(pred).argmax(axis=-1))#[0])
943
+ true = np.squeeze(np.asarray(true).argmax(axis=-1) )
944
+
945
+ ##==============================================
946
+ ## print a classification report to screen, showing f1, precision, recall and accuracy
947
+ print("==========================================")
948
+ print("Classification report for "+var)
949
+ print(classification_report(true, pred))
950
+
951
+ fig = plt.figure()
952
+ ##==============================================
953
+ ## create figures showing confusion matrices for train and test data sets
954
+ if type(SM) == list:
955
+ if test_df is not None:
956
+ plot_confmat(pred, true, var, classes)
957
+ plt.savefig(weights_path[0].replace('.hdf5','_cm.png').\
958
+ replace('batch','_'.join(np.asarray(BATCH_SIZE, dtype='str'))),
959
+ dpi=300, bbox_inches='tight')
960
+ plt.close('all')
961
+
962
+ plot_confmat(predT, trueT, var+'T',classes)
963
+ plt.savefig(weights_path[0].replace('.hdf5','_cmT.png').\
964
+ replace('batch','_'.join(np.asarray(BATCH_SIZE, dtype='str'))),
965
+ dpi=300, bbox_inches='tight')
966
+ plt.close('all')
967
+
968
+ else:
969
+ if test_df is not None:
970
+ plot_confmat(pred, true, var, classes)
971
+ plt.savefig(weights_path.replace('.hdf5','_cm.png'),
972
+ dpi=300, bbox_inches='tight')
973
+ plt.close('all')
974
+
975
+ plot_confmat(predT, trueT, var+'T',classes)
976
+ plt.savefig(weights_path.replace('.hdf5','_cmT.png'),
977
+ dpi=300, bbox_inches='tight')
978
+ plt.close('all')
979
+
980
+ plt.close()
981
+ del fig
982
+
983
+
984
+ ###===================================================
985
+ def predict_train_siso_simo(a, b, vars,
986
+ SM, weights_path, name, mode, greyscale,
987
+ dropout,do_aug, standardize,#CS,# scale,
988
+ count_in):
989
+ """
990
+ This function makes predictions on test and train data
991
+ """
992
+ ##==============================================
993
+ ## make predictions on training data
994
+ if type(SM) == list:
995
+ counter = 0
996
+ for s,wp in zip(SM, weights_path):
997
+ exec('SM[counter].load_weights(wp)')
998
+ counter += 1
999
+ else:
1000
+ SM.load_weights(weights_path)
1001
+
1002
+ # if scale == True:
1003
+ #
1004
+ # if len(vars)>1:
1005
+ # counter = 0
1006
+ # for v in vars:
1007
+ # exec(
1008
+ # v+\
1009
+ # '_trueT = np.squeeze(CS[counter].inverse_transform(b[counter].reshape(-1,1)))'
1010
+ # )
1011
+ # counter +=1
1012
+ # else:
1013
+ # exec(
1014
+ # vars[0]+\
1015
+ # '_trueT = np.squeeze(CS[0].inverse_transform(b[0].reshape(-1,1)))'
1016
+ # )
1017
+ #
1018
+ # else:
1019
+ if len(vars)>1:
1020
+ counter = 0
1021
+ for v in vars:
1022
+ exec(
1023
+ v+\
1024
+ '_trueT = np.squeeze(b[counter])'
1025
+ )
1026
+ counter +=1
1027
+ else:
1028
+ exec(
1029
+ vars[0]+\
1030
+ '_trueT = np.squeeze(b)'
1031
+ )
1032
+
1033
+ del b
1034
+
1035
+ for v in vars:
1036
+ exec(v+'_PT = []')
1037
+
1038
+ # if scale == True:
1039
+ #
1040
+ # if type(SM) == list:
1041
+ # counter = 0 #model iterator
1042
+ # for s in SM:
1043
+ # train_vals=s.predict(a, batch_size=8)
1044
+ #
1045
+ # if len(vars)>1:
1046
+ # counter2 = 0 #variable iterator
1047
+ # for v in vars:
1048
+ # exec(
1049
+ # v+\
1050
+ # '_PT.append(np.squeeze(CS[counter].inverse_transform(train_vals[counter2].reshape(-1,1))))'
1051
+ # )
1052
+ # counter2 +=1
1053
+ # else:
1054
+ # exec(
1055
+ # vars[0]+\
1056
+ # '_PT.append(np.asarray(np.squeeze(CS[0].inverse_transform(train_vals.reshape(-1,1)))))'
1057
+ # )
1058
+ #
1059
+ # del train_vals
1060
+ #
1061
+ # if len(vars)>1:
1062
+ # #counter = 0
1063
+ # for v in vars:
1064
+ # exec(
1065
+ # v+\
1066
+ # '_PT = np.median('+v+'_PT, axis=0)'
1067
+ # )
1068
+ # #counter +=1
1069
+ # else:
1070
+ # exec(
1071
+ # vars[0]+\
1072
+ # '_PT = np.median('+v+'_PT, axis=0)'
1073
+ # )
1074
+ #
1075
+ # else:
1076
+ # train_vals = SM.predict(a, batch_size=8) #128)
1077
+ #
1078
+ # if len(vars)>1:
1079
+ # counter = 0
1080
+ # for v in vars:
1081
+ # exec(
1082
+ # v+\
1083
+ # '_PT.append(np.squeeze(CS[counter].inverse_transform(train_vals[counter].reshape(-1,1))))'
1084
+ # )
1085
+ # counter +=1
1086
+ # else:
1087
+ # exec(
1088
+ # vars[0]+\
1089
+ # '_PT.append(np.asarray(np.squeeze(CS[0].inverse_transform(train_vals.reshape(-1,1)))))'
1090
+ # )
1091
+ #
1092
+ # del train_vals
1093
+ #
1094
+ # else:
1095
+
1096
+ if type(SM) == list:
1097
+ #counter = 0
1098
+ for s in SM:
1099
+ train_vals=s.predict(a, batch_size=8)
1100
+
1101
+ if len(vars)>1:
1102
+ counter2 = 0
1103
+ for v in vars:
1104
+ exec(
1105
+ v+\
1106
+ '_PT.append(np.squeeze(train_vals[counter2]))'
1107
+ )
1108
+ counter2 +=1
1109
+ else:
1110
+ exec(
1111
+ vars[0]+\
1112
+ '_PT.append(np.asarray(train_vals))'
1113
+ )
1114
+
1115
+ del train_vals
1116
+
1117
+ if len(vars)>1:
1118
+ #counter = 0
1119
+ for v in vars:
1120
+ exec(
1121
+ v+\
1122
+ '_PT = np.median('+v+'_PT, axis=0)'
1123
+ )
1124
+ #counter +=1
1125
+ else:
1126
+ exec(
1127
+ vars[0]+\
1128
+ '_PT = np.median('+v+'_PT, axis=0)'
1129
+ )
1130
+
1131
+ else:
1132
+ train_vals = SM.predict(a, batch_size=1)#8) #128)
1133
+
1134
+ if len(vars)>1:
1135
+ counter = 0
1136
+ for v in vars:
1137
+ exec(
1138
+ v+\
1139
+ '_PT.append(np.squeeze(train_vals[counter]))'
1140
+ )
1141
+ counter +=1
1142
+ else:
1143
+ exec(
1144
+ vars[0]+\
1145
+ '_PT.append(np.asarray(np.squeeze(train_vals)))'
1146
+ )
1147
+
1148
+ del train_vals
1149
+
1150
+
1151
+
1152
+ if len(vars)>1:
1153
+ for k in range(len(vars)):
1154
+ exec(vars[k]+'_predT = np.squeeze(np.asarray('+vars[k]+'_PT))')
1155
+ else:
1156
+ exec(vars[0]+'_predT = np.squeeze(np.asarray('+vars[0]+'_PT))')
1157
+
1158
+
1159
+ for v in vars:
1160
+ exec('del '+v+'_PT')
1161
+
1162
+ del a #train_gen,
1163
+
1164
+ if len(vars)==9:
1165
+ nrows = 3; ncols = 3
1166
+ elif len(vars)==8:
1167
+ nrows = 2; ncols = 4
1168
+ elif len(vars)==7:
1169
+ nrows = 3; ncols = 3
1170
+ elif len(vars)==6:
1171
+ nrows = 3; ncols = 2
1172
+ elif len(vars)==5:
1173
+ nrows = 3; ncols = 2
1174
+ elif len(vars)==4:
1175
+ nrows = 2; ncols = 2
1176
+ elif len(vars)==3:
1177
+ nrows = 2; ncols = 2
1178
+ elif len(vars)==2:
1179
+ nrows = 1; ncols = 2
1180
+ elif len(vars)==1:
1181
+ nrows = 1; ncols = 1
1182
+
1183
+ out = dict()
1184
+
1185
+ ## make a plot
1186
+ fig = plt.figure(figsize=(6*nrows,6*ncols))
1187
+ labs = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
1188
+ for k in range(1,1+(nrows*ncols)):
1189
+ # try:
1190
+ plt.subplot(nrows,ncols,k)
1191
+ x1 = eval(vars[k-1]+'_trueT')
1192
+ y1 = eval(vars[k-1]+'_predT')
1193
+ out[vars[k-1]+'_trueT'] = eval(vars[k-1]+'_trueT')
1194
+ out[vars[k-1]+'_predT'] = eval(vars[k-1]+'_predT')
1195
+
1196
+ plt.plot(x1, y1, 'ko', markersize=5)
1197
+ plt.plot([ np.min(np.hstack((x1,y1))), np.max(np.hstack((x1,y1)))],
1198
+ [ np.min(np.hstack((x1,y1))), np.max(np.hstack((x1,y1)))],
1199
+ 'k', lw=2)
1200
+ plt.text(np.nanmin(x1), 0.7*np.max(np.hstack((x1,y1))),'Train : '+\
1201
+ str(np.nanmean(100*(np.abs(y1-x1) / x1)))[:5]+\
1202
+ ' %', fontsize=10)
1203
+
1204
+ plt.title(r''+labs[k-1]+') '+vars[k-1], fontsize=8, loc='left')
1205
+
1206
+ #varstring = ''.join([str(k)+'_' for k in vars])
1207
+ varstring = str(len(vars))+'vars'
1208
+
1209
+ # except:
1210
+ # pass
1211
+ if type(SM) == list:
1212
+ plt.savefig(weights_path[0].replace('.hdf5', '_skill_ensemble'+str(count_in)+'.png').\
1213
+ replace('batch','_'.join(np.asarray(BATCH_SIZE, dtype='str'))),
1214
+ dpi=300, bbox_inches='tight')
1215
+
1216
+ else:
1217
+ plt.savefig(weights_path.replace('.hdf5', '_skill'+str(count_in)+'.png'),
1218
+ dpi=300, bbox_inches='tight')
1219
+
1220
+ plt.close()
1221
+ del fig
1222
+
1223
+ np.savez_compressed(weights_path.replace('.hdf5', '_out'+str(count_in)+'.npz'),**out)
1224
+ del out
1225
+
1226
+ if len(vars)==9:
1227
+ nrows = 3; ncols = 3
1228
+ elif len(vars)==8:
1229
+ nrows = 2; ncols = 4
1230
+ elif len(vars)==7:
1231
+ nrows = 3; ncols = 3
1232
+ elif len(vars)==6:
1233
+ nrows = 3; ncols = 2
1234
+ elif len(vars)==5:
1235
+ nrows = 3; ncols = 2
1236
+ elif len(vars)==4:
1237
+ nrows = 2; ncols = 2
1238
+ elif len(vars)==3:
1239
+ nrows = 2; ncols = 2
1240
+ elif len(vars)==2:
1241
+ nrows = 1; ncols = 2
1242
+ elif len(vars)==1:
1243
+ nrows = 1; ncols = 1
1244
+
1245
+ out = dict()
1246
+
1247
+ ## make a plot
1248
+ fig = plt.figure(figsize=(6*nrows,6*ncols))
1249
+ labs = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
1250
+ for k in range(1,1+(nrows*ncols)):
1251
+ # try:
1252
+ plt.subplot(nrows,ncols,k)
1253
+ x1 = eval(vars[k-1]+'_trueT')
1254
+ y1 = eval(vars[k-1]+'_predT')
1255
+ out[vars[k-1]+'_trueT'] = eval(vars[k-1]+'_trueT')
1256
+ out[vars[k-1]+'_predT'] = eval(vars[k-1]+'_predT')
1257
+
1258
+
1259
+ plt.plot(x1, y1, 'ko', markersize=5)
1260
+ plt.plot([ np.min(np.hstack((x1,y1))), np.max(np.hstack((x1,y1)))],
1261
+ [ np.min(np.hstack((x1,y1))), np.max(np.hstack((x1,y1)))],
1262
+ 'k', lw=2)
1263
+ plt.text(np.nanmin(x1), 0.7*np.max(np.hstack((x1,y1))),'Train : '+\
1264
+ str(np.nanmean(100*(np.abs(y1-x1) / x1)))[:5]+\
1265
+ ' %', fontsize=10)
1266
+
1267
+ plt.title(r''+labs[k-1]+') '+vars[k-1], fontsize=8, loc='left')
1268
+
1269
+ #varstring = ''.join([str(k)+'_' for k in vars])
1270
+ varstring = str(len(vars))+'vars'
1271
+
1272
+ # except:
1273
+ # pass
1274
+ if type(SM) == list:
1275
+ plt.savefig(weights_path[0].replace('.hdf5', '_skill_ensemble'+str(count_in)+'.png').\
1276
+ replace('batch','_'.join(np.asarray(BATCH_SIZE, dtype='str'))),
1277
+ dpi=300, bbox_inches='tight')
1278
+
1279
+ else:
1280
+ plt.savefig(weights_path.replace('.hdf5', '_skill'+str(count_in)+'.png'),
1281
+ dpi=300, bbox_inches='tight')
1282
+
1283
+ plt.close()
1284
+ del fig
1285
+
1286
+ np.savez_compressed(weights_path.replace('.hdf5', '_out'+str(count_in)+'.npz'),**out)
1287
+ del out
1288
+
1289
+
1290
+ ###============================================================
1291
+ def plot_all_save_all(weights_path, vars):
1292
+
1293
+ if type(weights_path) == list:
1294
+ npz_files = glob(weights_path[0].replace('.hdf5', '*.npz'))
1295
+ else:
1296
+ npz_files = glob(weights_path.replace('.hdf5', '*.npz'))
1297
+
1298
+ npz_files = [n for n in npz_files if '_all.npz' not in n]
1299
+
1300
+ print("Found %i npz files "%(len(npz_files)))
1301
+ if len(vars)==9:
1302
+ nrows = 3; ncols = 3
1303
+ elif len(vars)==8:
1304
+ nrows = 2; ncols = 4
1305
+ elif len(vars)==7:
1306
+ nrows = 3; ncols = 3
1307
+ elif len(vars)==6:
1308
+ nrows = 3; ncols = 2
1309
+ elif len(vars)==5:
1310
+ nrows = 3; ncols = 2
1311
+ elif len(vars)==4:
1312
+ nrows = 2; ncols = 2
1313
+ elif len(vars)==3:
1314
+ nrows = 2; ncols = 2
1315
+ elif len(vars)==2:
1316
+ nrows = 1; ncols = 2
1317
+ elif len(vars)==1:
1318
+ nrows = 1; ncols = 1
1319
+
1320
+ ## make a plot
1321
+ fig = plt.figure(figsize=(6*nrows,6*ncols))
1322
+
1323
+ for counter,file in enumerate(npz_files):
1324
+ out = dict()
1325
+ with np.load(file, allow_pickle=True) as dat:
1326
+ for k in dat.keys():
1327
+ out[k] = dat[k]
1328
+
1329
+ labs = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
1330
+ X = []; Y=[]
1331
+ Xt = []; Yt=[]
1332
+
1333
+ for k in range(1,1+(nrows*ncols)):
1334
+ # try:
1335
+ plt.subplot(nrows,ncols,k)
1336
+ x1 = out[vars[k-1]+'_trueT']
1337
+ y1 = out[vars[k-1]+'_predT']
1338
+
1339
+ X.append(x1.flatten()); Y.append(y1.flatten())
1340
+ del x1, y1
1341
+
1342
+ x1 = np.array(X)
1343
+ y1 = np.array(Y)
1344
+
1345
+ plt.plot(x1, y1, 'ko', markersize=5)
1346
+
1347
+ if counter==len(npz_files)-1:
1348
+ plt.plot([ np.min(np.hstack((x1,y1))), np.max(np.hstack((x1,y1)))],
1349
+ [ np.min(np.hstack((x1,y1))), np.max(np.hstack((x1,y1)))],
1350
+ 'k', lw=2)
1351
+
1352
+ plt.text(np.nanmin(x1), 0.8*np.max(np.hstack((x1,y1))),'Train : '+\
1353
+ str(np.mean(100*(np.abs(y1-x1) / x1)))[:5]+\
1354
+ ' %', fontsize=12, color='r')
1355
+
1356
+ plt.title(r''+labs[k-1]+') '+vars[k-1], fontsize=8, loc='left')
1357
+
1358
+ out[vars[k-1]+'_trueT'] = x1 #eval(vars[k-1]+'_trueT')
1359
+ out[vars[k-1]+'_predT'] = y1 #eval(vars[k-1]+'_predT')
1360
+
1361
+
1362
+ if counter==len(npz_files)-1:
1363
+ plt.plot([ np.min(np.hstack((x1,y1))), np.max(np.hstack((x1,y1)))],
1364
+ [ np.min(np.hstack((x1,y1))), np.max(np.hstack((x1,y1)))],
1365
+ 'k', lw=2)
1366
+
1367
+ plt.text(np.nanmin(x1), 0.8*np.max(np.hstack((x1,y1))),'Train : '+\
1368
+ str(np.mean(100*(np.abs(y1-x1) / x1)))[:5]+\
1369
+ ' %', fontsize=12, color='r')
1370
+
1371
+ try:
1372
+ plt.text(np.nanmin(x2), 0.8*np.max(np.hstack((x2,y2))),'Test : '+\
1373
+ str(np.mean(100*(np.abs(y2-x2) / x2)))[:5]+\
1374
+ ' %', fontsize=12, color='r')
1375
+ except:
1376
+ pass
1377
+
1378
+ plt.title(r''+labs[k-1]+') '+vars[k-1], fontsize=8, loc='left')
1379
+
1380
+
1381
+ if type(weights_path) == list:
1382
+ plt.savefig(weights_path[0].replace('.hdf5', '_skill_ensemble.png').\
1383
+ replace('batch','_'.join(np.asarray(BATCH_SIZE, dtype='str'))),
1384
+ dpi=300, bbox_inches='tight')
1385
+
1386
+ else:
1387
+ plt.savefig(weights_path.replace('.hdf5', '_skill.png'),
1388
+ dpi=300, bbox_inches='tight')
1389
+
1390
+ plt.close()
1391
+ del fig
1392
+ np.savez_compressed(weights_path.replace('.hdf5', '_out_all.npz'),**out)
1393
+ del out
1394
+
1395
+
1396
+ ###===================================================
1397
+ ### MISC. UTILITIES
1398
+
1399
+ def tidy(name,res_folder):
1400
+ """
1401
+ This function moves training outputs to a specific folder
1402
+ """
1403
+
1404
+ pngfiles = glob('*'+name+'*.png')
1405
+ jsonfiles = glob('*'+name+'*.json')
1406
+ hfiles = glob('*'+name+'*.hdf5')
1407
+ tfiles = glob('*'+name+'*.txt')
1408
+ #pfiles = glob('*'+name+'*.pkl')
1409
+ nfiles = glob('*'+name+'*.npz')
1410
+
1411
+ try:
1412
+ [shutil.move(k, res_folder) for k in pngfiles]
1413
+ [shutil.move(k, res_folder) for k in hfiles]
1414
+ [shutil.move(k, res_folder) for k in jsonfiles]
1415
+ [shutil.move(k, res_folder) for k in tfiles]
1416
+ #[shutil.move(k, res_folder) for k in pfiles]
1417
+ [shutil.move(k, res_folder) for k in nfiles]
1418
+ except:
1419
+ pass
1420
+
1421
+ ###===================================================
1422
+ def get_df(csvfile,fortrain=False):
1423
+ """
1424
+ This function reads a csvfile with image names and labels
1425
+ and returns shuffled indices, the dataframe (possibly split into chunks), and a split flag
1426
+ """
1427
+ ###===================================================
1428
+ ## read the data set in, clean and modify the pathnames so they are absolute
1429
+ df = pd.read_csv(csvfile)
1430
+
1431
+ num_split = 50
1432
+ if fortrain==False:
1433
+ if len(df)>num_split:
1434
+ #print('Splitting into chunks')
1435
+ df = np.array_split(df, int(np.round(len(df)/num_split)))
1436
+ split = True
1437
+ else:
1438
+ split = False
1439
+ else:
1440
+ split = False
1441
+
1442
+ if split:
1443
+ for k in range(len(df)):
1444
+ df[k]['filenames'] = [k.strip() for k in df[k]['filenames']]
1445
+ else:
1446
+ df['filenames'] = [k.strip() for k in df['filenames']]
1447
+
1448
+ if split:
1449
+ for k in range(len(df)):
1450
+ df[k]['filenames'] = [os.getcwd()+os.sep+f.replace('\\',os.sep) for f in df[k]['filenames']]
1451
+ else:
1452
+ df['filenames'] = [os.getcwd()+os.sep+f.replace('\\',os.sep) for f in df['filenames']]
1453
+
1454
+ np.random.seed(2021)
1455
+ if type(df)==list:
1456
+ idx = [np.random.permutation(len(d)) for d in df]
1457
+ else:
1458
+ idx = np.random.permutation(len(df))
1459
+
1460
+ return idx, df, split
1461
+
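A short usage sketch for get_df (the csv name is a placeholder): for prediction (fortrain=False) a large dataframe is split into chunks of roughly 50 rows, so the split flag tells the caller whether df and idx are lists or single objects:

    idx, df, split = get_df('sample.csv')

    if split:
        for chunk, chunk_idx in zip(df, idx):   # lists of dataframes / index permutations
            print(len(chunk), chunk_idx[:3])
    else:
        print(len(df), idx[:3])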
1462
+
1463
+
1464
+
1465
+ #
1466
+ # ###===================================================
1467
+ # def predict_test_siso_simo(a, b, vars,
1468
+ # SM, weights_path, name, mode, greyscale,
1469
+ # CS, dropout, scale, do_aug, standardize,
1470
+ # count_in):
1471
+ #
1472
+ # #
1473
+ # # ## make predictions on testing data
1474
+ # # if d is not None:
1475
+ # # do_aug = False
1476
+ # # for_training = False
1477
+ # # # test_gen = get_data_generator_Nvars_siso_simo(test_df, test_idx, for_training,
1478
+ # # # vars, len(test_idx), greyscale, CS, do_aug, standardize) #np.min((200, len(test_idx)))
1479
+ # # #
1480
+ # # # x_test, vals = next(test_gen)
1481
+ # #
1482
+ # # if scale == True:
1483
+ # #
1484
+ # # if len(vars)>1:
1485
+ # # counter = 0
1486
+ # # for v in vars:
1487
+ # # exec(
1488
+ # # v+\
1489
+ # # '_true = np.squeeze(CS[counter].inverse_transform(d[counter].reshape(-1,1)))'
1490
+ # # )
1491
+ # # counter +=1
1492
+ # # else:
1493
+ # # exec(
1494
+ # # vars[0]+\
1495
+ # # '_true = np.squeeze(CS[0].inverse_transform(d[0].reshape(-1,1)))'
1496
+ # # )
1497
+ # #
1498
+ # # else:
1499
+ # # if len(vars)>1:
1500
+ # # counter = 0
1501
+ # # for v in vars:
1502
+ # # exec(
1503
+ # # v+\
1504
+ # # '_true = np.squeeze(d[counter])'
1505
+ # # )
1506
+ # # counter +=1
1507
+ # # else:
1508
+ # # exec(
1509
+ # # vars[0]+\
1510
+ # # '_true = np.squeeze(d)'
1511
+ # # )
1512
+ # #
1513
+ # #
1514
+ # # del d
1515
+ # #
1516
+ # # for v in vars:
1517
+ # # exec(v+'_P = []')
1518
+ # #
1519
+ # # if scale == True:
1520
+ # #
1521
+ # # if type(SM) == list:
1522
+ # # #counter = 0
1523
+ # # for s in SM:
1524
+ # # test_vals=s.predict(c, batch_size=8)
1525
+ # #
1526
+ # # if len(vars)>1:
1527
+ # # counter = 0
1528
+ # # for v in vars:
1529
+ # # exec(
1530
+ # # v+\
1531
+ # # '_P.append(np.squeeze(CS[counter].inverse_transform(test_vals[counter].reshape(-1,1))))'
1532
+ # # )
1533
+ # # counter +=1
1534
+ # # else:
1535
+ # # exec(
1536
+ # # vars[0]+\
1537
+ # # '_P.append(np.asarray(np.squeeze(CS[0].inverse_transform(test_vals.reshape(-1,1)))))'
1538
+ # # )
1539
+ # #
1540
+ # # del test_vals
1541
+ # #
1542
+ # # if len(vars)>1:
1543
+ # # #counter = 0
1544
+ # # for v in vars:
1545
+ # # exec(
1546
+ # # v+\
1547
+ # # '_P = np.median('+v+'_P, axis=0)'
1548
+ # # )
1549
+ # # #counter +=1
1550
+ # # else:
1551
+ # # exec(
1552
+ # # vars[0]+\
1553
+ # # '_P = np.median('+v+'_P, axis=0)'
1554
+ # # )
1555
+ # #
1556
+ # # else:
1557
+ # #
1558
+ # # test_vals = SM.predict(c, batch_size=8) #128)
1559
+ # # if len(vars)>1:
1560
+ # # counter = 0
1561
+ # # for v in vars:
1562
+ # # exec(
1563
+ # # v+\
1564
+ # # '_P.append(np.squeeze(CS[counter].inverse_transform(test_vals[counter].reshape(-1,1))))'
1565
+ # # )
1566
+ # # counter +=1
1567
+ # # else:
1568
+ # # exec(
1569
+ # # vars[0]+\
1570
+ # # '_P.append(np.asarray(np.squeeze(CS[0].inverse_transform(test_vals.reshape(-1,1)))))'
1571
+ # # )
1572
+ # #
1573
+ # # del test_vals
1574
+ # #
1575
+ # #
1576
+ # # else: #no scale
1577
+ # #
1578
+ # # if type(SM) == list:
1579
+ # # counter = 0
1580
+ # # for s in SM:
1581
+ # # test_vals=s.predict(c, batch_size=8)
1582
+ # #
1583
+ # # if len(vars)>1:
1584
+ # # counter = 0
1585
+ # # for v in vars:
1586
+ # # exec(
1587
+ # # v+\
1588
+ # # '_P.append(np.squeeze(test_vals[counter]))'
1589
+ # # )
1590
+ # # counter +=1
1591
+ # # else:
1592
+ # # exec(
1593
+ # # vars[0]+\
1594
+ # # '_P.append(np.asarray(np.squeeze(test_vals)))'
1595
+ # # )
1596
+ # #
1597
+ # # del test_vals
1598
+ # #
1599
+ # # if len(vars)>1:
1600
+ # # #counter = 0
1601
+ # # for v in vars:
1602
+ # # exec(
1603
+ # # v+\
1604
+ # # '_P = np.median('+v+'_P, axis=0)'
1605
+ # # )
1606
+ # # #counter +=1
1607
+ # # else:
1608
+ # # exec(
1609
+ # # vars[0]+\
1610
+ # # '_P = np.median('+v+'_P, axis=0)'
1611
+ # # )
1612
+ # #
1613
+ # # else:
1614
+ # #
1615
+ # # test_vals = SM.predict(c, batch_size=8) #128)
1616
+ # # if len(vars)>1:
1617
+ # # counter = 0
1618
+ # # for v in vars:
1619
+ # # exec(
1620
+ # # v+\
1621
+ # # '_P.append(np.squeeze(test_vals[counter]))'
1622
+ # # )
1623
+ # # counter +=1
1624
+ # # else:
1625
+ # # exec(
1626
+ # # vars[0]+\
1627
+ # # '_P.append(np.asarray(np.squeeze(test_vals)))'
1628
+ # # )
1629
+ # #
1630
+ # # # del test_vals
1631
+ # #
1632
+ # #
1633
+ # # del c #test_gen,
1634
+ #
1635
+ # # if len(vars)>1:
1636
+ # # for k in range(len(vars)):
1637
+ # # exec(vars[k]+'_pred = np.squeeze(np.asarray('+vars[k]+'_P))')
1638
+ # # else:
1639
+ # # exec(vars[0]+'_pred = np.squeeze(np.asarray('+vars[0]+'_P))')
1640
+ # #
1641
+ # # for v in vars:
1642
+ # # exec('del '+v+'_P')
1643
+ #
1644
+ # # ## write out results to text files
1645
+ # # if len(vars)>1:
1646
+ # # for k in range(len(vars)):
1647
+ # # exec('np.savetxt("'+name+'_test'+vars[k]+'.txt", ('+vars[k]+'_pred))') #','+vars[k]+'_true))')
1648
+ # # exec('np.savetxt("'+name+'_train'+vars[k]+'.txt", ('+vars[k]+'_predT))') #,'+vars[k]+'_trueT))')
1649
+ # # np.savetxt(name+"_test_files.txt", np.asarray(test_df.files.values), fmt="%s")
1650
+ # # np.savetxt(name+"_train_files.txt", np.asarray(train_df.files.values), fmt="%s")
1651
+ # #
1652
+ # # else:
1653
+ # # exec('np.savetxt("'+name+'_test'+vars[0]+'.txt", ('+vars[0]+'_pred))') #','+vars[k]+'_true))')
1654
+ # # exec('np.savetxt("'+name+'_train'+vars[0]+'.txt", ('+vars[0]+'_predT))') #,'+vars[k]+'_trueT))')
1655
+ # # np.savetxt(name+"_test_files.txt", np.asarray(test_df.files.values), fmt="%s")
1656
+ # # np.savetxt(name+"_train_files.txt", np.asarray(train_df.files.values), fmt="%s")
1657
+ #
1658
+ # if len(vars)==9:
1659
+ # nrows = 3; ncols = 3
1660
+ # elif len(vars)==8:
1661
+ # nrows = 2; ncols = 4
1662
+ # elif len(vars)==7:
1663
+ # nrows = 3; ncols = 3
1664
+ # elif len(vars)==6:
1665
+ # nrows = 3; ncols = 2
1666
+ # elif len(vars)==5:
1667
+ # nrows = 3; ncols = 2
1668
+ # elif len(vars)==4:
1669
+ # nrows = 2; ncols = 2
1670
+ # elif len(vars)==3:
1671
+ # nrows = 2; ncols = 2
1672
+ # elif len(vars)==2:
1673
+ # nrows = 1; ncols = 2
1674
+ # elif len(vars)==1:
1675
+ # nrows = 1; ncols = 1
1676
+ #
1677
+ # out = dict()
1678
+ #
1679
+ # ## make a plot
1680
+ # fig = plt.figure(figsize=(6*nrows,6*ncols))
1681
+ # labs = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
1682
+ # for k in range(1,1+(nrows*ncols)):
1683
+ # # try:
1684
+ # plt.subplot(nrows,ncols,k)
1685
+ # x1 = eval(vars[k-1]+'_trueT')
1686
+ # y1 = eval(vars[k-1]+'_predT')
1687
+ # out[vars[k-1]+'_trueT'] = eval(vars[k-1]+'_trueT')
1688
+ # out[vars[k-1]+'_predT'] = eval(vars[k-1]+'_predT')
1689
+ #
1690
+ # # z = np.polyfit(y1,x1, 1)
1691
+ # # Z.append(z)
1692
+ # #
1693
+ # # y1 = np.polyval(z,y1)
1694
+ # # y1 = np.abs(y1)
1695
+ #
1696
+ # plt.plot(x1, y1, 'ko', markersize=5)
1697
+ # plt.plot([ np.min(np.hstack((x1,y1))), np.max(np.hstack((x1,y1)))],
1698
+ # [ np.min(np.hstack((x1,y1))), np.max(np.hstack((x1,y1)))],
1699
+ # 'k', lw=2)
1700
+ # plt.text(np.nanmin(x1), 0.7*np.max(np.hstack((x1,y1))),'Train : '+\
1701
+ # str(np.nanmean(100*(np.abs(y1-x1) / x1)))[:5]+\
1702
+ # ' %', fontsize=10)
1703
+ #
1704
+ # # if test_vals is not None:
1705
+ # # x2 = eval(vars[k-1]+'_true')
1706
+ # # y2 = eval(vars[k-1]+'_pred')
1707
+ # # # y2 = np.abs(np.polyval(z,y2))
1708
+ # #
1709
+ # # plt.plot(x2, y2, 'bx', markersize=5)
1710
+ # #
1711
+ # # if test_vals is not None:
1712
+ # # plt.text(np.nanmin(x2), 0.75*np.max(np.hstack((x2,y2))),'Test : '+\
1713
+ # # str(np.mean(100*(np.abs(y2-x2) / x2)))[:5]+\
1714
+ # # ' %', fontsize=10, color='b')
1715
+ # # else:
1716
+ # # plt.text(np.nanmin(x1), 0.7*np.max(np.hstack((x1,y1))),''+\
1717
+ # # str(np.mean(100*(np.abs(y1-x1) / x1)))[:5]+\
1718
+ # # ' %', fontsize=10)
1719
+ # plt.title(r''+labs[k-1]+') '+vars[k-1], fontsize=8, loc='left')
1720
+ #
1721
+ # #varstring = ''.join([str(k)+'_' for k in vars])
1722
+ # varstring = str(len(vars))+'vars'
1723
+ #
1724
+ # # except:
1725
+ # # pass
1726
+ # if type(SM) == list:
1727
+ # plt.savefig(weights_path[0].replace('.hdf5', '_skill_ensemble'+str(count_in)+'.png').\
1728
+ # replace('batch','_'.join(np.asarray(BATCH_SIZE, dtype='str'))),
1729
+ # dpi=300, bbox_inches='tight')
1730
+ #
1731
+ # else:
1732
+ # plt.savefig(weights_path.replace('.hdf5', '_skill'+str(count_in)+'.png'),
1733
+ # dpi=300, bbox_inches='tight')
1734
+ #
1735
+ # plt.close()
1736
+ # del fig
1737
+ #
1738
+ # np.savez_compressed(weights_path.replace('.hdf5', '_out'+str(count_in)+'.npz'),**out)
1739
+ # del out
1740
+
1741
+ #
1742
+ # ###===================================================
1743
+ # def predict_test_train_miso_mimo(train_df, test_df, train_idx, test_idx,
1744
+ # vars, auxin, SM, weights_path, name, mode,
1745
+ # greyscale, CS, CSaux):
1746
+ # """
1747
+ # This function makes predictions on test and train data
1748
+ # """
1749
+ # ##==============================================
1750
+ # ## make predictions on training data
1751
+ #
1752
+ # SM.load_weights(weights_path)
1753
+ #
1754
+ # train_gen = get_data_generator_Nvars_miso_mimo(train_df, train_idx, False,
1755
+ # vars, auxin,aux_mean, aux_std, len(train_idx), greyscale)
1756
+ #
1757
+ # x_train, tmp = next(train_gen)
1758
+ #
1759
+ # if len(vars)>1:
1760
+ # counter = 0
1761
+ # for v in vars:
1762
+ # exec(
1763
+ # v+\
1764
+ # '_trueT = np.squeeze(CS[counter].inverse_transform(tmp[counter].reshape(-1,1)))'
1765
+ # )
1766
+ # counter +=1
1767
+ # else:
1768
+ # exec(
1769
+ # vars[0]+\
1770
+ # '_trueT = np.squeeze(CS[0].inverse_transform(tmp[0].reshape(-1,1)))'
1771
+ # )
1772
+ #
1773
+ # for v in vars:
1774
+ # exec(v+'_PT = []')
1775
+ #
1776
+ # del tmp
1777
+ # tmp = SM.predict(x_train, batch_size=8) #128)
1778
+ # if len(vars)>1:
1779
+ # counter = 0
1780
+ # for v in vars:
1781
+ # exec(
1782
+ # v+\
1783
+ # '_PT.append(np.squeeze(CS[counter].inverse_transform(tmp[counter].reshape(-1,1))))'
1784
+ # )
1785
+ # counter +=1
1786
+ # else:
1787
+ # exec(
1788
+ # vars[0]+\
1789
+ # '_PT.append(np.asarray(np.squeeze(CS[0].inverse_transform(tmp.reshape(-1,1)))))'
1790
+ # )
1791
+ #
1792
+ #
1793
+ # if len(vars)>1:
1794
+ # for k in range(len(vars)):
1795
+ # exec(
1796
+ # vars[k]+\
1797
+ # '_predT = np.squeeze(np.mean(np.asarray('+vars[k]+'_PT), axis=0))'
1798
+ # )
1799
+ # else:
1800
+ # exec(
1801
+ # vars[0]+\
1802
+ # '_predT = np.squeeze(np.mean(np.asarray('+vars[0]+'_PT), axis=0))'
1803
+ # )
1804
+ #
1805
+ # ## make predictions on testing data
1806
+ # test_gen = get_data_generator_Nvars_miso_mimo(test_df, test_idx, False,
1807
+ # vars, auxin, aux_mean, aux_std, len(test_idx), greyscale)
1808
+ #
1809
+ # del tmp
1810
+ # x_test, tmp = next(test_gen)
1811
+ # if len(vars)>1:
1812
+ # counter = 0
1813
+ # for v in vars:
1814
+ # exec(v+\
1815
+ # '_true = np.squeeze(CS[counter].inverse_transform(tmp[counter].reshape(-1,1)))'
1816
+ # )
1817
+ # counter +=1
1818
+ # else:
1819
+ # exec(vars[0]+\
1820
+ # '_true = np.squeeze(CS[0].inverse_transform(tmp[0].reshape(-1,1)))'
1821
+ # )
1822
+ #
1823
+ # for v in vars:
1824
+ # exec(v+'_P = []')
1825
+ #
1826
+ # del tmp
1827
+ # tmp = SM.predict(x_test, batch_size=8) #128)
1828
+ # if len(vars)>1:
1829
+ # counter = 0
1830
+ # for v in vars:
1831
+ # exec(
1832
+ # v+\
1833
+ # '_P.append(np.squeeze(CS[counter].inverse_transform(tmp[counter].reshape(-1,1))))'
1834
+ # )
1835
+ # counter +=1
1836
+ # else:
1837
+ # exec(
1838
+ # vars[0]+\
1839
+ # '_P.append(np.asarray(np.squeeze(CS[0].inverse_transform(tmp.reshape(-1,1)))))'
1840
+ # )
1841
+ #
1842
+ # if len(vars)>1:
1843
+ # for k in range(len(vars)):
1844
+ # exec(
1845
+ # vars[k]+\
1846
+ # '_pred = np.squeeze(np.mean(np.asarray('+vars[k]+'_P), axis=0))'
1847
+ # )
1848
+ # else:
1849
+ # exec(
1850
+ # vars[0]+\
1851
+ # '_pred = np.squeeze(np.mean(np.asarray('+vars[0]+'_P), axis=0))'
1852
+ # )
1853
+ #
1854
+ #
1855
+ # if len(vars)==9:
1856
+ # nrows = 3; ncols = 3
1857
+ # elif len(vars)==8:
1858
+ # nrows = 4; ncols = 2
1859
+ # elif len(vars)==7:
1860
+ # nrows = 4; ncols = 2
1861
+ # elif len(vars)==6:
1862
+ # nrows = 3; ncols = 2
1863
+ # elif len(vars)==5:
1864
+ # nrows = 3; ncols = 2
1865
+ # elif len(vars)==4:
1866
+ # nrows = 2; ncols = 2
1867
+ # elif len(vars)==3:
1868
+ # nrows = 3; ncols = 1
1869
+ # elif len(vars)==2:
1870
+ # nrows = 2; ncols = 1
1871
+ # elif len(vars)==1:
1872
+ # nrows = 1; ncols = 1
1873
+ #
1874
+ # ## make a plot
1875
+ # fig = plt.figure(figsize=(4*nrows,4*ncols))
1876
+ # labs = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
1877
+ # for k in range(1,1+(nrows*ncols)):
1878
+ # plt.subplot(nrows,ncols,k)
1879
+ # x = eval(vars[k-1]+'_trueT')
1880
+ # y = eval(vars[k-1]+'_predT')
1881
+ # plt.plot(x, y, 'ko', markersize=5)
1882
+ # plt.plot(eval(vars[k-1]+'_true'), eval(vars[k-1]+'_pred'),
1883
+ # 'bx', markersize=5)
1884
+ # plt.plot([ np.min(np.hstack((x,y))), np.max(np.hstack((x,y)))],
1885
+ # [ np.min(np.hstack((x,y))), np.max(np.hstack((x,y)))], 'k', lw=2)
1886
+ #
1887
+ # plt.text(np.nanmin(x), 0.96*np.max(np.hstack((x,y))),'Test : '+\
1888
+ # str(np.mean(100*(np.abs(eval(vars[k-1]+'_pred') -\
1889
+ # eval(vars[k-1]+'_true')) / eval(vars[k-1]+'_true'))))[:5]+\
1890
+ # ' %', fontsize=8, color='b')
1891
+ # plt.text(np.nanmin(x), np.max(np.hstack((x,y))),'Train : '+\
1892
+ # str(np.mean(100*(np.abs(eval(vars[k-1]+'_predT') -\
1893
+ # eval(vars[k-1]+'_trueT')) / eval(vars[k-1]+'_trueT'))))[:5]+\
1894
+ # ' %', fontsize=8)
1895
+ # plt.title(r''+labs[k-1]+') '+vars[k-1], fontsize=8, loc='left')
1896
+ #
1897
+ # varstring = ''.join([str(k)+'_' for k in vars])
1898
+ #
1899
+ # plt.savefig(weights_path.replace('.hdf5', '_skill.png'),
1900
+ # dpi=300, bbox_inches='tight')
1901
+ # plt.close()
1902
+ # del fig
1903
+ #
1904
+
1905
+ #
1906
+ # ###===================================================
1907
+ # def get_data_generator_Nvars_miso_mimo(df, indices, for_training, vars, auxin,
1908
+ # batch_size, greyscale, CS, CSaux): ##BATCH_SIZE
1909
+ # """
1910
+ # This function generates data for a batch of images and 1 auxiliary variable,
1911
+ # and N associated output metrics
1912
+ # """
1913
+ # if len(vars)==1:
1914
+ # images, a, p1s = [], [], []
1915
+ # elif len(vars)==2:
1916
+ # images, a, p1s, p2s = [], [], [], []
1917
+ # elif len(vars)==3:
1918
+ # images, a, p1s, p2s, p3s = [], [], [], [], []
1919
+ # elif len(vars)==4:
1920
+ # images, a, p1s, p2s, p3s, p4s = [], [], [], [], [], []
1921
+ # elif len(vars)==5:
1922
+ # images, a, p1s, p2s, p3s, p4s, p5s = [], [], [], [], [], [], []
1923
+ # elif len(vars)==6:
1924
+ # images, a, p1s, p2s, p3s, p4s, p5s, p6s = \
1925
+ # [], [], [], [], [], [], [], []
1926
+ # elif len(vars)==7:
1927
+ # images, a, p1s, p2s, p3s, p4s, p5s, p6s, p7s = \
1928
+ # [], [], [], [], [], [], [], [], []
1929
+ # elif len(vars)==8:
1930
+ # images, a, p1s, p2s, p3s, p4s, p5s, p6s, p7s, p8s = \
1931
+ # [], [], [], [], [], [], [], [], [], []
1932
+ # elif len(vars)==9:
1933
+ # images, a, p1s, p2s, p3s, p4s, p5s, p6s, p7s, p8s, p9s = \
1934
+ # [], [], [], [], [], [], [], [], [], [], []
1935
+ #
1936
+ # while True:
1937
+ # for i in indices:
1938
+ # r = df.iloc[i]
1939
+ # if len(vars)==1:
1940
+ # file, p1, aa = r['files'], r[vars[0]], r[auxin]
1941
+ # if len(vars)==2:
1942
+ # file, p1, p2, aa = \
1943
+ # r['files'], r[vars[0]], r[vars[1]], r[auxin]
1944
+ # if len(vars)==3:
1945
+ # file, p1, p2, p3, aa = \
1946
+ # r['files'], r[vars[0]], r[vars[1]], r[vars[2]], r[auxin]
1947
+ # if len(vars)==4:
1948
+ # file, p1, p2, p3, p4, aa = \
1949
+ # r['files'], r[vars[0]], r[vars[1]], r[vars[2]], r[vars[3]], r[auxin]
1950
+ # if len(vars)==5:
1951
+ # file, p1, p2, p3, p4, p5, aa = \
1952
+ # r['files'], r[vars[0]], r[vars[1]], r[vars[2]], r[vars[3]], r[vars[4]], r[auxin]
1953
+ # if len(vars)==6:
1954
+ # file, p1, p2, p3, p4, p5, p6, aa = \
1955
+ # r['files'], r[vars[0]], r[vars[1]], r[vars[2]], r[vars[3]], r[vars[4]], r[vars[5]], r[auxin]
1956
+ # if len(vars)==7:
1957
+ # file, p1, p2, p3, p4, p5, p6, p7, aa =\
1958
+ # r['files'], r[vars[0]], r[vars[1]], r[vars[2]], r[vars[3]], r[vars[4]], r[vars[5]], r[vars[6]], r[auxin]
1959
+ # if len(vars)==8:
1960
+ # file, p1, p2, p3, p4, p5, p6, p7, p8, aa = \
1961
+ # r['files'], r[vars[0]], r[vars[1]], r[vars[2]], r[vars[3]], r[vars[4]], r[vars[5]], r[vars[6]], r[vars[7]], r[auxin]
1962
+ # elif len(vars)==9:
1963
+ # file, p1, p2, p3, p4, p5, p6, p7, p8, p9, aa = \
1964
+ # r['files'], r[vars[0]], r[vars[1]], r[vars[2]], r[vars[3]], r[vars[4]], r[vars[5]], r[vars[6]], r[vars[7]], r[vars[8]], r[auxin]
1965
+ #
1966
+ # if greyscale==True:
1967
+ # im = Image.open(file).convert('LA')
1968
+ # else:
1969
+ # im = Image.open(file)
1970
+ # im = im.resize((IM_HEIGHT, IM_HEIGHT))
1971
+ # im = np.array(im) / 255.0
1972
+ #
1973
+ # if np.ndim(im)==2:
1974
+ # im = np.dstack((im, im , im)) ##np.expand_dims(im[:,:,0], axis=2)
1975
+ #
1976
+ # im = im[:,:,:3]
1977
+ #
1978
+ # if greyscale==True:
1979
+ # images.append(np.expand_dims(im, axis=2))
1980
+ # else:
1981
+ # images.append(im)
1982
+ #
1983
+ # if len(vars)==1:
1984
+ # p1s.append(p1); a.append(aa)
1985
+ # elif len(vars)==2:
1986
+ # p1s.append(p1); p2s.append(p2); a.append(aa)
1987
+ # elif len(vars)==3:
1988
+ # p1s.append(p1); p2s.append(p2); a.append(aa)
1989
+ # p3s.append(p3);
1990
+ # elif len(vars)==4:
1991
+ # p1s.append(p1); p2s.append(p2); a.append(aa)
1992
+ # p3s.append(p3); p4s.append(p4)
1993
+ # elif len(vars)==5:
1994
+ # p1s.append(p1); p2s.append(p2); a.append(aa)
1995
+ # p3s.append(p3); p4s.append(p4)
1996
+ # p5s.append(p5);
1997
+ # elif len(vars)==6:
1998
+ # p1s.append(p1); p2s.append(p2); a.append(aa)
1999
+ # p3s.append(p3); p4s.append(p4)
2000
+ # p5s.append(p5); p6s.append(p6)
2001
+ # elif len(vars)==7:
2002
+ # p1s.append(p1); p2s.append(p2); a.append(aa)
2003
+ # p3s.append(p3); p4s.append(p4)
2004
+ # p5s.append(p5); p6s.append(p6)
2005
+ # p7s.append(p7);
2006
+ # elif len(vars)==8:
2007
+ # p1s.append(p1); p2s.append(p2); a.append(aa)
2008
+ # p3s.append(p3); p4s.append(p4)
2009
+ # p5s.append(p5); p6s.append(p6)
2010
+ # p7s.append(p7); p8s.append(p8)
2011
+ # elif len(vars)==9:
2012
+ # p1s.append(p1); p2s.append(p2); a.append(aa)
2013
+ # p3s.append(p3); p4s.append(p4)
2014
+ # p5s.append(p5); p6s.append(p6)
2015
+ # p7s.append(p7); p8s.append(p8)
2016
+ # p9s.append(p9)
2017
+ #
2018
+ #
2019
+ # if len(images) >= batch_size:
2020
+ # if len(vars)==1:
2021
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
2022
+ # a = np.squeeze(CSaux[0].transform(np.array(a).reshape(-1, 1)))
2023
+ # yield [np.array(a), np.array(images)], [np.array(p1s)]
2024
+ # images, a, p1s = [], [], []
2025
+ # elif len(vars)==2:
2026
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
2027
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
2028
+ # a = np.squeeze(CSaux[0].transform(np.array(a).reshape(-1, 1)))
2029
+ # yield [np.array(a), np.array(images)],[np.array(p1s), np.array(p2s)]
2030
+ # images, a, p1s, p2s = [], [], [], []
2031
+ # elif len(vars)==3:
2032
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
2033
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
2034
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
2035
+ # a = np.squeeze(CSaux[0].transform(np.array(a).reshape(-1, 1)))
2036
+ # yield [np.array(a), np.array(images)],[np.array(p1s), np.array(p2s), np.array(p3s)]
2037
+ # images, a, p1s, p2s, p3s = [], [], [], [], []
2038
+ # elif len(vars)==4:
2039
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
2040
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
2041
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
2042
+ # p4s = np.squeeze(CS[3].transform(np.array(p4s).reshape(-1, 1)))
2043
+ # a = np.squeeze(CSaux[0].transform(np.array(a).reshape(-1, 1)))
2044
+ # yield [np.array(a), np.array(images)],[np.array(p1s), np.array(p2s), np.array(p3s), np.array(p4s)]
2045
+ # images, a, p1s, p2s, p3s, p4s = [], [], [], [], [], []
2046
+ # elif len(vars)==5:
2047
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
2048
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
2049
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
2050
+ # p4s = np.squeeze(CS[3].transform(np.array(p4s).reshape(-1, 1)))
2051
+ # p5s = np.squeeze(CS[4].transform(np.array(p5s).reshape(-1, 1)))
2052
+ # a = np.squeeze(CSaux[0].transform(np.array(a).reshape(-1, 1)))
2053
+ # yield [np.array(a), np.array(images)],[np.array(p1s), np.array(p2s), np.array(p3s),
2054
+ # np.array(p4s), np.array(p5s)]
2055
+ # images, a, p1s, p2s, p3s, p4s, p5s = \
2056
+ # [], [], [], [], [], [], []
2057
+ # elif len(vars)==6:
2058
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
2059
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
2060
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
2061
+ # p4s = np.squeeze(CS[3].transform(np.array(p4s).reshape(-1, 1)))
2062
+ # p5s = np.squeeze(CS[4].transform(np.array(p5s).reshape(-1, 1)))
2063
+ # p6s = np.squeeze(CS[5].transform(np.array(p6s).reshape(-1, 1)))
2064
+ # a = np.squeeze(CSaux[0].transform(np.array(a).reshape(-1, 1)))
2065
+ # yield [np.array(a), np.array(images)],[np.array(p1s), np.array(p2s), np.array(p3s),
2066
+ # np.array(p4s), np.array(p5s), np.array(p6s)]
2067
+ # images, a, p1s, p2s, p3s, p4s, p5s, p6s = \
2068
+ # [], [], [], [], [], [], [], []
2069
+ # elif len(vars)==7:
2070
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
2071
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
2072
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
2073
+ # p4s = np.squeeze(CS[3].transform(np.array(p4s).reshape(-1, 1)))
2074
+ # p5s = np.squeeze(CS[4].transform(np.array(p5s).reshape(-1, 1)))
2075
+ # p6s = np.squeeze(CS[5].transform(np.array(p6s).reshape(-1, 1)))
2076
+ # p7s = np.squeeze(CS[6].transform(np.array(p7s).reshape(-1, 1)))
2077
+ # a = np.squeeze(CSaux[0].transform(np.array(a).reshape(-1, 1)))
2078
+ # yield [np.array(a), np.array(images)],[np.array(p1s), np.array(p2s), np.array(p3s),
2079
+ # np.array(p4s), np.array(p5s), np.array(p6s), np.array(p7s)]
2080
+ # images, a, p1s, p2s, p3s, p4s, p5s, p6s, p7s = \
2081
+ # [], [], [], [], [], [], [], [], []
2082
+ # elif len(vars)==8:
2083
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
2084
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
2085
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
2086
+ # p4s = np.squeeze(CS[3].transform(np.array(p4s).reshape(-1, 1)))
2087
+ # p5s = np.squeeze(CS[4].transform(np.array(p5s).reshape(-1, 1)))
2088
+ # p6s = np.squeeze(CS[5].transform(np.array(p6s).reshape(-1, 1)))
2089
+ # p7s = np.squeeze(CS[6].transform(np.array(p7s).reshape(-1, 1)))
2090
+ # p8s = np.squeeze(CS[7].transform(np.array(p8s).reshape(-1, 1)))
2091
+ # a = np.squeeze(CSaux[0].transform(np.array(a).reshape(-1, 1)))
2092
+ # yield [np.array(a), np.array(images)],[np.array(p1s), np.array(p2s), np.array(p3s),
2093
+ # np.array(p4s), np.array(p5s), np.array(p6s),
2094
+ # np.array(p7s), np.array(p8s)]
2095
+ # images, a, p1s, p2s, p3s, p4s, p5s, p6s, p7s, p8s = \
2096
+ # [], [], [], [], [], [], [], [], [], []
2097
+ # elif len(vars)==9:
2098
+ # p1s = np.squeeze(CS[0].transform(np.array(p1s).reshape(-1, 1)))
2099
+ # p2s = np.squeeze(CS[1].transform(np.array(p2s).reshape(-1, 1)))
2100
+ # p3s = np.squeeze(CS[2].transform(np.array(p3s).reshape(-1, 1)))
2101
+ # p4s = np.squeeze(CS[3].transform(np.array(p4s).reshape(-1, 1)))
2102
+ # p5s = np.squeeze(CS[4].transform(np.array(p5s).reshape(-1, 1)))
2103
+ # p6s = np.squeeze(CS[5].transform(np.array(p6s).reshape(-1, 1)))
2104
+ # p7s = np.squeeze(CS[6].transform(np.array(p7s).reshape(-1, 1)))
2105
+ # p8s = np.squeeze(CS[7].transform(np.array(p8s).reshape(-1, 1)))
2106
+ # p9s = np.squeeze(CS[8].transform(np.array(p9s).reshape(-1, 1)))
2107
+ # a = np.squeeze(CSaux[0].transform(np.array(a).reshape(-1, 1)))
2108
+ # try:
2109
+ # yield [np.array(a), np.array(images)],[np.array(p1s), np.array(p2s), np.array(p3s),
2110
+ # np.array(p4s), np.array(p5s), np.array(p6s),
2111
+ # np.array(p7s), np.array(p8s), np.array(p9s)]
2112
+ # except GeneratorExit:
2113
+ # print(" ") #pass
2114
+ # images, a, p1s, p2s, p3s, p4s, p5s, p6s, p7s, p8s, p9s = \
2115
+ # [], [], [], [], [], [], [], [], [], [], []
2116
+ # if not for_training:
2117
+ # break
examples/20210208_172834_cropped.jpg ADDED

Git LFS Details

  • SHA256: 691a6e1691bc74a3ef953da2365896fa3ef8c9d4d5b96b119a0ca6a712665fa7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.92 MB
examples/20220101_165359_cropped.jpg ADDED

Git LFS Details

  • SHA256: 381a942416b08d2f6b8e664ce27617e2d58528fc20b504e81fa8377f62898627
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
examples/IMG_20210922_170908944_cropped.jpg ADDED

Git LFS Details

  • SHA256: 3cede9fc81a7e219e7afcab3c094f2804b431c2d9809d44f814651c2650992a7
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ tensorflow
2
+ numpy
3
+ matplotlib
4
+ scikit-image
weights/config_usace_combined2021_2022_v12.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54a50d316b63ca29d01277eadbaa098ee71435f69e7b57f5a9c0d2be80c8b282
3
+ size 597
weights/sandsnap_merged_1024_modelrevOct2022_v12_simo_batch10_im1024_9vars_mse_noaug.hdf5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6383827d51f659cb4fc433c9c2d75594708456a5e3fd15906be8821454c0c74e
3
+ size 1010376
weights/sandsnap_merged_1024_modelrevOct2022_v12_simo_batch10_im1024_9vars_mse_noaug.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:927bf368c23ef8c512e5c563bf953b0b1b04950d3eb6f2c69f4e72d32b1d0cfe
3
+ size 17855