Spaces:
Running
Running
File size: 5,151 Bytes
c56c253 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
from speaker_encoder.visualizations import Visualizations
from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
from speaker_encoder.params_model import *
from speaker_encoder.model import SpeakerEncoder
from utils.profiler import Profiler
from pathlib import Path
import torch
def sync(device: torch.device, enabled: bool = False):
    """Optionally block until all queued CUDA work on ``device`` has finished.

    CUDA operations are launched asynchronously, so without a synchronization
    point a profiler attributes GPU time to whichever later step happens to
    wait on it. Only CUDA devices are synchronized; any other device type is
    a no-op.

    Args:
        device: Device whose pending work should be awaited.
        enabled: When False (the default), this function does nothing. That
            matches the previous behaviour, where synchronization was disabled
            by an unconditional early return. Pass True to get accurate
            per-stage profiling timings.
    """
    # FIXME: synchronization is disabled by default. The reason it was turned
    # off is not recorded here; re-enable via `enabled=True` once confirmed safe.
    if enabled and device.type == "cuda":
        torch.cuda.synchronize(device)
def _load_checkpoint(fpath: Path, model, optimizer) -> int:
    """Restore model/optimizer state from ``fpath`` and return the step to resume at."""
    # Map onto the CPU so a checkpoint written by a CUDA run can also be resumed
    # on a CPU-only host. load_state_dict copies the tensors into the existing
    # parameters and optimizer state on whatever device those already live.
    checkpoint = torch.load(fpath, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["step"]


def _save_checkpoint(fpath: Path, step: int, model, optimizer) -> None:
    """Write model/optimizer state to ``fpath`` so training resumes at ``step + 1``."""
    # Create the target directory (and any missing parents such as models_dir)
    # so the first save of a fresh run does not fail with FileNotFoundError.
    fpath.parent.mkdir(parents=True, exist_ok=True)
    torch.save({
        "step": step + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, fpath)


def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
    """Train a speaker encoder, resuming from ``models_dir/<run_id>.pt`` when it exists.

    Args:
        run_id: Name of the run. Used for the checkpoint file name, the backup
            folder name and the visualization environment.
        clean_data_root: Dataset root passed to ``SpeakerVerificationDataset``.
        models_dir: Directory holding ``<run_id>.pt`` and ``<run_id>_backups/``.
            Created on demand the first time something is written there.
        umap_every: Draw and save an embedding projection every this many
            steps (0 disables).
        save_every: Overwrite ``<run_id>.pt`` every this many steps (0 disables).
        backup_every: Write a numbered backup checkpoint every this many steps
            (0 disables).
        vis_every: Forwarded to ``Visualizations``.
        force_restart: Ignore any existing checkpoint and start from step 1.
        visdom_server: Forwarded to ``Visualizations`` as ``server``.
        no_visdom: Forwarded to ``Visualizations`` as ``disabled``.

    There is no explicit stopping condition here. The loop ends only when
    ``loader`` stops yielding batches (if it ever does).
    """
    # Create a dataset and a dataloader. speakers_per_batch, utterances_per_speaker
    # and learning_rate_init presumably come from the params_model star import.
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=8,
    )

    # Setup the device on which to run the forward pass and the loss. These can be different,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on your
    # hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Configure file paths for the model
    state_fpath = models_dir.joinpath(run_id + ".pt")
    backup_dir = models_dir.joinpath(run_id + "_backups")

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            init_step = _load_checkpoint(state_fpath, model, optimizer)
            # Reset to the configured initial learning rate instead of keeping
            # the one stored in the checkpointed optimizer state.
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop. Step numbering continues from the checkpoint when resuming.
    profiler = Profiler(summarize_every=10, disabled=False)
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
        # Reshape to (speakers_per_batch, utterances_per_speaker, -1) for the loss.
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")
        model.do_gradient_ops()
        optimizer.step()
        profiler.tick("Parameter update")

        # Update visualizations
        vis.update(loss.item(), eer, step)

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            backup_dir.mkdir(parents=True, exist_ok=True)
            projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
            projection_embeds = embeds.detach().cpu().numpy()
            vis.draw_projections(projection_embeds, utterances_per_speaker, step, projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
            _save_checkpoint(state_fpath, step, model, optimizer)

        # Make a backup
        if backup_every != 0 and step % backup_every == 0:
            print("Making a backup (step %d)" % step)
            backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
            _save_checkpoint(backup_fpath, step, model, optimizer)

        profiler.tick("Extras (visualizations, saving)")
|