import json import torch import random import transformers import networkx as nx from tqdm import tqdm from peft import (LoraConfig, get_peft_model, prepare_model_for_kbit_training) from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig class QloraTrainer_CS: def __init__(self, config: dict, use_predefined_graph=False): self.config = config self.use_predefined_graph = use_predefined_graph self.tokenizer = None self.base_model = None self.adapter_model = None self.merged_model = None self.transformer_trainer = None self.test_data = None template_file_path = 'configs/alpaca.json' with open(template_file_path) as fp: self.template = json.load(fp) def load_base_model(self): model_id = self.config['inference']["base_model"] print(model_id) bnb_config = BitsAndBytesConfig( load_in_8bit=True, bnb_8bit_use_double_quant=True, bnb_8bit_quant_type="nf8", bnb_8bit_compute_dtype=torch.bfloat16 ) tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.model_max_length = self.config['training']['tokenizer']["max_length"] if not tokenizer.pad_token: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, torch_dtype=torch.bfloat16) if model.device.type != 'cuda': model.to('cuda') model.gradient_checkpointing_enable() model = prepare_model_for_kbit_training(model) self.tokenizer = tokenizer self.base_model = model def train(self): # Set up lora config or load pre-trained adapter lora_config = LoraConfig( r=self.config['training']['qlora']['rank'], lora_alpha=self.config['training']['qlora']['lora_alpha'], target_modules=self.config['training']['qlora']['target_modules'], lora_dropout=self.config['training']['qlora']['lora_dropout'], bias="none", task_type="CAUSAL_LM", ) model = get_peft_model(self.base_model, lora_config) self._print_trainable_parameters(model) print("Start data preprocessing") train_data = self._process_data_instruction() print('Length of dataset: ', len(train_data)) print("Start training") self.transformer_trainer = transformers.Trainer( model=model, train_dataset=train_data, args=transformers.TrainingArguments( per_device_train_batch_size=self.config["training"]['trainer_args']["per_device_train_batch_size"], gradient_accumulation_steps=self.config['model_saving']['index'], warmup_steps=self.config["training"]['trainer_args']["warmup_steps"], num_train_epochs=self.config["training"]['trainer_args']["num_train_epochs"], learning_rate=self.config["training"]['trainer_args']["learning_rate"], lr_scheduler_type=self.config["training"]['trainer_args']["lr_scheduler_type"], fp16=self.config["training"]['trainer_args']["fp16"], logging_steps=self.config["training"]['trainer_args']["logging_steps"], output_dir=self.config["training"]['trainer_args']["trainer_output_dir"], report_to="wandb", save_steps=self.config["training"]['trainer_args']["save_steps"], ), data_collator=transformers.DataCollatorForLanguageModeling(self.tokenizer, mlm=False), ) model.config.use_cache = False self.transformer_trainer.train() model_save_path = f"{self.config['model_saving']['model_output_dir']}/{self.config['model_saving']['model_name']}_{self.config['model_saving']['index']}_adapter_test_graph" self.transformer_trainer.save_model(model_save_path) self.adapter_model = model print(f"Training complete, adapter model saved in {model_save_path}") def _print_trainable_parameters(self, model): """ Prints the number of trainable parameters in the model. """ trainable_params = 0 all_param = 0 for _, param in model.named_parameters(): all_param += param.numel() if param.requires_grad: trainable_params += param.numel() print( f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" ) def _process_data_instruction(self): context_window = self.tokenizer.model_max_length if self.use_predefined_graph: graph_data = nx.read_gexf('datasets/' + self.config["training"]["predefined_graph_path"], node_type=None, relabel=False, version='1.2draft') else: graph_path = self.config['data_downloading']['download_directory'] + 'description/' + self.config['data_downloading']['gexf_file'] graph_data = nx.read_gexf(graph_path, node_type=None, relabel=False, version='1.2draft') raw_graph = graph_data test_set_size = len(graph_data.nodes()) // 10 all_test_nodes = set(list(graph_data.nodes())[:test_set_size]) all_train_nodes = set(list(graph_data.nodes())[test_set_size:]) raw_id_2_title_abs = dict() for paper_id in list(graph_data.nodes())[test_set_size:]: title = graph_data.nodes()[paper_id]['title'] abstract = graph_data.nodes()[paper_id]['abstract'] raw_id_2_title_abs[paper_id] = [title, abstract] raw_id_2_intro = dict() for paper_id in list(graph_data.nodes())[test_set_size:]: if graph_data.nodes[paper_id]['introduction'] != '': intro = graph_data.nodes[paper_id]['introduction'] raw_id_2_intro[paper_id] = intro raw_id_pair_2_sentence = dict() for edge in list(graph_data.edges()): sentence = graph_data.edges()[edge]['sentence'] raw_id_pair_2_sentence[edge] = sentence test_data = [] edge_list = [] for edge in list(raw_graph.edges()): src, tar = edge if src not in all_test_nodes and tar not in all_test_nodes: edge_list.append(edge) else: test_data.append(edge) train_num = int(len(edge_list)) data_LP = [] data_abstract_2_title = [] data_paper_retrieval = [] data_citation_sentence = [] data_abs_completion = [] data_title_2_abs = [] data_intro_2_abs = [] for sample in tqdm(random.sample(edge_list, train_num)): source, target = sample[0], sample[1] source_title, source_abs = raw_id_2_title_abs[source] target_title, target_abs = raw_id_2_title_abs[target] # LP prompt rand_ind = random.choice(list(raw_id_2_title_abs.keys())) neg_title, neg_abs = raw_id_2_title_abs[rand_ind] data_LP.append({'s_title':source_title, 's_abs':source_abs, 't_title':target_title, 't_abs':target_abs, 'label':'yes'}) data_LP.append({'s_title':source_title, 's_abs':source_abs, 't_title':neg_title, 't_abs':neg_abs, 'label':'no'}) for sample in tqdm(random.sample(edge_list, train_num)): source, target = sample[0], sample[1] source_title, source_abs = raw_id_2_title_abs[source] target_title, target_abs = raw_id_2_title_abs[target] # abs_2_title prompt data_abstract_2_title.append({'title':source_title, 'abs':source_abs}) data_abstract_2_title.append({'title':target_title, 'abs':target_abs}) for sample in tqdm(random.sample(edge_list, train_num)): source, target = sample[0], sample[1] source_title, source_abs = raw_id_2_title_abs[source] target_title, target_abs = raw_id_2_title_abs[target] # paper_retrieval prompt neighbors = list(nx.all_neighbors(raw_graph, source)) sample_node_list = list(all_train_nodes - set(neighbors) - set([source]) - set([target])) sampled_neg_nodes = random.sample(sample_node_list, 5) + [target] random.shuffle(sampled_neg_nodes) data_paper_retrieval.append({'title':source_title, 'abs':source_abs, 'sample_title': [raw_id_2_title_abs[node][0] for node in sampled_neg_nodes], 'right_title':target_title}) for sample in tqdm(random.sample(edge_list, train_num)): source, target = sample[0], sample[1] source_title, source_abs = raw_id_2_title_abs[source] target_title, target_abs = raw_id_2_title_abs[target] # citation_sentence prompt citation_sentence = raw_id_pair_2_sentence[(source, target)] if (source, target) in raw_id_pair_2_sentence.keys() else raw_id_pair_2_sentence[(target, source)] data_citation_sentence.append({'s_title':source_title, 's_abs':source_abs, 't_title':target_title, 't_abs':target_abs, 'sentence': citation_sentence}) for sample in tqdm(random.sample(edge_list, train_num)): source, target = sample[0], sample[1] source_title, source_abs = raw_id_2_title_abs[source] target_title, target_abs = raw_id_2_title_abs[target] # abs_complete prompt data_abs_completion.append({'title':source_title, 'abs':source_abs}) data_abs_completion.append({'title':target_title, 'abs':target_abs}) for sample in tqdm(random.sample(edge_list, train_num)): source, target = sample[0], sample[1] source_title, source_abs = raw_id_2_title_abs[source] target_title, target_abs = raw_id_2_title_abs[target] # title_2_abs prompt data_title_2_abs.append({'title':source_title, 'right_abs':source_abs}) data_title_2_abs.append({'title':target_title, 'right_abs':target_abs}) for sample in tqdm(random.sample(edge_list, train_num)): source, target = sample[0], sample[1] if source in raw_id_2_intro: source_intro = raw_id_2_intro[source] _, source_abs = raw_id_2_title_abs[source] data_intro_2_abs.append({'intro':source_intro, 'abs':source_abs}) if target in raw_id_2_intro: target_intro = raw_id_2_intro[target] _, target_abs = raw_id_2_title_abs[target] data_intro_2_abs.append({'intro':target_intro, 'abs':target_abs}) data_prompt = [] data_prompt += [self._generate_paper_retrieval_prompt(data_point) for data_point in data_paper_retrieval] data_prompt += [self._generate_LP_prompt(data_point) for data_point in data_LP] data_prompt += [self._generate_abstract_2_title_prompt(data_point) for data_point in data_abstract_2_title] data_prompt += [self._generate_citation_sentence_prompt(data_point) for data_point in data_citation_sentence] data_prompt += [self._generate_abstract_completion_prompt(data_point) for data_point in data_abs_completion] data_prompt += [self._generate_title_2_abstract_prompt(data_point) for data_point in data_title_2_abs] data_prompt += [self._generate_intro_2_abstract_prompt(data_point, context_window) for data_point in data_intro_2_abs] print("Total prompts:", len(data_prompt)) random.shuffle(data_prompt) if self.tokenizer.chat_template is None: data_tokenized = [self.tokenizer(sample, max_length=context_window, truncation=True) for sample in tqdm(data_prompt)] else: data_tokenized = [self.tokenizer.apply_chat_template(sample, max_length=context_window, truncation=True, tokenize=False) for sample in tqdm(data_prompt)] return data_tokenized def _generate_LP_prompt(self, data_point: dict): instruction = "Determine if paper A will cite paper B." prompt_input = "" prompt_input = prompt_input + "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] != None else 'Unknown') + "\n" prompt_input = prompt_input + "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] != None else 'Unknown') + "\n" prompt_input = prompt_input + "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] != None else 'Unknown') + "\n" prompt_input = prompt_input + "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] != None else 'Unknown') + "\n" if self.tokenizer.chat_template is None: res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input) res = f"{res}{data_point['label']}" else: res = [ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)}, {"role": "assistant", "content": data_point['label']} ] return res def _generate_abstract_2_title_prompt(self, data_point: dict): instruction = "Please generate the title of paper based on its abstract." prompt_input = "" prompt_input = prompt_input + "Abstract: " + data_point['abs'] + "\n" if self.tokenizer.chat_template is None: res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input) res = f"{res}{data_point['title']}" else: res = [ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)}, {"role": "assistant", "content": data_point['title']} ] return res def _generate_paper_retrieval_prompt(self, data_point: dict): instruction = "Please select the paper that is more likely to be cited by paper A from candidate papers." prompt_input = "" prompt_input = prompt_input + "Title of the Paper A: " + data_point['title'] + "\n" prompt_input = prompt_input + "Abstract of the Paper A: " + data_point['abs'] + "\n" prompt_input = prompt_input + "candidate papers: " + "\n" for i in range(len(data_point['sample_title'])): prompt_input = prompt_input + str(i) + '. ' + data_point['sample_title'][i] + "\n" if self.tokenizer.chat_template is None: res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input) res = f"{res}{data_point['right_title']}" else: res = [ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)}, {"role": "assistant", "content": data_point['right_title']} ] return res def _generate_citation_sentence_prompt(self, data_point: dict): instruction = "Please generate the citation sentence of how Paper A cites paper B in its related work section." prompt_input = "" prompt_input = prompt_input + "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] != None else 'Unknown') + "\n" prompt_input = prompt_input + "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] != None else 'Unknown') + "\n" prompt_input = prompt_input + "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] != None else 'Unknown') + "\n" prompt_input = prompt_input + "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] != None else 'Unknown') + "\n" if self.tokenizer.chat_template is None: res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input) res = f"{res}{data_point['sentence']}" else: res = [ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)}, {"role": "assistant", "content": data_point['sentence']} ] return res def _generate_abstract_completion_prompt(self, data_point: dict): instruction = "Please complete the abstract of a paper." prompt_input = "" prompt_input = prompt_input + "Title: " + data_point['title'] if data_point['title'] != None else 'Unknown' + "\n" split_abs = data_point['abs'][: int(0.3*len(data_point['abs']))] prompt_input = prompt_input + "Part of abstract: " + split_abs + "\n" if self.tokenizer.chat_template is None: res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input) res = f"{res}{data_point['abs']}" else: res = [ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)}, {"role": "assistant", "content": data_point['abs']} ] return res def _generate_title_2_abstract_prompt(self, data_point: dict): instruction = "Please generate the abstract of paper based on its title." prompt_input = "" prompt_input = prompt_input + "Title: " + data_point['title'] + "\n" if self.tokenizer.chat_template is None: res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input) res = f"{res}{data_point['right_abs']}" else: res = [ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)}, {"role": "assistant", "content": data_point['right_abs']} ] return res def _generate_intro_2_abstract_prompt(self, data_point: dict, context_window): instruction = "Please generate the abstract of paper based on its introduction section." prompt_input = "" prompt_input = prompt_input + "Introduction: " + data_point['intro'] + "\n" # Reduce it to make it fit prompt_input = prompt_input[:int(context_window*2)] if self.tokenizer.chat_template is None: res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input) res = f"{res}{data_point['abs']}" else: res = [ {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)}, {"role": "assistant", "content": data_point['abs']} ] return res