|
import argparse |
|
import os |
|
|
|
import pytorch_lightning as pl |
|
import soundfile as sf |
|
import torch |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
from pytorch_lightning.utilities.model_summary import summarize |
|
from torch.utils.data import DataLoader |
|
|
|
from config import CONFIG |
|
from dataset import TrainDataset, TestLoader, BlindTestLoader |
|
from models.frn import PLCModel, OnnxWrapper |
|
from utils.tblogger import TensorBoardLoggerExpanded |
|
from utils.utils import mkdir_p |
|
|
|
# Command-line interface. Arguments are parsed at module level because the
# functions below read the module-level ``args``.
parser = argparse.ArgumentParser()

parser.add_argument('--version', default=None,
                    help='version to resume')
# ``choices`` lets argparse reject an unknown mode with a usage message.
# The previous ``assert`` would be skipped entirely under ``python -O``.
parser.add_argument('--mode', default='train',
                    choices=['train', 'eval', 'test', 'onnx'],
                    help='training or testing mode')

args = parser.parse_args()

# Expose only the configured GPU ids to CUDA through the environment.
os.environ["CUDA_VISIBLE_DEVICES"] = str(CONFIG.gpus)
|
|
|
|
|
def resume(train_dataset, val_dataset, version):
    """Load a ``PLCModel`` from the checkpoint stored for a logged version.

    The checkpoint is looked up in
    ``<CONFIG.LOG.log_dir>/version_<version>/checkpoints/`` and the
    hyper-parameters in ``<CONFIG.LOG.log_dir>/version_<version>/hparams.yaml``.

    Args:
        train_dataset: Forwarded to ``PLCModel.load_from_checkpoint``.
        val_dataset: Forwarded to ``PLCModel.load_from_checkpoint``.
        version: Version identifier used to build the ``version_<version>``
            directory name.

    Returns:
        The object returned by ``PLCModel.load_from_checkpoint``.

    Raises:
        FileNotFoundError: If the checkpoint directory does not exist or
            contains no ``.ckpt`` file.
    """
    print("Version", version)
    version_dir = os.path.join(CONFIG.LOG.log_dir, 'version_{}'.format(version))
    model_path = os.path.join(version_dir, 'checkpoints')
    config_path = os.path.join(version_dir, 'hparams.yaml')

    # ``os.listdir`` returns entries in arbitrary order; sorting makes the
    # selected checkpoint deterministic when several are present.
    ckpt_names = sorted(x for x in os.listdir(model_path) if x.endswith(".ckpt"))
    if not ckpt_names:
        raise FileNotFoundError(
            "No .ckpt file found in '{}'".format(model_path))
    ckpt_path = os.path.join(model_path, ckpt_names[0])

    checkpoint = PLCModel.load_from_checkpoint(ckpt_path,
                                               strict=True,
                                               hparams_file=config_path,
                                               train_dataset=train_dataset,
                                               val_dataset=val_dataset,
                                               window_size=CONFIG.DATA.window_size)

    return checkpoint
|
|
|
|
|
def train():
    """Train a ``PLCModel``, optionally resuming from ``args.version``.

    Builds the train/val datasets, a checkpoint callback monitoring
    ``val_loss``, and a TensorBoard logger. It then either restores the model
    via ``resume`` (when ``--version`` was given) or creates a fresh one from
    ``CONFIG.MODEL``, and runs ``trainer.fit``.
    """
    train_dataset = TrainDataset('train')
    val_dataset = TrainDataset('val')
    checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', verbose=True,
                                          filename='frn-{epoch:02d}-{val_loss:.4f}',
                                          save_weights_only=False)
    # Coerce to ``str`` first, as the CUDA_VISIBLE_DEVICES setup at module
    # level does, so a non-string ``CONFIG.gpus`` (e.g. an int) also works.
    gpus = str(CONFIG.gpus).split(',')
    logger = TensorBoardLoggerExpanded(CONFIG.DATA.sr)

    if args.version is not None:
        model = resume(train_dataset, val_dataset, args.version)
    else:
        model = PLCModel(train_dataset,
                         val_dataset,
                         window_size=CONFIG.DATA.window_size,
                         enc_layers=CONFIG.MODEL.enc_layers,
                         enc_in_dim=CONFIG.MODEL.enc_in_dim,
                         enc_dim=CONFIG.MODEL.enc_dim,
                         pred_dim=CONFIG.MODEL.pred_dim,
                         pred_layers=CONFIG.MODEL.pred_layers)

    # Use the same ``accelerator``/``devices`` Trainer arguments as the
    # eval/test branches of this script instead of the legacy ``gpus=``.
    trainer = pl.Trainer(logger=logger,
                         gradient_clip_val=CONFIG.TRAIN.clipping_val,
                         accelerator='gpu',
                         devices=len(gpus),
                         max_epochs=CONFIG.TRAIN.epochs,
                         callbacks=[checkpoint_callback])

    print(model.hparams)
    print(
        'Dataset: {}, Train files: {}, Val files {}'.format(CONFIG.DATA.dataset, len(train_dataset), len(val_dataset)))
    trainer.fit(model)
|
|
|
|
|
def to_onnx(model, onnx_path):
    """Export ``model`` to an ONNX file at ``onnx_path``.

    The model is switched to eval mode and wrapped in ``OnnxWrapper``. The
    wrapper supplies the example input (``sample``) and the input/output
    names passed to ``torch.onnx.export``.
    """
    model.eval()
    wrapper = OnnxWrapper(model)

    export_options = dict(
        export_params=True,
        opset_version=12,
        input_names=wrapper.input_names,
        output_names=wrapper.output_names,
        do_constant_folding=True,
        verbose=False,
    )
    torch.onnx.export(wrapper, wrapper.sample, onnx_path, **export_options)
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: dispatch on ``--mode``.
    if args.mode == 'train':
        train()
    else:
        # Every non-train mode restores a saved model. Without a version
        # the lookup would target a ``version_None`` directory.
        if args.version is None:
            parser.error("--version is required for mode '{}'".format(args.mode))

        model = resume(None, None, args.version)
        print(model.hparams)
        print(summarize(model))

        model.eval()
        model.freeze()
        if args.mode == 'eval':
            model.cuda(device=0)
            trainer = pl.Trainer(accelerator='gpu', devices=1, enable_checkpointing=False, logger=False)
            testset = TestLoader()
            test_loader = DataLoader(testset, batch_size=1, num_workers=4)
            trainer.test(model, test_loader)
            print('Version', args.version)
            masking = CONFIG.DATA.EVAL.masking
            if masking == 'real':
                print('Evaluate with real trace')
            else:
                # Only the generated-trace report needs the transition
                # probabilities, so they are read in this branch only.
                # NOTE(review): the formula looks like the stationary loss
                # rate of a two-state chain - confirm against the dataset code.
                prob = CONFIG.DATA.EVAL.transition_probs[0]
                loss_percent = (1 - prob[0]) / (2 - prob[0] - prob[1]) * 100
                print('Evaluate with generated trace with {:.2f}% packet loss'.format(loss_percent))
        elif args.mode == 'test':
            model.cuda(device=0)
            testset = BlindTestLoader(test_dir=CONFIG.TEST.in_dir)
            test_loader = DataLoader(testset, batch_size=1, num_workers=4)
            trainer = pl.Trainer(accelerator='gpu', devices=1, enable_checkpointing=False, logger=False)
            preds = trainer.predict(model, test_loader, return_predictions=True)
            mkdir_p(CONFIG.TEST.out_dir)
            # Write each prediction under the input file's base name.
            for idx, path in enumerate(test_loader.dataset.data_list):
                out_path = os.path.join(CONFIG.TEST.out_dir, os.path.basename(path))
                sf.write(out_path, preds[idx], samplerate=CONFIG.DATA.sr, subtype='PCM_16')
        else:
            # Save the ONNX file next to the checkpoint that was loaded,
            # under the same log directory that ``resume`` reads from.
            onnx_path = os.path.join(CONFIG.LOG.log_dir,
                                     'version_{}'.format(args.version),
                                     'checkpoints',
                                     'frn.onnx')
            to_onnx(model, onnx_path)
            print('ONNX model saved to', onnx_path)
|
|