import copy import os from datetime import timedelta import sys from time import time from pathlib import Path from typing import List, Literal, Optional, Tuple, Union import torch import torch.nn.functional as F import transformers from accelerate import ( Accelerator, DistributedType, InitProcessGroupKwargs, find_executable_batch_size, ) from packaging import version from peft import PeftModel from peft import __version__ as PEFT_VERSION from tqdm import tqdm from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, ) from transformers import TextStreamer from lm_eval import utils from lm_eval.api.instance import Instance from lm_eval.api.model import TemplateLM from lm_eval.api.registry import register_model from lm_eval.models.utils import ( Collator, clear_torch_cache, get_dtype, pad_and_concat, stop_sequences_criteria, ) from lm_eval.models.huggingface import HFLM class StopWatch(TextStreamer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.start_prefilling = None self.prefilling_time = None self.start_decoding = None self.decoding_time = None self.decoding_iterations = 0 def put(self, value): if self.start_prefilling is None: self.start_prefilling = time() return elif self.prefilling_time is None: self.prefilling_time = time() - self.start_prefilling self.start_decoding = time() self.decoding_iterations += 1 return def end(self): if self.decoding_time is None and self.start_decoding is not None: self.decoding_time = time() - self.start_decoding return class HFLMWithMeasurement(HFLM): def __init__(self, **kwargs): super().__init__(**kwargs) def _loglikelihood_tokens( self, requests: List[Tuple[Tuple[str, str], List[int], List[int]]], disable_tqdm: bool = False, override_bs: int = None, ) -> List[Tuple[float, bool]]: # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context res = [] def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]): """Defines the key for the sorted method""" # the negative sign on len(toks) sorts descending - this has a few advantages: # - time estimates will always be over not underestimates, which is more useful for planning # - to know the size of a batch when going through the list, you know the first one is always the batch # padded context length. this is useful to simplify the batching logic and more importantly to make # automatic adaptive batches much much easier to implement # - any OOMs will happen right away rather than near the end toks = req[1] + req[2] return -len(toks), tuple(toks) def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): """Defines the key to group and lookup one-token continuations""" # Use with group_by="contexts" (optional)" # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. # speeds up some multiple-choice tasks proportionally to the number of choices. # groups requests by context+continuation[:-1] and infer on one request/group. return req[-2] + req[-1][:-1] re_ord = Collator( requests, sort_fn=_collate, group_by="contexts" if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM and self.logits_cache else None, group_fn=_lookup_one_token_cont, ) # automatic (variable) batch size detection for vectorization # pull longest context sample from request n_reordered_requests = len(re_ord) batch_size = ( self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0 ) batch_fn = ( self._batch_scheduler if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs else None ) chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) pbar = tqdm( total=len(requests), disable=(disable_tqdm or (self.rank != 0)), desc="Running loglikelihood requests", ) for chunk in chunks: inps = [] cont_toks_list = [] inplens = [] conts = [] encoder_attns = [] padding_len_inp = None padding_len_cont = None # because vectorizing is annoying, we first convert each (context, continuation) pair to padded # tensors, then we pack them together into a batch, call the model, and then pick it all apart # again because vectorizing is annoying for _, context_enc, continuation_enc in chunk: # sanity check assert len(context_enc) > 0 assert len(continuation_enc) > 0 assert len(continuation_enc) <= self.max_length # how this all works (illustrated on a causal decoder-only setup): # CTX CONT # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] # model \ \ # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice # when too long to fit in context, truncate from the left if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: inp = torch.tensor( (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], dtype=torch.long, device=self.device, ) (inplen,) = inp.shape elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: inp = torch.tensor( (context_enc)[-self.max_length :], dtype=torch.long, device=self.device, ) (inplen,) = inp.shape # build encoder attn masks encoder_attns.append(torch.ones_like(inp)) cont = torch.tensor( (continuation_enc)[-self.max_length :], # TODO: left-shift these? # TODO: our code assumes we never end up truncating conts for either model type dtype=torch.long, device=self.device, ) (contlen,) = cont.shape conts.append(cont) padding_len_cont = ( max(padding_len_cont, contlen) if padding_len_cont is not None else contlen ) padding_len_inp = ( max(padding_len_inp, inplen) if padding_len_inp is not None else inplen ) inps.append(inp) # [1, inp_length] cont_toks_list.append(continuation_enc) inplens.append(inplen) # create encoder attn mask and batched conts, if seq2seq call_kwargs = {} if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: batched_inps = pad_and_concat( padding_len_inp, inps, padding_side="right" ) # [batch, padding_len_inp] elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: # TODO: left-pad encoder inps and mask? batched_inps = pad_and_concat( padding_len_inp, inps ) # [batch, padding_len_inp] batched_conts = pad_and_concat( padding_len_cont, conts ) # [batch, padding_len_cont] batched_encoder_mask = pad_and_concat( padding_len_inp, encoder_attns ) # [batch, padding_len_inp] call_kwargs = { "attn_mask": batched_encoder_mask, "labels": batched_conts, } start = time() intermediate_res = self._model_call(batched_inps, **call_kwargs) end = time() multi_logits = F.log_softmax( intermediate_res , dim=-1 ) # [batch, padding_length (inp or cont), vocab] per_sample_time = (end - start) / len(multi_logits) for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip( chunk, multi_logits, inplens, cont_toks_list ): # Slice to original seq length contlen = len(cont_toks) # take only logits in the continuation # (discard context toks if decoder-only ; discard right-padding) # also discards + checks for "virtual tokens" in the causal LM's input window # from prompt/prefix tuning tokens, if applicable ctx_len = ( inplen + (logits.shape[0] - padding_len_inp) if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM else None ) logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) logits = logits.unsqueeze(0) # [1, seq, vocab] # Check if per-token argmax is exactly equal to continuation greedy_tokens = logits.argmax(dim=-1) # check for one-token continuation cache hits. # noop in case group_by != "contexts" or no cache hit and returns the # original args. Otherwise, expands the logits batch dimension and yields each # batch along with matching continuation tokens and prompt strings. # logits -> [1, seq, vocab] for request_str, cont_toks, logits in re_ord.get_cache( req_str=request_str, cxt_toks=ctx_tokens, cont_toks=cont_toks, logits=logits, ): cont_toks = torch.tensor( cont_toks, dtype=torch.long, device=self.device ).unsqueeze(0) # [1, seq] max_equal = (greedy_tokens == cont_toks).all() # Obtain log-probs at the corresponding continuation token indices # last_token_slice = logits[:, -1, :].squeeze(0).tolist() logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze( -1 ) # [1, seq] # Answer: (log prob, is-exact-match) answer = (float(logits.sum()), bool(max_equal)) res.append((answer, per_sample_time, 0, 0)) self.cache_hook.add_partial("loglikelihood", request_str, answer) pbar.update(1) pbar.close() return re_ord.get_original(res) def _model_generate(self, context, max_length, stop, **generation_kwargs): # temperature = 0.0 if not set # if do_sample is false and temp==0.0: # remove temperature, as do_sample=False takes care of this # and we don't want a warning from HF generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) do_sample = generation_kwargs.get("do_sample", None) # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies if generation_kwargs.get("temperature") == 0.0 and do_sample is None: generation_kwargs["do_sample"] = do_sample = False if do_sample is False and generation_kwargs.get("temperature") == 0.0: generation_kwargs.pop("temperature") # build stopping criteria stopping_criteria = stop_sequences_criteria( self.tokenizer, stop, context.shape[1], context.shape[0] ) stop_watch = StopWatch(self.tokenizer) start = time() res = self.model.generate( input_ids=context, max_length=max_length, stopping_criteria=stopping_criteria, pad_token_id=self.tokenizer.pad_token_id, use_cache=True, streamer=stop_watch, **generation_kwargs, ) end = time() batch_size = context.shape[0] output_length = stop_watch.decoding_iterations end_to_end_time = (end - start) / batch_size prefilling_time = stop_watch.prefilling_time / batch_size decoding_time = stop_watch.decoding_time / batch_size token_per_sec = output_length / decoding_time return res, end_to_end_time, prefilling_time, token_per_sec def generate_until( self, requests: List[Instance], disable_tqdm: bool = False ) -> List[str]: res = [] def _collate(req: Tuple[str, dict]): """Defines the key for the sorted method""" # the negative sign on len(toks) sorts descending - this has a few advantages: # - time estimates will always be over not underestimates, which is more useful for planning # - to know the size of a batch when going through the list, you know the first one is always the batch # padded context length. this is useful to simplify the batching logic and more importantly to make # automatic adaptive batches much much easier to implement # - any OOMs will happen right away rather than near the end toks = self.tok_encode(req[0]) return -len(toks), req[0] pbar = tqdm( total=len(requests), disable=(disable_tqdm or (self.rank != 0)), desc="Running generate_until requests", ) adaptive_batch_size = None if self.batch_size == "auto": # using rolling window with maximum context print("Passed argument batch_size = auto. Detecting largest batch size") batch_size = self._detect_batch_size() print(f"Determined Largest batch size: {batch_size}") adaptive_batch_size = batch_size # for each different set of kwargs, we execute all requests, by batch. batch_size = ( self.batch_size if self.batch_size != "auto" else adaptive_batch_size if adaptive_batch_size is not None else 0 ) batch_fn = ( self._batch_scheduler if self.batch_size == "auto" and not adaptive_batch_size else None ) # we group requests by their generation_kwargs, # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling # in the same batch. # group_fn=lambda x: x[1] -> x=(context, gen_kwargs) re_ords = Collator( [reg.args for reg in requests], sort_fn=_collate, group_by="gen_kwargs", group_fn=lambda x: x[1], ) chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn) for chunk in chunks: contexts, all_gen_kwargs = zip(*chunk) # we assume all gen kwargs in the batch are the same # this is safe to assume because the `grouper` object ensures it. gen_kwargs = all_gen_kwargs[0] # unpack our keyword arguments. until = None if isinstance(gen_kwargs, dict): kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 if "until" in kwargs.keys(): until = kwargs.pop("until") if isinstance(until, str): until = [kwargs] elif not isinstance(until, list): raise ValueError( f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}" ) else: raise ValueError( f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" ) # add EOS token to stop sequences eos = self.tok_decode(self.eot_token_id) if not until: until = [eos] else: until.append(eos) if "max_gen_toks" in kwargs.keys(): max_gen_toks = kwargs.pop("max_gen_toks") else: max_gen_toks = self.max_gen_toks # set the max length in tokens of inputs ("context_enc") if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: # max len for inputs = max length, minus room to generate the max new tokens max_ctx_len = self.max_length - max_gen_toks elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: # max len for inputs = encoder's whole max_length max_ctx_len = self.max_length # encode, pad, and truncate contexts for this batch context_enc, attn_masks = self.tok_batch_encode( contexts, left_truncate_len=max_ctx_len, truncation=self.truncation, ) context_enc = context_enc.to(self.device) attn_masks = attn_masks.to(self.device) if "max_length" not in kwargs: kwargs["max_length"] = context_enc.shape[1] + max_gen_toks # perform batched generation cont, end_to_end_time, prefilling_time, token_per_sec = self._model_generate( context=context_enc, attention_mask=attn_masks, stop=until, **kwargs, ) cont_toks_list = cont.tolist() for cont_toks, context in zip(cont_toks_list, contexts): # discard context + left-padding toks if using causal decoder-only LM if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: cont_toks = cont_toks[context_enc.shape[1] :] s = self.tok_decode(cont_toks) # use secondary stop seqs to cut off should-have-been-stopped content post-hoc for term in until: if len(term) > 0: # ignore '' separator, # for seq2seq case where self.tok_decode(self.eot_token_id) = '' s = s.split(term)[0] res.append((s, end_to_end_time, prefilling_time, token_per_sec)) self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s) pbar.update(1) # reorder this group of results back to original unsorted form res = re_ords.get_original(res) pbar.close() return res