import logging
import sys
import os
import torch
import json
from typing import Optional, Tuple, Union, List, Callable
from transformers import LlamaForCausalLM
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.beam_search import BeamSearchScorer
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.generation.utils import (
    LogitsProcessorList,
    StoppingCriteriaList,
    GenerationConfig,
    GenerationMixin,
)
import warnings
from peft import PeftModel, PeftModelForCausalLM, LoraConfig
import peft
import torch.distributed as dist
from torch import nn
import copy
from accelerate.hooks import (
    AlignDevicesHook,
    add_hook_to_module,
    remove_hook_from_submodules,
)
from accelerate.utils import get_balanced_memory
from huggingface_hub import hf_hub_download
from accelerate import dispatch_model, infer_auto_device_map
from peft.utils import PeftType, set_peft_model_state_dict


def printf(*args, **kargs):
    # only print when the DEBUG environment variable is set
    if os.environ.get('DEBUG', False):
        end = '\n'
        if 'end' in kargs:
            end = kargs['end']
        print(*args, end=end, flush=True)


class ColorFormatter(logging.Formatter):

    grey = "\x1b[38;20m"
    blue = "\x1b[34;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"

    def __init__(self, fmt):
        super().__init__(fmt)
        self.FORMATS = {
            logging.DEBUG: self.grey + fmt + self.reset,
            logging.INFO: self.blue + fmt + self.reset,
            logging.WARNING: self.yellow + fmt + self.reset,
            logging.ERROR: self.red + fmt + self.reset,
            logging.CRITICAL: self.bold_red + fmt + self.reset,
        }

    def format(self, record):
        log_fmt = self.FORMATS.get(record.levelno)
        formatter = logging.Formatter(log_fmt)
        return formatter.format(record)


def set_console_logger(name):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    consoleHandler = logging.StreamHandler(sys.stdout)
    consoleHandler.setLevel(logging.INFO)
    consoleHandler.setFormatter(ColorFormatter("%(asctime)s | %(levelname)s %(message)s"))
    logger.addHandler(consoleHandler)
    return logger


def set_file_logger(name, dir, use_console=False):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    os.makedirs(dir, exist_ok=True)

    if use_console:
        logger.propagate = False  # disable default handler
        consoleHandler = logging.StreamHandler(sys.stdout)
        consoleHandler.setLevel(logging.INFO)
        consoleHandler.setFormatter(ColorFormatter("%(asctime)s | %(levelname)s %(message)s"))
        logger.addHandler(consoleHandler)

    fileHandler = logging.FileHandler(os.path.join(dir, 'session.log'), mode='a')
    fileHandler.setLevel(logging.INFO)
    fileHandler.setFormatter(logging.Formatter("%(asctime)s | %(levelname)s %(message)s"))
    logger.addHandler(fileHandler)
    return logger


def to_jsonl(data, path):
    with open(path, 'a') as f:
        for line in data:
            f.write(json.dumps(line, ensure_ascii=False) + '\n')


def from_json(path):
    return json.load(open(path))


def from_jsonl(path):
    return [json.loads(line) for line in open(path, 'r')]


def to_json(data, path):
    json.dump(data, open(path, 'w'), ensure_ascii=False)


class StreamGenerationMixin(GenerationMixin):
    # adds streaming (token-by-token) generation support
    # TODO: group_beam_search

    @torch.no_grad()
    def stream_generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        **kwargs,
    ):
        if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
            synced_gpus = True
        else:
            synced_gpus = False

        if kwargs.get("attention_mask", None) is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(
                kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
            ).to(kwargs["input_ids"].device)
            kwargs["attention_mask"] = torch.cat(
                (prefix_attention_mask, kwargs["attention_mask"]), dim=1
            )
        if kwargs.get("position_ids", None) is not None:
            warnings.warn(
                "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
            )
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn(
                "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids."
            )
            kwargs["token_type_ids"] = None

        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)

        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        has_default_max_length = (
            kwargs.get("max_length") is None
            and generation_config.max_length is not None
        )
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length
            )

        if generation_config.min_new_tokens is not None:
            generation_config.min_length = (
                generation_config.min_new_tokens + input_ids_seq_length
            )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = (
                "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            )
            # NOTE: `input_ids_string` is currently unused; the corresponding
            # length warning from transformers' `generate` is omitted here.

        # 2. Set generation parameters if not already defined
        logits_processor = (
            logits_processor if logits_processor is not None else LogitsProcessorList()
        )
        stopping_criteria = (
            stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        )

        # 7. determine generation mode
        is_constraint_gen_mode = (
            generation_config.constraints is not None
            or generation_config.force_words_ids is not None
        )
        is_contrastive_search_gen_mode = (
            generation_config.top_k is not None
            and generation_config.top_k > 1
            and generation_config.do_sample is False
            and generation_config.penalty_alpha is not None
            and generation_config.penalty_alpha > 0
        )
        is_greedy_gen_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is False
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        # beam=1 and do_sample=True
        is_sample_gen_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is True
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_beam_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is False
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_beam_sample_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is True
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_group_beam_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups > 1)
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )

        # 8. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        # 9. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        if is_greedy_gen_mode:
            # 11. run greedy search
            return self.stream_greedy_search(
                input_ids,
                logits_processor,
                stopping_criteria,
                generation_config,
                synced_gpus,
                **model_kwargs,
            )
        elif is_sample_gen_mode:
            # 12. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            return self.stream_sample(
                generation_config,
                input_ids,
                logits_processor,
                logits_warper,
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        elif is_beam_gen_mode:
            return self.stream_beam_search(
                generation_config,
                input_ids,
                logits_processor,
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        elif is_beam_sample_gen_mode:
            # interleave input_ids with `num_beams` additional sequences per batch
            return self.stream_beam_sample(
                input_ids,
                logits_processor,
                logits_warper,
                stopping_criteria,
                generation_config,
                synced_gpus,
                **model_kwargs,
            )
        else:
            raise NotImplementedError(
                "This generation mode (e.g. group beam search, constrained or "
                "contrastive search) is not supported by stream_generate."
            )

    def stream_sample(
        self,
        generation_config,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        this_peer_finished = False  # used by synced_gpus only
        scores = ()

        # auto-regressive generation
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True
        yield input_ids

    def stream_beam_sample(
        self,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        num_beams = generation_config.num_beams
        batch_size, cur_len = input_ids.shape[0], input_ids.shape[-1]
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams * generation_config.num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        scores = ()
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
            # Note: logits warpers are intentionally applied after adding running beam scores. On some logits
            # warpers (like top_p) this is indifferent, but on others (like temperature) it is not. For reference,
            # see https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, _indices)

            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            yield input_ids

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield sequence_outputs["sequences"]

    def stream_greedy_search(
        self,
        input_ids,
        logits_processor,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        # init values
        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        # init attention / hidden states / scores tuples
        scores = ()

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True
        yield input_ids

    def stream_beam_search(
        self,
        generation_config,
        input_ids,
        logits_processor,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        # 10. go into beam search generation modes
        # 11. prepare beam search scorer
        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        num_beams = generation_config.num_beams
        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        # 12. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        # beam_search logits
        batch_beam_size, cur_len = input_ids.shape
        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )
        beam_scores = torch.zeros(
            (batch_size, num_beams), dtype=torch.float, device=input_ids.device
        )
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(
                    0.0 if this_peer_finished else 1.0
                ).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian (kept commented out here):
            # next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[
                :, None
            ].expand_as(next_token_scores)

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(
                batch_size, num_beams * vocab_size
            )

            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat(
                [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
            )
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(
                    model_kwargs["past_key_values"], beam_idx
                )

            # increase cur_len
            cur_len = cur_len + 1

            yield input_ids

            if beam_scorer.is_done or stopping_criteria(input_ids, None):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        final_result = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield final_result["sequences"]


class StreamLlamaForCausalLM(LlamaForCausalLM, StreamGenerationMixin):
    pass


class StreamPeftGenerationMixin(PeftModelForCausalLM, StreamGenerationMixin):
    # By default, `PeftModel.from_pretrained` instantiates the wrapper via
    # `MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config)` rather than `cls`,
    # so merely inheriting from PeftModelForCausalLM is not enough: we override
    # `from_pretrained` to build an instance of this class instead.

    @classmethod
    def from_pretrained(cls, model, model_id, adapter_name="default", is_trainable=False, **kwargs):
        # works with peft==0.3.0
        if peft.__version__ >= '0.3.0' and peft.__version__ != '0.3.0.dev0':
            # load the config
            from peft.utils import PromptLearningConfig

            config = LoraConfig.from_pretrained(model_id)
            if (getattr(model, "hf_device_map", None) is not None) and len(
                set(model.hf_device_map.values()).intersection({"cpu", "disk"})
            ) > 0:
                remove_hook_from_submodules(model)

            if isinstance(config, PromptLearningConfig) and is_trainable:
                raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
            else:
                config.inference_mode = not is_trainable

            # here is the hack: construct `cls` directly instead of the mapped PEFT class
            model = cls(model, config, adapter_name)
            model.load_adapter(model_id, adapter_name, **kwargs)
            # NOTICE: expose the base model's generation helpers on the wrapper
            model.base_model_prepare_inputs_for_generation = model.base_model.prepare_inputs_for_generation
            model._reorder_cache = model.base_model._reorder_cache
            return model
        else:
            return cls.from_pretrained_old_peft_version(model, model_id, **kwargs)

    @classmethod
    def from_pretrained_old_peft_version(cls, model, model_id, **kwargs):
        # works with peft@e536616888d51b453ed354a6f1e243fecb02ea08

        # load the config
        config = LoraConfig.from_pretrained(model_id)

        if getattr(model, "hf_device_map", None) is not None:
            remove_hook_from_submodules(model)

        # here is the hack: construct `cls` directly instead of the mapped PEFT class
        model = cls(model, config)
        model._reorder_cache = model.base_model._reorder_cache

        # load weights if any
        if os.path.exists(os.path.join(model_id, "adapter_model.bin")):
            filename = os.path.join(model_id, "adapter_model.bin")
        else:
            try:
                filename = hf_hub_download(model_id, "adapter_model.bin")
            except Exception:
                raise ValueError(
                    f"Can't find weights for {model_id} locally or in the Hugging Face Hub. "
                    f"Please check that the file adapter_model.bin is present at {model_id}."
                )

        adapters_weights = torch.load(
            filename,
            map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
        # load the weights into the model
        model = set_peft_model_state_dict(model, adapters_weights)
        if getattr(model, "hf_device_map", None) is not None:
            device_map = kwargs.get("device_map", "auto")
            max_memory = kwargs.get("max_memory", None)
            no_split_module_classes = model._no_split_modules
            if device_map != "sequential":
                max_memory = get_balanced_memory(
                    model,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                    low_zero=(device_map == "balanced_low_0"),
                )
            if isinstance(device_map, str):
                device_map = infer_auto_device_map(
                    model,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                )
            model = dispatch_model(model, device_map=device_map)
            hook = AlignDevicesHook(io_same_device=True)
            if model.peft_config.peft_type == PeftType.LORA:
                add_hook_to_module(model.base_model.model, hook)
            else:
                remove_hook_from_submodules(model.prompt_encoder)
                add_hook_to_module(model.base_model, hook)
        return model
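

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): how the
# classes above are meant to be combined for streaming inference. The
# checkpoint paths below are hypothetical placeholders -- substitute your own
# LLaMA base checkpoint and LoRA adapter directory. Assumes a GPU with enough
# memory for fp16 weights; drop `torch_dtype`/`device_map` for a CPU run.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import LlamaTokenizer

    BASE_MODEL = "path/to/llama-7b-hf"      # hypothetical base checkpoint
    LORA_WEIGHTS = "path/to/lora-adapter"   # hypothetical adapter directory

    tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
    base_model = LlamaForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.float16, device_map="auto"
    )
    # Wrap the base model so the LoRA weights are loaded and `stream_generate`
    # becomes available on the resulting object.
    model = StreamPeftGenerationMixin.from_pretrained(
        base_model, LORA_WEIGHTS, torch_dtype=torch.float16, device_map="auto"
    )
    model.eval()

    prompt = "Hello, my name is"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(base_model.device)
    generation_config = GenerationConfig(
        do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=64
    )

    # `stream_generate` is a generator: every step yields the full sequence of
    # token ids produced so far, so the partial text can be re-decoded and
    # printed incrementally.
    for ids in model.stream_generate(
        input_ids=input_ids, generation_config=generation_config
    ):
        print(tokenizer.decode(ids[0], skip_special_tokens=True), flush=True)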