Spaces:
Running
Running
"""Module for training a custom classifier. | |
Can be used to train a custom classifier with new training data. | |
""" | |
import argparse | |
import os | |
import numpy as np | |
import audio | |
import config as cfg | |
import model | |
import utils | |
def _loadTrainingData(cache_mode="none", cache_file=""):
    """Loads the data for training.

    Reads all subdirectories of "config.TRAIN_DATA_PATH" and uses their names as new labels.
    These directories should contain all the training data for each label.
    If cache mode is 'load' and the cache file exists, the training data is loaded from there.

    Folders whose lower-cased name is listed in `cfg.NON_EVENT_CLASSES`, as well as
    folders whose name starts with "-", do not become classes of their own: their
    samples are added with an all-zero label vector.

    Args:
        cache_mode: Cache mode. Can be 'none', 'load' or 'save'. Defaults to 'none'.
        cache_file: Path to cache file.

    Returns:
        A tuple of (x_train, y_train, labels). When built from audio files,
        x_train and y_train are float32 numpy arrays and `labels` holds one
        name per column of y_train.
    """
    # Load from cache
    if cache_mode == "load":
        if os.path.isfile(cache_file):
            print(f"\t...loading from cache: {cache_file}", flush=True)
            x_train, y_train, labels = utils.loadFromCache(cache_file)

            return x_train, y_train, labels

        # Cache is missing: fall through and rebuild the data from the audio files.
        print(f"\t...cache file not found: {cache_file}", flush=True)

    def is_event_label(name):
        # Non-event folders and "-"-prefixed folders never receive a positive
        # target, so they must not occupy a column in the label vector.
        return name.lower() not in cfg.NON_EVENT_CLASSES and not name.startswith("-")

    # Get list of subfolders as labels
    labels = sorted(utils.list_subdirectories(cfg.TRAIN_DATA_PATH))

    # Get valid labels and their column index in the label vector
    valid_labels = [name for name in labels if is_event_label(name)]
    label_index = {name: idx for idx, name in enumerate(valid_labels)}

    # Load training data
    x_train = []
    y_train = []

    for label in labels:
        # Current label
        print(f"\t- {label}", flush=True)

        # Get label vector (stays all-zero for folders that are not a class)
        label_vector = np.zeros((len(valid_labels),), dtype="float32")

        if label in label_index:
            label_vector[label_index[label]] = 1

        # Get list of files
        # Filter files that start with '.' because macOS seems to use them for temp files.
        label_dir = os.path.join(cfg.TRAIN_DATA_PATH, label)
        files = filter(
            os.path.isfile,
            (
                os.path.join(label_dir, f)
                for f in sorted(os.listdir(label_dir))
                if not f.startswith(".") and f.rsplit(".", 1)[-1].lower() in cfg.ALLOWED_FILETYPES
            ),
        )

        # Load files
        for f in files:
            # Load audio (only the first SIG_LENGTH seconds are needed in 'first' mode)
            sig, rate = audio.openAudioFile(f, duration=cfg.SIG_LENGTH if cfg.SAMPLE_CROP_MODE == "first" else None)

            # Crop training samples
            if cfg.SAMPLE_CROP_MODE == "center":
                sig_splits = [audio.cropCenter(sig, rate, cfg.SIG_LENGTH)]
            elif cfg.SAMPLE_CROP_MODE == "first":
                # Slice instead of indexing so a file that yields no chunk is skipped
                # rather than raising IndexError.
                # NOTE(review): presumably splitSignal can return an empty list for
                # clips shorter than SIG_MINLEN - confirm against audio.splitSignal.
                sig_splits = audio.splitSignal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)[:1]
            else:
                sig_splits = audio.splitSignal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)

            # Get feature embeddings
            for chunk in sig_splits:
                embeddings = model.embeddings([chunk])[0]

                # Add to training data
                x_train.append(embeddings)
                y_train.append(label_vector)

    # Convert to numpy arrays
    x_train = np.array(x_train, dtype="float32")
    y_train = np.array(y_train, dtype="float32")

    # Only real classes are reported, so labels line up with the columns of y_train
    labels = valid_labels

    # Save to cache?
    if cache_mode == "save":
        print(f"\t...saving training data to cache: {cache_file}", flush=True)

        try:
            utils.saveToCache(cache_file, x_train, y_train, labels)
        except Exception as e:
            # Caching is best-effort: a failed save must not discard the loaded data.
            print(f"\t...error saving cache: {e}", flush=True)

    return x_train, y_train, labels
def trainModel(on_epoch_end=None):
    """Trains a custom classifier.

    Loads the training data, builds and trains the classifier using the
    settings in `cfg`, and saves it to `cfg.CUSTOM_CLASSIFIER` in the format
    given by `cfg.TRAINED_MODEL_OUTPUT_FORMAT` ('tflite', 'raven' or 'both').

    Args:
        on_epoch_end: A callback function that takes two arguments `epoch`, `logs`.

    Returns:
        A keras `History` object, whose `history` property contains all the metrics.

    Raises:
        ValueError: If the configured output format is unknown or no training
            samples could be loaded.
    """
    # Validate the output format up front so a typo does not waste a full training run
    output_format = cfg.TRAINED_MODEL_OUTPUT_FORMAT

    if output_format not in ("both", "tflite", "raven"):
        raise ValueError(f"Unknown model output format: {output_format}")

    # Load training data
    print("Loading training data...", flush=True)
    x_train, y_train, labels = _loadTrainingData(cfg.TRAIN_CACHE_MODE, cfg.TRAIN_CACHE_FILE)

    # With no samples the arrays are 1-D and `shape[1]` below would raise an opaque IndexError
    if len(x_train.shape) < 2 or len(y_train.shape) < 2:
        raise ValueError(f"No training samples found in: {cfg.TRAIN_DATA_PATH}")

    print(f"...Done. Loaded {x_train.shape[0]} training samples and {y_train.shape[1]} labels.", flush=True)

    # Build model
    print("Building model...", flush=True)
    classifier = model.buildLinearClassifier(y_train.shape[1], x_train.shape[1], cfg.TRAIN_HIDDEN_UNITS, cfg.TRAIN_DROPOUT)
    print("...Done.", flush=True)

    # Train model
    print("Training model...", flush=True)
    classifier, history = model.trainLinearClassifier(
        classifier,
        x_train,
        y_train,
        epochs=cfg.TRAIN_EPOCHS,
        batch_size=cfg.TRAIN_BATCH_SIZE,
        learning_rate=cfg.TRAIN_LEARNING_RATE,
        val_split=cfg.TRAIN_VAL_SPLIT,
        upsampling_ratio=cfg.UPSAMPLING_RATIO,
        upsampling_mode=cfg.UPSAMPLING_MODE,
        train_with_mixup=cfg.TRAIN_WITH_MIXUP,
        train_with_label_smoothing=cfg.TRAIN_WITH_LABEL_SMOOTHING,
        on_epoch_end=on_epoch_end,
    )

    # Best validation AUPRC (at minimum validation loss)
    # NOTE(review): relies on "val_AUPRC" and "val_loss" being recorded - presumably
    # this needs a non-zero validation split; confirm in model.trainLinearClassifier.
    best_val_auprc = history.history["val_AUPRC"][np.argmin(history.history["val_loss"])]

    # Save the model; for 'both' the raven model is written first, then the tflite one
    if output_format in ("raven", "both"):
        model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels)

    if output_format in ("tflite", "both"):
        model.saveLinearClassifier(classifier, cfg.CUSTOM_CLASSIFIER, labels)

    print(f"...Done. Best AUPRC: {best_val_auprc}", flush=True)

    return history
if __name__ == "__main__": | |
# Parse arguments | |
parser = argparse.ArgumentParser(description="Train a custom classifier with BirdNET") | |
parser.add_argument("--i", default="train_data/", help="Path to training data folder. Subfolder names are used as labels.") | |
parser.add_argument("--crop_mode", default="center", help="Crop mode for training data. Can be 'center', 'first' or 'segments'. Defaults to 'center'.") | |
parser.add_argument("--crop_overlap", type=float, default=0.0, help="Overlap of training data segments in seconds if crop_mode is 'segments'. Defaults to 0.") | |
parser.add_argument( | |
"--o", default="checkpoints/custom/Custom_Classifier", help="Path to trained classifier model output." | |
) | |
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs. Defaults to 100.") | |
parser.add_argument("--batch_size", type=int, default=32, help="Batch size. Defaults to 32.") | |
parser.add_argument("--val_split", type=float, default=0.2, help="Validation split ratio. Defaults to 0.2.") | |
parser.add_argument("--learning_rate", type=float, default=0.01, help="Learning rate. Defaults to 0.01.") | |
parser.add_argument( | |
"--hidden_units", | |
type=int, | |
default=0, | |
help="Number of hidden units. Defaults to 0. If set to >0, a two-layer classifier is used.", | |
) | |
parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate. Defaults to 0.") | |
parser.add_argument("--mixup", action=argparse.BooleanOptionalAction, help="Whether to use mixup for training.") | |
parser.add_argument("--upsampling_ratio", type=float, default=0.0, help="Balance train data and upsample minority classes. Values between 0 and 1. Defaults to 0.") | |
parser.add_argument("--upsampling_mode", default="repeat", help="Upsampling mode. Can be 'repeat', 'mean' or 'smote'. Defaults to 'repeat'.") | |
parser.add_argument("--model_format", default="tflite", help="Model output format. Can be 'tflite', 'raven' or 'both'. Defaults to 'tflite'.") | |
parser.add_argument("--cache_mode", default="none", help="Cache mode. Can be 'none', 'load' or 'save'. Defaults to 'none'.") | |
parser.add_argument("--cache_file", default="train_cache.npz", help="Path to cache file. Defaults to 'train_cache.npz'.") | |
args = parser.parse_args() | |
# Config | |
cfg.TRAIN_DATA_PATH = args.i | |
cfg.SAMPLE_CROP_MODE = args.crop_mode | |
cfg.SIG_OVERLAP = args.crop_overlap | |
cfg.CUSTOM_CLASSIFIER = args.o | |
cfg.TRAIN_EPOCHS = args.epochs | |
cfg.TRAIN_BATCH_SIZE = args.batch_size | |
cfg.TRAIN_VAL_SPLIT = args.val_split | |
cfg.TRAIN_LEARNING_RATE = args.learning_rate | |
cfg.TRAIN_HIDDEN_UNITS = args.hidden_units | |
cfg.TRAIN_DROPOUT = min(max(0, args.dropout), 0.9) | |
cfg.TRAIN_WITH_MIXUP = args.mixup | |
cfg.UPSAMPLING_RATIO = min(max(0, args.upsampling_ratio), 1) | |
cfg.UPSAMPLING_MODE = args.upsampling_mode | |
cfg.TRAINED_MODEL_OUTPUT_FORMAT = args.model_format | |
cfg.TRAIN_CACHE_MODE = args.cache_mode.lower() | |
cfg.TRAIN_CACHE_FILE = args.cache_file | |
cfg.TFLITE_THREADS = 4 # Set this to 4 to speed things up a bit | |
# Train model | |
trainModel() | |