chatlawv1 / utils.py
teachyourselfcoding's picture
Upload 245 files
fa6856c
import logging
import sys
import os
import torch
import json
from typing import Optional, Tuple, Union, List, Callable
from transformers import LlamaForCausalLM
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.beam_search import BeamSearchScorer
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.generation.utils import (
LogitsProcessorList,
StoppingCriteriaList,
GenerationConfig,
GenerationMixin,
)
import warnings
from peft import PeftModel, PeftModelForCausalLM, LoraConfig
import peft
import torch.distributed as dist
from torch import nn
import copy
from accelerate.hooks import (
AlignDevicesHook,
add_hook_to_module,
remove_hook_from_submodules,
)
from accelerate.utils import get_balanced_memory
from huggingface_hub import hf_hub_download
from accelerate import dispatch_model, infer_auto_device_map
from peft.utils import PeftType, set_peft_model_state_dict
def printf(*args, **kargs):
    """Debug-only ``print``: emit ``args`` only when ``DEBUG`` is set.

    Nothing is printed unless the ``DEBUG`` environment variable holds a
    non-empty value. Output is flushed immediately. Of the keyword
    arguments only ``end`` is honoured (default ``'\\n'``); any other
    keyword is ignored.
    """
    if not os.environ.get('DEBUG', False):
        return
    print(*args, end=kargs.get('end', '\n'), flush=True)
class ColorFormatter(logging.Formatter):
    """Formatter that wraps each record in an ANSI colour chosen by level.

    ``fmt`` is a regular ``logging`` format string. For the five standard
    levels the formatted record is surrounded by the matching colour escape
    and the reset sequence. Records with any other level number are
    formatted with the plain, uncoloured ``fmt``.
    """

    # ANSI SGR escape sequences.
    grey = "\x1b[38;20m"
    blue = "\x1b[34;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"

    def __init__(self, fmt):
        super().__init__(fmt)
        # Level number -> coloured format string (kept as a public attribute).
        self.FORMATS = {
            logging.DEBUG: self.grey + fmt + self.reset,
            logging.INFO: self.blue + fmt + self.reset,
            logging.WARNING: self.yellow + fmt + self.reset,
            logging.ERROR: self.red + fmt + self.reset,
            logging.CRITICAL: self.bold_red + fmt + self.reset,
        }
        # Build one Formatter per level up front instead of one per record.
        self._formatters = {
            level: logging.Formatter(colored_fmt)
            for level, colored_fmt in self.FORMATS.items()
        }

    def format(self, record):
        """Return ``record`` rendered with the colour for its level."""
        formatter = self._formatters.get(record.levelno)
        if formatter is None:
            # Custom level: fall back to the uncoloured format rather than
            # silently dropping the configured format string.
            return super().format(record)
        return formatter.format(record)
def set_console_logger(name):
    """Return the logger called ``name`` with a coloured stdout handler added.

    The logger itself is set to DEBUG; the attached handler only emits
    records at INFO or above. Each call adds another handler to the logger.
    """
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(
        ColorFormatter("%(asctime)s | %(levelname)s %(message)s")
    )

    console_logger = logging.getLogger(name)
    console_logger.setLevel(logging.DEBUG)
    console_logger.addHandler(stream_handler)
    return console_logger
def set_file_logger(name, dir, use_console=False):
    """Return the logger called ``name`` writing to ``<dir>/session.log``.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        dir: Directory for the log file; created if it does not exist.
        use_console: When true, also attach a coloured stdout handler and
            stop the logger from propagating records to ancestor loggers.

    Returns:
        The configured ``logging.Logger``. The logger level is DEBUG; both
        handlers only emit records at INFO or above. The log file is opened
        in append mode.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    os.makedirs(dir, exist_ok=True)
    if use_console:
        logger.propagate = False  # disable default handler
        consoleHandler = logging.StreamHandler(sys.stdout)
        consoleHandler.setLevel(logging.INFO)
        consoleHandler.setFormatter(ColorFormatter("%(asctime)s | %(levelname)s %(message)s"))
        logger.addHandler(consoleHandler)
    # Pin the encoding: messages may contain non-ASCII text, and the platform
    # default encoding is not guaranteed to be able to represent it.
    fileHandler = logging.FileHandler(
        os.path.join(dir, 'session.log'), mode='a', encoding='utf-8'
    )
    fileHandler.setLevel(logging.INFO)
    fileHandler.setFormatter(logging.Formatter("%(asctime)s | %(levelname)s %(message)s"))
    logger.addHandler(fileHandler)
    return logger
def to_jsonl(data, path):
    """Append each item of ``data`` to ``path`` as one JSON document per line.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``), so
    the file is opened explicitly as UTF-8 rather than with the platform
    default encoding.
    """
    with open(path, 'a', encoding='utf-8') as f:
        f.writelines(
            json.dumps(line, ensure_ascii=False) + '\n' for line in data
        )
def from_json(path):
    """Load and return the JSON document stored at ``path`` (read as UTF-8)."""
    # Context manager so the file handle is closed deterministically.
    with open(path, encoding='utf-8') as f:
        return json.load(f)
def from_jsonl(path):
    """Return the list of JSON documents stored one per line in ``path``.

    The file is read as UTF-8 and closed before returning. Lines that are
    empty or whitespace-only (e.g. a trailing blank line) are skipped
    instead of raising a decode error.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]
def to_json(data, path):
    """Serialize ``data`` as JSON to ``path``, overwriting any existing file.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``), so
    the file is opened explicitly as UTF-8. The context manager guarantees
    the data is flushed and the handle closed when the function returns.
    """
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False)
class StreamGenerationMixin(GenerationMixin):
    """Generation mixin whose decoding loops are generators.

    ``stream_generate`` picks a decoding strategy from the generation config
    and returns the matching ``stream_*`` generator. Every generator yields
    the token-id tensor accumulated so far after each decoding step, so a
    caller can decode and display partial output while generation runs.

    Supported modes: greedy search, multinomial sampling, beam search and
    beam sampling.
    TODO: group_beam_search
    """

    @staticmethod
    def _stream_eos_and_pad(generation_config):
        """Return ``(eos_token_id, pad_token_id)`` with an int eos wrapped in a list."""
        eos_token_id = generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        return eos_token_id, generation_config.pad_token_id

    @staticmethod
    def _stream_all_peers_done(this_peer_finished, device):
        """All-reduce the per-rank "still running" flag; True when every rank is done."""
        # send 0.0 if we finished, 1.0 otherwise
        flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        # did all peers finish? the reduced sum will be 0.0 then
        return flag.item() == 0.0

    def _stream_num_virtual_tokens(self):
        """Return the adapter's ``num_virtual_tokens``, or 0 when it has none."""
        peft_config = getattr(self, "peft_config", None)
        if isinstance(peft_config, dict):
            # NOTE(review): presumably newer peft keeps one config per adapter
            # name on the model -- confirm against the installed peft version.
            peft_config = peft_config.get(getattr(self, "active_adapter", "default"))
        return getattr(peft_config, "num_virtual_tokens", None) or 0

    @torch.no_grad()
    def stream_generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        **kwargs,
    ):
        """Select a decoding strategy and return its streaming generator.

        Args:
            input_ids: Prompt token ids; the first dimension is the batch and
                the last dimension the prompt length.
            generation_config: Base generation config. Defaults to
                ``self.generation_config``. It is deep-copied before
                ``kwargs`` are merged in, so the caller's object is untouched.
            logits_processor: Extra logits processors (optional).
            stopping_criteria: Extra stopping criteria (optional).
            prefix_allowed_tokens_fn: Forwarded to ``_get_logits_processor``.
            **kwargs: Generation-config overrides; entries the config does not
                consume are forwarded to the model as model kwargs.

        Returns:
            A generator yielding the running ``input_ids`` tensor after every
            decoding step.

        Raises:
            Exception: If the configuration selects a mode other than greedy,
                sample, beam search or beam sample.
        """
        # `torch.distributed` exposes `get_world_size()`; there is no `world_size()`.
        synced_gpus = bool(is_deepspeed_zero3_enabled() and dist.get_world_size() > 1)

        num_virtual_tokens = self._stream_num_virtual_tokens()
        if kwargs.get("attention_mask", None) is not None and num_virtual_tokens:
            # concat prompt attention mask
            # `input_ids` is a named parameter and therefore never inside
            # `kwargs`. The prefix is only meaningful for adapters that define
            # virtual tokens; for others (e.g. LoRA) the mask is left as is.
            prefix_attention_mask = torch.ones(
                input_ids.shape[0], num_virtual_tokens
            ).to(input_ids.device)
            kwargs["attention_mask"] = torch.cat(
                (prefix_attention_mask, kwargs["attention_mask"]), dim=1
            )
        if kwargs.get("position_ids", None) is not None:
            warnings.warn(
                "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
            )
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn(
                "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
            )
            kwargs["token_type_ids"] = None

        input_ids_seq_length = input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        # `update` applies matching kwargs and returns the ones it did not use.
        model_kwargs = generation_config.update(**kwargs)

        has_default_max_length = (
            kwargs.get("max_length") is None
            and generation_config.max_length is not None
        )
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length
            )
        if generation_config.min_new_tokens is not None:
            generation_config.min_length = (
                generation_config.min_new_tokens + input_ids_seq_length
            )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = (
                "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            )
            # The label used to be computed and then discarded; surface it.
            warnings.warn(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`.",
                UserWarning,
            )

        # 2. Set generation parameters if not already defined
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        if stopping_criteria is None:
            stopping_criteria = StoppingCriteriaList()

        # 7. determine generation mode
        is_constraint_gen_mode = (
            generation_config.constraints is not None
            or generation_config.force_words_ids is not None
        )
        is_contrastive_search_gen_mode = (
            generation_config.top_k is not None
            and generation_config.top_k > 1
            and generation_config.do_sample is False
            and generation_config.penalty_alpha is not None
            and generation_config.penalty_alpha > 0
        )
        is_plain_mode = not is_constraint_gen_mode and not is_contrastive_search_gen_mode
        single_group = generation_config.num_beam_groups == 1
        single_beam = generation_config.num_beams == 1 and single_group
        multi_beam = generation_config.num_beams > 1 and single_group

        is_greedy_gen_mode = (
            single_beam and generation_config.do_sample is False and is_plain_mode
        )
        # beam=1 and do_sample=True
        is_sample_gen_mode = (
            single_beam and generation_config.do_sample is True and is_plain_mode
        )
        is_beam_gen_mode = (
            multi_beam and generation_config.do_sample is False and is_plain_mode
        )
        is_beam_sample_gen_mode = (
            multi_beam and generation_config.do_sample is True and is_plain_mode
        )
        # Group beam search (num_beam_groups > 1) is not implemented and
        # falls through to the final `raise` below.

        # 8. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )
        # 9. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        if is_greedy_gen_mode:
            # 11. run greedy search
            return self.stream_greedy_search(
                input_ids,
                logits_processor,
                stopping_criteria,
                generation_config,
                synced_gpus,
                **model_kwargs,
            )
        elif is_sample_gen_mode:
            # 12. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            return self.stream_sample(
                generation_config,
                input_ids,
                logits_processor,
                logits_warper,
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        elif is_beam_gen_mode:
            return self.stream_beam_search(
                generation_config,
                input_ids,
                logits_processor,
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        elif is_beam_sample_gen_mode:
            # interleave input_ids with `num_beams` additional sequences per batch
            return self.stream_beam_sample(
                input_ids,
                logits_processor,
                logits_warper,
                stopping_criteria,
                generation_config,
                synced_gpus,
                **model_kwargs,
            )
        else:
            raise Exception('not implement')

    # The decorator on `stream_generate` only covers its own (setup) call: the
    # forward passes run later, while the returned generator is iterated.
    # Decorating the generator functions themselves disables autograd there.
    @torch.no_grad()
    def stream_sample(
        self,
        generation_config,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        """Multinomial sampling that yields ``input_ids`` after every new token.

        Finished rows are padded with ``pad_token_id``. After the loop ends
        the final ``input_ids`` is yielded once more.

        Raises:
            ValueError: If an eos token is configured but no pad token is.
        """
        eos_token_id, pad_token_id = self._stream_eos_and_pad(generation_config)
        eos_token_id_tensor = (
            torch.tensor(eos_token_id).to(input_ids.device)
            if eos_token_id is not None
            else None
        )
        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device
        )
        this_peer_finished = False  # used by synced_gpus only
        scores = ()
        # auto-regressive generation
        while True:
            # Under synced_gpus the `forward` call must continue until all gpus
            # complete their sequence; break early once every peer is done.
            if synced_gpus and self._stream_all_peers_done(
                this_peer_finished, input_ids.device
            ):
                break
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need
            next_token_logits = outputs.logits[:, -1, :]
            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)
            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids
            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
                    .ne(eos_token_id_tensor.unsqueeze(1))
                    .prod(dim=0)
                )
            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                this_peer_finished = True
        yield input_ids

    @torch.no_grad()
    def stream_beam_sample(
        self,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        """Beam search with multinomial sampling, streamed step by step.

        Yields the running ``input_ids`` of all live beams after each step
        and finally the sequences returned by ``BeamSearchScorer.finalize``.
        """
        eos_token_id, pad_token_id = self._stream_eos_and_pad(generation_config)
        num_beams = generation_config.num_beams
        batch_size, cur_len = input_ids.shape[0], input_ids.shape[-1]
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        # Expand by `num_beams` only, exactly like `stream_beam_search`. The
        # scorer and `beam_scores` are sized for batch_size * num_beams rows,
        # so an extra `num_return_sequences` factor made the shapes disagree
        # whenever num_return_sequences > 1. The scorer already keeps
        # `num_return_sequences` hypotheses per batch item.
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        scores = ()
        beam_scores = torch.zeros(
            (batch_size, num_beams), dtype=torch.float, device=input_ids.device
        )
        beam_scores = beam_scores.view((batch_size * num_beams,))
        this_peer_finished = False  # used by synced_gpus only
        while True:
            # Under synced_gpus the `forward` call must continue until all gpus
            # complete their sequence; break early once every peer is done.
            if synced_gpus and self._stream_all_peers_done(
                this_peer_finished, input_ids.device
            ):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need
            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
            # Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
            # (like top_p) this is indifferent, but on others (like temperature) it is not. For reference, see
            # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
            next_token_scores = logits_warper(input_ids, next_token_scores)
            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, _indices)
            # Split the flattened (beam, token) index back into its two parts.
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size
            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            yield input_ids
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                this_peer_finished = True
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield sequence_outputs["sequences"]

    @torch.no_grad()
    def stream_greedy_search(
        self,
        input_ids,
        logits_processor,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        """Greedy (argmax) decoding that yields ``input_ids`` after every new token.

        Finished rows are padded with ``pad_token_id``. After the loop ends
        the final ``input_ids`` is yielded once more.

        Raises:
            ValueError: If an eos token is configured but no pad token is.
        """
        # init values
        eos_token_id, pad_token_id = self._stream_eos_and_pad(generation_config)
        eos_token_id_tensor = (
            torch.tensor(eos_token_id).to(input_ids.device)
            if eos_token_id is not None
            else None
        )
        # init attention / hidden states / scores tuples
        scores = ()
        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device
        )
        this_peer_finished = False  # used by synced_gpus only
        while True:
            # Under synced_gpus the `forward` call must continue until all gpus
            # complete their sequence; break early once every peer is done.
            if synced_gpus and self._stream_all_peers_done(
                this_peer_finished, input_ids.device
            ):
                break
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need
            next_token_logits = outputs.logits[:, -1, :]
            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)
            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids
            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
                    .ne(eos_token_id_tensor.unsqueeze(1))
                    .prod(dim=0)
                )
            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                this_peer_finished = True
        yield input_ids

    @torch.no_grad()
    def stream_beam_search(
        self,
        generation_config,
        input_ids,
        logits_processor,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        """Deterministic beam search, streamed step by step.

        Yields the running ``input_ids`` of all live beams after each step
        and finally the sequences returned by ``BeamSearchScorer.finalize``.

        Raises:
            ValueError: If the expanded batch dimension is not
                ``num_beams * batch_size``.
        """
        # 10. go into beam search generation modes
        # 11. prepare beam search scorer
        eos_token_id, pad_token_id = self._stream_eos_and_pad(generation_config)
        num_beams = generation_config.num_beams
        batch_size = input_ids.shape[0]
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        # 12. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # beam_search logits
        batch_beam_size, cur_len = input_ids.shape
        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )
        beam_scores = torch.zeros(
            (batch_size, num_beams), dtype=torch.float, device=input_ids.device
        )
        # All beams start out identical: give every beam but the first a very
        # low score so the first step does not select duplicate hypotheses.
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))
        this_peer_finished = False  # used by synced_gpus only
        while True:
            # Under synced_gpus the `forward` call must continue until all gpus
            # complete their sequence; break early once every peer is done.
            if synced_gpus and self._stream_all_peers_done(
                this_peer_finished, input_ids.device
            ):
                break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need
            next_token_logits = outputs.logits[:, -1, :]
            # The Marian-specific `adjust_logits_during_generation` hook is
            # intentionally not applied in this method.
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[
                :, None
            ].expand_as(next_token_scores)
            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(
                batch_size, num_beams * vocab_size
            )
            # Take the 2 * num_beams best candidates (so we have some spare tokens and match output of beam search)
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            # Split the flattened (beam, token) index back into its two parts.
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size
            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            input_ids = torch.cat(
                [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
            )
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(
                    model_kwargs["past_key_values"], beam_idx
                )
            # increase cur_len
            cur_len = cur_len + 1
            yield input_ids
            if beam_scorer.is_done or stopping_criteria(input_ids, None):
                if not synced_gpus:
                    break
                this_peer_finished = True
        final_result = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield final_result["sequences"]
class StreamLlamaForCausalLM(LlamaForCausalLM, StreamGenerationMixin):
    """``LlamaForCausalLM`` combined with ``StreamGenerationMixin``; adds no members of its own."""
class StreamPeftGenerationMixin(PeftModelForCausalLM, StreamGenerationMixin):
    """PEFT causal-LM wrapper that also exposes the streaming generators.

    By default peft builds the wrapper through
    ``MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config)``
    rather than through ``cls``, so simply inheriting from
    ``PeftModelForCausalLM`` is not enough: the loaders below instantiate
    ``cls`` explicitly so the result carries the ``stream_*`` methods.
    """

    @staticmethod
    def _release_tuple(version):
        """Return the leading numeric components of ``version`` as an int tuple.

        ``"0.10.0"`` -> ``(0, 10, 0)``; ``"0.3.0.dev0"`` -> ``(0, 3, 0)``.
        Parsing stops at the first component that is not purely numeric.
        """
        release = []
        for piece in version.split("."):
            digits = ""
            for char in piece:
                if char not in "0123456789":
                    break
                digits += char
            if not digits:
                break
            release.append(int(digits))
            if digits != piece:
                # e.g. "0rc1": keep the number, ignore everything after it.
                break
        return tuple(release)

    @classmethod
    def from_pretrained(cls, model, model_id, adapter_name="default", is_trainable=False, **kwargs):
        """Wrap ``model`` with the LoRA adapter stored at ``model_id``.

        Args:
            model: Base model to wrap.
            model_id: Local directory or Hub id holding the adapter.
            adapter_name: Name under which the adapter is registered.
            is_trainable: When false the adapter config is put into
                inference mode.
            **kwargs: Forwarded to ``load_adapter`` (or to the legacy loader).

        Returns:
            An instance of ``cls`` with the adapter loaded.

        Raises:
            ValueError: If a prompt-learning adapter is requested as trainable.
        """
        # work in peft==0.3.0
        # Compare numeric release tuples: comparing the raw strings orders
        # lexicographically, which would rank "0.10.0" below "0.3.0".
        is_new_peft = (
            cls._release_tuple(peft.__version__) >= (0, 3, 0)
            and peft.__version__ != '0.3.0.dev0'
        )
        if not is_new_peft:
            return cls.from_pretrained_old_peft_version(model, model_id, **kwargs)

        # load the config
        from peft.utils import PromptLearningConfig

        config = LoraConfig.from_pretrained(model_id)
        # Hooks are only removed when part of the model is offloaded to cpu/disk.
        if (getattr(model, "hf_device_map", None) is not None) and len(
            set(model.hf_device_map.values()).intersection({"cpu", "disk"})
        ) > 0:
            remove_hook_from_submodules(model)
        if isinstance(config, PromptLearningConfig) and is_trainable:
            raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
        else:
            config.inference_mode = not is_trainable
        # here is the hack: build `cls` directly instead of the mapped peft class
        model = cls(model, config, adapter_name)
        model.load_adapter(model_id, adapter_name, **kwargs)
        # NOTICE: expose the base model's generation hooks on the wrapper.
        model.base_model_prepare_inputs_for_generation = model.base_model.prepare_inputs_for_generation
        model._reorder_cache = model.base_model._reorder_cache
        return model

    @classmethod
    def from_pretrained_old_peft_version(cls, model, model_id, **kwargs):
        """Legacy loader for peft versions that predate ``load_adapter``.

        Reads ``adapter_model.bin`` from the local directory ``model_id`` or,
        failing that, downloads it from the Hugging Face Hub. When the base
        model carries an ``hf_device_map`` the wrapped model is re-dispatched
        across devices (``device_map`` / ``max_memory`` are read from
        ``kwargs``).

        Raises:
            ValueError: If the adapter weights cannot be found or downloaded.
        """
        # work well in peft@e536616888d51b453ed354a6f1e243fecb02ea08
        # load the config
        config = LoraConfig.from_pretrained(model_id)
        if getattr(model, "hf_device_map", None) is not None:
            remove_hook_from_submodules(model)
        # here is the hack
        model = cls(model, config)
        model._reorder_cache = model.base_model._reorder_cache
        # load weights if any
        local_weights = os.path.join(model_id, "adapter_model.bin")
        if os.path.exists(local_weights):
            filename = local_weights
        else:
            try:
                filename = hf_hub_download(model_id, "adapter_model.bin")
            except Exception as err:
                # `Exception` rather than a bare `except`, so KeyboardInterrupt
                # and SystemExit are not converted into a ValueError.
                raise ValueError(
                    f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
                    f"Please check that the file {'adapter_model.bin'} is present at {model_id}."
                ) from err
        # NOTE(review): `torch.load` unpickles the file -- only load adapter
        # weights from sources you trust.
        adapters_weights = torch.load(
            filename,
            map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
        # load the weights into the model
        # NOTE(review): the return value is re-bound to `model`; this assumes
        # the targeted peft revision returns the model -- TODO confirm.
        model = set_peft_model_state_dict(model, adapters_weights)
        if getattr(model, "hf_device_map", None) is not None:
            device_map = kwargs.get("device_map", "auto")
            max_memory = kwargs.get("max_memory", None)
            no_split_module_classes = model._no_split_modules
            if device_map != "sequential":
                max_memory = get_balanced_memory(
                    model,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                    low_zero=(device_map == "balanced_low_0"),
                )
            if isinstance(device_map, str):
                device_map = infer_auto_device_map(
                    model,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                )
            model = dispatch_model(model, device_map=device_map)
            hook = AlignDevicesHook(io_same_device=True)
            if model.peft_config.peft_type == PeftType.LORA:
                add_hook_to_module(model.base_model.model, hook)
            else:
                remove_hook_from_submodules(model.prompt_encoder)
                add_hook_to_module(model.base_model, hook)
        return model