# Provenance (file-viewer header): author lhzstar, commit 6bc94ac ("initial commits"), 7.31 kB, no virus detected.
from pathlib import Path
import numpy as np
from os.path import exists
import torch
from encoder.data_objects import DataLoader, Train_Dataset, Dev_Dataset
from encoder.model import SpeakerEncoder
from encoder.params_model import *
from encoder.visualizations import Visualizations
from utils.profiler import Profiler
def sync(device: torch.device):
    """Block until pending CUDA work on ``device`` has finished.

    CUDA kernels are launched asynchronously, so profiling timestamps are only
    meaningful after a synchronization. Does nothing for non-CUDA devices.

    Args:
        device: The torch device whose outstanding work should be awaited.
    """
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts.
        lr: The learning rate to store under each group's ``"lr"`` key.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)
def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
    """Train the speaker encoder with periodic validation, checkpointing and plots.

    Every ``save_every`` steps the learning rate is multiplied by 0.995, the
    model is evaluated on the dev set, and the checkpoint
    ``<models_dir>/<run_id>/encoder.pt`` is overwritten only if the dev EER is
    lower than the best value recorded in ``encoder_loss/best_eer.npy``
    (relative to the current working directory).

    Args:
        run_id: Name of the run; used as the model sub-directory name and
            passed to the visualization environment.
        clean_data_root: Directory containing the ``train`` and ``dev`` datasets.
        models_dir: Parent directory under which ``<run_id>/`` is created.
        umap_every: Draw embedding projections every this many steps; 0 disables.
        save_every: Validate / decay the learning rate / possibly checkpoint
            every this many steps; 0 disables.
        backup_every: Currently unused (periodic backups are disabled).
        vis_every: Forwarded to ``Visualizations``.
        force_restart: If True, ignore any existing checkpoint for ``run_id``.
        visdom_server: Forwarded to ``Visualizations`` as ``server``.
        no_visdom: Forwarded to ``Visualizations`` as ``disabled``.
    """
    # Create a dataset and a dataloader
    train_dataset = Train_Dataset(clean_data_root.joinpath("train"))
    dev_dataset = Dev_Dataset(clean_data_root.joinpath("dev"))
    train_loader = DataLoader(
        train_dataset,
        speakers_per_batch,
        utterances_per_speaker,
        shuffle=True,
        num_workers=8,
        pin_memory=True
    )
    # The dev loader is configured with the full dev dataset length as its batch size.
    dev_batch = len(dev_dataset)
    dev_loader = DataLoader(
        dev_dataset,
        dev_batch,
        utterances_per_speaker,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    # Setup the device on which to run the forward pass and the loss. These can be different,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on your
    # hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    # loss_device = torch.device("cpu")
    loss_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  ####modified####

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    current_lr = learning_rate_init
    init_step = 1

    # Configure file path for the model
    model_dir = models_dir / run_id
    model_dir.mkdir(exist_ok=True, parents=True)
    state_fpath = model_dir / "encoder.pt"

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            # map_location lets a checkpoint written on a CUDA machine be resumed on a
            # CPU-only one (and vice versa).
            checkpoint = torch.load(state_fpath, map_location=device)
            init_step = checkpoint["step"]
            print(f"Resuming training from step {init_step}")
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            # NOTE(review): this restarts the learning-rate decay from learning_rate_init on
            # every resume, discarding the decayed value stored in the optimizer state.
            update_lr(optimizer, current_lr)
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")

    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(train_dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Best dev EER seen so far, persisted across invocations. The file lives in the current
    # working directory and is not keyed by run_id.
    best_eer_fpath = Path("encoder_loss") / "best_eer.npy"
    best_eer_fpath.parent.mkdir(exist_ok=True)
    best_eer = np.load(best_eer_fpath)[0] if best_eer_fpath.exists() else 1

    # Embeddings returned by the most recent validation pass; None until one has run.
    dev_embeds = None

    # Training loop
    profiler = Profiler(summarize_every=1000, disabled=False)
    for step, speaker_batch in enumerate(train_loader, init_step):
        # validate() switches the model to eval mode, so re-enable train mode every step.
        model.train()
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Data to GPU mem
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)

        # Forward pass
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()  # Sets gradients of all model parameters to zero
        loss.backward()  # Calc gradients of all model parameters
        profiler.tick("Backward pass")
        model.do_gradient_ops()
        optimizer.step()  # do gradient descent of all model parameters
        profiler.tick("Parameter update")

        # Validate, update visualizations and overwrite the checkpoint if the dev EER improved
        if save_every != 0 and step % save_every == 0:
            current_lr *= 0.995
            update_lr(optimizer, current_lr)
            dev_loss, dev_eer, dev_embeds = validate(dev_loader, model, dev_batch, device, loss_device)
            sync(device)
            sync(loss_device)
            profiler.tick("validate")
            vis.update(loss.item(), eer, step, dev_loss, dev_eer)
            if dev_eer < best_eer:
                best_eer = dev_eer
                np.save(best_eer_fpath, np.array([best_eer]))
                print("Saving the model (step %d)" % step)
                torch.save({
                    "step": step + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                }, state_fpath)
        else:
            vis.update(loss.item(), eer, step)

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            projection_fpath = model_dir / f"umap_{step:06d}.png"
            dev_projection_fpath = model_dir / f"dev_umap_{step:06d}.png"
            if dev_embeds is None:
                # No validation pass has produced dev embeddings yet (e.g. save_every == 0 or
                # the first projection step precedes the first validation step).
                _, _, dev_embeds = validate(dev_loader, model, dev_batch, device, loss_device)
            # Use separate names for the NumPy copies so dev_embeds stays a tensor and can be
            # reused on a later projection step that has no fresh validation pass.
            train_embeds_np = embeds.detach().cpu().numpy()
            dev_embeds_np = dev_embeds.detach().cpu().numpy()
            vis.draw_projections(train_embeds_np, dev_embeds_np, utterances_per_speaker, step,
                                 projection_fpath, dev_projection_fpath)
            vis.save()

        # NOTE: backup_every is currently unused; periodic backups are disabled.

        profiler.tick("Extras (visualizations, saving)")
def validate(dev_loader: DataLoader, model: SpeakerEncoder, dev_batch, device, loss_device):
    """Evaluate ``model`` on every batch of ``dev_loader`` without tracking gradients.

    The model is switched to eval mode and is NOT switched back; the caller is
    responsible for calling ``model.train()`` before resuming training.

    Args:
        dev_loader: Loader yielding batches that expose a NumPy ``data`` attribute.
        model: The speaker encoder; must provide ``loss(embeds)`` returning
            ``(loss, eer)``.
        dev_batch: Number of speakers per dev batch, used to reshape the
            embeddings to ``(dev_batch, utterances_per_speaker, -1)``.
        device: Device on which the forward pass runs.
        loss_device: Device on which the loss is computed.

    Returns:
        A tuple ``(mean_loss, mean_eer, embeds)`` where the means are taken over
        all batches and ``embeds`` are the detached embeddings of the last batch.

    Raises:
        ValueError: If ``dev_loader`` yields no batches.
    """
    model.eval()
    losses = []
    eers = []
    embeds = None
    with torch.no_grad():
        for speaker_batch in dev_loader:
            frames = torch.from_numpy(speaker_batch.data).to(device)
            # Call the module rather than .forward() so registered hooks run, as in train().
            embeds = model(frames)
            embeds_loss = embeds.view((dev_batch, utterances_per_speaker, -1)).to(loss_device)
            loss, eer = model.loss(embeds_loss)
            losses.append(loss.item())
            eers.append(eer)
    if not losses:
        # Fail with a clear message instead of an opaque ZeroDivisionError below.
        raise ValueError("dev_loader yielded no batches; cannot compute validation metrics")
    return sum(losses) / len(losses), sum(eers) / len(eers), embeds.detach()