| import os |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from tqdm import tqdm |
| from dataset import Autoencoder_dataset |
| from model import Autoencoder |
| from torch.utils.tensorboard import SummaryWriter |
| import argparse |
|
|
# Anomaly detection pinpoints the op that produced NaN/Inf gradients, but it
# adds significant overhead to every backward pass.
# NOTE(review): this looks like a leftover debugging switch — consider
# disabling it for production training runs.
torch.autograd.set_detect_anomaly(True)
|
|
def l2_loss(network_output, gt):
    """Mean squared error between the reconstruction and the target."""
    diff = network_output - gt
    return diff.pow(2).mean()
|
|
def cos_loss(network_output, gt):
    """One minus the mean cosine similarity along dim 0.

    NOTE(review): similarity is taken over dim 0 (the batch axis for 2-D
    inputs), not the feature axis — presumably intentional here; confirm
    against the upstream training recipe.
    """
    similarity = F.cosine_similarity(network_output, gt, dim=0)
    return 1 - similarity.mean()
|
|
|
|
if __name__ == '__main__':
    # Train an autoencoder that compresses per-pixel language features to a
    # low-dimensional code, checkpointing the best model by eval loss.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', type=str, required=True)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--encoder_dims',
                        nargs='+',
                        type=int,
                        default=[256, 128, 64, 32, 3],
                        )
    parser.add_argument('--decoder_dims',
                        nargs='+',
                        type=int,
                        default=[16, 32, 64, 128, 256, 256, 512],
                        )
    parser.add_argument('--dataset_name', type=str, required=True)
    args = parser.parse_args()

    dataset_path = args.dataset_path
    num_epochs = args.num_epochs
    data_dir = f"{dataset_path}/language_features"
    os.makedirs(f'ckpt/{args.dataset_name}', exist_ok=True)

    train_dataset = Autoencoder_dataset(data_dir)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=16,
        drop_last=False
    )

    # NOTE(review): evaluation runs on the training set (no held-out split),
    # so "eval loss" measures reconstruction fit, not generalization.
    test_loader = DataLoader(
        dataset=train_dataset,
        batch_size=256,
        shuffle=False,
        num_workers=16,
        drop_last=False
    )

    encoder_hidden_dims = args.encoder_dims
    decoder_hidden_dims = args.decoder_dims

    model = Autoencoder(encoder_hidden_dims, decoder_hidden_dims).to("cuda:0")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    logdir = f'ckpt/{args.dataset_name}'
    tb_writer = SummaryWriter(logdir)

    # float("inf") (was 100.0) so the best checkpoint is saved even if the
    # eval loss never drops below 100.
    best_eval_loss = float("inf")
    best_epoch = 0
    for epoch in tqdm(range(num_epochs)):
        model.train()
        for idx, feature in enumerate(train_loader):
            data = feature.to("cuda:0")
            outputs_dim3 = model.encode(data)
            outputs = model.decode(outputs_dim3)

            l2loss = l2_loss(outputs, data)
            cosloss = cos_loss(outputs, data)
            # Cosine term is heavily down-weighted; L2 dominates training.
            loss = l2loss + cosloss * 0.001

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            global_iter = epoch * len(train_loader) + idx
            tb_writer.add_scalar('train_loss/l2_loss', l2loss.item(), global_iter)
            tb_writer.add_scalar('train_loss/cos_loss', cosloss.item(), global_iter)
            tb_writer.add_scalar('train_loss/total_loss', loss.item(), global_iter)
            tb_writer.add_histogram("feat", outputs, global_iter)

        # Evaluate only during the last few epochs. The original hard-coded
        # `epoch > 95`, which disabled evaluation entirely for runs with
        # --num_epochs < 96; tying it to num_epochs preserves the default
        # behavior (96..99 for 100 epochs) while respecting shorter runs.
        if epoch > num_epochs - 5:
            eval_loss = 0.0
            model.eval()
            # Single no_grad context so the per-batch losses are also
            # grad-free; accumulate a Python float instead of a CUDA tensor.
            with torch.no_grad():
                for idx, feature in enumerate(test_loader):
                    data = feature.to("cuda:0")
                    outputs = model(data)
                    loss = l2_loss(outputs, data) + cos_loss(outputs, data)
                    eval_loss += loss.item() * len(feature)
            eval_loss = eval_loss / len(train_dataset)
            print("eval_loss:{:.8f}".format(eval_loss))
            if eval_loss < best_eval_loss:
                best_eval_loss = eval_loss
                best_epoch = epoch
                torch.save(model.state_dict(), f'ckpt/{args.dataset_name}/best_ckpt.pth')

        # Periodic snapshot every 10 epochs regardless of eval loss.
        if epoch % 10 == 0:
            torch.save(model.state_dict(), f'ckpt/{args.dataset_name}/{epoch}_ckpt.pth')

    tb_writer.close()
    print(f"best_epoch: {best_epoch}")
    print("best_loss: {:.8f}".format(best_eval_loss))