import copy
import json

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from tools import create_key


class TimbreEncoder(nn.Module):
    """LSTM-based encoder that maps a spectral representation to a timbre feature and
    classifies instrument, instrument family, velocity, and the binary quality tags."""

    def __init__(self, input_dim, feature_dim, hidden_dim, num_instrument_classes,
                 num_instrument_family_classes, num_velocity_classes, num_qualities,
                 num_layers=1):
        super(TimbreEncoder, self).__init__()

        # Input projection layer
        self.input_layer = nn.Linear(input_dim, feature_dim)

        # LSTM layer
        self.lstm = nn.LSTM(feature_dim, hidden_dim, num_layers=num_layers, batch_first=True)

        # Fully connected layers for classification
        self.instrument_classifier_layer = nn.Linear(hidden_dim, num_instrument_classes)
        self.instrument_family_classifier_layer = nn.Linear(hidden_dim, num_instrument_family_classes)
        self.velocity_classifier_layer = nn.Linear(hidden_dim, num_velocity_classes)
        self.qualities_classifier_layer = nn.Linear(hidden_dim, num_qualities)

        # LogSoftmax turns logits into log-probabilities (paired with NLLLoss during training)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # Merge the channel and frequency dimensions: [batch_size, C, F, T] -> [batch_size, C*F, T]
        batch_size, _, _, seq_len = x.shape
        x = x.view(batch_size, -1, seq_len)  # [batch_size, input_dim, seq_len]

        # Put time first, project to feature_dim, and keep the last LSTM hidden state
        x = x.permute(0, 2, 1)  # [batch_size, seq_len, input_dim]
        x = self.input_layer(x)
        feature, _ = self.lstm(x)
        feature = feature[:, -1, :]

        # Apply the classification heads
        instrument_logits = self.instrument_classifier_layer(feature)
        instrument_family_logits = self.instrument_family_classifier_layer(feature)
        velocity_logits = self.velocity_classifier_layer(feature)
        qualities = self.qualities_classifier_layer(feature)

        # Log-probabilities for the multi-class heads, sigmoid for the multi-label qualities head
        instrument_logits = self.softmax(instrument_logits)
        instrument_family_logits = self.softmax(instrument_family_logits)
        velocity_logits = self.softmax(velocity_logits)
        qualities = torch.sigmoid(qualities)

        return feature, instrument_logits, instrument_family_logits, velocity_logits, qualities


def get_multiclass_acc(outputs, ground_truth):
    """Accuracy (in percent) for a multi-class head."""
    _, predicted = torch.max(outputs.data, 1)
    total = ground_truth.size(0)
    correct = (predicted == ground_truth).sum().item()
    accuracy = 100 * correct / total
    return accuracy


def get_binary_accuracy(y_pred, y_true):
    """Accuracy (in percent) for the multi-label qualities head, thresholded at 0.5."""
    predictions = (y_pred > 0.5).int()
    correct_predictions = (predictions == y_true).float()
    accuracy = correct_predictions.mean()
    return accuracy.item() * 100.0


def get_timbre_encoder(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    """Build a TimbreEncoder from a config dict and optionally load pretrained weights."""
    timbreEncoder = TimbreEncoder(**model_Config)
    print(f"Model initialized, size: {sum(p.numel() for p in timbreEncoder.parameters() if p.requires_grad)}")
    timbreEncoder.to(device)
    if load_pretrain:
        print(f"Loading weights from models/{model_name}_timbre_encoder.pth")
        checkpoint = torch.load(f'models/{model_name}_timbre_encoder.pth', map_location=device)
        timbreEncoder.load_state_dict(checkpoint['model_state_dict'])
        timbreEncoder.eval()
    return timbreEncoder


def evaluate_timbre_encoder(device, model, iterator, nll_Loss, bce_Loss, n_sample=100):
    """Average the combined classification loss over n_sample batches drawn from iterator."""
    model.to(device)
    model.eval()
    eva_loss = []
    with torch.no_grad():
        for i in range(n_sample):
            # Draw a fresh batch from the data loader
            representation, attributes = next(iter(iterator))
            instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
            instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device)
            velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
            # The last 10 characters of the sample key encode the binary quality labels
            qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes],
                                     dtype=torch.float32).to(device)

            _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device))

            # Combined loss over all four heads
            instrument_loss = nll_Loss(instrument_logits, instrument)
            instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
            velocity_loss = nll_Loss(velocity_logits, velocity)
            qualities_loss = bce_Loss(qualities_pred, qualities)
            loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss
            eva_loss.append(loss.item())
    eva_loss = np.mean(eva_loss)
    return eva_loss


def train_timbre_encoder(device, model_name, timbre_encoder_Config, BATCH_SIZE, lr, max_iter,
                         training_iterator, load_pretrain):
    """Train a TimbreEncoder, logging to TensorBoard and checkpointing the best model so far."""

    def save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size,
                                  current_iter, current_loss):
        model_hyperparameter = timbre_encoder_Config
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr"] = lr
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_timbre_encoder.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    model = TimbreEncoder(**timbre_encoder_Config)
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model size: {model_size}")
    model.to(device)

    nll_Loss = torch.nn.NLLLoss()
    bce_Loss = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=False)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_timbre_encoder.pth")
        checkpoint = torch.load(f'models/{model_name}_timbre_encoder.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")

    if max_iter == 0:
        print("Return model directly.")
        return model, model

    train_loss, training_instrument_acc, training_instrument_family_acc, training_velocity_acc, training_qualities_acc = [], [], [], [], []
    writer = SummaryWriter(f'runs/{model_name}_timbre_encoder')

    # Keep a snapshot (not an alias) of the best-performing weights
    current_best_model = copy.deepcopy(model)
    previous_lowest_loss = 100.0
    print(f"initial loss: {previous_lowest_loss}")

    for i in range(max_iter):
        model.train()
        # Draw a fresh batch from the data loader
        representation, attributes = next(iter(training_iterator))
        instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
        instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device)
        velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
        qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes],
                                 dtype=torch.float32).to(device)

        optimizer.zero_grad()
        _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device))

        # Combined loss over all four heads
        instrument_loss = nll_Loss(instrument_logits, instrument)
        instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
        velocity_loss = nll_Loss(velocity_logits, velocity)
        qualities_loss = bce_Loss(qualities_pred, qualities)
        loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss
        loss.backward()
        optimizer.step()

        # Track training metrics
        instrument_acc = get_multiclass_acc(instrument_logits, instrument)
        instrument_family_acc = get_multiclass_acc(instrument_family_logits, instrument_family)
        velocity_acc = get_multiclass_acc(velocity_logits, velocity)
        qualities_acc = get_binary_accuracy(qualities_pred, qualities)

        train_loss.append(loss.item())
        training_instrument_acc.append(instrument_acc)
        training_instrument_family_acc.append(instrument_family_acc)
        training_velocity_acc.append(velocity_acc)
        training_qualities_acc.append(qualities_acc)

        # Read the global step from the optimizer state (it survives checkpoint reloads)
        step = int(optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].numpy())

        if (i + 1) % 100 == 0:
            print('%d step' % (step))

        save_steps = 500
        if (i + 1) % save_steps == 0:
            current_loss = np.mean(train_loss[-save_steps:])
            current_instrument_acc = np.mean(training_instrument_acc[-save_steps:])
            current_instrument_family_acc = np.mean(training_instrument_family_acc[-save_steps:])
            current_velocity_acc = np.mean(training_velocity_acc[-save_steps:])
            current_qualities_acc = np.mean(training_qualities_acc[-save_steps:])
            print('train_loss: %.5f' % current_loss)
            print('current_instrument_acc: %.5f' % current_instrument_acc)
            print('current_instrument_family_acc: %.5f' % current_instrument_family_acc)
            print('current_velocity_acc: %.5f' % current_velocity_acc)
            print('current_qualities_acc: %.5f' % current_qualities_acc)

            writer.add_scalar("train_loss", current_loss, step)
            writer.add_scalar("current_instrument_acc", current_instrument_acc, step)
            writer.add_scalar("current_instrument_family_acc", current_instrument_family_acc, step)
            writer.add_scalar("current_velocity_acc", current_velocity_acc, step)
            writer.add_scalar("current_qualities_acc", current_qualities_acc, step)

            # Checkpoint whenever the running loss improves
            if current_loss < previous_lowest_loss:
                previous_lowest_loss = current_loss
                current_best_model = copy.deepcopy(model)
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, f'models/{model_name}_timbre_encoder.pth')
                save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr,
                                          model_size, step, current_loss)

    return model, current_best_model
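

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training pipeline): the config
# keys below mirror the TimbreEncoder constructor, but the concrete values, the
# dummy input shape, and the device are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_config = {
        "input_dim": 512,                     # assumed: channels * frequency bins of the representation
        "feature_dim": 256,                   # assumed
        "hidden_dim": 256,                    # assumed
        "num_instrument_classes": 1006,       # assumed vocabulary size
        "num_instrument_family_classes": 11,  # assumed
        "num_velocity_classes": 128,          # assumed: full MIDI velocity range
        "num_qualities": 10,                  # matches the 10 quality bits used above
    }
    encoder = get_timbre_encoder(example_config, load_pretrain=False, device="cpu")

    # A dummy [batch, channel, frequency, time] input just to exercise forward()
    dummy = torch.randn(2, 1, example_config["input_dim"], 100)
    feature, inst, family, vel, quals = encoder(dummy)
    print(feature.shape, inst.shape, family.shape, vel.shape, quals.shape)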