| | import copy
|
| | import random
|
| | import argparse
|
| | import os
|
| | import torch
|
| | import torch.nn as nn
|
| | from torch.utils.data import Dataset
|
| | from tqdm import tqdm
|
| | from collections import defaultdict
|
| | import torch.distributed as dist
|
| | import logging
|
| | import re
|
| | import pdb
|
| | import json
|
| | from prompt_finetune import sft_prompt, all_prompt
|
| | import numpy as np
|
| |
|
| |
|
class BaseDataset(Dataset):
    """Shared configuration and item-index helpers for the fine-tuning datasets.

    The ``get_*`` helpers read ``self.indices``, a mapping whose values are
    sequences of index-token strings. It is populated by ``_load_data`` and is
    not set in ``__init__``, so ``_load_data`` must run before the helpers.
    """

    def __init__(self, args):
        """Copy the dataset settings from ``args``.

        Args:
            args: object exposing ``dataset``, ``data_path``, ``max_his_len``,
                ``his_sep``, ``index_file`` and ``add_prefix``.
        """
        super().__init__()

        self.args = args
        self.dataset = args.dataset
        # Per-dataset directory: <args.data_path>/<dataset>
        self.data_path = os.path.join(args.data_path, self.dataset)

        self.max_his_len = args.max_his_len
        self.his_sep = args.his_sep
        self.index_file = args.index_file
        self.add_prefix = args.add_prefix

        # Lazily computed caches, filled by the get_* methods below.
        self.new_tokens = None
        self.allowed_tokens = None
        self.all_items = None

    def _load_data(self):
        """Load the item-index mapping from ``<data_path>/<dataset><index_file>``."""
        index_path = os.path.join(self.data_path, self.dataset + self.index_file)
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(index_path, 'r', encoding="utf-8") as f:
            self.indices = json.load(f)

    def get_new_tokens(self):
        """Return the sorted list of distinct index tokens (cached after first call)."""
        if self.new_tokens is None:
            self.new_tokens = sorted(
                {token for index in self.indices.values() for token in index}
            )
        return self.new_tokens

    def get_all_items(self):
        """Return the set of item identifiers, each the concatenation of its tokens."""
        if self.all_items is None:
            self.all_items = {"".join(index) for index in self.indices.values()}
        return self.all_items

    def get_prefix_allowed_tokens_fn(self, tokenizer):
        """Build a ``prefix_allowed_tokens_fn(batch_id, sentence)`` callback.

        For every token position ``i`` inside an item index, the ids of all
        tokens seen at that position are collected; one extra position after
        the deepest one allows only ``tokenizer.eos_token_id``. The callback
        counts how many ids follow the last occurrence of the tokenised
        ``"Response:"`` marker and returns the ids allowed at that position.

        Args:
            tokenizer: callable returning a mapping with ``"input_ids"`` and
                exposing ``eos_token_id``.

        Returns:
            The callback. It returns ``None`` when the marker is not found and
            raises ``KeyError`` when the position has no entry.
        """
        if self.allowed_tokens is None:
            # Build locally so a tokenizer failure cannot leave a half-filled cache.
            allowed = {}
            for index in self.indices.values():
                for position, token in enumerate(index):
                    # [1] skips the first id -- presumably a BOS token the
                    # tokenizer prepends; TODO confirm for the tokenizer in use.
                    token_id = tokenizer(token)["input_ids"][1]
                    allowed.setdefault(position, set()).add(token_id)
            # One position past the deepest index token: only EOS may follow.
            allowed[len(allowed)] = {tokenizer.eos_token_id}
            self.allowed_tokens = allowed

        sep = tokenizer("Response:")["input_ids"][1:]
        sep_len = len(sep)
        reversed_sep = sep[::-1]

        def prefix_allowed_tokens_fn(batch_id, sentence):
            # Scan from the end: the first match is the last marker, and the
            # offset ``i`` equals the number of ids generated after it.
            reversed_sent = sentence.tolist()[::-1]
            for i in range(len(reversed_sent)):
                if reversed_sent[i:i + sep_len] == reversed_sep:
                    return list(self.allowed_tokens[i])
            # Marker not present: no constraint is produced.
            return None

        return prefix_allowed_tokens_fn

    def _process_data(self):
        """Hook for subclasses; the base class provides no implementation."""
        raise NotImplementedError
|
| |
|
| |
|
| |
|
class SeqRecFinetune(BaseDataset):
    """Sequential-recommendation prompts built from per-user interaction lists.

    Split used by the ``_process_*`` methods: training targets come from each
    sequence with its last two items removed, the validation target is the
    second-to-last item, and the test target is the last item.
    """

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        """Load the data and build the samples for ``mode``.

        Args:
            args: settings object (see ``BaseDataset``); ``valid`` mode also
                reads ``args.sample_valid`` and ``args.valid_prompt_id``.
            mode: ``'train'``, ``'valid'`` or ``'test'``.
            prompt_sample_num: length multiplier in train mode; number of
                prompts drawn per sample when validation prompts are sampled.
            prompt_id: prompt template used in test mode.
            sample_num: if positive, number of test samples to keep.

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompts = all_prompt["seqrec"]

        self._load_data()
        self._remap_items()

        if self.mode == 'train':
            self.inter_data = self._process_train_data()
        elif self.mode == 'valid':
            self.sample_valid = args.sample_valid
            self.valid_prompt_id = args.valid_prompt_id
            self.inter_data = self._process_valid_data()
            self._construct_valid_text()
        elif self.mode == 'test':
            self.inter_data = self._process_test_data()
        else:
            raise NotImplementedError

    def _load_data(self):
        """Read the interactions file and the index file (``index_file`` is used as a path)."""
        inter_path = os.path.join(self.data_path, self.dataset + ".inter.json")
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(inter_path, 'r', encoding="utf-8") as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r', encoding="utf-8") as f:
            self.indices = json.load(f)

    def _remap_items(self):
        """Replace every item id in ``self.inters`` by its concatenated index tokens."""
        self.remapped_inters = {
            uid: ["".join(self.indices[str(i)]) for i in items]
            for uid, items in self.inters.items()
        }

    def _make_sample(self, target, history):
        """Build one ``{"item", "inters"}`` record from a target and its history."""
        if self.max_his_len > 0:
            # Keep only the most recent interactions.
            history = history[-self.max_his_len:]
        if self.add_prefix:
            history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
        return {"item": target, "inters": self.his_sep.join(history)}

    def _process_train_data(self):
        """Return one sample per non-first position of each truncated sequence."""
        inter_data = []
        for seq in self.remapped_inters.values():
            items = seq[:-2]  # the last two items are held out
            for i in range(1, len(items)):
                inter_data.append(self._make_sample(items[i], items[:i]))
        return inter_data

    def _process_valid_data(self):
        """Return one sample per user, targeting the second-to-last item."""
        return [
            self._make_sample(items[-2], items[:-2])
            for items in self.remapped_inters.values()
        ]

    def _process_test_data(self):
        """Return one sample per user, targeting the last item, optionally subsampled."""
        inter_data = [
            self._make_sample(items[-1], items[:-1])
            for items in self.remapped_inters.values()
        ]

        if self.sample_num > 0:
            # Raises ValueError when sample_num exceeds the number of samples.
            sample_idx = np.random.choice(len(inter_data), self.sample_num, replace=False)
            inter_data = [inter_data[i] for i in sample_idx]

        return inter_data

    def set_prompt(self, prompt_id):
        """Select the prompt template used in test mode."""
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.inter_data) * self.prompt_sample_num
        elif self.mode == 'valid':
            return len(self.valid_text_data)
        elif self.mode == 'test':
            return len(self.inter_data)
        else:
            raise NotImplementedError

    def _construct_valid_text(self):
        """Pre-render the validation texts into ``self.valid_text_data``."""
        self.valid_text_data = []
        if self.sample_valid:
            all_prompt_ids = range(len(self.prompts))
            for d in self.inter_data:
                prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
                for prompt_id in prompt_ids:
                    source, target = self._get_text_data(d, self.prompts[prompt_id])
                    self.valid_text_data.append({"input_ids": source, "labels": target})
        else:
            # A single fixed prompt: every sample yields exactly one text.
            self.prompt_sample_num = 1
            prompt = self.prompts[self.valid_prompt_id]
            for d in self.inter_data:
                source, target = self._get_text_data(d, prompt)
                self.valid_text_data.append({"input_ids": source, "labels": target})

    def _get_text_data(self, data, prompt):
        """Render ``prompt`` with ``data``.

        Returns:
            ``(input_text, label_text)``. The label is the bare response in
            test mode and the full prompt including the response otherwise.
        """
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        source = sft_prompt.format(instruction=instruction, response="")
        if self.mode == 'test':
            return source, response

        return source, sft_prompt.format(instruction=instruction, response=response)

    def __getitem__(self, index):
        """Return ``{"input_ids": text, "labels": text}`` for ``index``."""
        if self.mode == 'valid':
            return self.valid_text_data[index]

        if self.mode == 'train':
            # Train length is len(inter_data) * prompt_sample_num, so several
            # indices map to one sample, each drawing a random prompt.
            d = self.inter_data[index // self.prompt_sample_num]
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            # Test length is len(inter_data): index samples one-to-one so the
            # whole test set is covered whatever prompt_sample_num is.
            d = self.inter_data[index]
            prompt_id = self.prompt_id
        else:
            raise NotImplementedError

        source, target = self._get_text_data(d, self.prompts[prompt_id])

        return dict(input_ids=source, labels=target)
|
| |
|
| |
|
class FusionSeqRecFinetune(BaseDataset):
    """Sequential-recommendation prompts that also carry item titles and descriptions.

    Uses the same split as the ``_process_*`` methods show: training targets
    come from each sequence without its last two items, the validation target
    is the second-to-last item, and the test target is the last item.
    """

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        """Load the data and build the samples for ``mode``.

        Args:
            args: settings object (see ``BaseDataset``); ``valid`` mode also
                reads ``args.sample_valid`` and ``args.valid_prompt_id``.
            mode: ``'train'``, ``'valid'`` or ``'test'``.
            prompt_sample_num: length multiplier in train mode; number of
                prompts drawn per sample when validation prompts are sampled.
            prompt_id: prompt template used in test mode.
            sample_num: if positive, number of samples to keep in every mode.

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompts = all_prompt["fusionseqrec"]

        self._load_data()

        if self.mode == 'train':
            self.inter_data = self._process_train_data()
        elif self.mode == 'valid':
            self.sample_valid = args.sample_valid
            self.valid_prompt_id = args.valid_prompt_id
            self.inter_data = self._process_valid_data()
            self._construct_valid_text()
        elif self.mode == 'test':
            self.inter_data = self._process_test_data()
        else:
            raise NotImplementedError

    def _load_data(self):
        """Read interactions, the index file (``index_file`` used as a path) and item features."""
        inter_path = os.path.join(self.data_path, self.dataset + ".inter.json")
        item_path = os.path.join(self.data_path, self.dataset + ".item.json")
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(inter_path, 'r', encoding="utf-8") as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r', encoding="utf-8") as f:
            self.indices = json.load(f)
        with open(item_path, 'r', encoding="utf-8") as f:
            self.item_feat = json.load(f)

    def _item_title(self, item_id):
        """Return the item's title with surrounding whitespace and punctuation removed."""
        return self.item_feat[str(item_id)]["title"].strip().strip(".!?,;:`")

    def _make_sample(self, target, history):
        """Build one record for raw item id ``target`` and raw-id list ``history``."""
        one_data = dict()
        one_data["item"] = "".join(self.indices[str(target)])
        one_data["title"] = self._item_title(target)
        one_data["description"] = self.item_feat[str(target)]["description"]

        if self.max_his_len > 0:
            # Keep only the most recent interactions.
            history = history[-self.max_his_len:]
        inters = ["".join(self.indices[str(j)]) for j in history]
        inter_titles = ["\"" + self._item_title(j) + "\"" for j in history]

        if self.add_prefix:
            inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
            inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]

        one_data["inters"] = self.his_sep.join(inters)
        one_data["inter_titles"] = self.his_sep.join(inter_titles)
        return one_data

    def _subsample(self, inter_data):
        """Return ``sample_num`` records drawn without replacement, or all if not positive."""
        if self.sample_num > 0:
            # Raises ValueError when sample_num exceeds the number of samples.
            sample_idx = np.random.choice(len(inter_data), self.sample_num, replace=False)
            inter_data = [inter_data[i] for i in sample_idx]
        return inter_data

    def _process_train_data(self):
        """Return one sample per non-first position of each truncated sequence."""
        inter_data = []
        for seq in self.inters.values():
            items = seq[:-2]  # the last two items are held out
            for i in range(1, len(items)):
                inter_data.append(self._make_sample(items[i], items[:i]))
        return self._subsample(inter_data)

    def _process_valid_data(self):
        """Return one sample per user, targeting the second-to-last item."""
        inter_data = [
            self._make_sample(items[-2], items[:-2])
            for items in self.inters.values()
        ]
        return self._subsample(inter_data)

    def _process_test_data(self):
        """Return one sample per user, targeting the last item."""
        inter_data = [
            self._make_sample(items[-1], items[:-1])
            for items in self.inters.values()
        ]
        return self._subsample(inter_data)

    def set_prompt(self, prompt_id):
        """Select the prompt template used in test mode."""
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.inter_data) * self.prompt_sample_num
        elif self.mode == 'valid':
            return len(self.valid_text_data)
        elif self.mode == 'test':
            return len(self.inter_data)
        else:
            raise NotImplementedError

    def _construct_valid_text(self):
        """Pre-render the validation texts into ``self.valid_text_data``."""
        self.valid_text_data = []
        if self.sample_valid:
            all_prompt_ids = range(len(self.prompts))
            for d in self.inter_data:
                prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
                for prompt_id in prompt_ids:
                    source, target = self._get_text_data(d, self.prompts[prompt_id])
                    self.valid_text_data.append({"input_ids": source, "labels": target})
        else:
            # A single fixed prompt: every sample yields exactly one text.
            self.prompt_sample_num = 1
            prompt = self.prompts[self.valid_prompt_id]
            for d in self.inter_data:
                source, target = self._get_text_data(d, prompt)
                self.valid_text_data.append({"input_ids": source, "labels": target})

    def _get_text_data(self, data, prompt):
        """Render ``prompt`` with ``data``.

        Returns:
            ``(input_text, label_text)``. The label is the bare response in
            test mode and the full prompt including the response otherwise.
        """
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        source = sft_prompt.format(instruction=instruction, response="")
        if self.mode == 'test':
            return source, response

        return source, sft_prompt.format(instruction=instruction, response=response)

    def __getitem__(self, index):
        """Return ``{"input_ids": text, "labels": text}`` for ``index``."""
        if self.mode == 'valid':
            return self.valid_text_data[index]

        if self.mode == 'train':
            # Train length is len(inter_data) * prompt_sample_num, so several
            # indices map to one sample, each drawing a random prompt.
            d = self.inter_data[index // self.prompt_sample_num]
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            # Test length is len(inter_data): index samples one-to-one so the
            # whole test set is covered whatever prompt_sample_num is.
            d = self.inter_data[index]
            prompt_id = self.prompt_id
        else:
            raise NotImplementedError

        source, target = self._get_text_data(d, self.prompts[prompt_id])

        return dict(input_ids=source, labels=target)
|
| |
|
| |
|
class ItemFeatFinetune(BaseDataset):
    """Prompts that relate an item's features to its index tokens."""

    def __init__(self, args, task="item2index", prompt_sample_num=1, sample_num=-1):
        """Load item features and build the samples.

        Args:
            args: settings object (see ``BaseDataset``).
            task: key into ``all_prompt`` (lower-cased before lookup).
            prompt_sample_num: length multiplier; each sample is served this
                many times, every time with a randomly drawn prompt.
            sample_num: if positive, number of items to keep.
        """
        super().__init__(args)

        self.task = task.lower()
        self.prompt_sample_num = prompt_sample_num
        self.sample_num = sample_num

        self.prompts = all_prompt[self.task]

        self._load_data()
        self.feat_data = self._process_data()

    def _load_data(self):
        """Read the index file (``index_file`` used as a path) and the item features."""
        item_path = os.path.join(self.data_path, self.dataset + ".item.json")
        # JSON text is UTF-8; item titles and descriptions are free text, so do
        # not depend on the platform's default encoding.
        with open(self.index_file, 'r', encoding="utf-8") as f:
            self.indices = json.load(f)
        with open(item_path, 'r', encoding="utf-8") as f:
            self.item_feat = json.load(f)

    def _process_data(self):
        """Return the feature dicts, each extended with ``"item"`` and a cleaned title.

        The dicts stored in ``self.item_feat`` are updated in place and reused
        as the samples.
        """
        feat_data = []
        for iid, feat in self.item_feat.items():
            feat["item"] = "".join(self.indices[iid])
            feat["title"] = feat["title"].strip().strip(".!?,;:`")
            feat_data.append(feat)

        if self.sample_num > 0:
            # Raises ValueError when sample_num exceeds the number of items.
            sample_idx = np.random.choice(len(feat_data), self.sample_num, replace=False)
            feat_data = [feat_data[i] for i in sample_idx]

        return feat_data

    def __len__(self):
        return len(self.feat_data) * self.prompt_sample_num

    def _get_text_data(self, data, prompt):
        """Render ``prompt`` with ``data``; return ``(input_text, full_text)``."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        source = sft_prompt.format(instruction=instruction, response="")
        target = sft_prompt.format(instruction=instruction, response=response)

        return source, target

    def __getitem__(self, index):
        """Return ``{"input_ids": text, "labels": text}`` with a randomly drawn prompt."""
        # Several consecutive indices map to the same underlying item.
        d = self.feat_data[index // self.prompt_sample_num]
        prompt = self.prompts[random.randint(0, len(self.prompts) - 1)]

        source, target = self._get_text_data(d, prompt)

        return dict(input_ids=source, labels=target)
|
| |
|
| |
|
class ItemSearchFinetune(BaseDataset):
    """Item-search prompts built from user preferences and query intentions."""

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        """Load the user data and build the samples for ``mode``.

        Args:
            args: settings object (see ``BaseDataset``).
            mode: ``'train'`` or ``'test'``; anything else makes
                ``_process_data`` raise ``NotImplementedError``.
            prompt_sample_num: length multiplier in train mode.
            prompt_id: prompt template used in test mode.
            sample_num: if positive, number of samples to keep.
        """
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompts = all_prompt["itemsearch"]

        self._load_data()
        self.search_data = self._process_data()

    def _load_data(self):
        """Read the index file (``index_file`` used as a path) and the user file."""
        user_path = os.path.join(self.data_path, self.dataset + ".user.json")
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(self.index_file, 'r', encoding="utf-8") as f:
            self.indices = json.load(f)
        with open(user_path, 'r', encoding="utf-8") as f:
            self.user_info = json.load(f)

    def _process_data(self):
        """Return one record per user listed in ``user_explicit_preference``.

        Raises:
            NotImplementedError: if ``self.mode`` is neither ``'train'`` nor ``'test'``.
        """
        user_explicit_preference = self.user_info["user_explicit_preference"]
        user_vague_intention = self.user_info["user_vague_intention"]
        if self.mode == 'train':
            user_vague_intention = user_vague_intention["train"]
        elif self.mode == 'test':
            user_vague_intention = user_vague_intention["test"]
        else:
            raise NotImplementedError

        search_data = []
        for uid, user_ep in user_explicit_preference.items():
            intention = user_vague_intention[uid]
            user_vi = intention["querys"]

            one_data = {}
            one_data["explicit_preferences"] = user_ep
            one_data["user_related_intention"] = user_vi[0]
            one_data["item_related_intention"] = user_vi[1]

            iid = intention["item"]
            inters = intention["inters"]

            one_data["item"] = "".join(self.indices[str(iid)])

            if self.max_his_len > 0:
                # Keep only the most recent interactions.
                inters = inters[-self.max_his_len:]
            inters = ["".join(self.indices[str(i)]) for i in inters]
            if self.add_prefix:
                inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]

            one_data["inters"] = self.his_sep.join(inters)

            search_data.append(one_data)

        if self.sample_num > 0:
            # Raises ValueError when sample_num exceeds the number of samples.
            sample_idx = np.random.choice(len(search_data), self.sample_num, replace=False)
            search_data = [search_data[i] for i in sample_idx]

        return search_data

    def set_prompt(self, prompt_id):
        """Select the prompt template used in test mode."""
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.search_data) * self.prompt_sample_num
        # Every other mode exposes each sample exactly once.
        return len(self.search_data)

    def _get_text_data(self, data, prompt):
        """Render ``prompt`` with ``data``.

        Returns:
            ``(input_text, label_text)``. The label is the bare response in
            test mode and the full prompt including the response otherwise.
        """
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        source = sft_prompt.format(instruction=instruction, response="")
        if self.mode == 'test':
            return source, response

        return source, sft_prompt.format(instruction=instruction, response=response)

    def __getitem__(self, index):
        """Return ``{"input_ids": text, "labels": text}`` for ``index``.

        One explicit preference and one of the two queries are drawn at random
        on every access and written into the stored record before rendering.
        """
        if self.mode == 'train':
            # Train length is len(search_data) * prompt_sample_num, so several
            # indices map to one sample, each drawing a random prompt.
            d = self.search_data[index // self.prompt_sample_num]
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            # Test length is len(search_data): index samples one-to-one so the
            # whole test set is covered whatever prompt_sample_num is.
            d = self.search_data[index]
            prompt_id = self.prompt_id
        else:
            raise NotImplementedError

        prompt = self.prompts[prompt_id]

        d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
        all_querys = [d["user_related_intention"], d["item_related_intention"]]
        d["query"] = random.choice(all_querys)

        source, target = self._get_text_data(d, prompt)

        return dict(input_ids=source, labels=target)
|
| |
|
| |
|
| |
|
class PreferenceObtainFinetune(BaseDataset):
    """Prompts pairing a user's interaction history with an explicit preference."""

    def __init__(self, args, prompt_sample_num=1, sample_num=-1):
        """Load the data and build the samples.

        Args:
            args: settings object (see ``BaseDataset``).
            prompt_sample_num: length multiplier; each sample is served this
                many times, every time with a randomly drawn prompt.
            sample_num: if positive, number of samples to keep.
        """
        super().__init__(args)

        self.prompt_sample_num = prompt_sample_num
        self.sample_num = sample_num

        self.prompts = all_prompt["preferenceobtain"]

        self._load_data()
        self._remap_items()

        self.preference_data = self._process_data()

    def _load_data(self):
        """Read the user file, the interactions file and the index file."""
        user_path = os.path.join(self.data_path, self.dataset + ".user.json")
        inter_path = os.path.join(self.data_path, self.dataset + ".inter.json")
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(user_path, 'r', encoding="utf-8") as f:
            self.user_info = json.load(f)
        with open(inter_path, 'r', encoding="utf-8") as f:
            self.inters = json.load(f)
        # index_file is used as a path as given, not joined with data_path.
        with open(self.index_file, 'r', encoding="utf-8") as f:
            self.indices = json.load(f)

    def _remap_items(self):
        """Replace every item id in ``self.inters`` by its concatenated index tokens."""
        self.remapped_inters = {
            uid: ["".join(self.indices[str(i)]) for i in items]
            for uid, items in self.inters.items()
        }

    def _process_data(self):
        """Return one record per user listed in ``user_explicit_preference``."""
        user_explicit_preference = self.user_info["user_explicit_preference"]

        preference_data = []
        for uid, user_ep in user_explicit_preference.items():
            # The last three interactions are excluded from the history.
            inters = self.remapped_inters[uid][:-3]

            if self.max_his_len > 0:
                # Keep only the most recent interactions.
                inters = inters[-self.max_his_len:]
            if self.add_prefix:
                inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]

            preference_data.append({
                "explicit_preferences": user_ep,
                "inters": self.his_sep.join(inters),
            })

        if self.sample_num > 0:
            # Raises ValueError when sample_num exceeds the number of samples.
            sample_idx = np.random.choice(len(preference_data), self.sample_num, replace=False)
            preference_data = [preference_data[i] for i in sample_idx]

        return preference_data

    def set_prompt(self, prompt_id):
        """Store ``prompt_id``; ``__getitem__`` always draws its prompt at random and does not read it."""
        self.prompt_id = prompt_id

    def __len__(self):
        return len(self.preference_data) * self.prompt_sample_num

    def _get_text_data(self, data, prompt):
        """Render ``prompt`` with ``data``; return ``(input_text, full_text)``."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        source = sft_prompt.format(instruction=instruction, response="")
        target = sft_prompt.format(instruction=instruction, response=response)

        return source, target

    def __getitem__(self, index):
        """Return ``{"input_ids": text, "labels": text}`` with a randomly drawn prompt.

        One explicit preference is drawn at random on every access and written
        into the stored record before rendering.
        """
        # Several consecutive indices map to the same underlying record.
        d = self.preference_data[index // self.prompt_sample_num]
        prompt = self.prompts[random.randint(0, len(self.prompts) - 1)]

        d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))

        source, target = self._get_text_data(d, prompt)

        return dict(input_ids=source, labels=target)
|
| |
|
| |
|
| |
|
| |
|
| |
|
class SeqRecTestDataset(BaseDataset):
    """Test-only sequential-recommendation dataset: predict each user's last item."""

    def __init__(self, args, prompt_id=0, sample_num=-1):
        """Load the data and build one test sample per user.

        Args:
            args: settings object (see ``BaseDataset``).
            prompt_id: index of the ``"seqrec"`` prompt template to use.
            sample_num: if positive, number of samples to keep.
        """
        super().__init__(args)

        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompt = all_prompt["seqrec"][self.prompt_id]

        self._load_data()
        self._remap_items()

        self.inter_data = self._process_test_data()

    def _load_data(self):
        """Read the interactions file and ``<data_path>/<dataset><index_file>``."""
        inter_path = os.path.join(self.data_path, self.dataset + ".inter.json")
        # NOTE(review): here index_file is appended to the dataset name, while
        # SeqRecFinetune opens index_file as a path -- confirm which form the
        # caller passes.
        index_path = os.path.join(self.data_path, self.dataset + self.index_file)
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(inter_path, 'r', encoding="utf-8") as f:
            self.inters = json.load(f)
        with open(index_path, 'r', encoding="utf-8") as f:
            self.indices = json.load(f)

    def _remap_items(self):
        """Replace every item id in ``self.inters`` by its concatenated index tokens."""
        self.remapped_inters = {
            uid: ["".join(self.indices[str(i)]) for i in items]
            for uid, items in self.inters.items()
        }

    def _process_test_data(self):
        """Return one ``{"item", "inters"}`` record per user, optionally subsampled."""
        inter_data = []
        for items in self.remapped_inters.values():
            history = items[:-1]
            if self.max_his_len > 0:
                # Keep only the most recent interactions.
                history = history[-self.max_his_len:]
            if self.add_prefix:
                history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
            inter_data.append({
                "item": items[-1],
                "inters": self.his_sep.join(history),
            })

        if self.sample_num > 0:
            # Raises ValueError when sample_num exceeds the number of samples.
            sample_idx = np.random.choice(len(inter_data), self.sample_num, replace=False)
            inter_data = [inter_data[i] for i in sample_idx]

        return inter_data

    def set_prompt(self, prompt_id):
        """Switch to another ``"seqrec"`` prompt template."""
        self.prompt_id = prompt_id
        self.prompt = all_prompt["seqrec"][self.prompt_id]

    def __len__(self):
        return len(self.inter_data)

    def _get_text_data(self, data, prompt):
        """Render ``prompt`` with ``data``; return ``(input_text, bare_response)``."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        source = sft_prompt.format(instruction=instruction, response="")

        return source, response

    def __getitem__(self, index):
        """Return ``{"input_ids": prompt_text, "labels": response_text}`` for ``index``."""
        source, target = self._get_text_data(self.inter_data[index], self.prompt)

        return dict(input_ids=source, labels=target)