|
import argparse |
|
import random |
|
import torch |
|
import numpy as np |
|
from time import time |
|
import logging |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from datasets import EmbDataset |
|
from models.rqvae import RQVAE |
|
from trainer import Trainer |
|
|
|
def _str2bool(value):
    """Convert a command-line token into a real ``bool``.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value (including ``"False"``) is parsed as ``True``.
    This converter interprets the usual textual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy, exactly as bool("") was before.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value)
    )


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace holding optimisation, data, quantizer, device and
        checkpoint options.
    """
    parser = argparse.ArgumentParser(description="Index")

    # Optimisation and data-loading options.
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--epochs', type=int, default=5000, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of DataLoader worker processes')
    parser.add_argument('--eval_step', type=int, default=50, help='eval step')
    parser.add_argument('--learner', type=str, default="AdamW", help='optimizer')
    parser.add_argument("--data_path", type=str,
                        default="../data/Games/Games.emb-llama-td.npy",
                        help="Input data path.")

    # Regularisation, loss and codebook-initialisation options.
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='l2 regularization weight')
    parser.add_argument("--dropout_prob", type=float, default=0.0, help="dropout ratio")
    # Boolean flags go through _str2bool so that "--bn False" really means False.
    parser.add_argument("--bn", type=_str2bool, default=False, help="use bn or not")
    parser.add_argument("--loss_type", type=str, default="mse", help="loss_type")
    parser.add_argument("--kmeans_init", type=_str2bool, default=True, help="use kmeans_init or not")
    parser.add_argument("--kmeans_iters", type=int, default=100, help="max kmeans iters")
    parser.add_argument('--sk_epsilons', type=float, nargs='+', default=[0.0, 0.0, 0.0], help="sinkhorn epsilons")
    parser.add_argument("--sk_iters", type=int, default=50, help="max sinkhorn iters")

    parser.add_argument("--device", type=str, default="cuda:1", help="gpu or cpu")

    # Model-shape options.
    parser.add_argument('--num_emb_list', type=int, nargs='+', default=[256,256,256], help='emb num of every vq')
    parser.add_argument('--e_dim', type=int, default=32, help='vq codebook embedding size')
    parser.add_argument('--quant_loss_weight', type=float, default=1.0, help='vq quantion loss weight')
    parser.add_argument('--layers', type=int, nargs='+', default=[2048,1024,512,256,128,64], help='hidden sizes of every layer')

    parser.add_argument("--ckpt_dir", type=str, default="", help="output directory for model")

    return parser.parse_args(argv)
|
|
|
|
|
if __name__ == '__main__':
    # Fix the random seed of every RNG in use so runs are repeatable.
    seed = 2023
    for seed_fn in (random.seed, np.random.seed,
                    torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    args = parse_args()
    print(args)

    logging.basicConfig(level=logging.DEBUG)

    # Build the dataset; the model input width is read from the dataset.
    data = EmbDataset(args.data_path)

    # Every remaining RQVAE keyword argument is forwarded from the CLI
    # option of the same name.
    forwarded = (
        "num_emb_list", "e_dim", "layers", "dropout_prob", "bn",
        "loss_type", "quant_loss_weight", "kmeans_init", "kmeans_iters",
        "sk_epsilons", "sk_iters",
    )
    model_kwargs = {name: getattr(args, name) for name in forwarded}
    model = RQVAE(in_dim=data.dim, **model_kwargs)
    print(model)

    data_loader = DataLoader(
        data,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
    )

    # Train and report the two values handed back by Trainer.fit.
    trainer = Trainer(args, model)
    best_loss, best_collision_rate = trainer.fit(data_loader)

    print("Best Loss",best_loss)
    print("Best Collision Rate", best_collision_rate)
|
|
|
|