BilalSardar committed
Commit bab1cc1
1 Parent(s): 13cbb3a

Upload 4 files

Files changed (4)
  1. audio.py +123 -0
  2. config.py +269 -0
  3. model.py +505 -0
  4. utils.py +335 -0
audio.py ADDED
@@ -0,0 +1,123 @@
+ """Module containing audio helper functions."""
+ import numpy as np
+
+ import config as cfg
+
+ RANDOM = np.random.RandomState(cfg.RANDOM_SEED)
+
+
+ def openAudioFile(path: str, sample_rate=48000, offset=0.0, duration=None):
+     """Opens an audio file.
+
+     Opens an audio file with librosa and the given settings.
+
+     Args:
+         path: Path to the audio file.
+         sample_rate: The sample rate at which the file should be processed.
+         offset: The starting offset in seconds.
+         duration: Maximum duration of the loaded content in seconds.
+
+     Returns:
+         Returns the audio time series and the sampling rate.
+     """
+     # Open file with librosa (uses ffmpeg or libav)
+     import librosa
+
+     sig, rate = librosa.load(path, sr=sample_rate, offset=offset, duration=duration, mono=True, res_type="kaiser_fast")
+
+     return sig, rate
+
+
+ def get_sample_rate(path: str):
+     """Returns the sample rate of an audio file without loading it."""
+     import librosa
+
+     return librosa.get_samplerate(path)
+
+
+ def saveSignal(sig, fname: str):
+     """Saves a signal to a file as 16-bit PCM at 48 kHz.
+
+     Args:
+         sig: The signal to be saved.
+         fname: The file path.
+     """
+     import soundfile as sf
+
+     sf.write(fname, sig, 48000, "PCM_16")
+
+
+ def noise(sig, shape, amount=None):
+     """Creates noise.
+
+     Creates a noise vector with the given shape.
+
+     Args:
+         sig: The original audio signal.
+         shape: Shape of the noise.
+         amount: The noise intensity.
+
+     Returns:
+         A numpy array of noise with the given shape.
+     """
+     # Random noise intensity
+     if amount is None:
+         amount = RANDOM.uniform(0.1, 0.5)
+
+     # Create Gaussian noise; fall back to silence for empty or degenerate signals
+     try:
+         noise = RANDOM.normal(min(sig) * amount, max(sig) * amount, shape)
+     except Exception:
+         noise = np.zeros(shape)
+
+     return noise.astype("float32")
+
+
+ def splitSignal(sig, rate, seconds, overlap, minlen):
+     """Splits a signal into segments with overlap.
+
+     Args:
+         sig: The original signal to be split.
+         rate: The sampling rate.
+         seconds: The duration of a segment.
+         overlap: The overlap between consecutive segments in seconds.
+         minlen: Minimum length of a split in seconds.
+
+     Returns:
+         A list of splits.
+     """
+     sig_splits = []
+
+     for i in range(0, len(sig), int((seconds - overlap) * rate)):
+         split = sig[i : i + int(seconds * rate)]
+
+         # End of signal?
+         if len(split) < int(minlen * rate) and len(sig_splits) > 0:
+             break
+
+         # Signal chunk too short? Pad with noise.
+         if len(split) < int(rate * seconds):
+             split = np.hstack((split, noise(split, (int(rate * seconds) - len(split)), 0.5)))
+
+         sig_splits.append(split)
+
+     return sig_splits
+
+
+ def cropCenter(sig, rate, seconds):
+     """Crops a signal to its center.
+
+     Args:
+         sig: The original signal.
+         rate: The sampling rate.
+         seconds: The target length of the signal in seconds.
+
+     Returns:
+         The cropped (or noise-padded) signal.
+     """
+     if len(sig) > int(seconds * rate):
+         start = int((len(sig) - int(seconds * rate)) / 2)
+         end = start + int(seconds * rate)
+         sig = sig[start:end]
+
+     # Pad with noise
+     elif len(sig) < int(seconds * rate):
+         sig = np.hstack((sig, noise(sig, (int(seconds * rate) - len(sig)), 0.5)))
+
+     return sig
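For orientation, a minimal usage sketch of the helpers above, assuming a hypothetical recording at example/soundscape.wav and the defaults from config.py (librosa must be installed for loading):

import audio
import config as cfg

# Load and resample the recording to 48 kHz (path is a placeholder)
sig, rate = audio.openAudioFile("example/soundscape.wav", cfg.SAMPLE_RATE)

# Split into 3-second chunks; a short trailing chunk is padded with noise
chunks = audio.splitSignal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)

# Every chunk has exactly SAMPLE_RATE * SIG_LENGTH = 144000 samples
print(len(chunks), len(chunks[0]))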
config.py ADDED
@@ -0,0 +1,269 @@
+ #################
+ # Misc settings #
+ #################
+
+ # Random seed for Gaussian noise
+ RANDOM_SEED = 42
+
+ ##########################
+ # Model paths and config #
+ ##########################
+
+ MODEL_VERSION = 'V2.4'
+ PB_MODEL = 'checkpoints/V2.4/BirdNET_GLOBAL_6K_V2.4_Model'
+ # MODEL_PATH = PB_MODEL  # This will load the protobuf model
+ MODEL_PATH = 'checkpoints/V2.4/BirdNET_GLOBAL_6K_V2.4_Model_FP32.tflite'
+ MDATA_MODEL_PATH = 'checkpoints/V2.4/BirdNET_GLOBAL_6K_V2.4_MData_Model_FP16.tflite'
+ LABELS_FILE = 'checkpoints/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels.txt'
+ TRANSLATED_LABELS_PATH = 'labels/V2.4'
+
+ # Path to custom trained classifier
+ # If None, no custom classifier will be used
+ # Make sure to set the LABELS_FILE above accordingly
+ CUSTOM_CLASSIFIER = None
+
+ ##################
+ # Audio settings #
+ ##################
+
+ # We use a sample rate of 48 kHz, so the model input size is
+ # (batch size, 48000 Hz * 3 seconds) = (1, 144000)
+ # Recordings will be resampled automatically.
+ SAMPLE_RATE: int = 48000
+
+ # We're using 3-second chunks
+ SIG_LENGTH: float = 3.0
+
+ # Overlap between consecutive chunks in seconds; must be < 3.0, 0 = no overlap
+ SIG_OVERLAP: float = 0.0
+
+ # Minimum length of an audio chunk for prediction;
+ # shorter chunks will be padded to 3 seconds
+ SIG_MINLEN: float = 1.0
+
+ # Frequency range. This is model specific and should not be changed.
+ SIG_FMIN = 0
+ SIG_FMAX = 15000
+
+ #####################
+ # Metadata settings #
+ #####################
+
+ LATITUDE = -1
+ LONGITUDE = -1
+ WEEK = -1
+ LOCATION_FILTER_THRESHOLD = 0.03
+
+ ######################
+ # Inference settings #
+ ######################
+
+ # If None or an empty file, no custom species list will be used
+ # Note: Entries in this list have to match entries from the LABELS_FILE
+ # We use the 2021 eBird taxonomy for species names (Clements list)
+ CODES_FILE = 'eBird_taxonomy_codes_2021E.json'
+ SPECIES_LIST_FILE = 'example/species_list.txt'
+
+ # File input path and output path for selection tables
+ INPUT_PATH: str = 'example/'
+ OUTPUT_PATH: str = 'example/'
+
+ ALLOWED_FILETYPES = ['wav', 'flac', 'mp3', 'ogg', 'm4a']
+
+ # Number of threads to use for inference.
+ # Can be as high as the number of CPUs in your system
+ CPU_THREADS: int = 8
+ TFLITE_THREADS: int = 1
+
+ # False will output logits, True will convert to sigmoid activations
+ APPLY_SIGMOID: bool = True
+ SIGMOID_SENSITIVITY: float = 1.0
+
+ # Minimum confidence score to include in the selection table
+ # (be aware: if APPLY_SIGMOID = False, scores no longer represent
+ # probabilities and this threshold needs to be adjusted)
+ MIN_CONFIDENCE: float = 0.1
+
+ # Number of samples to process at the same time. Higher values can increase
+ # processing speed, but will also increase memory usage.
+ # Might only be useful for GPU inference.
+ BATCH_SIZE: int = 1
+
+ # Specifies the output format. 'table' denotes a Raven selection table,
+ # 'audacity' denotes a TXT file with the same format as Audacity timeline labels,
+ # 'csv' denotes a CSV file with start, end, species and confidence.
+ RESULT_TYPE = 'table'
+
+ #####################
+ # Training settings #
+ #####################
+
+ # Training data path
+ TRAIN_DATA_PATH = 'train_data/'
+
+ # Sample crop mode
+ SAMPLE_CROP_MODE = 'center'
+
+ # List of non-event classes
+ NON_EVENT_CLASSES = ["noise", "other", "background", "silence"]
+
+ # Upsampling settings
+ UPSAMPLING_RATIO = 0.0
+ UPSAMPLING_MODE = 'repeat'
+
+ # Number of epochs to train for
+ TRAIN_EPOCHS: int = 100
+
+ # Batch size for training
+ TRAIN_BATCH_SIZE: int = 32
+
+ # Validation split (fraction of the training data)
+ TRAIN_VAL_SPLIT: float = 0.2
+
+ # Learning rate for training
+ TRAIN_LEARNING_RATE: float = 0.01
+
+ # Number of hidden units in the custom classifier
+ # If > 0, a two-layer classifier will be trained
+ TRAIN_HIDDEN_UNITS: int = 0
+
+ # Dropout rate for training
+ TRAIN_DROPOUT: float = 0.0
+
+ # Whether to use mixup for training
+ TRAIN_WITH_MIXUP: bool = False
+
+ # Whether to apply label smoothing for training
+ TRAIN_WITH_LABEL_SMOOTHING: bool = False
+
+ # Model output format
+ TRAINED_MODEL_OUTPUT_FORMAT = 'tflite'
+
+ # Cache settings
+ TRAIN_CACHE_MODE = 'none'
+ TRAIN_CACHE_FILE = 'train_cache.npz'
+
+ #####################
+ # Misc runtime vars #
+ #####################
+
+ CODES = {}
+ LABELS: list[str] = []
+ TRANSLATED_LABELS: list[str] = []
+ SPECIES_LIST: list[str] = []
+ ERROR_LOG_FILE: str = 'error_log.txt'
+ FILE_LIST = []
+ FILE_STORAGE_PATH = ''
+
+ ######################
+ # Get and set config #
+ ######################
+
+
+ def getConfig():
+     """Returns the current settings as a dictionary."""
+     return {
+         'RANDOM_SEED': RANDOM_SEED,
+         'MODEL_PATH': MODEL_PATH,
+         'MDATA_MODEL_PATH': MDATA_MODEL_PATH,
+         'LABELS_FILE': LABELS_FILE,
+         'CUSTOM_CLASSIFIER': CUSTOM_CLASSIFIER,
+         'SAMPLE_RATE': SAMPLE_RATE,
+         'SIG_LENGTH': SIG_LENGTH,
+         'SIG_OVERLAP': SIG_OVERLAP,
+         'SIG_MINLEN': SIG_MINLEN,
+         'LATITUDE': LATITUDE,
+         'LONGITUDE': LONGITUDE,
+         'WEEK': WEEK,
+         'LOCATION_FILTER_THRESHOLD': LOCATION_FILTER_THRESHOLD,
+         'CODES_FILE': CODES_FILE,
+         'SPECIES_LIST_FILE': SPECIES_LIST_FILE,
+         'INPUT_PATH': INPUT_PATH,
+         'OUTPUT_PATH': OUTPUT_PATH,
+         'CPU_THREADS': CPU_THREADS,
+         'TFLITE_THREADS': TFLITE_THREADS,
+         'APPLY_SIGMOID': APPLY_SIGMOID,
+         'SIGMOID_SENSITIVITY': SIGMOID_SENSITIVITY,
+         'MIN_CONFIDENCE': MIN_CONFIDENCE,
+         'BATCH_SIZE': BATCH_SIZE,
+         'RESULT_TYPE': RESULT_TYPE,
+         'TRAIN_DATA_PATH': TRAIN_DATA_PATH,
+         'TRAIN_EPOCHS': TRAIN_EPOCHS,
+         'TRAIN_BATCH_SIZE': TRAIN_BATCH_SIZE,
+         'TRAIN_LEARNING_RATE': TRAIN_LEARNING_RATE,
+         'TRAIN_HIDDEN_UNITS': TRAIN_HIDDEN_UNITS,
+         'CODES': CODES,
+         'LABELS': LABELS,
+         'TRANSLATED_LABELS': TRANSLATED_LABELS,
+         'SPECIES_LIST': SPECIES_LIST,
+         'ERROR_LOG_FILE': ERROR_LOG_FILE
+     }
+
+
+ def setConfig(c):
+     """Applies the settings from the given dictionary to this module."""
+     global RANDOM_SEED
+     global MODEL_PATH
+     global MDATA_MODEL_PATH
+     global LABELS_FILE
+     global CUSTOM_CLASSIFIER
+     global SAMPLE_RATE
+     global SIG_LENGTH
+     global SIG_OVERLAP
+     global SIG_MINLEN
+     global LATITUDE
+     global LONGITUDE
+     global WEEK
+     global LOCATION_FILTER_THRESHOLD
+     global CODES_FILE
+     global SPECIES_LIST_FILE
+     global INPUT_PATH
+     global OUTPUT_PATH
+     global CPU_THREADS
+     global TFLITE_THREADS
+     global APPLY_SIGMOID
+     global SIGMOID_SENSITIVITY
+     global MIN_CONFIDENCE
+     global BATCH_SIZE
+     global RESULT_TYPE
+     global TRAIN_DATA_PATH
+     global TRAIN_EPOCHS
+     global TRAIN_BATCH_SIZE
+     global TRAIN_LEARNING_RATE
+     global TRAIN_HIDDEN_UNITS
+     global CODES
+     global LABELS
+     global TRANSLATED_LABELS
+     global SPECIES_LIST
+     global ERROR_LOG_FILE
+
+     RANDOM_SEED = c['RANDOM_SEED']
+     MODEL_PATH = c['MODEL_PATH']
+     MDATA_MODEL_PATH = c['MDATA_MODEL_PATH']
+     LABELS_FILE = c['LABELS_FILE']
+     CUSTOM_CLASSIFIER = c['CUSTOM_CLASSIFIER']
+     SAMPLE_RATE = c['SAMPLE_RATE']
+     SIG_LENGTH = c['SIG_LENGTH']
+     SIG_OVERLAP = c['SIG_OVERLAP']
+     SIG_MINLEN = c['SIG_MINLEN']
+     LATITUDE = c['LATITUDE']
+     LONGITUDE = c['LONGITUDE']
+     WEEK = c['WEEK']
+     LOCATION_FILTER_THRESHOLD = c['LOCATION_FILTER_THRESHOLD']
+     CODES_FILE = c['CODES_FILE']
+     SPECIES_LIST_FILE = c['SPECIES_LIST_FILE']
+     INPUT_PATH = c['INPUT_PATH']
+     OUTPUT_PATH = c['OUTPUT_PATH']
+     CPU_THREADS = c['CPU_THREADS']
+     TFLITE_THREADS = c['TFLITE_THREADS']
+     APPLY_SIGMOID = c['APPLY_SIGMOID']
+     SIGMOID_SENSITIVITY = c['SIGMOID_SENSITIVITY']
+     MIN_CONFIDENCE = c['MIN_CONFIDENCE']
+     BATCH_SIZE = c['BATCH_SIZE']
+     RESULT_TYPE = c['RESULT_TYPE']
+     TRAIN_DATA_PATH = c['TRAIN_DATA_PATH']
+     TRAIN_EPOCHS = c['TRAIN_EPOCHS']
+     TRAIN_BATCH_SIZE = c['TRAIN_BATCH_SIZE']
+     TRAIN_LEARNING_RATE = c['TRAIN_LEARNING_RATE']
+     TRAIN_HIDDEN_UNITS = c['TRAIN_HIDDEN_UNITS']
+     CODES = c['CODES']
+     LABELS = c['LABELS']
+     TRANSLATED_LABELS = c['TRANSLATED_LABELS']
+     SPECIES_LIST = c['SPECIES_LIST']
+     ERROR_LOG_FILE = c['ERROR_LOG_FILE']
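Because these settings live at module level, they do not automatically carry over to spawned worker processes; getConfig/setConfig exist to snapshot and restore them as a plain dictionary. A minimal sketch (the threshold tweak is purely illustrative):

import config as cfg

# Snapshot the current settings, adjust one, and apply the snapshot back
snapshot = cfg.getConfig()
snapshot['MIN_CONFIDENCE'] = 0.25
cfg.setConfig(snapshot)

assert cfg.MIN_CONFIDENCE == 0.25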
model.py ADDED
@@ -0,0 +1,505 @@
+ """Contains functions to use the BirdNET models."""
+ import os
+ import warnings
+
+ import numpy as np
+
+ import config as cfg
+ import utils
+
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+ warnings.filterwarnings("ignore")
+
+ # Import TFLite from the runtime package or from TensorFlow;
+ # import Keras only if the protobuf model is used.
+ # NOTE: we have to use TFLite if we want to use
+ # the metadata model or want to extract embeddings
+ try:
+     import tflite_runtime.interpreter as tflite
+ except ModuleNotFoundError:
+     from tensorflow import lite as tflite
+ if not cfg.MODEL_PATH.endswith(".tflite"):
+     from tensorflow import keras
+
+ INTERPRETER: tflite.Interpreter = None
+ C_INTERPRETER: tflite.Interpreter = None
+ M_INTERPRETER: tflite.Interpreter = None
+ PBMODEL = None
+
+
+ def loadModel(class_output=True):
+     """Initializes the BirdNET model.
+
+     Args:
+         class_output: Omits the last layer when False (returns embeddings instead).
+     """
+     global PBMODEL
+     global INTERPRETER
+     global INPUT_LAYER_INDEX
+     global OUTPUT_LAYER_INDEX
+
+     # Do we have to load the tflite or protobuf model?
+     if cfg.MODEL_PATH.endswith(".tflite"):
+         # Load TFLite model and allocate tensors.
+         INTERPRETER = tflite.Interpreter(model_path=cfg.MODEL_PATH, num_threads=cfg.TFLITE_THREADS)
+         INTERPRETER.allocate_tensors()
+
+         # Get input and output tensors.
+         input_details = INTERPRETER.get_input_details()
+         output_details = INTERPRETER.get_output_details()
+
+         # Get input tensor index
+         INPUT_LAYER_INDEX = input_details[0]["index"]
+
+         # Get classification output or feature embeddings
+         if class_output:
+             OUTPUT_LAYER_INDEX = output_details[0]["index"]
+         else:
+             OUTPUT_LAYER_INDEX = output_details[0]["index"] - 1
+
+     else:
+         # Load protobuf model
+         # Note: This will throw a bunch of warnings about custom gradients
+         # which we will ignore until TF lets us block them
+         PBMODEL = keras.models.load_model(cfg.MODEL_PATH, compile=False)
+
+
+ def loadCustomClassifier():
+     """Loads the custom classifier."""
+     global C_INTERPRETER
+     global C_INPUT_LAYER_INDEX
+     global C_OUTPUT_LAYER_INDEX
+     global C_INPUT_SIZE
+
+     # Load TFLite model and allocate tensors.
+     C_INTERPRETER = tflite.Interpreter(model_path=cfg.CUSTOM_CLASSIFIER, num_threads=cfg.TFLITE_THREADS)
+     C_INTERPRETER.allocate_tensors()
+
+     # Get input and output tensors.
+     input_details = C_INTERPRETER.get_input_details()
+     output_details = C_INTERPRETER.get_output_details()
+
+     # Get input tensor index and size
+     C_INPUT_LAYER_INDEX = input_details[0]["index"]
+     C_INPUT_SIZE = input_details[0]["shape"][-1]
+
+     # Get classification output
+     C_OUTPUT_LAYER_INDEX = output_details[0]["index"]
+
+
+ def loadMetaModel():
+     """Loads the model for species prediction.
+
+     Initializes the model used to predict the species list based on coordinates and week of year.
+     """
+     global M_INTERPRETER
+     global M_INPUT_LAYER_INDEX
+     global M_OUTPUT_LAYER_INDEX
+
+     # Load TFLite model and allocate tensors.
+     M_INTERPRETER = tflite.Interpreter(model_path=cfg.MDATA_MODEL_PATH, num_threads=cfg.TFLITE_THREADS)
+     M_INTERPRETER.allocate_tensors()
+
+     # Get input and output tensors.
+     input_details = M_INTERPRETER.get_input_details()
+     output_details = M_INTERPRETER.get_output_details()
+
+     # Get input and output tensor indices
+     M_INPUT_LAYER_INDEX = input_details[0]["index"]
+     M_OUTPUT_LAYER_INDEX = output_details[0]["index"]
+
+
+ def buildLinearClassifier(num_labels, input_size, hidden_units=0, dropout=0.0):
+     """Builds a classifier.
+
+     Args:
+         num_labels: Output size.
+         input_size: Size of the input.
+         hidden_units: If > 0, creates another hidden layer with the given number of units.
+         dropout: Dropout rate; if > 0, adds dropout layers.
+
+     Returns:
+         A new classifier.
+     """
+     # import keras
+     from tensorflow import keras
+
+     # Build a simple one- or two-layer linear classifier
+     model = keras.Sequential()
+
+     # Input layer
+     model.add(keras.layers.InputLayer(input_shape=(input_size,)))
+
+     # Hidden layer
+     if hidden_units > 0:
+         # Dropout layer?
+         if dropout > 0:
+             model.add(keras.layers.Dropout(dropout))
+         model.add(keras.layers.Dense(hidden_units, activation="relu"))
+
+     # Dropout layer?
+     if dropout > 0:
+         model.add(keras.layers.Dropout(dropout))
+
+     # Classification layer
+     model.add(keras.layers.Dense(num_labels))
+
+     # Activation layer
+     model.add(keras.layers.Activation("sigmoid"))
+
+     return model
+
+
+ def trainLinearClassifier(classifier,
+                           x_train,
+                           y_train,
+                           epochs,
+                           batch_size,
+                           learning_rate,
+                           val_split,
+                           upsampling_ratio,
+                           upsampling_mode,
+                           train_with_mixup,
+                           train_with_label_smoothing,
+                           on_epoch_end=None):
+     """Trains a custom classifier.
+
+     Trains a new classifier for BirdNET based on the given data.
+
+     Args:
+         classifier: The classifier to be trained.
+         x_train: Samples.
+         y_train: Labels.
+         epochs: Number of epochs to train.
+         batch_size: Batch size.
+         learning_rate: The learning rate during training.
+         val_split: Fraction of the data held out for validation.
+         upsampling_ratio: Minimum ratio of minority to majority class samples.
+         upsampling_mode: Upsampling mode; 'repeat', 'mean' or 'smote'.
+         train_with_mixup: Whether to apply mixup augmentation.
+         train_with_label_smoothing: Whether to smooth the labels.
+         on_epoch_end: Optional callback `function(epoch, logs)`.
+
+     Returns:
+         (classifier, history)
+     """
+     # import keras
+     from tensorflow import keras
+
+     class FunctionCallback(keras.callbacks.Callback):
+         def __init__(self, on_epoch_end=None) -> None:
+             super().__init__()
+             self.on_epoch_end_fn = on_epoch_end
+
+         def on_epoch_end(self, epoch, logs=None):
+             if self.on_epoch_end_fn:
+                 self.on_epoch_end_fn(epoch, logs)
+
+     # Set random seed
+     np.random.seed(cfg.RANDOM_SEED)
+
+     # Shuffle data
+     idx = np.arange(x_train.shape[0])
+     np.random.shuffle(idx)
+     x_train = x_train[idx]
+     y_train = y_train[idx]
+
+     # Random val split
+     x_train, y_train, x_val, y_val = utils.random_split(x_train, y_train, val_split)
+     print(f"Training on {x_train.shape[0]} samples, validating on {x_val.shape[0]} samples.", flush=True)
+
+     # Upsample training data
+     if upsampling_ratio > 0:
+         x_train, y_train = utils.upsampling(x_train, y_train, upsampling_ratio, upsampling_mode)
+         print(f"Upsampled training data to {x_train.shape[0]} samples.", flush=True)
+
+     # Apply mixup to training data
+     if train_with_mixup:
+         x_train, y_train = utils.mixup(x_train, y_train)
+
+     # Apply label smoothing
+     if train_with_label_smoothing:
+         y_train = utils.label_smoothing(y_train)
+
+     # Early stopping
+     callbacks = [
+         keras.callbacks.EarlyStopping(
+             monitor="val_loss", patience=5, verbose=1, start_from_epoch=epochs // 4, restore_best_weights=True
+         ),
+         FunctionCallback(on_epoch_end=on_epoch_end),
+     ]
+
+     # Cosine annealing learning rate schedule
+     lr_schedule = keras.experimental.CosineDecay(learning_rate, epochs * x_train.shape[0] / batch_size)
+
+     # Compile model
+     classifier.compile(
+         optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
+         loss="binary_crossentropy",
+         metrics=[keras.metrics.AUC(curve="PR", multi_label=False, name="AUPRC")],
+     )
+
+     # Train model
+     history = classifier.fit(
+         x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_val, y_val), callbacks=callbacks
+     )
+
+     return classifier, history
+
+
+ def saveLinearClassifier(classifier, model_path, labels):
+     """Saves a custom classifier to disk.
+
+     Saves the classifier as a tflite model, as well as the used labels in a .txt file.
+
+     Args:
+         classifier: The custom classifier.
+         model_path: Path the model will be saved at.
+         labels: List of labels used for the classifier.
+     """
+     import tensorflow as tf
+
+     saved_model = PBMODEL if PBMODEL else tf.keras.models.load_model(cfg.PB_MODEL, compile=False)
+
+     # Remove activation layer
+     classifier.pop()
+
+     combined_model = tf.keras.Sequential([saved_model.embeddings_model, classifier], "basic")
+
+     # Append .tflite if necessary
+     if not model_path.endswith(".tflite"):
+         model_path += ".tflite"
+
+     # Make folders
+     os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+     # Save model as tflite (use the full TensorFlow converter,
+     # since tflite_runtime does not ship one)
+     converter = tf.lite.TFLiteConverter.from_keras_model(combined_model)
+     tflite_model = converter.convert()
+
+     with open(model_path, "wb") as f:
+         f.write(tflite_model)
+
+     # Save labels
+     with open(model_path.replace(".tflite", "_Labels.txt"), "w") as f:
+         for label in labels:
+             f.write(label + "\n")
+
+
+ def save_raven_model(classifier, model_path, labels):
+     """Saves a custom classifier in the Raven model format.
+
+     Args:
+         classifier: The custom classifier.
+         model_path: Path the model will be saved at.
+         labels: List of labels used for the classifier.
+     """
+     import csv
+     import json
+
+     import tensorflow as tf
+
+     saved_model = PBMODEL if PBMODEL else tf.keras.models.load_model(cfg.PB_MODEL, compile=False)
+     combined_model = tf.keras.Sequential([saved_model.embeddings_model, classifier], "basic")
+
+     # Make signatures
+     class SignatureModule(tf.Module):
+         def __init__(self, keras_model):
+             super().__init__()
+             self.model = keras_model
+
+         @tf.function(input_signature=[tf.TensorSpec(shape=[None, 144000], dtype=tf.float32)])
+         def basic(self, inputs):
+             return {"scores": self.model(inputs)}
+
+     smodel = SignatureModule(combined_model)
+     signatures = {
+         "basic": smodel.basic,
+     }
+
+     # Save signature model
+     os.makedirs(os.path.dirname(model_path), exist_ok=True)
+     model_path = model_path[:-7] if model_path.endswith(".tflite") else model_path
+     tf.saved_model.save(smodel, model_path, signatures=signatures)
+
+     # Save label file
+     labelIds = [label[:4].replace(" ", "") + str(i) for i, label in enumerate(labels, 1)]
+     labels_dir = os.path.join(model_path, "labels")
+
+     os.makedirs(labels_dir, exist_ok=True)
+
+     with open(os.path.join(labels_dir, "label_names.csv"), "w", newline="") as labelsfile:
+         labelwriter = csv.writer(labelsfile)
+         labelwriter.writerows(zip(labelIds, labels))
+
+     # Save class names file
+     classes_dir = os.path.join(model_path, "classes")
+
+     os.makedirs(classes_dir, exist_ok=True)
+
+     with open(os.path.join(classes_dir, "classes.csv"), "w", newline="") as classesfile:
+         classeswriter = csv.writer(classesfile)
+         for labelId in labelIds:
+             classeswriter.writerow((labelId, 0.25, cfg.SIG_FMIN, cfg.SIG_FMAX, False))
+
+     # Save model config
+     model_config = os.path.join(model_path, "model_config.json")
+
+     with open(model_config, "w") as modelconfigfile:
+         modelconfig = {
+             "specVersion": 1,
+             "modelDescription": "Custom classifier trained with BirdNET "
+             + cfg.MODEL_VERSION
+             + " embeddings.\nBirdNET was developed by the K. Lisa Yang Center for Conservation Bioacoustics at the Cornell Lab of Ornithology in collaboration with Chemnitz University of Technology.\n\nhttps://birdnet.cornell.edu",
+             "modelTypeConfig": {"modelType": "RECOGNITION"},
+             "signatures": [
+                 {
+                     "signatureName": "basic",
+                     "modelInputs": [{"inputName": "inputs", "sampleRate": 48000.0, "inputConfig": ["batch", "samples"]}],
+                     "modelOutputs": [{"outputName": "scores", "outputType": "SCORES"}],
+                 }
+             ],
+             "globalSemanticKeys": labelIds,
+         }
+         json.dump(modelconfig, modelconfigfile, indent=2)
+
+
+ def predictFilter(lat, lon, week):
+     """Predicts the probability for each species.
+
+     Args:
+         lat: The latitude.
+         lon: The longitude.
+         week: The week of the year [1-48]. Use -1 for year-round.
+
+     Returns:
+         A list of probabilities for all species.
+     """
+     global M_INTERPRETER
+
+     # Does the interpreter exist?
+     if M_INTERPRETER is None:
+         loadMetaModel()
+
+     # Prepare mdata as sample
+     sample = np.expand_dims(np.array([lat, lon, week], dtype="float32"), 0)
+
+     # Run inference
+     M_INTERPRETER.set_tensor(M_INPUT_LAYER_INDEX, sample)
+     M_INTERPRETER.invoke()
+
+     return M_INTERPRETER.get_tensor(M_OUTPUT_LAYER_INDEX)[0]
+
+
+ def explore(lat: float, lon: float, week: int):
+     """Predicts the species list.
+
+     Predicts the species list based on the coordinates and week of year.
+
+     Args:
+         lat: The latitude.
+         lon: The longitude.
+         week: The week of the year [1-48]. Use -1 for year-round.
+
+     Returns:
+         A sorted list of tuples with the score and the species.
+     """
+     # Make filter prediction
+     l_filter = predictFilter(lat, lon, week)
+
+     # Apply threshold
+     l_filter = np.where(l_filter >= cfg.LOCATION_FILTER_THRESHOLD, l_filter, 0)
+
+     # Zip with labels
+     l_filter = list(zip(l_filter, cfg.LABELS))
+
+     # Sort by filter value
+     l_filter = sorted(l_filter, key=lambda x: x[0], reverse=True)
+
+     return l_filter
+
+
+ def flat_sigmoid(x, sensitivity=-1):
+     """Applies a sigmoid with the given sensitivity to the clipped input."""
+     return 1 / (1.0 + np.exp(sensitivity * np.clip(x, -15, 15)))
+
+
+ def predict(sample):
+     """Uses the main net to predict a sample.
+
+     Args:
+         sample: Audio sample.
+
+     Returns:
+         The prediction scores for the sample.
+     """
+     # Has custom classifier?
+     if cfg.CUSTOM_CLASSIFIER is not None:
+         return predictWithCustomClassifier(sample)
+
+     global INTERPRETER
+
+     # Does the interpreter or keras model exist?
+     if INTERPRETER is None and PBMODEL is None:
+         loadModel()
+
+     if PBMODEL is None:
+         # Reshape input tensor
+         INTERPRETER.resize_tensor_input(INPUT_LAYER_INDEX, [len(sample), *sample[0].shape])
+         INTERPRETER.allocate_tensors()
+
+         # Make a prediction (Audio only for now)
+         INTERPRETER.set_tensor(INPUT_LAYER_INDEX, np.array(sample, dtype="float32"))
+         INTERPRETER.invoke()
+         prediction = INTERPRETER.get_tensor(OUTPUT_LAYER_INDEX)
+
+         return prediction
+
+     else:
+         # Make a prediction (Audio only for now)
+         prediction = PBMODEL.embeddings_model.predict(sample)
+
+         return prediction
+
+
+ def predictWithCustomClassifier(sample):
+     """Uses the custom classifier to make a prediction.
+
+     Args:
+         sample: Audio sample.
+
+     Returns:
+         The prediction scores for the sample.
+     """
+     global C_INTERPRETER
+     global C_INPUT_SIZE
+
+     # Does the interpreter exist?
+     if C_INTERPRETER is None:
+         loadCustomClassifier()
+
+     # Classifiers that do not take raw audio expect embeddings as input
+     vector = embeddings(sample) if C_INPUT_SIZE != 144000 else sample
+
+     # Reshape input tensor
+     C_INTERPRETER.resize_tensor_input(C_INPUT_LAYER_INDEX, [len(vector), *vector[0].shape])
+     C_INTERPRETER.allocate_tensors()
+
+     # Make a prediction
+     C_INTERPRETER.set_tensor(C_INPUT_LAYER_INDEX, np.array(vector, dtype="float32"))
+     C_INTERPRETER.invoke()
+     prediction = C_INTERPRETER.get_tensor(C_OUTPUT_LAYER_INDEX)
+
+     return prediction
+
+
+ def embeddings(sample):
+     """Extracts the embeddings for a sample.
+
+     Args:
+         sample: Audio samples.
+
+     Returns:
+         The embeddings.
+     """
+     global INTERPRETER
+
+     # Does the interpreter exist?
+     if INTERPRETER is None:
+         loadModel(False)
+
+     # Reshape input tensor
+     INTERPRETER.resize_tensor_input(INPUT_LAYER_INDEX, [len(sample), *sample[0].shape])
+     INTERPRETER.allocate_tensors()
+
+     # Extract feature embeddings
+     INTERPRETER.set_tensor(INPUT_LAYER_INDEX, np.array(sample, dtype="float32"))
+     INTERPRETER.invoke()
+     features = INTERPRETER.get_tensor(OUTPUT_LAYER_INDEX)
+
+     return features
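A minimal end-to-end sketch tying the modules together, assuming the checkpoint and label files referenced in config.py are present and using a hypothetical input path:

import numpy as np

import audio
import config as cfg
import model
import utils

cfg.LABELS = utils.readLines(cfg.LABELS_FILE)

# Load a recording and take the first 3-second chunk (path is a placeholder)
sig, rate = audio.openAudioFile("example/soundscape.wav", cfg.SAMPLE_RATE)
chunks = audio.splitSignal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)

# Predict logits for a batch of one chunk and convert to sigmoid activations
scores = model.flat_sigmoid(np.array(model.predict([chunks[0]])), -cfg.SIGMOID_SENSITIVITY)

# Report the top-scoring label
best = int(np.argmax(scores[0]))
print(cfg.LABELS[best], float(scores[0][best]))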
utils.py ADDED
@@ -0,0 +1,335 @@
+ """Module containing common functions."""
+ import os
+ import traceback
+ from pathlib import Path
+
+ import numpy as np
+
+ import config as cfg
+
+
+ def collect_audio_files(path: str):
+     """Collects all audio files in the given directory.
+
+     Args:
+         path: The directory to be searched.
+
+     Returns:
+         A sorted list of all audio files in the directory.
+     """
+     # Get all files in directory with os.walk
+     files = []
+
+     for root, _, flist in os.walk(path):
+         for f in flist:
+             if not f.startswith(".") and f.rsplit(".", 1)[-1].lower() in cfg.ALLOWED_FILETYPES:
+                 files.append(os.path.join(root, f))
+
+     return sorted(files)
+
+
+ def readLines(path: str):
+     """Reads the lines of a file into a list.
+
+     Opens the file and reads its contents into a list.
+     The file is expected to have one line for each species or label.
+
+     Args:
+         path: Absolute path to the species file.
+
+     Returns:
+         A list of all species inside the file.
+     """
+     return Path(path).read_text(encoding="utf-8").splitlines() if path else []
+
+
+ def list_subdirectories(path: str):
+     """Lists all directories inside a path.
+
+     Retrieves all the subdirectories in a given path without recursion.
+
+     Args:
+         path: Directory to be searched.
+
+     Returns:
+         A filter sequence containing the names of all subdirectories.
+     """
+     return filter(lambda el: os.path.isdir(os.path.join(path, el)), os.listdir(path))
+
+
+ def random_split(x, y, val_ratio=0.2):
+     """Splits the data into training and validation data.
+
+     Makes sure that each class is represented in both sets.
+
+     Args:
+         x: Samples.
+         y: One-hot labels.
+         val_ratio: The ratio of validation data.
+
+     Returns:
+         A tuple of (x_train, y_train, x_val, y_val).
+     """
+     # Set numpy random seed
+     np.random.seed(cfg.RANDOM_SEED)
+
+     # Get number of classes
+     num_classes = y.shape[1]
+
+     # Initialize training and validation data
+     x_train, y_train, x_val, y_val = [], [], [], []
+
+     # Split data
+     for i in range(num_classes):
+         # Get indices of current class
+         indices = np.where(y[:, i] == 1)[0]
+
+         # Get number of samples for each set
+         num_samples = len(indices)
+         num_samples_train = max(1, int(num_samples * (1 - val_ratio)))
+         num_samples_val = max(0, num_samples - num_samples_train)
+
+         # Randomly choose samples for training and validation
+         np.random.shuffle(indices)
+         train_indices = indices[:num_samples_train]
+         val_indices = indices[num_samples_train:num_samples_train + num_samples_val]
+
+         # Append samples to training and validation data
+         x_train.append(x[train_indices])
+         y_train.append(y[train_indices])
+         x_val.append(x[val_indices])
+         y_val.append(y[val_indices])
+
+     # Concatenate data
+     x_train = np.concatenate(x_train)
+     y_train = np.concatenate(y_train)
+     x_val = np.concatenate(x_val)
+     y_val = np.concatenate(y_val)
+
+     # Shuffle data
+     indices = np.arange(len(x_train))
+     np.random.shuffle(indices)
+     x_train = x_train[indices]
+     y_train = y_train[indices]
+
+     indices = np.arange(len(x_val))
+     np.random.shuffle(indices)
+     x_val = x_val[indices]
+     y_val = y_val[indices]
+
+     return x_train, y_train, x_val, y_val
+
+
+ def mixup(x, y, augmentation_ratio=0.25, alpha=0.2):
+     """Applies mixup to the given data.
+
+     Mixup is a data augmentation technique that generates new samples by
+     mixing two samples and their labels.
+
+     Args:
+         x: Samples.
+         y: One-hot labels.
+         augmentation_ratio: The ratio of augmented samples.
+         alpha: The beta distribution parameter.
+
+     Returns:
+         Augmented data.
+     """
+     # Calculate the number of samples to augment based on the ratio
+     num_samples_to_augment = int(len(x) * augmentation_ratio)
+
+     for _ in range(num_samples_to_augment):
+         # Randomly choose one instance from the dataset
+         index = np.random.choice(len(x))
+         x1, y1 = x[index], y[index]
+
+         # Randomly choose a different instance from the dataset
+         second_index = np.random.choice(len(x))
+         while second_index == index:
+             second_index = np.random.choice(len(x))
+         x2, y2 = x[second_index], y[second_index]
+
+         # Generate a random mixing coefficient (lambda)
+         lambda_ = np.random.beta(alpha, alpha)
+
+         # Mix the embeddings and labels
+         mixed_x = lambda_ * x1 + (1 - lambda_) * x2
+         mixed_y = lambda_ * y1 + (1 - lambda_) * y2
+
+         # Replace one of the original samples and labels with the augmented sample and labels
+         x[index] = mixed_x
+         y[index] = mixed_y
+
+     return x, y
+
+
+ def label_smoothing(y, alpha=0.1):
+     """Applies label smoothing to the given one-hot labels.
+
+     Args:
+         y: One-hot labels.
+         alpha: The smoothing factor.
+
+     Returns:
+         The smoothed labels.
+     """
+     # Subtract alpha from the correct label when it is > 0
+     y[y > 0] -= alpha
+
+     # Assign alpha to all other labels
+     y[y == 0] = alpha / y.shape[0]
+
+     return y
+
+
+ def upsampling(x, y, ratio=0.5, mode="repeat"):
+     """Balances data through upsampling.
+
+     We upsample minority classes so they have at least `ratio` times the
+     samples of the majority class (e.g. 10% for ratio=0.1).
+
+     Args:
+         x: Samples.
+         y: One-hot labels.
+         ratio: The minimum ratio of minority to majority samples.
+         mode: The upsampling mode. Either 'repeat', 'mean' or 'smote'.
+
+     Returns:
+         Upsampled data.
+     """
+     # Set numpy random seed
+     np.random.seed(cfg.RANDOM_SEED)
+
+     # Determine the minimum number of samples per class
+     min_samples = int(np.max(y.sum(axis=0)) * ratio)
+
+     x_temp = []
+     y_temp = []
+
+     if mode == 'repeat':
+         # For each class with fewer than min_samples, randomly repeat samples;
+         # count additions per class so earlier classes don't affect the threshold
+         for i in range(y.shape[1]):
+             num_added = 0
+
+             while y[:, i].sum() + num_added < min_samples:
+                 # Randomly choose a sample from the minority class
+                 random_index = np.random.choice(np.where(y[:, i] == 1)[0])
+
+                 # Append the sample and label to a temp list
+                 x_temp.append(x[random_index])
+                 y_temp.append(y[random_index])
+                 num_added += 1
+
+     elif mode == 'mean':
+         # For each class with fewer than min_samples,
+         # select two random samples and calculate their mean
+         for i in range(y.shape[1]):
+             num_added = 0
+
+             while y[:, i].sum() + num_added < min_samples:
+                 # Randomly choose two samples from the minority class
+                 random_indices = np.random.choice(np.where(y[:, i] == 1)[0], 2)
+
+                 # Calculate the mean of the two samples
+                 mean = np.mean(x[random_indices], axis=0)
+
+                 # Append the mean and label to a temp list
+                 x_temp.append(mean)
+                 y_temp.append(y[random_indices[0]])
+                 num_added += 1
+
+     elif mode == 'smote':
+         # For each class with fewer than min_samples, apply SMOTE
+         for i in range(y.shape[1]):
+             num_added = 0
+
+             while y[:, i].sum() + num_added < min_samples:
+                 # Randomly choose a sample from the minority class
+                 random_index = np.random.choice(np.where(y[:, i] == 1)[0])
+
+                 # Get the k nearest neighbors (among all samples)
+                 k = 5
+                 distances = np.sqrt(np.sum((x - x[random_index])**2, axis=1))
+                 indices = np.argsort(distances)[1:k + 1]
+
+                 # Randomly choose one of the neighbors
+                 random_neighbor = np.random.choice(indices)
+
+                 # Calculate the difference vector
+                 diff = x[random_neighbor] - x[random_index]
+
+                 # Randomly choose a weight between 0 and 1
+                 weight = np.random.uniform(0, 1)
+
+                 # Calculate the new sample
+                 new_sample = x[random_index] + weight * diff
+
+                 # Append the new sample and label to a temp list
+                 x_temp.append(new_sample)
+                 y_temp.append(y[random_index])
+                 num_added += 1
+
+     # Append the temp list to the original data
+     if len(x_temp) > 0:
+         x = np.vstack((x, np.array(x_temp)))
+         y = np.vstack((y, np.array(y_temp)))
+
+     # Shuffle data
+     indices = np.arange(len(x))
+     np.random.shuffle(indices)
+     x = x[indices]
+     y = y[indices]
+
+     return x, y
+
+
+ def saveToCache(cache_file: str, x_train: np.ndarray, y_train: np.ndarray, labels: list[str]):
+     """Saves the training data to a cache file.
+
+     Args:
+         cache_file: The path to the cache file.
+         x_train: The training samples.
+         y_train: The training labels.
+         labels: The list of labels.
+     """
+     # Create cache directory
+     os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+
+     # Save to cache
+     np.savez_compressed(cache_file, x_train=x_train, y_train=y_train, labels=labels)
+
+
+ def loadFromCache(cache_file: str):
+     """Loads the training data from a cache file.
+
+     Args:
+         cache_file: The path to the cache file.
+
+     Returns:
+         A tuple of (x_train, y_train, labels).
+     """
+     # Load from cache
+     cache = np.load(cache_file, allow_pickle=True)
+
+     # Get data
+     x_train = cache["x_train"]
+     y_train = cache["y_train"]
+     labels = cache["labels"]
+
+     return x_train, y_train, labels
+
+
+ def clearErrorLog():
+     """Clears the error log file.
+
+     For debugging purposes.
+     """
+     if os.path.isfile(cfg.ERROR_LOG_FILE):
+         os.remove(cfg.ERROR_LOG_FILE)
+
+
+ def writeErrorLog(ex: Exception):
+     """Writes an exception to the error log.
+
+     Formats the stacktrace and writes it to the error log file configured in the config.
+
+     Args:
+         ex: An exception that occurred.
+     """
+     with open(cfg.ERROR_LOG_FILE, "a") as elog:
+         elog.write("".join(traceback.TracebackException.from_exception(ex).format()) + "\n")
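To illustrate the training-data helpers, a small sketch on toy arrays (shapes and class counts are made up):

import numpy as np

import utils

# Toy dataset: 12 embeddings with 8 features, two one-hot classes,
# class 1 heavily underrepresented
x = np.random.rand(12, 8).astype("float32")
y = np.zeros((12, 2), dtype="float32")
y[:10, 0] = 1
y[10:, 1] = 1

# Upsample class 1 to at least half the size of class 0
x, y = utils.upsampling(x, y, ratio=0.5, mode="repeat")

# Stratified split: every class ends up in both sets
x_train, y_train, x_val, y_val = utils.random_split(x, y, val_ratio=0.2)
print(x_train.shape, x_val.shape)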