# DiffuSynth: model/timbre_encoder_pretrain.py
# Author: WeixuanYuan (commit 2b389c5, "Upload 49 files")
import json
import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tools import create_key
class TimbreEncoder(nn.Module):
    """LSTM timbre encoder with multi-task classification heads.

    The input is flattened over its two middle axes, projected per time step by
    a linear layer and run through an LSTM; the LSTM output at the last time
    step is the timbre feature. Four heads read that feature: instrument,
    instrument family and velocity (log-probabilities via ``LogSoftmax``) and
    qualities (independent probabilities via sigmoid).

    Args:
        input_dim: size of the flattened middle axes of the input
            (``x.shape[1] * x.shape[2]``).
        feature_dim: size of the per-time-step projection fed to the LSTM.
        hidden_dim: LSTM hidden size; also the size of the returned feature.
        num_instrument_classes: output size of the instrument head.
        num_instrument_family_classes: output size of the instrument-family head.
        num_velocity_classes: output size of the velocity head.
        num_qualities: output size of the qualities head.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_dim, feature_dim, hidden_dim, num_instrument_classes, num_instrument_family_classes,
                 num_velocity_classes, num_qualities, num_layers=1):
        super().__init__()
        # Per-time-step projection of the flattened input.
        self.input_layer = nn.Linear(input_dim, feature_dim)
        # Sequence model over time (batch axis first).
        self.lstm = nn.LSTM(feature_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        # Classification heads, all reading the last-step LSTM output.
        self.instrument_classifier_layer = nn.Linear(hidden_dim, num_instrument_classes)
        self.instrument_family_classifier_layer = nn.Linear(hidden_dim, num_instrument_family_classes)
        self.velocity_classifier_layer = nn.Linear(hidden_dim, num_velocity_classes)
        self.qualities_classifier_layer = nn.Linear(hidden_dim, num_qualities)
        # NOTE: despite its name this is a *log*-softmax (pairs with nn.NLLLoss);
        # the attribute name is kept for backward compatibility.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Encode ``x`` and classify it.

        Args:
            x: 4-D tensor ``(batch, a, b, seq_len)`` with ``a * b == input_dim``;
                the last axis is treated as time. Non-contiguous tensors are
                accepted.

        Returns:
            ``(feature, instrument_logits, instrument_family_logits,
            velocity_logits, qualities)``: ``feature`` is the last-step LSTM
            output ``(batch, hidden_dim)``; the three ``*_logits`` are
            log-probabilities ``(batch, num_classes)``; ``qualities`` are sigmoid
            probabilities ``(batch, num_qualities)``.
        """
        batch_size, _, _, seq_len = x.shape
        # reshape (not view): view raises on non-contiguous inputs such as
        # transposed batches, reshape copies only when it has to.
        x = x.reshape(batch_size, -1, seq_len)  # (batch, input_dim, seq_len)
        x = x.permute(0, 2, 1)                  # (batch, seq_len, input_dim)
        x = self.input_layer(x)
        feature, _ = self.lstm(x)
        feature = feature[:, -1, :]             # output at the last time step

        instrument_logits = self.instrument_classifier_layer(feature)
        instrument_family_logits = self.instrument_family_classifier_layer(feature)
        velocity_logits = self.velocity_classifier_layer(feature)
        qualities = self.qualities_classifier_layer(feature)

        # Multi-class heads -> log-probabilities; qualities -> independent probabilities.
        instrument_logits = self.softmax(instrument_logits)
        instrument_family_logits = self.softmax(instrument_family_logits)
        velocity_logits = self.softmax(velocity_logits)
        qualities = torch.sigmoid(qualities)
        return feature, instrument_logits, instrument_family_logits, velocity_logits, qualities
def get_multiclass_acc(outputs, ground_truth):
    """Return the top-1 accuracy of ``outputs`` against ``ground_truth``, in percent.

    ``outputs`` holds one row of class scores per sample (e.g. log-probabilities);
    the arg-max over dim 1 is taken as the predicted class.
    """
    predicted_classes = outputs.detach().argmax(dim=1)
    n_correct = (predicted_classes == ground_truth).sum().item()
    return 100 * n_correct / ground_truth.size(0)
def get_binary_accuracy(y_pred, y_true):
    """Return the element-wise accuracy of thresholded probabilities, in percent.

    An entry of ``y_pred`` counts as a positive prediction when it is strictly
    greater than 0.5; it is correct when that 0/1 value equals the matching
    entry of ``y_true``.
    """
    hard_labels = y_pred.gt(0.5).int()
    hit_rate = torch.eq(hard_labels, y_true).float().mean()
    return hit_rate.item() * 100.0
def get_timbre_encoder(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    """Build a ``TimbreEncoder``, optionally loading pretrained weights.

    Args:
        model_Config: keyword arguments forwarded to ``TimbreEncoder``.
        load_pretrain: if True, load ``checkpoint['model_state_dict']`` from
            ``models/{model_name}_timbre_encoder.pth``.
        model_name: checkpoint prefix; only used when ``load_pretrain`` is True.
        device: device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The model, on ``device`` and in eval mode.
    """
    timbre_encoder = TimbreEncoder(**model_Config)
    n_trainable = sum(p.numel() for p in timbre_encoder.parameters() if p.requires_grad)
    print(f"Model initialized, size: {n_trainable}")
    timbre_encoder.to(device)
    if load_pretrain:
        checkpoint_path = f"models/{model_name}_timbre_encoder.pth"
        print(f"Loading weights from {checkpoint_path}")
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        timbre_encoder.load_state_dict(checkpoint["model_state_dict"])
    timbre_encoder.eval()
    return timbre_encoder
def evaluate_timbre_encoder(device, model, iterator, nll_Loss, bce_Loss, n_sample=100):
    """Return the mean multi-task loss of ``model`` over ``n_sample`` batches.

    Each batch is a ``(representation, attributes)`` pair. ``attributes`` is a
    sequence of dicts with integer ``"instrument"``, ``"instrument_family"`` and
    ``"velocity"`` labels; the quality targets are the last 10 characters of
    ``create_key(attribute)`` parsed as digits. The per-batch loss is the sum of
    three ``nll_Loss`` terms and one ``bce_Loss`` term, the same objective the
    training loop uses.

    Batches are drawn sequentially from ``iterator``. When it is exhausted it is
    re-iterated, so ``n_sample`` may exceed the number of batches per pass.
    Previously a fresh iterator was created per sample, which evaluated the
    same first batch repeatedly for unshuffled loaders. The model is moved to
    ``device`` and left in eval mode.

    Args:
        device: torch device for the model and tensors.
        model: encoder returning ``(feature, log_probs x3, qualities)``.
        iterator: iterable of ``(representation, attributes)`` batches.
        nll_Loss: loss for the three log-probability heads.
        bce_Loss: loss for the qualities head.
        n_sample: number of batches to average over.

    Returns:
        Mean loss; NaN when ``n_sample`` is 0.

    Raises:
        StopIteration: if ``iterator`` yields no batches at all.
    """
    def cycle_batches(source):
        # Yield batches forever, re-iterating `source` only once it is
        # exhausted; stop if a whole pass yields nothing (avoids spinning).
        while True:
            yielded = False
            for batch in source:
                yielded = True
                yield batch
            if not yielded:
                return

    model.to(device)
    model.eval()
    batches = cycle_batches(iterator)
    eva_loss = []
    # Evaluation needs no autograd graph; skipping it saves memory and time.
    with torch.no_grad():
        for _ in range(n_sample):
            representation, attributes = next(batches)
            instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
            instrument_family = torch.tensor([s["instrument_family"] for s in attributes],
                                             dtype=torch.long).to(device)
            velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
            # Quality targets: last 10 key characters, one digit each (presumably 0/1, fed to BCE).
            qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]]
                                      for attribute in attributes], dtype=torch.float32).to(device)
            _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(
                representation.to(device))
            # compute loss
            instrument_loss = nll_Loss(instrument_logits, instrument)
            instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
            velocity_loss = nll_Loss(velocity_logits, velocity)
            qualities_loss = bce_Loss(qualities_pred, qualities)
            loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss
            eva_loss.append(loss.item())
    if not eva_loss:
        return float("nan")
    return np.mean(eva_loss)
def train_timbre_encoder(device, model_name, timbre_encoder_Config, BATCH_SIZE, lr, max_iter, training_iterator,
                         load_pretrain, save_steps=500, print_steps=100):
    """Train a ``TimbreEncoder`` on a multi-task classification objective.

    The per-batch loss is the sum of three NLL losses (instrument, instrument
    family, velocity) and a BCE loss on the quality targets. Every
    ``save_steps`` iterations, the window-averaged loss and accuracies are
    printed and written to TensorBoard (``runs/{model_name}_timbre_encoder``).
    Whenever that window loss is the lowest seen in this run, two files are
    written:

    * model + optimizer state to ``models/{model_name}_timbre_encoder.pth``;
    * hyperparameters to ``models/hyperparameters/{model_name}_timbre_encoder.json``.

    Batches are drawn sequentially from ``training_iterator``; when it is
    exhausted it is re-iterated (a new epoch). No fresh iterator is created
    per step, because that re-spawns DataLoader workers every step and, for an
    unshuffled loader, would return the same first batch forever.

    Args:
        device: torch device used for training.
        model_name: prefix for the checkpoint, hyperparameter and log paths.
        timbre_encoder_Config: keyword arguments for ``TimbreEncoder``; this
            function does not modify it.
        BATCH_SIZE: only recorded in the hyperparameter JSON. The real batch
            size is whatever ``training_iterator`` yields.
        lr: Adam learning rate.
        max_iter: number of optimisation steps. ``0`` returns right after
            building (and optionally loading) the model.
        training_iterator: iterable of ``(representation, attributes)`` batches.
            ``attributes`` is a sequence of dicts with integer
            ``"instrument"``, ``"instrument_family"`` and ``"velocity"`` labels.
        load_pretrain: if True, resume model and optimizer state from the
            checkpoint file.
        save_steps: metric-window length and checkpoint interval.
        print_steps: interval at which the global optimizer step is printed.

    Returns:
        ``(model, best_model)``. ``model`` is the model after the last step.
        ``best_model`` is a snapshot taken at the lowest-loss window, or
        ``model`` itself if no window completed.
    """
    import copy
    import os
    from collections import deque

    checkpoint_path = f'models/{model_name}_timbre_encoder.pth'

    def save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, current_iter,
                                  current_loss):
        # Work on a copy: mutating the caller's config would leave training-only
        # keys in it and break a later TimbreEncoder(**timbre_encoder_Config).
        model_hyperparameter = dict(timbre_encoder_Config)
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr"] = lr
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        os.makedirs("models/hyperparameters", exist_ok=True)
        with open(f"models/hyperparameters/{model_name}_timbre_encoder.json", "w", encoding="utf-8") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    def cycle_batches(source):
        # Yield batches forever, re-iterating `source` only once it is
        # exhausted; stop if a whole pass yields nothing (avoids spinning).
        while True:
            yielded = False
            for batch in source:
                yielded = True
                yield batch
            if not yielded:
                return

    def make_targets(attributes):
        # Build the label tensors for one batch of attribute dicts.
        instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
        instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device)
        velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
        # Quality targets: last 10 key characters, one digit each (presumably 0/1, fed to BCE).
        qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes],
                                 dtype=torch.float32).to(device)
        return instrument, instrument_family, velocity, qualities

    model = TimbreEncoder(**timbre_encoder_Config)
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model size: {model_size}")
    model.to(device)
    nll_Loss = torch.nn.NLLLoss()
    bce_Loss = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=False)
    if load_pretrain:
        print(f"Loading weights from {checkpoint_path}")
        # Map to CPU so GPU-saved checkpoints load anywhere; load_state_dict
        # then places the values on the model's / optimizer's devices.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")
    if max_iter == 0:
        print("Return model directly.")
        return model, model

    # Metrics of the most recent `save_steps` iterations only (bounded memory).
    train_loss = deque(maxlen=save_steps)
    training_instrument_acc = deque(maxlen=save_steps)
    training_instrument_family_acc = deque(maxlen=save_steps)
    training_velocity_acc = deque(maxlen=save_steps)
    training_qualities_acc = deque(maxlen=save_steps)
    writer = SummaryWriter(f'runs/{model_name}_timbre_encoder')
    batches = cycle_batches(training_iterator)
    best_model = None
    # Start at +inf so the first completed window is always checkpointed.
    previous_lowest_loss = float("inf")
    print(f"initial__loss: {previous_lowest_loss}")
    try:
        for i in range(max_iter):
            model.train()
            representation, attributes = next(batches)
            instrument, instrument_family, velocity, qualities = make_targets(attributes)
            optimizer.zero_grad()
            _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(
                representation.to(device))
            # compute loss
            instrument_loss = nll_Loss(instrument_logits, instrument)
            instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
            velocity_loss = nll_Loss(velocity_logits, velocity)
            qualities_loss = bce_Loss(qualities_pred, qualities)
            loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())
            training_instrument_acc.append(get_multiclass_acc(instrument_logits, instrument))
            training_instrument_family_acc.append(get_multiclass_acc(instrument_family_logits, instrument_family))
            training_velocity_acc.append(get_multiclass_acc(velocity_logits, velocity))
            training_qualities_acc.append(get_binary_accuracy(qualities_pred, qualities))

            # Global Adam step (continues across resumed runs); int() accepts the
            # step whether stored as a Python number or as a CPU/GPU tensor.
            step = int(next(iter(optimizer.state.values()))["step"])
            if (i + 1) % print_steps == 0:
                print('%d step' % step)
            if (i + 1) % save_steps == 0:
                current_loss = np.mean(train_loss)
                current_instrument_acc = np.mean(training_instrument_acc)
                current_instrument_family_acc = np.mean(training_instrument_family_acc)
                current_velocity_acc = np.mean(training_velocity_acc)
                current_qualities_acc = np.mean(training_qualities_acc)
                print('train_loss: %.5f' % current_loss)
                print('current_instrument_acc: %.5f' % current_instrument_acc)
                print('current_instrument_family_acc: %.5f' % current_instrument_family_acc)
                print('current_velocity_acc: %.5f' % current_velocity_acc)
                print('current_qualities_acc: %.5f' % current_qualities_acc)
                writer.add_scalar("train_loss", current_loss, step)
                writer.add_scalar("current_instrument_acc", current_instrument_acc, step)
                writer.add_scalar("current_instrument_family_acc", current_instrument_family_acc, step)
                writer.add_scalar("current_velocity_acc", current_velocity_acc, step)
                writer.add_scalar("current_qualities_acc", current_qualities_acc, step)
                if current_loss < previous_lowest_loss:
                    previous_lowest_loss = current_loss
                    # Snapshot the weights: plain assignment would only alias the
                    # model that keeps training, making "best" equal "final".
                    best_model = copy.deepcopy(model)
                    os.makedirs("models", exist_ok=True)
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    }, checkpoint_path)
                    save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, step,
                                              current_loss)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()
    if best_model is None:
        # No complete window finished: fall back to the final model.
        best_model = model
    return model, best_model