import json import logging import os import random import datetime import numpy as np import torch from torch.utils.data import ConcatDataset from data import SeqRecDataset, ItemFeatDataset, ItemSearchDataset, FusionSeqRecDataset, SeqRecTestDataset, PreferenceObtainDataset def parse_global_args(parser): parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("--base_model", type=str, default="./llama-7b/", help="basic model path") parser.add_argument("--output_dir", type=str, default="./ckpt/", help="The output directory") return parser def parse_dataset_args(parser): parser.add_argument("--data_path", type=str, default="", help="data directory") parser.add_argument("--tasks", type=str, default="seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain", help="Downstream tasks, separate by comma") parser.add_argument("--dataset", type=str, default="Games", help="Dataset name") parser.add_argument("--inter_path", type=str, default="", help="data directory") parser.add_argument("--feature_path", type=str, default="", help="data directory") parser.add_argument("--index_file", type=str, default=".index.json", help="the item indices file") parser.add_argument("--dataloader_num_workers", type=int, default=0, help="dataloader num_workers") parser.add_argument("--dataloader_prefetch_factor", type=int, default=2, help="dataloader prefetch_factor") # arguments related to sequential task parser.add_argument("--max_his_len", type=int, default=20, help="the max number of items in history sequence, -1 means no limit") parser.add_argument("--add_prefix", action="store_true", default=False, help="whether add sequential prefix in history") parser.add_argument("--his_sep", type=str, default=", ", help="The separator used for history") parser.add_argument("--only_train_response", action="store_true", default=False, help="whether only train on responses") parser.add_argument("--train_prompt_sample_num", type=str, default="1,1,1,1,1,1", help="the number of sampling prompts for each task") parser.add_argument("--train_data_sample_num", type=str, default="0,0,0,100000,0,0", help="the number of sampling prompts for each task") parser.add_argument("--valid_prompt_id", type=int, default=0, help="The prompt used for validation") parser.add_argument("--sample_valid", action="store_true", default=True, help="use sampled prompt for validation") parser.add_argument("--valid_prompt_sample_num", type=int, default=2, help="the number of sampling validation sequential recommendation prompts") return parser def parse_train_args(parser): parser.add_argument("--optim", type=str, default="adamw_torch", help='The name of the optimizer') parser.add_argument("--epochs", type=int, default=4) parser.add_argument("--learning_rate", type=float, default=2e-5) parser.add_argument("--per_device_batch_size", type=int, default=8) parser.add_argument("--gradient_accumulation_steps", type=int, default=2) parser.add_argument("--logging_step", type=int, default=10) parser.add_argument("--model_max_length", type=int, default=2048) parser.add_argument("--weight_decay", type=float, default=0.01) parser.add_argument("--lora_r", type=int, default=8) parser.add_argument("--lora_alpha", type=int, default=32) parser.add_argument("--lora_dropout", type=float, default=0.05) parser.add_argument("--lora_target_modules", type=str, default="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj", help="separate by comma") parser.add_argument("--lora_modules_to_save", type=str, default="embed_tokens,lm_head", help="separate by comma") parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="either training checkpoint or final adapter") parser.add_argument("--warmup_ratio", type=float, default=0.01) parser.add_argument("--lr_scheduler_type", type=str, default="cosine") parser.add_argument("--save_and_eval_strategy", type=str, default="epoch") parser.add_argument("--save_and_eval_steps", type=int, default=1000) parser.add_argument("--fp16", action="store_true", default=False) parser.add_argument("--bf16", action="store_true", default=False) parser.add_argument("--deepspeed", type=str, default="./config/ds_z3_bf16.json") return parser def parse_test_args(parser): parser.add_argument("--ckpt_path", type=str, default="", help="The checkpoint path") parser.add_argument("--lora", action="store_true", default=False) parser.add_argument("--filter_items", action="store_true", default=False, help="whether filter illegal items") parser.add_argument("--results_file", type=str, default="./results/test-ddp.json", help="result output path") parser.add_argument("--test_batch_size", type=int, default=1) parser.add_argument("--num_beams", type=int, default=20) parser.add_argument("--sample_num", type=int, default=-1, help="test sample number, -1 represents using all test data") parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID when testing with single GPU") parser.add_argument("--test_prompt_ids", type=str, default="0", help="test prompt ids, separate by comma. 'all' represents using all") parser.add_argument("--metrics", type=str, default="hit@1,hit@5,hit@10,ndcg@5,ndcg@10", help="test metrics, separate by comma") parser.add_argument("--test_task", type=str, default="SeqRec") return parser def get_local_time(): cur = datetime.datetime.now() cur = cur.strftime("%b-%d-%Y_%H-%M-%S") return cur def set_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True torch.backends.cudnn.enabled = False def ensure_dir(dir_path): os.makedirs(dir_path, exist_ok=True) def load_datasets(args): tasks = args.tasks.split(",") train_prompt_sample_num = [int(_) for _ in args.train_prompt_sample_num.split(",")] assert len(tasks) == len(train_prompt_sample_num), "prompt sample number does not match task number" train_data_sample_num = [int(_) for _ in args.train_data_sample_num.split(",")] assert len(tasks) == len(train_data_sample_num), "data sample number does not match task number" train_datasets = [] for task, prompt_sample_num,data_sample_num in zip(tasks,train_prompt_sample_num,train_data_sample_num): if task.lower() == "seqrec": dataset = SeqRecDataset(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) elif task.lower() == "item2index" or task.lower() == "index2item": dataset = ItemFeatDataset(args, task=task.lower(), prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) elif task.lower() == "fusionseqrec": dataset = FusionSeqRecDataset(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) elif task.lower() == "itemsearch": dataset = ItemSearchDataset(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) elif task.lower() == "preferenceobtain": dataset = PreferenceObtainDataset(args, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) else: raise NotImplementedError train_datasets.append(dataset) train_data = ConcatDataset(train_datasets) valid_data = SeqRecDataset(args,"valid",args.valid_prompt_sample_num) return train_data, valid_data def load_test_dataset(args): if args.test_task.lower() == "seqrec": test_data = SeqRecDataset(args, mode="test", sample_num=args.sample_num) # test_data = SeqRecTestDataset(args, sample_num=args.sample_num) elif args.test_task.lower() == "itemsearch": test_data = ItemSearchDataset(args, mode="test", sample_num=args.sample_num) elif args.test_task.lower() == "fusionseqrec": test_data = FusionSeqRecDataset(args, mode="test", sample_num=args.sample_num) else: raise NotImplementedError return test_data def load_json(file): with open(file, 'r') as f: data = json.load(f) return data