Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import sys | |
import tqdm | |
from torch.utils.data import DataLoader | |
from torch.utils.data.distributed import DistributedSampler | |
from .utils import apply_model, average_metric, center_trim | |
def _grad_norm(model):
    """Return the global L2 norm of the gradients of ``model``'s parameters.

    Parameters whose ``grad`` is ``None`` are ignored. If no parameter has a
    gradient, the result is ``0 ** 0.5`` (i.e. ``0.0``).
    """
    squared = 0
    for p in model.parameters():
        if p.grad is not None:
            squared += p.grad.data.norm()**2
    return squared**0.5


def train_model(epoch,
                dataset,
                model,
                criterion,
                optimizer,
                augment,
                quantizer=None,
                diffq=0,
                repeat=1,
                device="cpu",
                seed=None,
                workers=4,
                world_size=1,
                batch_size=16):
    """Train ``model`` for one epoch made of ``repeat`` passes over ``dataset``.

    For every batch: the sources are moved to ``device`` and augmented, then
    summed over dim 1 into a mixture. The mixture is fed to ``model``, and the
    sources are center-trimmed to the estimates before being scored with
    ``criterion``. When ``quantizer`` is given, ``diffq * quantizer.model_size()``
    is added to the back-propagated loss. The reported loss excludes that
    penalty.

    Args:
        epoch: epoch index, used in the progress-bar label and to seed the
            distributed sampler.
        dataset: dataset batched by a ``DataLoader``.
        model: module mapping a mixture to source estimates.
        criterion: loss function called as ``criterion(estimates, sources)``.
        optimizer: optimizer stepped (and zeroed) once per processed batch.
        augment: callable applied to each batch of sources on ``device``.
        quantizer: optional object exposing ``model_size()``.
        diffq: weight of the model-size penalty.
        repeat: number of passes over the dataset in this call.
        device: device batches are moved to.
        seed: optional offset mixed into the distributed sampler epoch.
        workers: number of ``DataLoader`` worker processes.
        world_size: number of distributed workers. Values > 1 enable a
            ``DistributedSampler`` and divide ``batch_size`` by ``world_size``.
        batch_size: batch size before division by ``world_size``.

    Returns:
        ``(current_loss, model_size)``. ``current_loss`` is the mean
        (penalty-free) loss over the processed batches of the last
        repetition; when distributed, it is passed through
        ``average_metric``. ``model_size`` is the value from the last
        processed batch (0 without a quantizer).
    """
    sampler = None
    if world_size > 1:
        sampler = DistributedSampler(dataset)
        sampler_epoch = epoch * repeat
        if seed is not None:
            sampler_epoch += seed * 1000
        sampler.set_epoch(sampler_epoch)
        # Each worker loads its share of the batch.
        batch_size //= world_size
        loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                            num_workers=workers)
    else:
        loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers,
                            shuffle=True)

    current_loss = 0
    model_size = 0
    for repetition in range(repeat):
        tq = tqdm.tqdm(loader,
                       ncols=120,
                       desc=f"[{epoch:03d}] train ({repetition + 1}/{repeat})",
                       leave=False,
                       file=sys.stdout,
                       unit=" batch")
        total_loss = 0
        # Count only batches actually trained on, so skipped batches never
        # dilute the running mean.
        processed = 0
        for sources in tq:
            if len(sources) < batch_size:
                # Skip incomplete batches for augment.Remix to work properly.
                continue
            sources = sources.to(device)
            sources = augment(sources)
            mix = sources.sum(dim=1)
            estimates = model(mix)
            sources = center_trim(sources, estimates)
            loss = criterion(estimates, sources)

            model_size = 0
            if quantizer is not None:
                model_size = quantizer.model_size()
            train_loss = loss + diffq * model_size
            train_loss.backward()
            grad_norm = _grad_norm(model)
            optimizer.step()
            optimizer.zero_grad()

            if quantizer is not None:
                model_size = model_size.item()
            total_loss += loss.item()
            processed += 1
            current_loss = total_loss / processed
            tq.set_postfix(loss=f"{current_loss:.4f}", ms=f"{model_size:.2f}",
                           grad=f"{grad_norm:.5f}")

            # Free some space before the next batch.
            del sources, mix, estimates, loss, train_loss

        if sampler is not None:
            # Use a different shuffle for the next repetition, via the
            # public API rather than mutating `sampler.epoch` directly.
            sampler.set_epoch(sampler.epoch + 1)

    if world_size > 1:
        current_loss = average_metric(current_loss)
    return current_loss, model_size
def validate_model(epoch,
                   dataset,
                   model,
                   criterion,
                   device="cpu",
                   rank=0,
                   world_size=1,
                   shifts=0,
                   overlap=0.25,
                   split=False):
    """Compute the mean validation loss of ``model`` on this worker's tracks.

    Tracks are sharded across workers: this call evaluates the indices
    ``range(rank, len(dataset), world_size)``. Each ``dataset[index]`` is
    indexed so that element 0 is the mixture and elements 1: are the
    reference sources. Estimates come from ``apply_model`` with the given
    ``shifts``, ``split`` and ``overlap`` options.

    Args:
        epoch: epoch index, used only in the progress-bar label.
        dataset: indexable dataset of stacked (mixture, sources...) tensors.
        model: separation model passed to ``apply_model``.
        criterion: loss function called as ``criterion(estimates, sources)``.
        device: device each track is moved to.
        rank: index of this worker.
        world_size: number of workers sharing the dataset.
        shifts, overlap, split: forwarded unchanged to ``apply_model``.

    Returns:
        The mean loss over this worker's tracks as a Python float. When
        ``world_size > 1``, it is combined across workers with
        ``average_metric``, weighted by the number of tracks per worker.
    """
    # Local import: the module itself only imports torch.utils.data.
    import torch

    indexes = range(rank, len(dataset), world_size)
    tq = tqdm.tqdm(indexes,
                   ncols=120,
                   desc=f"[{epoch:03d}] valid",
                   leave=False,
                   file=sys.stdout,
                   unit=" track")
    current_loss = 0
    # Validation never back-propagates (only loss.item() is used). Avoid
    # building an autograd graph over whole tracks regardless of whether the
    # caller already disabled gradients.
    with torch.no_grad():
        for index in tq:
            streams = dataset[index]
            # Keep only the first 15M samples to avoid OOM on --upsample
            # models. The original note calls this "the first five minutes";
            # that presumes a ~44.1 kHz sample rate, so confirm.
            streams = streams[..., :15_000_000]
            streams = streams.to(device)
            sources = streams[1:]
            mix = streams[0]
            estimates = apply_model(model, mix, shifts=shifts, split=split, overlap=overlap)
            loss = criterion(estimates, sources)
            current_loss += loss.item() / len(indexes)
            # Release per-track tensors before loading the next track.
            del estimates, streams, sources, mix, loss
    if world_size > 1:
        current_loss = average_metric(current_loss, len(indexes))
    return current_loss