import os, json import argparse import numpy as np from datetime import datetime from const import lftkplus_names from copy import deepcopy def parse_args(ckpt=None): parser = argparse.ArgumentParser() parser.add_argument('--data_dir', default='/data/mohamed/data') parser.add_argument('--data', default='ling_conversion') parser.add_argument('--data_sources') parser.add_argument('--data_type', default='text') parser.add_argument('--aim_repo', default='/data/mohamed/') parser.add_argument('--ckpt_dir', default='/data/mohamed/checkpoints') parser.add_argument('--kld_annealing', default='cyclic') parser.add_argument('--lingpred_annealing', default='mono') parser.add_argument('--ling_embed_type', default = 'one-layer') parser.add_argument('--combine_weight', default=1, type=float) parser.add_argument('--alpha_kld', default=1, type=float) parser.add_argument('--alpha_lingpred', default=1, type=float) parser.add_argument('--alpha_sem', default=1, type=float) parser.add_argument('--max_grad_norm', default=10, type=float) parser.add_argument('--sem_loss_tao', default=0.5, type=float) parser.add_argument('--sem_loss_eps', default=1, type=float) parser.add_argument('--ckpt') parser.add_argument('--disc_ckpt') parser.add_argument('--sem_ckpt') parser.add_argument('--lng_ids') parser.add_argument('--lng_ids_idx', type=int) parser.add_argument('--lng_ids_path', default='/data/mohamed/indices') parser.add_argument('--preds_dir', default='/data/mohamed/preds') parser.add_argument('--model_name', default="google/flan-t5-base") parser.add_argument('--model_path', default="mohdelgaar/lingconv") parser.add_argument('--sem_path', default="mohdelgaar/lingconv-semantic-classifier") parser.add_argument('--sem_model_path', default="mohdelgaar/lingconv-semantic-classifier") parser.add_argument('--disc_model_path', default="mohdelgaar/lingconv-discriminator") parser.add_argument('--disc_type', default="t5") parser.add_argument('--aim_exp', default='ling-conversion') parser.add_argument('--sem_loss_type', default='dedicated') parser.add_argument('--combine_method', default='none') parser.add_argument('--train_log', type=int, default=200) parser.add_argument('--val_log', type=int, default=2000) parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--eval_batch_size', type=int, default=32) parser.add_argument('--max_eval_samples', type=int, default=1000) parser.add_argument('--test_batch_size', type=int, default=1) parser.add_argument('--hidden_dim', type=int, default=500) parser.add_argument('--latent_dim', type=int, default=150) parser.add_argument('--lng_dim', type=int, default=40) parser.add_argument('--disc_lng_dim', type=int) parser.add_argument('--use_lora', action='store_true') parser.add_argument('--lora_r', type=int, default=64) parser.add_argument('--gpu', type=str, default='0') parser.add_argument('--epochs', type=int, default=10) parser.add_argument('--grad_accumulation', type=int, default=1) parser.add_argument('--n_ica', type=int, default=10) parser.add_argument('--max_length', type=int, default=200) parser.add_argument('--total_steps', type=int) parser.add_argument('--kld_const', type=float, default=1) parser.add_argument('--lr', type=float, default=1e-4) parser.add_argument('--kl_weight', type=float, default=1e-1) parser.add_argument('--weight_decay', type=float, default=1e-2) parser.add_argument('--ling_dropout', type=float, default=0.1) parser.add_argument('--predict_fn', default = 'logs/test.txt') parser.add_argument('--save_predict', action='store_true') parser.add_argument('--use_ica', action='store_true') parser.add_argument('--pretrain_gen', action='store_true') parser.add_argument('--pretrain_sem', action='store_true') parser.add_argument('--pretrain_disc', action='store_true') parser.add_argument('--linggen_type', default='none') parser.add_argument('--linggen_input', default='s+l') parser.add_argument('--aug_same', action='store_true') parser.add_argument('--ling_vae', action='store_true') parser.add_argument('--process_lingpred', action='store_true') parser.add_argument('--fudge_lambda', type=float, default=1.0) parser.add_argument('--use_lingpred', action='store_true') parser.add_argument('--ling2_only', action='store_true') parser.add_argument('--cycle_loss', action='store_true') parser.add_argument('--disc_loss', action='store_true') parser.add_argument('--sem_loss', action='store_true') parser.add_argument('--sim_loss', action='store_true') parser.add_argument('--optuna', action='store_true') parser.add_argument('--debug', action='store_true') parser.add_argument('--demo', action='store_true') parser.add_argument('--fudge', action='store_true') parser.add_argument('--fb_log', default='feedback_logs/default.txt') parser.add_argument('--eval_only', action='store_true') parser.add_argument('--predict_with_feedback', action='store_true') parser.add_argument('--feedback_param', default = 's') parser.add_argument('--eval_ling', action='store_true') parser.add_argument('--seed', type=int, default=0) parser.add_argument('--major_arg', default = 0, type=int) parser.add_argument('--quantize_lng', action='store_true') parser.add_argument('--quant_nbins', type=int, default=20) parser.add_argument('--src_lng', default = 'ling') parser.add_argument('--to_restore', nargs='+', default=[]) # args = parser.parse_args() args, unknown = parser.parse_known_args() args.name = f'{datetime.now().strftime("%m%d_%H-%M-%S")}-{args.data}-{args.combine_method}' major_arg = args.major_arg to_restore = [ ] + args.to_restore to_restore = {k: args.__dict__[k] for k in to_restore} if not args.disc_loss or args.disc_ckpt: args.disc_steps = 0 if args.data_sources is not None: args.data_sources = args.data_sources.split(',') if ckpt is not None: args.ckpt = ckpt args_list = [args] if args.ckpt: if ',' in args.ckpt: ckpts = args.ckpt.split(',') args_list = [deepcopy(args) for _ in range(len(ckpts))] for i in range(len(ckpts)): args_path = ckpts[i].replace('_best', '').replace('.pt', '.json') with open(args_path) as f: args_list[i].__dict__.update(json.load(f)) args_list[i].__dict__.update(to_restore) args_list[i].ckpt = ckpts[i] else: args_path = args.ckpt.replace('_best', '').replace('.pt', '.json') ckpt = args.ckpt with open(args_path) as f: args.__dict__.update(json.load(f)) args.__dict__.update(to_restore) args.ckpt = ckpt lng_names = lftkplus_names for i in range(len(args_list)): if args_list[i].lng_ids or args_list[i].lng_ids_idx: if args_list[i].lng_ids_idx: lng_ids = np.load(os.path.join(args_list[i].lng_ids_path, f'{args_list[i].lng_ids_idx}.npy')) elif args_list[i].lng_ids[0].isnumeric(): lng_ids = [int(x) for x in args_list[i].lng_ids.split(',')] elif ',' in args_list[i].lng_ids: lng_ids = [lng_names.index(x) for x in args_list[i].lng_ids.split(',')] else: lng_ids = np.load(args_list[i].lng_ids) args_list[i].lng_dim = len(lng_ids) args_list[i].lng_ids = lng_ids.tolist() # lng_names = [lng_names[i] for i in lng_ids] elif args_list[i].use_ica: args_list[i].lng_dim = args_list[i].n_ica if args_list[i].disc_lng_dim is None: args_list[i].disc_lng_dim = args_list[i].lng_dim if not args.ckpt and not args.eval_only: args_path = os.path.join(args.ckpt_dir, '%s.json'%args.name) with open(args_path, 'w') as f: s = json.dumps(args.__dict__) f.write(s) return args_list[major_arg], args_list, lng_names