BilalSardar committed on
Commit
9f41dd5
1 Parent(s): 6b3d53a

Upload 2 files

Browse files
Files changed (2) hide show
  1. gui.py +1041 -0
  2. train.py +214 -0
gui.py ADDED
@@ -0,0 +1,1041 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import concurrent.futures
2
+ import os
3
+ import sys
4
+ from multiprocessing import freeze_support
5
+ from pathlib import Path
6
+
7
+ import gradio as gr
8
+ import librosa
9
+ import webview
10
+
11
+ import analyze
12
+ import config as cfg
13
+ import segments
14
+ import species
15
+ import utils
16
+ from train import trainModel
17
+
18
# Handle to the pywebview window; it is assigned in the __main__ section once the UI has been launched.
_WINDOW: webview.Window

# Maps the result-type labels offered in the UI to the identifiers stored in cfg.RESULT_TYPE.
OUTPUT_TYPE_MAP = {
    "Raven selection table": "table",
    "Audacity": "audacity",
    "R": "r",
    "CSV": "csv",
}

# Values of the config paths as they were when this module was imported.
ORIGINAL_MODEL_PATH = cfg.MODEL_PATH
ORIGINAL_MDATA_MODEL_PATH = cfg.MDATA_MODEL_PATH
ORIGINAL_LABELS_FILE = cfg.LABELS_FILE
ORIGINAL_TRANSLATED_LABELS_PATH = cfg.TRANSLATED_LABELS_PATH
24
+
25
+
26
def analyzeFile_wrapper(entry):
    """Analyzes one work item and pairs the outcome with the item's first element.

    Args:
        entry: Work item passed on to analyze.analyzeFile; entry[0] is echoed back.

    Returns:
        A tuple of (entry[0], result of analyze.analyzeFile(entry)).
    """
    key = entry[0]
    outcome = analyze.analyzeFile(entry)

    return key, outcome
28
+
29
+
30
def extractSegments_wrapper(entry):
    """Extracts the segments for one work item and pairs the outcome with entry[0][0].

    Args:
        entry: Work item passed on to segments.extractSegments; entry[0][0] is echoed back.

    Returns:
        A tuple of (entry[0][0], result of segments.extractSegments(entry)).
    """
    key = entry[0][0]
    outcome = segments.extractSegments(entry)

    return key, outcome
32
+
33
+
34
def validate(value, msg):
    """Ensures that the value is truthy.

    Args:
        value: Value to be tested.
        msg: Message of the error that is raised for a falsy value.

    Raises:
        gr.Error: If the value is falsy.
    """
    if value:
        return

    raise gr.Error(msg)
45
+
46
+
47
def runSingleFileAnalysis(
    input_path,
    confidence,
    sensitivity,
    overlap,
    species_list_choice,
    species_list_file,
    lat,
    lon,
    week,
    use_yearlong,
    sf_thresh,
    custom_classifier_file,
    locale,
):
    """Runs the analysis for the single file selected in the UI.

    Forwards the UI values to runAnalysis with the settings fixed for single-file mode:
    no output path, "csv" as result type, batch size 1, 4 threads, no input directory
    and no progress bar. An empty locale falls back to "en".

    Returns:
        Whatever runAnalysis returns.

    Raises:
        gr.Error: If no file is selected.
    """
    validate(input_path, "Please select a file.")

    return runAnalysis(
        input_path=input_path,
        output_path=None,
        confidence=confidence,
        sensitivity=sensitivity,
        overlap=overlap,
        species_list_choice=species_list_choice,
        species_list_file=species_list_file,
        lat=lat,
        lon=lon,
        week=week,
        use_yearlong=use_yearlong,
        sf_thresh=sf_thresh,
        custom_classifier_file=custom_classifier_file,
        output_type="csv",
        locale=locale if locale else "en",
        batch_size=1,
        threads=4,
        input_dir=None,
        progress=None,
    )
85
+
86
+
87
def runBatchAnalysis(
    output_path,
    confidence,
    sensitivity,
    overlap,
    species_list_choice,
    species_list_file,
    lat,
    lon,
    week,
    use_yearlong,
    sf_thresh,
    custom_classifier_file,
    output_type,
    locale,
    batch_size,
    threads,
    input_dir,
    progress=gr.Progress(),
):
    """Runs the analysis for all files of the directory selected in the UI.

    Validates the UI values and forwards them to runAnalysis. A missing or
    non-positive batch size falls back to 1, a missing or non-positive thread
    count falls back to 4, and an empty locale falls back to "en".

    Returns:
        Whatever runAnalysis returns.

    Raises:
        gr.Error: If no input directory is selected, or the custom species list is
            chosen but no list file is selected.
    """
    validate(input_dir, "Please select a directory.")

    # The values may be None (the fallbacks below expect that); int(None) would raise
    # a TypeError before the fallback could apply, so convert only real values.
    batch_size = int(batch_size) if batch_size else 0
    threads = int(threads) if threads else 0

    if species_list_choice == _CUSTOM_SPECIES:
        validate(species_list_file, "Please select a species list.")

    return runAnalysis(
        None,
        output_path,
        confidence,
        sensitivity,
        overlap,
        species_list_choice,
        species_list_file,
        lat,
        lon,
        week,
        use_yearlong,
        sf_thresh,
        custom_classifier_file,
        output_type,
        "en" if not locale else locale,
        batch_size if batch_size > 0 else 1,
        threads if threads > 0 else 4,
        input_dir,
        progress,
    )
135
+
136
+
137
def runAnalysis(
    input_path: str,
    output_path: str | None,
    confidence: float,
    sensitivity: float,
    overlap: float,
    species_list_choice: str,
    species_list_file,
    lat: float,
    lon: float,
    week: int,
    use_yearlong: bool,
    sf_thresh: float,
    custom_classifier_file,
    output_type: str,
    locale: str,
    batch_size: int,
    threads: int,
    input_dir: str,
    progress: gr.Progress | None,
):
    """Starts the analysis.

    Writes all settings into the global config module, collects the input files and
    analyzes them either sequentially or in a process pool.

    Args:
        input_path: Either a file or directory.
        output_path: The output path for the result, if None the input_path is used
        confidence: The selected minimum confidence.
        sensitivity: The selected sensitivity.
        overlap: The selected segment overlap.
        species_list_choice: The choice for the species list.
        species_list_file: The selected custom species list file.
        lat: The selected latitude.
        lon: The selected longitude.
        week: The selected week of the year.
        use_yearlong: Use yearlong instead of week.
        sf_thresh: The threshold for the predicted species list.
        custom_classifier_file: Custom classifier to be used.
        output_type: The type of result to be generated.
        locale: The translation to be used.
        batch_size: The number of samples in a batch.
        threads: The number of threads to be used.
        input_dir: The input directory.
        progress: The gradio progress bar.

    Returns:
        If input_dir is set, a list of [file path relative to input_dir, analysis result]
        rows; otherwise the output path stored in cfg.OUTPUT_PATH.

    Raises:
        gr.Error: If the custom classifier is chosen but none is selected, or if no
            audio files are found.
    """
    if progress is not None:
        progress(0, desc="Preparing ...")

    locale = locale.lower()

    # Load eBird codes, labels
    cfg.CODES = analyze.loadCodes()
    # A previous custom-classifier run leaves its own labels file in the config; restore
    # the default so the translated-labels lookup below uses the right file name.
    cfg.LABELS_FILE = ORIGINAL_LABELS_FILE
    cfg.LABELS = utils.readLines(ORIGINAL_LABELS_FILE)
    cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK = lat, lon, -1 if use_yearlong else week
    cfg.LOCATION_FILTER_THRESHOLD = sf_thresh

    if species_list_choice == _CUSTOM_SPECIES:
        if not species_list_file or not species_list_file.name:
            cfg.SPECIES_LIST_FILE = None
        else:
            cfg.SPECIES_LIST_FILE = os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), species_list_file.name)

            if os.path.isdir(cfg.SPECIES_LIST_FILE):
                cfg.SPECIES_LIST_FILE = os.path.join(cfg.SPECIES_LIST_FILE, "species_list.txt")

        cfg.SPECIES_LIST = utils.readLines(cfg.SPECIES_LIST_FILE)
        cfg.CUSTOM_CLASSIFIER = None
    elif species_list_choice == _PREDICT_SPECIES:
        cfg.SPECIES_LIST_FILE = None
        cfg.CUSTOM_CLASSIFIER = None
        cfg.SPECIES_LIST = species.getSpeciesList(cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK, cfg.LOCATION_FILTER_THRESHOLD)
    elif species_list_choice == _CUSTOM_CLASSIFIER:
        if custom_classifier_file is None:
            raise gr.Error("No custom classifier selected.")

        # Set custom classifier?
        cfg.CUSTOM_CLASSIFIER = custom_classifier_file  # we treat this as absolute path, so no need to join with dirname
        cfg.LABELS_FILE = custom_classifier_file.replace(".tflite", "_Labels.txt")  # same for labels file
        cfg.LABELS = utils.readLines(cfg.LABELS_FILE)
        cfg.LATITUDE = -1
        cfg.LONGITUDE = -1
        cfg.SPECIES_LIST_FILE = None
        cfg.SPECIES_LIST = []
        locale = "en"
    else:
        cfg.SPECIES_LIST_FILE = None
        cfg.SPECIES_LIST = []
        cfg.CUSTOM_CLASSIFIER = None

    # Load translated labels
    lfile = os.path.join(cfg.TRANSLATED_LABELS_PATH, os.path.basename(cfg.LABELS_FILE).replace(".txt", f"_{locale}.txt"))
    if locale != "en" and os.path.isfile(lfile):
        cfg.TRANSLATED_LABELS = utils.readLines(lfile)
    else:
        cfg.TRANSLATED_LABELS = cfg.LABELS

    if len(cfg.SPECIES_LIST) == 0:
        print(f"Species list contains {len(cfg.LABELS)} species")
    else:
        print(f"Species list contains {len(cfg.SPECIES_LIST)} species")

    # Set input and output path
    cfg.INPUT_PATH = input_path

    if input_dir:
        cfg.OUTPUT_PATH = output_path if output_path else input_dir
    elif output_path:
        cfg.OUTPUT_PATH = output_path
    else:
        # Replace only the file extension; splitting at the first "." of the whole path
        # would cut it at a dotted directory name.
        cfg.OUTPUT_PATH = os.path.splitext(input_path)[0] + ".csv"

    # Parse input files
    if input_dir:
        cfg.FILE_LIST = utils.collect_audio_files(input_dir)
        cfg.INPUT_PATH = input_dir
    elif os.path.isdir(cfg.INPUT_PATH):
        cfg.FILE_LIST = utils.collect_audio_files(cfg.INPUT_PATH)
    else:
        cfg.FILE_LIST = [cfg.INPUT_PATH]

    validate(cfg.FILE_LIST, "No audio files found.")

    # Set confidence threshold
    cfg.MIN_CONFIDENCE = confidence

    # Set sensitivity
    cfg.SIGMOID_SENSITIVITY = sensitivity

    # Set overlap
    cfg.SIG_OVERLAP = overlap

    # Set result type
    cfg.RESULT_TYPE = OUTPUT_TYPE_MAP[output_type] if output_type in OUTPUT_TYPE_MAP else output_type.lower()

    if cfg.RESULT_TYPE not in ["table", "audacity", "r", "csv"]:
        cfg.RESULT_TYPE = "table"

    # Set number of threads: a directory run parallelizes over files,
    # a single-file run hands the threads to TFLite instead.
    if input_dir:
        cfg.CPU_THREADS = max(1, int(threads))
        cfg.TFLITE_THREADS = 1
    else:
        cfg.CPU_THREADS = 1
        cfg.TFLITE_THREADS = max(1, int(threads))

    # Set batch size
    cfg.BATCH_SIZE = max(1, int(batch_size))

    # Every entry carries its own copy of the config
    flist = [(f, cfg.getConfig()) for f in cfg.FILE_LIST]

    result_list = []

    if progress is not None:
        progress(0, desc="Starting ...")

    # Analyze files
    if cfg.CPU_THREADS < 2:
        for i, entry in enumerate(flist, start=1):
            result_list.append(analyzeFile_wrapper(entry))

            # Report progress here as well, as the parallel branch does
            if progress is not None:
                progress((i, len(flist)), total=len(flist), unit="files")
    else:
        with concurrent.futures.ProcessPoolExecutor(max_workers=cfg.CPU_THREADS) as executor:
            futures = [executor.submit(analyzeFile_wrapper, arg) for arg in flist]
            for i, f in enumerate(concurrent.futures.as_completed(futures), start=1):
                if progress is not None:
                    progress((i, len(flist)), total=len(flist), unit="files")
                result = f.result()

                result_list.append(result)

    return [[os.path.relpath(r[0], input_dir), r[1]] for r in result_list] if input_dir else cfg.OUTPUT_PATH
308
+
309
+
310
# Labels of the species-list radio choices; the analysis and the visibility
# handler compare the selected choice against these values.
_CUSTOM_SPECIES = "Custom species list"
_PREDICT_SPECIES = "Species by location"
_CUSTOM_CLASSIFIER = "Custom classifier"
_ALL_SPECIES = "all species"
314
+
315
+
316
def show_species_choice(choice: str):
    """Sets the visibility of the species list choices.

    Exactly one of the four elements is made visible, all others are hidden.

    Args:
        choice: The label of the currently active choice.

    Returns:
        A list of [
            Row update (visible for "Species by location"),
            File update (visible for "Custom species list"),
            Column update (visible for "Custom classifier"),
            Column update (visible for any other choice),
        ]
    """
    # Position of the one element that is shown
    if choice == _CUSTOM_SPECIES:
        shown = 1
    elif choice == _PREDICT_SPECIES:
        shown = 0
    elif choice == _CUSTOM_CLASSIFIER:
        shown = 2
    else:
        shown = 3

    return [
        gr.Row.update(visible=shown == 0),
        gr.File.update(visible=shown == 1),
        gr.Column.update(visible=shown == 2),
        gr.Column.update(visible=shown == 3),
    ]
358
+
359
+
360
def select_subdirectories():
    """Opens a directory selection dialog and lists the subdirectories of the selection.

    Returns:
        A tuple of (directory, list of one-element rows holding the subdirectories)
        or (None, None) if the dialog was canceled.
    """
    selection = _WINDOW.create_file_dialog(webview.FOLDER_DIALOG)

    if not selection:
        return None, None

    root = selection[0]

    return root, [[name] for name in utils.list_subdirectories(root)]
374
+
375
+
376
def select_file(filetypes=()):
    """Opens a file selection dialog.

    Args:
        filetypes: List of filetypes to be filtered in the dialog.

    Returns:
        The first selected file or None if the dialog was canceled.
    """
    picked = _WINDOW.create_file_dialog(webview.OPEN_DIALOG, file_types=filetypes)

    if not picked:
        return None

    return picked[0]
387
+
388
+
389
def format_seconds(secs: float):
    """Formats a number of seconds into a string.

    Formats the seconds into the format "h:mm:ss.ms". The value is rounded to whole
    milliseconds before it is split into its fields, so a carry propagates into the
    minutes and hours (59.9996 becomes " 0:01:00.000" instead of " 0:00:60.000").

    Args:
        secs: Number of seconds.

    Returns:
        A string with the formatted seconds; the hour field is padded to a width of two.
    """
    # Work on integer milliseconds so no field can round up to 60 or 1000 on its own
    total_ms = int(round(secs * 1000))
    total_secs, millis = divmod(total_ms, 1000)
    total_minutes, seconds = divmod(total_secs, 60)
    hours, minutes = divmod(total_minutes, 60)

    return "{:2d}:{:02d}:{:02d}.{:03d}".format(hours, minutes, seconds, millis)
404
+
405
+
406
def select_directory(collect_files=True):
    """Shows a directory selection system dialog.

    The dialog is created through the pywebview window.

    Args:
        collect_files: If True, also lists the audio files inside the directory.

    Returns:
        If collect_files==True, returns (directory path, list of [relative file path, formatted audio length])
        else just the directory path.
        All values will be None if the dialog is cancelled.
    """
    selection = _WINDOW.create_file_dialog(webview.FOLDER_DIALOG)

    if not collect_files:
        return selection[0] if selection else None

    if not selection:
        return None, None

    root = selection[0]
    rows = []

    for audio_file in utils.collect_audio_files(root):
        duration = format_seconds(librosa.get_duration(filename=audio_file))
        rows.append([os.path.relpath(audio_file, root), duration])

    return root, rows
432
+
433
+
434
def start_training(
    data_dir,
    crop_mode,
    crop_overlap,
    output_dir,
    classifier_name,
    epochs,
    batch_size,
    learning_rate,
    hidden_units,
    use_mixup,
    upsampling_ratio,
    upsampling_mode,
    model_format,
    progress=gr.Progress(),
):
    """Starts the training of a custom classifier.

    Args:
        data_dir: Directory containing the training data.
        crop_mode: Stored as cfg.SAMPLE_CROP_MODE.
        crop_overlap: Stored as cfg.SIG_OVERLAP.
        output_dir: Directory for the new classifier.
        classifier_name: File name of the classifier.
        epochs: Number of epochs to train for.
        batch_size: Number of samples in one batch.
        learning_rate: Learning rate for training.
        hidden_units: If > 0 the classifier contains a further hidden layer.
        use_mixup: Stored as cfg.TRAIN_WITH_MIXUP.
        upsampling_ratio: Stored as cfg.UPSAMPLING_RATIO after clamping to [0, 1].
        upsampling_mode: Stored as cfg.UPSAMPLING_MODE.
        model_format: Stored as cfg.TRAINED_MODEL_OUTPUT_FORMAT.
        progress: The gradio progress bar.

    Returns:
        Returns a matplotlib.pyplot figure.

    Raises:
        gr.Error: If a directory or the classifier name is missing, or epochs,
            batch size or learning rate are missing or negative.
    """
    validate(data_dir, "Please select your Training data.")
    validate(output_dir, "Please select a directory for the classifier.")
    validate(classifier_name, "Please enter a valid name for the classifier.")

    numeric_checks = (
        (epochs, "Please enter a valid number of epochs."),
        (batch_size, "Please enter a valid batch size."),
        (learning_rate, "Please enter a valid learning rate."),
    )

    for number, error_msg in numeric_checks:
        if not number or number < 0:
            raise gr.Error(error_msg)

    # A missing or negative number of hidden units means "no hidden layer"
    if not hidden_units or hidden_units < 0:
        hidden_units = 0

    if progress is not None:
        progress((0, epochs), desc="Loading data & building classifier", unit="epoch")

    cfg.TRAIN_DATA_PATH = data_dir
    cfg.SAMPLE_CROP_MODE = crop_mode
    cfg.SIG_OVERLAP = crop_overlap
    cfg.CUSTOM_CLASSIFIER = str(Path(output_dir) / classifier_name)
    cfg.TRAIN_EPOCHS = int(epochs)
    cfg.TRAIN_BATCH_SIZE = int(batch_size)
    cfg.TRAIN_LEARNING_RATE = learning_rate
    cfg.TRAIN_HIDDEN_UNITS = int(hidden_units)
    cfg.TRAIN_WITH_MIXUP = use_mixup
    cfg.UPSAMPLING_RATIO = min(max(0, upsampling_ratio), 1)
    cfg.UPSAMPLING_MODE = upsampling_mode
    cfg.TRAINED_MODEL_OUTPUT_FORMAT = model_format

    def progression(epoch, logs=None):
        """Forwards the end of an epoch to the progress bar."""
        if progress is None:
            return

        finished = epoch + 1

        if finished == epochs:
            # Last epoch: announce where the classifier is saved
            progress((finished, epochs), total=epochs, unit="epoch", desc=f"Saving at {cfg.CUSTOM_CLASSIFIER}")
        else:
            progress((finished, epochs), total=epochs, unit="epoch")

    history = trainModel(on_epoch_end=progression)

    if len(history.epoch) < epochs:
        gr.Info("Stopped early - validation metric not improving.")

    auprc = history.history["val_AUPRC"]

    import matplotlib.pyplot as plt

    fig = plt.figure()
    plt.plot(auprc)
    plt.ylabel("Area under precision-recall curve")
    plt.xlabel("Epoch")

    return fig
519
+
520
+
521
def extract_segments(audio_dir, result_dir, output_dir, min_conf, num_seq, seq_length, threads, progress=gr.Progress()):
    """Extracts segments from the audio files based on existing result files.

    Args:
        audio_dir: Directory containing the audio files.
        result_dir: Directory containing the result files; defaults to audio_dir.
        output_dir: Directory for the extracted segments; defaults to audio_dir.
        min_conf: Minimum confidence, clamped to [0.01, 0.99].
        num_seq: Maximum number of segments passed to segments.parseFiles (at least 1).
        seq_length: Length of the segments; never shorter than cfg.SIG_LENGTH.
        threads: Number of worker processes; a missing or non-positive value means 1.
        progress: The gradio progress bar.

    Returns:
        A list of [path relative to audio_dir, extraction result] rows.

    Raises:
        gr.Error: If no audio directory is selected.
    """
    validate(audio_dir, "No audio directory selected")

    result_dir = result_dir or audio_dir
    output_dir = output_dir or audio_dir

    if progress is not None:
        progress(0, desc="Searching files ...")

    # Parse audio and result folders
    cfg.FILE_LIST = segments.parseFolders(audio_dir, result_dir)

    # Set output folder
    cfg.OUTPUT_PATH = output_dir

    # Set number of threads; tolerate an empty field and clamp like runAnalysis does
    cfg.CPU_THREADS = max(1, int(threads)) if threads else 1

    # Set confidence threshold
    cfg.MIN_CONFIDENCE = max(0.01, min(0.99, min_conf))

    # Parse file list and make list of segments
    cfg.FILE_LIST = segments.parseFiles(cfg.FILE_LIST, max(1, int(num_seq)))

    # Add config items to each file list entry.
    # We have to do this for Windows which does not
    # support fork() and thus each process has to
    # have its own config.
    flist = [(entry, max(cfg.SIG_LENGTH, float(seq_length)), cfg.getConfig()) for entry in cfg.FILE_LIST]

    result_list = []

    # Extract segments
    if cfg.CPU_THREADS < 2:
        # Count from 1 so the bar reaches len(flist), as in the parallel branch
        for i, entry in enumerate(flist, start=1):
            result_list.append(extractSegments_wrapper(entry))

            if progress is not None:
                progress((i, len(flist)), total=len(flist), unit="files")
    else:
        with concurrent.futures.ProcessPoolExecutor(max_workers=cfg.CPU_THREADS) as executor:
            futures = [executor.submit(extractSegments_wrapper, arg) for arg in flist]
            for i, f in enumerate(concurrent.futures.as_completed(futures), start=1):
                if progress is not None:
                    progress((i, len(flist)), total=len(flist), unit="files")
                result = f.result()

                result_list.append(result)

    return [[os.path.relpath(r[0], audio_dir), r[1]] for r in result_list]
575
+
576
+
577
def sample_sliders(opened=True):
    """Creates the gradio accordion for the inference settings.

    Args:
        opened: If True the accordion is open on init.

    Returns:
        A tuple with the created elements:
        (Slider (min confidence), Slider (sensitivity), Slider (overlap))
    """
    with gr.Accordion("Inference settings", open=opened):
        with gr.Row():
            min_confidence = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.01,
                label="Minimum Confidence",
                info="Minimum confidence threshold.",
            )
            sensitivity = gr.Slider(
                minimum=0.5,
                maximum=1.5,
                value=1,
                step=0.01,
                label="Sensitivity",
                info="Detection sensitivity; Higher values result in higher sensitivity.",
            )
            segment_overlap = gr.Slider(
                minimum=0,
                maximum=2.99,
                value=0,
                step=0.01,
                label="Overlap",
                info="Overlap of prediction segments.",
            )

        return min_confidence, sensitivity, segment_overlap
605
+
606
+
607
def locale():
    """Creates the gradio elements for locale selection

    Lists the files of the translated labels directory (resolved relative to the
    directory of the running script) and takes the part between the last "_" and
    the following "." of every file name as locale code.

    Returns:
        The dropdown element. "EN" is always the first option, the remaining codes
        follow in alphabetical order without duplicates.
    """
    labels_dir = os.path.join(os.path.dirname(sys.argv[0]), ORIGINAL_TRANSLATED_LABELS_PATH)

    # os.listdir returns the entries in arbitrary order, so collect the codes in a set
    # and sort them for a stable dropdown.
    codes = {label_file.rsplit("_", 1)[-1].split(".")[0].upper() for label_file in os.listdir(labels_dir)}
    codes.discard("EN")  # "EN" is always offered first
    options = ["EN"] + sorted(codes)

    return gr.Dropdown(options, value="EN", label="Locale", info="Locale for the translated species common names.")
619
+
620
+
621
def species_lists(opened=True):
    """Creates the gradio accordion for species selection.

    Args:
        opened: If True the accordion is open on init.

    Returns:
        A tuple with the created elements:
        (Radio (choice), File (custom species list), Slider (lat), Slider (lon), Slider (week), Slider (threshold), Checkbox (yearlong?), State (custom classifier))
    """
    with gr.Accordion("Species selection", open=opened):
        with gr.Row():
            species_list_radio = gr.Radio(
                [_CUSTOM_SPECIES, _PREDICT_SPECIES, _CUSTOM_CLASSIFIER, _ALL_SPECIES],
                value=_ALL_SPECIES,
                label="Species list",
                info="List of all possible species",
                elem_classes="d-block",
            )

            with gr.Column(visible=False) as position_row:
                lat_number = gr.Slider(
                    minimum=-90, maximum=90, value=0, step=1, label="Latitude", info="Recording location latitude."
                )
                lon_number = gr.Slider(
                    minimum=-180, maximum=180, value=0, step=1, label="Longitude", info="Recording location longitude."
                )
                with gr.Row():
                    yearlong_checkbox = gr.Checkbox(True, label="Year-round")
                    week_number = gr.Slider(
                        minimum=1,
                        maximum=48,
                        value=1,
                        step=1,
                        interactive=False,
                        label="Week",
                        info="Week of the year when the recording was made. Values in [1, 48] (4 weeks per month).",
                    )

                    def onChange(use_yearlong):
                        # The week can only be edited while "Year-round" is unchecked
                        return gr.Slider.update(interactive=(not use_yearlong))

                    yearlong_checkbox.change(onChange, inputs=yearlong_checkbox, outputs=week_number, show_progress=False)
                sf_thresh_number = gr.Slider(
                    minimum=0.01,
                    maximum=0.99,
                    value=0.03,
                    step=0.01,
                    label="Location filter threshold",
                    info="Minimum species occurrence frequency threshold for location filter.",
                )

            species_file_input = gr.File(file_types=[".txt"], info="Path to species list file or folder.", visible=False)
            empty_col = gr.Column()

            with gr.Column(visible=False) as custom_classifier_selector:
                classifier_selection_button = gr.Button("Select classifier")
                classifier_file_input = gr.Files(
                    file_types=[".tflite"], info="Path to the custom classifier.", visible=False, interactive=False
                )
                selected_classifier_state = gr.State()

                def on_custom_classifier_selection_click():
                    file = select_file(("TFLite classifier (*.tflite)",))

                    if file:
                        # The labels file is expected next to the classifier
                        labels = os.path.splitext(file)[0] + "_Labels.txt"

                        return file, gr.File.update(value=[file, labels], visible=True)

                    # Dialog cancelled: the click handler has two outputs, so return one
                    # value for each (a bare None does not match the outputs list).
                    return None, gr.File.update(value=None, visible=False)

                classifier_selection_button.click(
                    on_custom_classifier_selection_click,
                    outputs=[selected_classifier_state, classifier_file_input],
                    show_progress=False,
                )

            species_list_radio.change(
                show_species_choice,
                inputs=[species_list_radio],
                outputs=[position_row, species_file_input, custom_classifier_selector, empty_col],
                show_progress=False,
            )

            return (
                species_list_radio,
                species_file_input,
                lat_number,
                lon_number,
                week_number,
                sf_thresh_number,
                yearlong_checkbox,
                selected_classifier_state,
            )
716
+
717
+
718
# Entry point: builds the four tabs of the gradio UI, launches the gradio server
# without blocking and shows it inside a pywebview window.
if __name__ == "__main__":
    freeze_support()

    def build_single_analysis_tab():
        """Creates the tab for the analysis of a single audio file."""
        with gr.Tab("Single file"):
            audio_input = gr.Audio(type="filepath", label="file", elem_id="single_file_audio")

            confidence_slider, sensitivity_slider, overlap_slider = sample_sliders(False)
            (
                species_list_radio,
                species_file_input,
                lat_number,
                lon_number,
                week_number,
                sf_thresh_number,
                yearlong_checkbox,
                selected_classifier_state,
            ) = species_lists(False)
            locale_radio = locale()

            # Order has to match the parameters of runSingleFileAnalysis
            inputs = [
                audio_input,
                confidence_slider,
                sensitivity_slider,
                overlap_slider,
                species_list_radio,
                species_file_input,
                lat_number,
                lon_number,
                week_number,
                yearlong_checkbox,
                sf_thresh_number,
                selected_classifier_state,
                locale_radio,
            ]

            output_dataframe = gr.Dataframe(
                type="pandas",
                headers=["Start (s)", "End (s)", "Scientific name", "Common name", "Confidence"],
                elem_classes="mh-200",
            )

            single_file_analyze = gr.Button("Analyze")

            single_file_analyze.click(runSingleFileAnalysis, inputs=inputs, outputs=output_dataframe)

    def build_multi_analysis_tab():
        """Creates the tab for the analysis of all audio files of a directory."""
        with gr.Tab("Multiple files"):
            input_directory_state = gr.State()
            output_directory_predict_state = gr.State()
            with gr.Row():
                with gr.Column():
                    select_directory_btn = gr.Button("Select directory (recursive)")
                    directory_input = gr.Matrix(interactive=False, elem_classes="mh-200", headers=["Subpath", "Length"])

                    def select_directory_on_empty():
                        res = select_directory()

                        # Show a placeholder row if nothing was found or the dialog was cancelled
                        return res if res[1] else [res[0], [["No files found"]]]

                    select_directory_btn.click(
                        select_directory_on_empty, outputs=[input_directory_state, directory_input], show_progress=True
                    )

                with gr.Column():
                    select_out_directory_btn = gr.Button("Select output directory.")
                    selected_out_textbox = gr.Textbox(
                        label="Output directory",
                        interactive=False,
                        placeholder="If not selected, the input directory will be used.",
                    )

                    def select_directory_wrapper():
                        # The same path goes to the state and to the textbox
                        return (select_directory(collect_files=False),) * 2

                    select_out_directory_btn.click(
                        select_directory_wrapper,
                        outputs=[output_directory_predict_state, selected_out_textbox],
                        show_progress=False,
                    )

            confidence_slider, sensitivity_slider, overlap_slider = sample_sliders()

            (
                species_list_radio,
                species_file_input,
                lat_number,
                lon_number,
                week_number,
                sf_thresh_number,
                yearlong_checkbox,
                selected_classifier_state,
            ) = species_lists()

            output_type_radio = gr.Radio(
                list(OUTPUT_TYPE_MAP.keys()),
                value="Raven selection table",
                label="Result type",
                info="Specifies output format.",
            )

            with gr.Row():
                batch_size_number = gr.Number(
                    precision=1, label="Batch size", value=1, info="Number of samples to process at the same time."
                )
                threads_number = gr.Number(precision=1, label="Threads", value=4, info="Number of CPU threads.")

            locale_radio = locale()

            start_batch_analysis_btn = gr.Button("Analyze")

            result_grid = gr.Matrix(headers=["File", "Execution"], elem_classes="mh-200")

            # Order has to match the parameters of runBatchAnalysis
            inputs = [
                output_directory_predict_state,
                confidence_slider,
                sensitivity_slider,
                overlap_slider,
                species_list_radio,
                species_file_input,
                lat_number,
                lon_number,
                week_number,
                yearlong_checkbox,
                sf_thresh_number,
                selected_classifier_state,
                output_type_radio,
                locale_radio,
                batch_size_number,
                threads_number,
                input_directory_state,
            ]

            start_batch_analysis_btn.click(runBatchAnalysis, inputs=inputs, outputs=result_grid)

    def build_train_tab():
        """Creates the tab for the training of a custom classifier."""
        with gr.Tab("Train"):
            input_directory_state = gr.State()
            output_directory_state = gr.State()

            with gr.Row():
                with gr.Column():
                    select_directory_btn = gr.Button("Training data")
                    directory_input = gr.List(headers=["Classes"], interactive=False, elem_classes="mh-200")
                    select_directory_btn.click(
                        select_subdirectories, outputs=[input_directory_state, directory_input], show_progress=False
                    )

                with gr.Column():
                    select_directory_btn = gr.Button("Classifier output")

                    with gr.Column():
                        classifier_name = gr.Textbox(
                            "CustomClassifier",
                            visible=False,
                            info="The name of the new classifier.",
                        )
                        output_format = gr.Radio(
                            ["tflite", "raven", "both"],
                            value="tflite",
                            label="Model output format",
                            info="Format for the trained classifier.",
                            visible=False,
                        )

                    def select_directory_and_update_tb():
                        dir_name = _WINDOW.create_file_dialog(webview.FOLDER_DIALOG)

                        if dir_name:
                            return (
                                dir_name[0],
                                # Use the separator of the running platform in the label
                                gr.Textbox.update(label=dir_name[0] + os.sep, visible=True),
                                gr.Radio.update(visible=True, interactive=True),
                            )

                        # Dialog cancelled: the handler has three outputs, so return one
                        # value for each; hide the fields that depend on a directory.
                        return None, gr.Textbox.update(visible=False), gr.Radio.update(visible=False)

                    select_directory_btn.click(
                        select_directory_and_update_tb,
                        outputs=[output_directory_state, classifier_name, output_format],
                        show_progress=False,
                    )

            with gr.Row():
                epoch_number = gr.Number(100, label="Epochs", info="Number of training epochs.")
                batch_size_number = gr.Number(32, label="Batch size", info="Batch size.")
                learning_rate_number = gr.Number(0.01, label="Learning rate", info="Learning rate.")

            with gr.Row():
                crop_mode = gr.Radio(
                    ["center", "first", "segments"],
                    value="center",
                    label="Crop mode",
                    info="Crop mode for training data.",
                )
                crop_overlap = gr.Number(0.0, label="Crop overlap", info="Overlap of training data segments", visible=False)

                def on_crop_select(new_crop_mode):
                    # The overlap is only offered for the "segments" crop mode
                    return gr.Number.update(visible=new_crop_mode == "segments", interactive=new_crop_mode == "segments")

                crop_mode.change(on_crop_select, inputs=crop_mode, outputs=crop_overlap)

            with gr.Row():
                upsampling_mode = gr.Radio(
                    ["repeat", "mean", "smote"],
                    value="repeat",
                    label="Upsampling mode",
                    info="Balance data through upsampling.",
                )
                upsampling_ratio = gr.Slider(
                    0.0, 1.0, 0.0, step=0.01, label="Upsampling ratio", info="Balance train data and upsample minority classes."
                )

            with gr.Row():
                hidden_units_number = gr.Number(
                    0, label="Hidden units", info="Number of hidden units. If set to >0, a two-layer classifier is used."
                )
                use_mixup = gr.Checkbox(False, label="Use mixup", info="Whether to use mixup for training.", show_label=True)

            train_history_plot = gr.Plot()

            start_training_button = gr.Button("Start training")

            # Order of the inputs has to match the parameters of start_training
            start_training_button.click(
                start_training,
                inputs=[
                    input_directory_state,
                    crop_mode,
                    crop_overlap,
                    output_directory_state,
                    classifier_name,
                    epoch_number,
                    batch_size_number,
                    learning_rate_number,
                    hidden_units_number,
                    use_mixup,
                    upsampling_ratio,
                    upsampling_mode,
                    output_format,
                ],
                outputs=[train_history_plot],
            )

    def build_segments_tab():
        """Creates the tab for the extraction of segments."""
        with gr.Tab("Segments"):
            audio_directory_state = gr.State()
            result_directory_state = gr.State()
            output_directory_state = gr.State()

            def select_directory_to_state_and_tb():
                # The same path goes to the state and to the textbox
                return (select_directory(collect_files=False),) * 2

            with gr.Row():
                select_audio_directory_btn = gr.Button("Select audio directory (recursive)")
                selected_audio_directory_tb = gr.Textbox(show_label=False, interactive=False)
                select_audio_directory_btn.click(
                    select_directory_to_state_and_tb,
                    outputs=[selected_audio_directory_tb, audio_directory_state],
                    show_progress=False,
                )

            with gr.Row():
                select_result_directory_btn = gr.Button("Select result directory")
                selected_result_directory_tb = gr.Textbox(
                    show_label=False, interactive=False, placeholder="Same as audio directory if not selected"
                )
                select_result_directory_btn.click(
                    select_directory_to_state_and_tb,
                    outputs=[result_directory_state, selected_result_directory_tb],
                    show_progress=False,
                )

            with gr.Row():
                select_output_directory_btn = gr.Button("Select output directory")
                selected_output_directory_tb = gr.Textbox(
                    show_label=False, interactive=False, placeholder="Same as audio directory if not selected"
                )
                select_output_directory_btn.click(
                    select_directory_to_state_and_tb,
                    outputs=[selected_output_directory_tb, output_directory_state],
                    show_progress=False,
                )

            min_conf_slider = gr.Slider(
                minimum=0.1, maximum=0.99, step=0.01, label="Minimum confidence", info="Minimum confidence threshold."
            )
            num_seq_number = gr.Number(
                100, label="Max number of segments", info="Maximum number of randomly extracted segments per species."
            )
            seq_length_number = gr.Number(3.0, label="Sequence length", info="Length of extracted segments in seconds.")
            threads_number = gr.Number(4, label="Threads", info="Number of CPU threads.")

            extract_segments_btn = gr.Button("Extract segments")

            result_grid = gr.Matrix(headers=["File", "Execution"], elem_classes="mh-200")

            # Order of the inputs has to match the parameters of extract_segments
            extract_segments_btn.click(
                extract_segments,
                inputs=[
                    audio_directory_state,
                    result_directory_state,
                    output_directory_state,
                    min_conf_slider,
                    num_seq_number,
                    seq_length_number,
                    threads_number,
                ],
                outputs=result_grid,
            )

    with gr.Blocks(
        css=r".d-block .wrap {display: block !important;} .mh-200 {max-height: 300px; overflow-y: auto !important;} footer {display: none !important;} #single_file_audio, #single_file_audio * {max-height: 81.6px; min-height: 0;}",
        theme=gr.themes.Default(),
        analytics_enabled=False,
    ) as demo:
        build_single_analysis_tab()
        build_multi_analysis_tab()
        build_train_tab()
        build_segments_tab()

    # launch() must not block here: the script continues and blocks in webview.start() below.
    # Element [1] of the launch result is used as the URL shown in the window.
    url = demo.queue(api_open=False).launch(prevent_thread_lock=True, quiet=True)[1]
    _WINDOW = webview.create_window("BirdNET-Analyzer", url.rstrip("/") + "?__theme=light", min_size=(1024, 768))

    webview.start(private_mode=False)
train.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module for training a custom classifier.
2
+
3
+ Can be used to train a custom classifier with new training data.
4
+ """
5
+ import argparse
6
+ import os
7
+
8
+ import numpy as np
9
+
10
+ import audio
11
+ import config as cfg
12
+ import model
13
+ import utils
14
+
15
+
16
def _loadTrainingData(cache_mode="none", cache_file=""):
    """Loads the data for training.

    Reads all subdirectories of "config.TRAIN_DATA_PATH" and uses their names as new labels.
    These directories should contain all the training data for each label.

    If `cache_mode` is 'load' and `cache_file` exists, the training data is read from the
    cache and returned as-is. If the cache file does not exist, a notice is printed and the
    data is loaded from the training directories instead.

    Args:
        cache_mode: Cache mode. Can be 'none', 'load' or 'save'. Defaults to 'none'.
        cache_file: Path to cache file.

    Returns:
        A tuple of (x_train, y_train, labels). When built from the training directories,
        `x_train` and `y_train` are float32 arrays holding one embedding and one label
        vector per extracted sample, and `labels` lists the subdirectory names that are
        not in `cfg.NON_EVENT_CLASSES` (one label-vector column per entry).
    """
    # Load from cache
    if cache_mode == "load":
        if os.path.isfile(cache_file):
            print(f"\t...loading from cache: {cache_file}", flush=True)
            x_train, y_train, labels = utils.loadFromCache(cache_file)
            return x_train, y_train, labels

        print(f"\t...cache file not found: {cache_file}", flush=True)

    # Get list of subfolders as labels
    labels = sorted(utils.list_subdirectories(cfg.TRAIN_DATA_PATH))

    # Get valid labels (everything that is not a non-event class)
    valid_labels = [l for l in labels if l.lower() not in cfg.NON_EVENT_CLASSES]

    # The crop mode does not change while loading, so read it once
    crop_mode = cfg.SAMPLE_CROP_MODE

    # Load training data
    x_train = []
    y_train = []

    for label in labels:
        # Current label
        print(f"\t- {label}", flush=True)

        # Get label vector. Non-event classes and labels starting with "-" keep an
        # all-zero vector, i.e. their samples are not assigned to any class.
        # NOTE(review): labels starting with "-" still occupy a column in the label
        # vector because they are part of valid_labels - confirm this is intended.
        label_vector = np.zeros((len(valid_labels),), dtype="float32")
        if label.lower() not in cfg.NON_EVENT_CLASSES and not label.startswith("-"):
            label_vector[valid_labels.index(label)] = 1

        # Get list of files
        # Skip files that start with '.' because macOS seems to use them for temp files.
        label_dir = os.path.join(cfg.TRAIN_DATA_PATH, label)
        files = []
        for name in sorted(os.listdir(label_dir)):
            if name.startswith(".") or name.rsplit(".", 1)[-1].lower() not in cfg.ALLOWED_FILETYPES:
                continue

            path = os.path.join(label_dir, name)
            if os.path.isfile(path):
                files.append(path)

        # Load files
        for f in files:
            # Load audio (only the first cfg.SIG_LENGTH seconds are requested in 'first' mode)
            sig, rate = audio.openAudioFile(f, duration=cfg.SIG_LENGTH if crop_mode == "first" else None)

            # Crop training samples
            if crop_mode == "center":
                sig_splits = [audio.cropCenter(sig, rate, cfg.SIG_LENGTH)]
            elif crop_mode == "first":
                # Slice instead of indexing [0] so that a file yielding no segment
                # simply contributes no sample rather than raising IndexError.
                sig_splits = audio.splitSignal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)[:1]
            else:
                sig_splits = audio.splitSignal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)

            # Get feature embeddings (one model call per segment)
            for segment in sig_splits:
                embeddings = model.embeddings([segment])[0]

                # Add to training data
                x_train.append(embeddings)
                y_train.append(label_vector)

    # Convert to numpy arrays
    x_train = np.array(x_train, dtype="float32")
    y_train = np.array(y_train, dtype="float32")

    # Remove non-event classes from labels
    labels = list(valid_labels)

    # Save to cache? A failure to write the cache is reported but does not abort training.
    if cache_mode == "save":
        print(f"\t...saving training data to cache: {cache_file}", flush=True)
        try:
            utils.saveToCache(cache_file, x_train, y_train, labels)
        except Exception as e:
            print(f"\t...error saving cache: {e}", flush=True)

    return x_train, y_train, labels
109
+
110
+
111
def trainModel(on_epoch_end=None):
    """Trains a custom classifier.

    Loads the training data, builds and trains the classifier using the settings
    in `config`, and saves the result in the format selected by
    `cfg.TRAINED_MODEL_OUTPUT_FORMAT`.

    Args:
        on_epoch_end: A callback function that takes two arguments `epoch`, `logs`.

    Returns:
        A keras `History` object, whose `history` property contains all the metrics.

    Raises:
        ValueError: If `cfg.TRAINED_MODEL_OUTPUT_FORMAT` is not 'tflite', 'raven' or 'both'.
            This is checked before any data is loaded, so no training time is spent
            on a run whose result could not be saved.
    """
    # Validate the output format up front: rejecting it only after training would
    # throw away the trained classifier.
    output_format = cfg.TRAINED_MODEL_OUTPUT_FORMAT

    if output_format not in ("tflite", "raven", "both"):
        raise ValueError(f"Unknown model output format: {output_format}")

    # Load training data
    print("Loading training data...", flush=True)
    x_train, y_train, labels = _loadTrainingData(cfg.TRAIN_CACHE_MODE, cfg.TRAIN_CACHE_FILE)
    print(f"...Done. Loaded {x_train.shape[0]} training samples and {y_train.shape[1]} labels.", flush=True)

    # Build model
    print("Building model...", flush=True)
    classifier = model.buildLinearClassifier(y_train.shape[1], x_train.shape[1], cfg.TRAIN_HIDDEN_UNITS, cfg.TRAIN_DROPOUT)
    print("...Done.", flush=True)

    # Train model
    print("Training model...", flush=True)
    classifier, history = model.trainLinearClassifier(
        classifier,
        x_train,
        y_train,
        epochs=cfg.TRAIN_EPOCHS,
        batch_size=cfg.TRAIN_BATCH_SIZE,
        learning_rate=cfg.TRAIN_LEARNING_RATE,
        val_split=cfg.TRAIN_VAL_SPLIT,
        upsampling_ratio=cfg.UPSAMPLING_RATIO,
        upsampling_mode=cfg.UPSAMPLING_MODE,
        train_with_mixup=cfg.TRAIN_WITH_MIXUP,
        train_with_label_smoothing=cfg.TRAIN_WITH_LABEL_SMOOTHING,
        on_epoch_end=on_epoch_end,
    )

    # Best validation AUPRC (at minimum validation loss)
    best_val_auprc = history.history["val_AUPRC"][np.argmin(history.history["val_loss"])]

    # Save the classifier; for 'both' the Raven model is written first, then the TFLite one.
    if output_format in ("raven", "both"):
        model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels)

    if output_format in ("tflite", "both"):
        model.saveLinearClassifier(classifier, cfg.CUSTOM_CLASSIFIER, labels)

    print(f"...Done. Best AUPRC: {best_val_auprc}", flush=True)

    return history
163
+
164
+
165
if __name__ == "__main__":
    # Command line entry point: parse the options, copy them into the shared
    # config module and run the training.

    # Parse arguments
    parser = argparse.ArgumentParser(description="Train a custom classifier with BirdNET")
    parser.add_argument("--i", default="train_data/", help="Path to training data folder. Subfolder names are used as labels.")
    parser.add_argument("--crop_mode", default="center", help="Crop mode for training data. Can be 'center', 'first' or 'segments'. Defaults to 'center'.")
    parser.add_argument("--crop_overlap", type=float, default=0.0, help="Overlap of training data segments in seconds if crop_mode is 'segments'. Defaults to 0.")
    parser.add_argument(
        "--o", default="checkpoints/custom/Custom_Classifier", help="Path to trained classifier model output."
    )
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs. Defaults to 100.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size. Defaults to 32.")
    parser.add_argument("--val_split", type=float, default=0.2, help="Validation split ratio. Defaults to 0.2.")
    parser.add_argument("--learning_rate", type=float, default=0.01, help="Learning rate. Defaults to 0.01.")
    parser.add_argument(
        "--hidden_units",
        type=int,
        default=0,
        help="Number of hidden units. Defaults to 0. If set to >0, a two-layer classifier is used.",
    )
    parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate. Defaults to 0.")
    # Without an explicit default, BooleanOptionalAction yields None when the flag is omitted.
    parser.add_argument(
        "--mixup", action=argparse.BooleanOptionalAction, default=False, help="Whether to use mixup for training."
    )
    parser.add_argument("--upsampling_ratio", type=float, default=0.0, help="Balance train data and upsample minority classes. Values between 0 and 1. Defaults to 0.")
    parser.add_argument("--upsampling_mode", default="repeat", help="Upsampling mode. Can be 'repeat', 'mean' or 'smote'. Defaults to 'repeat'.")
    parser.add_argument("--model_format", default="tflite", help="Model output format. Can be 'tflite', 'raven' or 'both'. Defaults to 'tflite'.")
    parser.add_argument("--cache_mode", default="none", help="Cache mode. Can be 'none', 'load' or 'save'. Defaults to 'none'.")
    parser.add_argument("--cache_file", default="train_cache.npz", help="Path to cache file. Defaults to 'train_cache.npz'.")

    args = parser.parse_args()

    # Config
    # Mode-like options are lower-cased (as cache_mode already was) because the
    # training code compares them against lower-case literals.
    cfg.TRAIN_DATA_PATH = args.i
    cfg.SAMPLE_CROP_MODE = args.crop_mode.lower()
    cfg.SIG_OVERLAP = args.crop_overlap
    cfg.CUSTOM_CLASSIFIER = args.o
    cfg.TRAIN_EPOCHS = args.epochs
    cfg.TRAIN_BATCH_SIZE = args.batch_size
    cfg.TRAIN_VAL_SPLIT = args.val_split
    cfg.TRAIN_LEARNING_RATE = args.learning_rate
    cfg.TRAIN_HIDDEN_UNITS = args.hidden_units
    cfg.TRAIN_DROPOUT = min(max(0, args.dropout), 0.9)  # clamp to [0, 0.9]
    cfg.TRAIN_WITH_MIXUP = args.mixup
    cfg.UPSAMPLING_RATIO = min(max(0, args.upsampling_ratio), 1)  # clamp to [0, 1]
    cfg.UPSAMPLING_MODE = args.upsampling_mode.lower()
    cfg.TRAINED_MODEL_OUTPUT_FORMAT = args.model_format.lower()
    cfg.TRAIN_CACHE_MODE = args.cache_mode.lower()
    cfg.TRAIN_CACHE_FILE = args.cache_file
    cfg.TFLITE_THREADS = 4  # Set this to 4 to speed things up a bit

    # Train model
    trainModel()