| |
| |
| |
| |
| |
|
|
import sys

import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from .utils import apply_model, average_metric, center_trim
|
|
|
|
def train_model(epoch,
                dataset,
                model,
                criterion,
                optimizer,
                augment,
                quantizer=None,
                diffq=0,
                repeat=1,
                device="cpu",
                seed=None,
                workers=4,
                world_size=1,
                batch_size=16):
    """Run one training epoch over `dataset` (repeated `repeat` times).

    Returns a tuple ``(avg_loss, model_size)`` where ``avg_loss`` is the
    mean criterion loss over the last pass (averaged across ranks when
    ``world_size > 1``) and ``model_size`` is the last quantizer-reported
    model size (0 when ``quantizer`` is None).
    """
    if world_size > 1:
        # Distributed run: shard the dataset per rank and shrink the
        # per-rank batch so the *global* batch size stays `batch_size`.
        sampler = DistributedSampler(dataset)
        shuffle_epoch = epoch * repeat
        if seed is not None:
            # Fold the seed in so different seeds see different shuffles.
            shuffle_epoch += seed * 1000
        sampler.set_epoch(shuffle_epoch)
        batch_size //= world_size
        loader = DataLoader(dataset, batch_size=batch_size,
                            sampler=sampler, num_workers=workers)
    else:
        loader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=workers, shuffle=True)

    avg_loss = 0
    model_size = 0
    for pass_idx in range(repeat):
        progress = tqdm.tqdm(loader,
                             ncols=120,
                             desc=f"[{epoch:03d}] train ({pass_idx + 1}/{repeat})",
                             leave=False,
                             file=sys.stdout,
                             unit=" batch")
        running_loss = 0
        for batch_idx, sources in enumerate(progress):
            # Skip incomplete trailing batches (presumably so batch-wise
            # augmentations see a fixed batch size — TODO confirm).
            if len(sources) < batch_size:
                continue
            sources = augment(sources.to(device))
            mix = sources.sum(dim=1)

            estimates = model(mix)
            # The model may shorten the time axis; trim targets to match.
            sources = center_trim(sources, estimates)
            loss = criterion(estimates, sources)

            model_size = 0
            if quantizer is not None:
                model_size = quantizer.model_size()
            # Joint objective: reconstruction loss + weighted size penalty.
            objective = loss + diffq * model_size
            objective.backward()

            # Global gradient L2 norm, reported for monitoring only.
            grad_norm = sum(
                p.grad.data.norm() ** 2
                for p in model.parameters() if p.grad is not None) ** 0.5
            optimizer.step()
            optimizer.zero_grad()

            if quantizer is not None:
                # model_size is a tensor when a quantizer is active;
                # keep only its scalar value for logging/return.
                model_size = model_size.item()

            running_loss += loss.item()
            avg_loss = running_loss / (1 + batch_idx)
            progress.set_postfix(loss=f"{avg_loss:.4f}", ms=f"{model_size:.2f}",
                                 grad=f"{grad_norm:.5f}")

            # Drop references so memory is freed before the next batch.
            del sources, mix, estimates, loss, objective

        if world_size > 1:
            # Advance the sampler epoch so the next pass reshuffles.
            sampler.epoch += 1

    if world_size > 1:
        avg_loss = average_metric(avg_loss)
    return avg_loss, model_size
|
|
|
|
def validate_model(epoch,
                   dataset,
                   model,
                   criterion,
                   device="cpu",
                   rank=0,
                   world_size=1,
                   shifts=0,
                   overlap=0.25,
                   split=False):
    """Evaluate `model` on whole tracks and return the mean validation loss.

    Each rank processes every ``world_size``-th track starting at ``rank``;
    with ``world_size > 1`` the per-rank means are combined via
    ``average_metric``.
    """
    # Round-robin shard: rank r handles tracks r, r + world_size, ...
    indexes = range(rank, len(dataset), world_size)
    progress = tqdm.tqdm(indexes,
                         ncols=120,
                         desc=f"[{epoch:03d}] valid",
                         leave=False,
                         file=sys.stdout,
                         unit=" track")
    avg_loss = 0
    for track_idx in progress:
        streams = dataset[track_idx]
        # Cap very long tracks at 15M samples along the last axis,
        # presumably to bound memory usage — TODO confirm.
        streams = streams[..., :15_000_000].to(device)
        # Convention here: stream 0 is the mixture, the rest are targets.
        mix = streams[0]
        sources = streams[1:]
        estimates = apply_model(model, mix, shifts=shifts, split=split,
                                overlap=overlap)
        # Accumulate the per-track loss pre-divided by the track count,
        # yielding the mean over this rank's shard.
        avg_loss += criterion(estimates, sources).item() / len(indexes)
        # Release track tensors before loading the next one.
        del estimates, streams, sources

    if world_size > 1:
        avg_loss = average_metric(avg_loss, len(indexes))
    return avg_loss
|
|