import logging
import sys
import os
import torch
import json
from typing import Optional, Tuple, Union, List, Callable
from transformers import LlamaForCausalLM
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.beam_search import BeamSearchScorer
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.generation.utils import (
    LogitsProcessorList,
    StoppingCriteriaList,
    GenerationConfig,
    GenerationMixin,
)
import warnings
from peft import PeftModel, PeftModelForCausalLM, LoraConfig
import peft
import torch.distributed as dist
from torch import nn
import copy
from accelerate.hooks import (
    AlignDevicesHook,
    add_hook_to_module,
    remove_hook_from_submodules,
)
from accelerate.utils import get_balanced_memory
from huggingface_hub import hf_hub_download
from accelerate import dispatch_model, infer_auto_device_map
from peft.utils import PeftType, set_peft_model_state_dict
def printf(*args, **kargs):
    if os.environ.get('DEBUG', False):
        end = '\n'
        if 'end' in kargs:
            end = kargs['end']
        print(*args, end=end, flush=True)
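# Example: printf output is gated by the DEBUG environment variable, e.g.
#   DEBUG=1 python your_script.py   # script name is a placeholder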
class ColorFormatter(logging.Formatter):
    grey = "\x1b[38;20m"
    blue = "\x1b[34;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"

    def __init__(self, fmt):
        super().__init__(fmt)
        self.FORMATS = {
            logging.DEBUG: self.grey + fmt + self.reset,
            logging.INFO: self.blue + fmt + self.reset,
            logging.WARNING: self.yellow + fmt + self.reset,
            logging.ERROR: self.red + fmt + self.reset,
            logging.CRITICAL: self.bold_red + fmt + self.reset,
        }

    def format(self, record):
        log_fmt = self.FORMATS.get(record.levelno)
        formatter = logging.Formatter(log_fmt)
        return formatter.format(record)
def set_console_logger(name):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    consoleHandler = logging.StreamHandler(sys.stdout)
    consoleHandler.setLevel(logging.INFO)
    consoleHandler.setFormatter(ColorFormatter("%(asctime)s | %(levelname)s %(message)s"))
    logger.addHandler(consoleHandler)
    return logger
def set_file_logger(name, dir, use_console=False):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    os.makedirs(dir, exist_ok=True)
    if use_console:
        logger.propagate = False  # disable default handler
        consoleHandler = logging.StreamHandler(sys.stdout)
        consoleHandler.setLevel(logging.INFO)
        consoleHandler.setFormatter(ColorFormatter("%(asctime)s | %(levelname)s %(message)s"))
        logger.addHandler(consoleHandler)
    fileHandler = logging.FileHandler(os.path.join(dir, 'session.log'), mode='a')
    fileHandler.setLevel(logging.INFO)
    fileHandler.setFormatter(logging.Formatter("%(asctime)s | %(levelname)s %(message)s"))
    logger.addHandler(fileHandler)
    return logger
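# Usage sketch for the logger helpers (directory below is a placeholder):
#   logger = set_console_logger("inference")                       # colored console output only
#   logger = set_file_logger("train", "./logs", use_console=True)  # also appends to ./logs/session.log
#   logger.info("starting run")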
def to_jsonl(data, path):
    with open(path, 'a') as f:
        for line in data:
            f.write(json.dumps(line, ensure_ascii=False) + '\n')


def from_json(path):
    with open(path) as f:
        return json.load(f)


def from_jsonl(path):
    with open(path, 'r') as f:
        return [json.loads(line) for line in f]


def to_json(data, path):
    with open(path, 'w') as f:
        json.dump(data, f, ensure_ascii=False)
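# Usage sketch for the JSON helpers (file names are placeholders):
#   to_jsonl([{"instruction": "hi", "output": "hello"}], "samples.jsonl")  # appends one record per line
#   records = from_jsonl("samples.jsonl")
#   to_json(records, "samples.json")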
class StreamGenerationMixin(GenerationMixin):
    # support for streaming generation
    # TODO: group_beam_search
    def stream_generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        **kwargs,
    ):
        if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
            synced_gpus = True
        else:
            synced_gpus = False
if kwargs.get("attention_mask", None) is not None: | |
# concat prompt attention mask | |
prefix_attention_mask = torch.ones( | |
kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens | |
).to(kwargs["input_ids"].device) | |
kwargs["attention_mask"] = torch.cat( | |
(prefix_attention_mask, kwargs["attention_mask"]), dim=1 | |
) | |
if kwargs.get("position_ids", None) is not None: | |
warnings.warn( | |
"Position ids are not supported for parameter efficient tuning. Ignoring position ids." | |
) | |
kwargs["position_ids"] = None | |
if kwargs.get("token_type_ids", None) is not None: | |
warnings.warn( | |
"Token type ids are not supported for parameter efficient tuning. Ignoring token type ids" | |
) | |
kwargs["token_type_ids"] = None | |
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] | |
if generation_config is None: | |
generation_config = self.generation_config | |
generation_config = copy.deepcopy(generation_config) | |
model_kwargs = generation_config.update(**kwargs) | |
bos_token_id, eos_token_id, pad_token_id = ( | |
generation_config.bos_token_id, | |
generation_config.eos_token_id, | |
generation_config.pad_token_id, | |
) | |
if isinstance(eos_token_id, int): | |
eos_token_id = [eos_token_id] | |
has_default_max_length = ( | |
kwargs.get("max_length") is None | |
and generation_config.max_length is not None | |
) | |
if has_default_max_length and generation_config.max_new_tokens is None: | |
warnings.warn( | |
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " | |
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" | |
" recommend using `max_new_tokens` to control the maximum length of the generation.", | |
UserWarning, | |
) | |
elif generation_config.max_new_tokens is not None: | |
generation_config.max_length = ( | |
generation_config.max_new_tokens + input_ids_seq_length | |
) | |
if generation_config.min_new_tokens is not None: | |
generation_config.min_length = ( | |
generation_config.min_new_tokens + input_ids_seq_length | |
) | |
        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = (
                "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            )
            warnings.warn(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is "
                f"{generation_config.max_length}. Consider increasing `max_new_tokens`."
            )
        # 2. Set generation parameters if not already defined
        logits_processor = (
            logits_processor if logits_processor is not None else LogitsProcessorList()
        )
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        # 7. determine generation mode
        is_constraint_gen_mode = (
            generation_config.constraints is not None or generation_config.force_words_ids is not None
        )
        is_contrastive_search_gen_mode = (
            generation_config.top_k is not None
            and generation_config.top_k > 1
            and generation_config.do_sample is False
            and generation_config.penalty_alpha is not None
            and generation_config.penalty_alpha > 0
        )
        is_greedy_gen_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is False
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        # beam=1 and do_sample=True
        is_sample_gen_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is True
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_beam_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is False
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_beam_sample_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is True
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_group_beam_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups > 1)
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        # 8. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )
        # 9. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)
        if is_greedy_gen_mode:
            # 11. run greedy search
            return self.stream_greedy_search(
                input_ids,
                logits_processor,
                stopping_criteria,
                generation_config,
                synced_gpus,
                **model_kwargs,
            )
        elif is_sample_gen_mode:
            # 12. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            return self.stream_sample(
                generation_config,
                input_ids,
                logits_processor,
                logits_warper,
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        elif is_beam_gen_mode:
            return self.stream_beam_search(
                generation_config,
                input_ids,
                logits_processor,
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        elif is_beam_sample_gen_mode:
            # interleave input_ids with `num_beams` additional sequences per batch
            return self.stream_beam_sample(
                input_ids,
                logits_processor,
                logits_warper,
                stopping_criteria,
                generation_config,
                synced_gpus,
                **model_kwargs,
            )
        else:
            raise NotImplementedError(
                "stream generation only supports greedy search, sampling, beam search and beam sample"
            )
    def stream_sample(
        self,
        generation_config,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        this_peer_finished = False  # used by synced_gpus only
        scores = ()
        # auto-regressive generation
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need
            next_token_logits = outputs.logits[:, -1, :]
            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)
            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids
            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )
            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True
        yield input_ids
    def stream_beam_sample(
        self,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        num_beams = generation_config.num_beams
        batch_size, cur_len = input_ids.shape[0], input_ids.shape[-1]
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        # NOTE: the reshapes below assume `num_return_sequences == 1`; larger values would
        # make the expanded batch inconsistent with `batch_size * num_beams`
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams * generation_config.num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        scores = ()
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores = beam_scores.view((batch_size * num_beams,))
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need
            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
            # Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
            # (like top_p) this is indifferent, but on others (like temperature) it is not. For reference, see
            # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
            next_token_scores = logits_warper(input_ids, next_token_scores)
            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, _indices)
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size
            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            yield input_ids
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield sequence_outputs["sequences"]
    def stream_greedy_search(
        self,
        input_ids,
        logits_processor,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        # init values
        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        # init attention / hidden states / scores tuples
        scores = ()
        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need
            next_token_logits = outputs.logits[:, -1, :]
            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)
            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids
            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )
            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True
        yield input_ids
    def stream_beam_search(
        self,
        generation_config,
        input_ids,
        logits_processor,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        # 10. go into beam search generation modes
        # 11. prepare beam search scorer
        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        num_beams = generation_config.num_beams
        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        # 12. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # beam_search logits
        batch_beam_size, cur_len = input_ids.shape
        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )
        beam_scores = torch.zeros(
            (batch_size, num_beams), dtype=torch.float, device=input_ids.device
        )
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(
                    0.0 if this_peer_finished else 1.0
                ).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )
            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need
            next_token_logits = outputs.logits[:, -1, :]
            # next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) hack: adjust tokens for Marian.
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[
                :, None
            ].expand_as(next_token_scores)
            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(
                batch_size, num_beams * vocab_size
            )
            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size
            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            input_ids = torch.cat(
                [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
            )
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(
                    model_kwargs["past_key_values"], beam_idx
                )
            # increase cur_len
            cur_len = cur_len + 1
            yield input_ids
            if beam_scorer.is_done or stopping_criteria(input_ids, None):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True
        final_result = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield final_result["sequences"]
class StreamLlamaForCausalLM(LlamaForCausalLM, StreamGenerationMixin):
    pass
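# StreamLlamaForCausalLM mixes the streaming helpers into the stock LlamaForCausalLM,
# so it can stream without any adapter. A minimal sketch (checkpoint path is a placeholder):
#   model = StreamLlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", device_map="auto")
#   for ids in model.stream_generate(input_ids=prompt_ids, generation_config=GenerationConfig(max_new_tokens=32)):
#       ...  # `ids` holds the prompt plus everything generated so far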
class StreamPeftGenerationMixin(PeftModelForCausalLM, StreamGenerationMixin):
    # peft's PeftModel.from_pretrained builds `MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config)`
    # instead of `cls`, so merely inheriting from PeftModelForCausalLM is not enough -- we override from_pretrained
    @classmethod
    def from_pretrained(cls, model, model_id, adapter_name="default", is_trainable=False, **kwargs):
        # works with peft==0.3.0
        # NOTE: naive string comparison of version numbers; fine for 0.3.x but not robust in general
        if peft.__version__ >= '0.3.0' and peft.__version__ != '0.3.0.dev0':
            # load the config
            from peft.utils import PromptLearningConfig
            config = LoraConfig.from_pretrained(model_id)
            if (getattr(model, "hf_device_map", None) is not None) and len(
                set(model.hf_device_map.values()).intersection({"cpu", "disk"})
            ) > 0:
                remove_hook_from_submodules(model)
            if isinstance(config, PromptLearningConfig) and is_trainable:
                raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
            else:
                config.inference_mode = not is_trainable
            # here is the hack: instantiate `cls` directly so the streaming mixin is part of the returned model
            model = cls(model, config, adapter_name)
            model.load_adapter(model_id, adapter_name, **kwargs)
            # NOTICE: expose the base model's generation helpers so the streaming loops can reuse them
            model.base_model_prepare_inputs_for_generation = model.base_model.prepare_inputs_for_generation
            model._reorder_cache = model.base_model._reorder_cache
            return model
        else:
            return cls.from_pretrained_old_peft_version(model, model_id, **kwargs)
    @classmethod
    def from_pretrained_old_peft_version(cls, model, model_id, **kwargs):
        # works with peft@e536616888d51b453ed354a6f1e243fecb02ea08
        # load the config
        config = LoraConfig.from_pretrained(model_id)
        if getattr(model, "hf_device_map", None) is not None:
            remove_hook_from_submodules(model)
        # here is the hack: instantiate `cls` directly so the streaming mixin is part of the returned model
        model = cls(model, config)
        model._reorder_cache = model.base_model._reorder_cache
        # load weights if any
        if os.path.exists(os.path.join(model_id, "adapter_model.bin")):
            filename = os.path.join(model_id, "adapter_model.bin")
        else:
            try:
                filename = hf_hub_download(model_id, "adapter_model.bin")
            except Exception:  # noqa
                raise ValueError(
                    f"Can't find weights for {model_id} locally or in the Hugging Face Hub. "
                    f"Please check that the file 'adapter_model.bin' is present at {model_id}."
                )
        adapters_weights = torch.load(
            filename,
            map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
        # load the weights into the model
        model = set_peft_model_state_dict(model, adapters_weights)
        if getattr(model, "hf_device_map", None) is not None:
            device_map = kwargs.get("device_map", "auto")
            max_memory = kwargs.get("max_memory", None)
            no_split_module_classes = model._no_split_modules
            if device_map != "sequential":
                max_memory = get_balanced_memory(
                    model,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                    low_zero=(device_map == "balanced_low_0"),
                )
            if isinstance(device_map, str):
                device_map = infer_auto_device_map(
                    model,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                )
            model = dispatch_model(model, device_map=device_map)
            hook = AlignDevicesHook(io_same_device=True)
            if model.peft_config.peft_type == PeftType.LORA:
                add_hook_to_module(model.base_model.model, hook)
            else:
                remove_hook_from_submodules(model.prompt_encoder)
                add_hook_to_module(model.base_model, hook)
        return model
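# ---------------------------------------------------------------------------
# Minimal streaming sketch (only runs when this file is executed directly).
# The base checkpoint, LoRA directory and prompt below are placeholders; swap
# in your own before running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import LlamaTokenizer

    BASE_MODEL = "decapoda-research/llama-7b-hf"  # placeholder base checkpoint
    LORA_WEIGHTS = "./lora-out"                   # placeholder adapter directory

    tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
    base_model = LlamaForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.float16, device_map="auto"
    )
    model = StreamPeftGenerationMixin.from_pretrained(base_model, LORA_WEIGHTS)
    model.eval()

    inputs = tokenizer("Hello, my name is", return_tensors="pt").to(base_model.device)
    generation_config = GenerationConfig(
        max_new_tokens=32,
        do_sample=True,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,  # LLaMA has no pad token; reuse EOS
    )

    with torch.no_grad():
        # each yielded tensor is the prompt plus all tokens generated so far
        for output_ids in model.stream_generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            generation_config=generation_config,
        ):
            print(tokenizer.decode(output_ids[0], skip_special_tokens=True), flush=True)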