# NOTE: import order is kept exactly as in the original file because
# `from .utils import *` sits in the middle of it; reordering could change
# which names end up shadowed.
import copy
import json
import time

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig,AutoConfig


from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from .modeling_mixtral_kv import MixtralForCausalLM as KVMixtralForCausalLM
from .utils import *
from .kv_cache import initialize_past_key_values

from transformers import AutoTokenizer
import os
from huggingface_hub import hf_hub_download
from .cnets import Model
from .configs import EConfig
from huggingface_hub import hf_hub_download


class EaModel(nn.Module):
    """Bundle a KV-cache base causal LM with a draft module (``ea_layer``).

    The instance keeps the base model, the base model's tokenizer and a
    ``.cnets.Model`` draft layer, and offers four decoding entry points:

    * ``eagenerate`` / ``ea_generate``: tree-based decoding driven by the
      draft layer (blocking / streaming generator).
    * ``naivegenerate`` / ``naive_generate``: plain token-by-token decoding
      with the base model only (blocking / streaming generator).
    """

    def __init__(
            self,
            base_model,
            base_model_name_or_path,
            ea_model_path,
            total_token,
            depth,
            top_k,
            threshold,
            ea_layer_state_dict
    ):
        """Build the wrapper around an already loaded base model.

        Args:
            base_model: Base causal LM exposing ``config``, ``lm_head``,
                ``model.layers`` and ``dtype``.
            base_model_name_or_path: Identifier used to load the tokenizer.
            ea_model_path: Path of the draft model's JSON config file (it is
                both passed to ``EConfig.from_pretrained`` and opened as JSON).
            total_token: Forwarded to the draft layer as ``total_tokens``.
            depth: Forwarded to the draft layer.
            top_k: Forwarded to the draft layer.
            threshold: Forwarded to the draft layer.
            ea_layer_state_dict: Weights loaded into the draft layer with
                ``strict=True``.
        """
        super().__init__()
        self.base_model = base_model
        self.config = base_model.config
        # lm_head.weight is indexed as (vocab_size, hidden_size) here.
        self.hidden_size = base_model.lm_head.weight.shape[-1]
        self.vocab_size = base_model.lm_head.weight.shape[0]
        self.base_model_name_or_path = base_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path, use_fast=False)

        config = EConfig.from_pretrained(ea_model_path)
        with open(ea_model_path, "r", encoding="utf-8") as f:
            con = json.load(f)
        # A config without a "bias" entry means the draft layer uses bias.
        bias = con.get("bias", True)

        print("draft init")
        self.ea_layer = Model(
            config,
            bias=bias,
            total_tokens=total_token,
            depth=depth,
            top_k=top_k,
            threshold=threshold,
        )
        print("draft init end")

        low_memory = False

        # The draft layer is placed on the device of the last decoder layer.
        device = base_model.model.layers[-1].self_attn.q_proj.weight.device
        if device != base_model.lm_head.weight.device:
            # lm_head lives on another device: either keep a local copy of its
            # weight next to the draft layer, or only record the target device.
            self.ea_layer.diff_device = True
            if not low_memory:
                self.ea_layer.headweight = base_model.lm_head.weight.clone().to(device)
            else:
                self.ea_layer.layer_device = device
        else:
            self.ea_layer.diff_device = False

        self.ea_layer.load_state_dict(ea_layer_state_dict, strict=True)
        self.ea_layer.to(self.base_model.dtype).to(device)
        self.ea_layer.init_tree()

    def get_tokenizer(self):
        """Get the tokenizer of the base model.

        Returns:
            Tokenizer: The tokenizer of the base model.
        """
        return self.tokenizer

    @classmethod
    def from_pretrained(
            cls,
            Type="LLaMA",
            base_model_path=None,
            ea_model_path=None,
            total_token=59,
            depth=5,
            top_k=10,
            threshold=1.0,
            **kwargs,
    ):
        """Load the base model and the draft weights and build an ``EaModel``.

        Args:
            Type: Ignored; kept for backward compatibility. The architecture
                is read from the base model's config instead.
            base_model_path: Path or hub id of the base model.
            ea_model_path: Local directory or hub id holding the draft model's
                ``config.json`` and ``pytorch_model.bin``.
            total_token: Draft token budget. ``-1`` triggers a short timing
                run of the base model to pick one of a fixed set of values.
            depth: Forwarded to the draft layer.
            top_k: Forwarded to the draft layer.
            threshold: Forwarded to the draft layer.
            **kwargs: Forwarded to the base model's ``from_pretrained``.

        Returns:
            EaModel: The assembled model. The base model has been moved to
            CUDA via ``.cuda()``.
        """
        architecture = AutoConfig.from_pretrained(base_model_path).architectures[0]
        if architecture == 'LlamaForCausalLM':
            base_model = KVLlamaForCausalLM.from_pretrained(
                base_model_path, **kwargs
            )
        else:
            # Every other architecture is loaded with the Mixtral KV class.
            base_model = KVMixtralForCausalLM.from_pretrained(
                base_model_path, **kwargs
            )
        base_model.cuda()

        # Prefer local files; fall back to downloading from the hub.
        configpath = os.path.join(ea_model_path, "config.json")
        if not os.path.exists(configpath):
            configpath = hf_hub_download(ea_model_path, "config.json")

        load_model_path = os.path.join(ea_model_path, "pytorch_model.bin")
        if not os.path.exists(load_model_path):
            load_model_path = hf_hub_download(ea_model_path, "pytorch_model.bin")
        # SECURITY NOTE(review): torch.load unpickles the checkpoint, which may
        # have just been downloaded. Consider `weights_only=True` if the
        # minimum supported torch version allows it.
        ea_layer_state_dict = torch.load(load_model_path, map_location="cpu")

        model = cls(
            base_model,
            base_model_path,
            configpath,
            total_token,
            depth,
            top_k,
            threshold,
            ea_layer_state_dict
        )

        if total_token == -1:
            # Time 20 base-model forward passes for each candidate length and
            # keep the candidate with the smallest scaled time.
            device = model.base_model.model.layers[0].self_attn.q_proj.weight.device
            cans = [40, 48, 50, 56, 60]
            x = [1, 1.05, 1.07, 1.1, 1.13]
            times = []

            for length, scale in zip(cans, x):
                input_ids = torch.randint(0, model.config.vocab_size - 200, (1, length)).to(device)
                torch.cuda.synchronize()
                start_time = time.perf_counter()
                for _ in range(20):
                    torch.cuda.synchronize()
                    with torch.no_grad():
                        model.base_model(input_ids)
                    torch.cuda.synchronize()
                torch.cuda.synchronize()
                end_time = time.perf_counter()
                times.append((end_time - start_time) / scale)
            total_token = cans[times.index(min(times))]
            model.ea_layer.total_tokens = total_token - 1

        return model

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            past_key_values=None,
            output_orig=False,
            position_ids=None,
    ):
        """Run the base model's decoder stack under ``torch.inference_mode``.

        Args:
            input_ids: Token ids passed to ``base_model.model``.
            attention_mask: Passed through to ``base_model.model``.
            past_key_values: Passed through to ``base_model.model``.
            output_orig: When true, also apply ``lm_head`` to the hidden
                states and return the result.
            position_ids: Passed through to ``base_model.model``.

        Returns:
            ``(outputs, hidden_states)``, or ``(outputs, orig, hidden_states)``
            when ``output_orig`` is true, where ``hidden_states`` is
            ``outputs[0]`` and ``orig`` is ``lm_head(outputs[0])``.
        """
        with torch.inference_mode():
            # Pass input through the base model
            outputs = self.base_model.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
            if output_orig:
                orig = self.base_model.lm_head(outputs[0])
            hidden_states = outputs[0]

        if output_orig:
            return outputs, orig, hidden_states
        return outputs, hidden_states

    # ------------------------------------------------------------------
    # Private helpers shared by the four public generation entry points.
    # ------------------------------------------------------------------

    def _get_kv_cache(self):
        """Return the base model's KV cache, creating it on first use.

        On later calls the cached buffers are reused and only the
        current-length tensor is zeroed in place.

        Returns:
            tuple: ``(past_key_values, past_key_values_data,
            current_length_data)``.
        """
        if hasattr(self, "past_key_values"):
            # Reset the past key and value states
            self.current_length_data.zero_()
        else:
            (
                self.past_key_values,
                self.past_key_values_data,
                self.current_length_data,
            ) = initialize_past_key_values(self.base_model)
        return self.past_key_values, self.past_key_values_data, self.current_length_data

    def _prepare_generation(self, input_ids, temperature, top_p, top_k, max_length, is_llama3):
        """Perform the setup common to every generation method.

        Returns:
            tuple: ``(input_ids, logits_processor, stop_token_id, max_length,
            kv_cache)`` where ``input_ids`` is a clone of the argument,
            ``logits_processor`` is ``None`` when ``temperature <= 1e-5``,
            ``stop_token_id`` is ``None`` unless ``is_llama3`` is set,
            ``max_length`` has the draft token budget plus 10 subtracted and
            ``kv_cache`` is the triple from ``_get_kv_cache``.
        """
        stop_token_id = (
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>") if is_llama3 else None
        )
        # Reserve room for one round of draft tokens plus a margin of 10.
        max_length = max_length - self.ea_layer.total_tokens - 10

        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None

        # Avoid modifying the input_ids in-place
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()
        kv_cache = self._get_kv_cache()
        reset_tree_mode(self)
        return input_ids, logits_processor, stop_token_id, max_length, kv_cache

    def _should_stop(self, input_ids, input_len, stop_token_id, new_token, max_new_tokens, max_length):
        """Tell whether decoding must stop after the current step.

        Stops when the generated suffix of row 0 contains ``stop_token_id``
        (if given) or the tokenizer's EOS id, when more than
        ``max_new_tokens`` tokens were produced, or when the sequence is
        longer than ``max_length``.
        """
        generated = input_ids[0, input_len:].tolist()
        if stop_token_id is not None and stop_token_id in generated:
            return True
        if self.tokenizer.eos_token_id in generated:
            return True
        if new_token > max_new_tokens:
            return True
        return input_ids.shape[1] > max_length

    def _ea_steps(self, input_ids, temperature, top_p, top_k, max_new_tokens, max_length, is_llama3):
        """Yield ``(input_ids, new_token, idx)`` after every tree-decoding step.

        Callers are expected to run this under ``torch.no_grad()``; the
        public wrappers do.
        """
        (
            input_ids,
            logits_processor,
            stop_token_id,
            max_length,
            (past_key_values, past_key_values_data, current_length_data),
        ) = self._prepare_generation(input_ids, temperature, top_p, top_k, max_length, is_llama3)

        # Single -1 token appended to the draft tokens before candidates are
        # gathered (presumably so that index -1 in retrieve_indices selects
        # this pad value; confirm in utils).
        padding = torch.full((1, 1), -1, dtype=torch.long).to(input_ids.device)
        input_len = input_ids.shape[1]

        draft_tokens, retrieve_indices, tree_mask, tree_position_ids, _, _, _ = initialize_tree(
            input_ids, self, past_key_values, logits_processor
        )
        new_token = 0

        for idx in range(max_length):
            self.base_model.model.tree_mask = tree_mask
            draft_tokens = draft_tokens.to(input_ids.device)

            logits, hidden_state_new, _ = tree_decoding(
                self,
                draft_tokens,
                past_key_values,
                tree_position_ids,
                input_ids,
                retrieve_indices,
            )

            draft_tokens = torch.cat((draft_tokens, padding), dim=1)
            candidates = draft_tokens[0, retrieve_indices]
            best_candidate, accept_length, sample_p = evaluate_posterior(
                logits, candidates, logits_processor
            )

            (
                input_ids,
                draft_tokens,
                retrieve_indices,
                tree_mask,
                tree_position_ids,
                new_token,
                _,
                _,
            ) = update_inference_inputs(
                input_ids,
                candidates,
                best_candidate,
                accept_length,
                retrieve_indices,
                logits_processor,
                new_token,
                past_key_values_data,
                current_length_data,
                self,
                hidden_state_new,
                sample_p
            )

            yield input_ids, new_token, idx

            if self._should_stop(input_ids, input_len, stop_token_id, new_token, max_new_tokens, max_length):
                break

    def _naive_steps(self, input_ids, temperature, top_p, top_k, max_new_tokens, max_length, is_llama3):
        """Yield ``(input_ids, new_token, idx)`` after every base-model token.

        Callers are expected to run this under ``torch.no_grad()``; the
        public wrappers do.
        """
        (
            input_ids,
            logits_processor,
            stop_token_id,
            max_length,
            (past_key_values, _, _),
        ) = self._prepare_generation(input_ids, temperature, top_p, top_k, max_length, is_llama3)

        input_len = input_ids.shape[1]
        # Prefill: run the whole prompt once to populate the KV cache.
        outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
        new_token = 0

        for idx in range(max_length):
            if logits_processor is not None:
                # Sample from the processed distribution of the last position.
                logits = outputs.logits[:, -1]
                logits = logits_processor(None, logits)
                probabilities = torch.nn.functional.softmax(logits, dim=-1)
                input_id = torch.multinomial(probabilities, 1)
            else:
                # Greedy choice.
                input_id = outputs.logits[:, -1:].argmax(dim=-1)

            outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
            input_ids = torch.cat([input_ids, input_id], dim=-1)
            new_token += 1

            yield input_ids, new_token, idx

            if self._should_stop(input_ids, input_len, stop_token_id, new_token, max_new_tokens, max_length):
                break

    @staticmethod
    def _drain(steps, input_ids, log):
        """Exhaust a step generator and format the blocking return value.

        Returns the last ``input_ids``, or ``(input_ids, new_token, idx)``
        when ``log`` is true. If no step ran at all, a clone of the prompt is
        returned with ``new_token == 0`` and ``idx == -1``.
        """
        state = None
        for state in steps:
            pass
        if state is None:
            state = (input_ids.clone(), 0, -1)
        final_ids, new_token, idx = state
        if not log:
            return final_ids
        return final_ids, new_token, idx

    # ------------------------------------------------------------------
    # Public generation API.
    # ------------------------------------------------------------------

    @torch.no_grad()
    def eagenerate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=2048,
            log=False,
            is_llama3=False,

    ):
        """Generate with draft-tree decoding and return the final sequence.

        Args:
            input_ids: Prompt token ids; row 0 is the one checked for stop
                tokens. The tensor is cloned, not modified.
            temperature: Sampling is used when ``> 1e-5``, otherwise the
                logits processor is ``None``.
            top_p: Forwarded to ``prepare_logits_processor``.
            top_k: Forwarded to ``prepare_logits_processor``.
            max_new_tokens: Stop once more than this many tokens were added.
            max_length: Sequence length limit before the draft budget and a
                margin of 10 are subtracted.
            log: When true, also return the token count and last step index.
            is_llama3: Additionally stop on the ``<|eot_id|>`` token.

        Returns:
            The extended ``input_ids``, or ``(input_ids, new_token, idx)``
            when ``log`` is true.
        """
        steps = self._ea_steps(input_ids, temperature, top_p, top_k, max_new_tokens, max_length, is_llama3)
        return self._drain(steps, input_ids, log)

    @torch.no_grad()
    def naivegenerate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=2048,
            log=False,
            is_llama3=False,

    ):
        """Generate token by token with the base model only.

        Takes the same arguments and returns the same shape of result as
        ``eagenerate``; the draft layer is only reset, not used for decoding.
        """
        steps = self._naive_steps(input_ids, temperature, top_p, top_k, max_new_tokens, max_length, is_llama3)
        return self._drain(steps, input_ids, log)

    @torch.no_grad()
    def ea_generate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=2048,
            log=False,
            is_llama3=False,

    ):
        """Streaming variant of ``eagenerate``.

        Yields the full ``input_ids`` tensor after every tree-decoding step.
        ``log`` is accepted for signature compatibility and is not used.
        """
        for step_ids, _, _ in self._ea_steps(
                input_ids, temperature, top_p, top_k, max_new_tokens, max_length, is_llama3
        ):
            yield step_ids

    @torch.no_grad()
    def naive_generate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=2048,
            log=False,
            is_llama3=False,

    ):
        """Streaming variant of ``naivegenerate``.

        Yields the full ``input_ids`` tensor after every generated token.
        ``log`` is accepted for signature compatibility and is not used.
        """
        for step_ids, _, _ in self._naive_steps(
                input_ids, temperature, top_p, top_k, max_new_tokens, max_length, is_llama3
        ):
            yield step_ids