import copy
import os
import sys
from datetime import timedelta
from pathlib import Path
from time import time
from typing import List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import transformers
from accelerate import (
    Accelerator,
    DistributedType,
    InitProcessGroupKwargs,
    find_executable_batch_size,
)
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from tqdm import tqdm
from transformers import TextStreamer
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)

from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,
    get_dtype,
    pad_and_concat,
    stop_sequences_criteria,
)


class StopWatch(TextStreamer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_prefilling = None
        self.prefilling_time = None
        self.start_decoding = None
        self.decoding_time = None
        self.decoding_iterations = 0

    def put(self, value):
        if self.start_prefilling is None:
            self.start_prefilling = time()
            return
        elif self.prefilling_time is None:
            self.prefilling_time = time() - self.start_prefilling
            self.start_decoding = time()
        self.decoding_iterations += 1
        return

    def end(self):
        if self.decoding_time is None and self.start_decoding is not None:
            self.decoding_time = time() - self.start_decoding
        return
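

# Note on the streamer hook (a descriptive sketch of the standard
# `transformers.TextStreamer` contract, not logic added by this module):
# `model.generate` calls `put()` once with the full prompt before the first
# forward pass, then once per decoding step with each new token, and finally
# calls `end()`. StopWatch relies on exactly this ordering: the first `put`
# starts the prefill clock, the second closes it and opens the decode clock,
# and every call from the second onward counts one generated token.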


class HFLMWithMeasurement(HFLM):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # temperature = 0.0 if not set
        # if do_sample is false and temp==0.0:
        # remove temperature, as do_sample=False takes care of this
        # and we don't want a warning from HF
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)

        # The temperature has to be a strictly positive float --
        # if it is 0.0, use greedy decoding strategies
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")

        # build stopping criteria
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )

        stop_watch = StopWatch(self.tokenizer)
        start = time()
        res = self.model.generate(
            input_ids=context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            streamer=stop_watch,
            **generation_kwargs,
        )
        end = time()

        batch_size = context.shape[0]
        output_length = stop_watch.decoding_iterations

        # normalize wall-clock measurements per batch element; prefilling_time
        # and decoding_time are still None (and these divisions raise) if
        # generation stopped before the first decoded token
        end_to_end_time = (end - start) / batch_size
        prefilling_time = stop_watch.prefilling_time / batch_size
        decoding_time = stop_watch.decoding_time / batch_size
        # output_length counts decoding steps, so dividing by the per-sample
        # decoding time gives aggregate tokens/sec across the batch
        token_per_sec = output_length / decoding_time

        return res, end_to_end_time, prefilling_time, token_per_sec
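
    # Result contract (descriptive note): `generate_until` below follows
    # lm-eval's stock HFLM implementation, but threads the timing tuple from
    # `_model_generate` through, so each appended result is a 4-tuple of
    # (generated_text, end_to_end_time, prefilling_time, token_per_sec).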
    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[Tuple[str, float, float, float]]:
        res = []

        def _collate(req: Tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the
            #   batch padded context length. this is useful to simplify the batching logic and more importantly
            #   to make automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
        # for each different set of kwargs, we execute all requests, by batch.
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
        re_ords = Collator(
            [req.args for req in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            until = None
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                if "until" in kwargs.keys():
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            # add EOS token to stop sequences
            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
            if not until:
                until = [eos]
            else:
                until.append(eos)
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            # set the max length in tokens of inputs ("context_enc")
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation
            cont, end_to_end_time, prefilling_time, token_per_sec = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    cont_toks = cont_toks[context_enc.shape[1] :]

                s = self.tok_decode(cont_toks)

                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                for term in until:
                    if len(term) > 0:
                        # ignore '' separator,
                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                        s = s.split(term)[0]

                res.append((s, end_to_end_time, prefilling_time, token_per_sec))

                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()

        return res
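

# A minimal usage sketch (an addition, not part of the measured harness):
# it exercises `_model_generate` directly on a single prompt. The checkpoint
# name "sshleifer/tiny-gpt2" is an illustrative placeholder, and the
# constructor kwargs assume lm-eval's standard HFLM interface
# (`pretrained=...`, `batch_size=...`); adapt both to your environment.
if __name__ == "__main__":
    lm = HFLMWithMeasurement(pretrained="sshleifer/tiny-gpt2", batch_size=1)

    # encode a single prompt the same way generate_until does
    context_enc, _ = lm.tok_batch_encode(["The capital of France is"])
    context_enc = context_enc.to(lm.device)

    out, e2e, prefill, tok_per_sec = lm._model_generate(
        context=context_enc,
        max_length=context_enc.shape[1] + 16,
        stop=["\n"],
        do_sample=False,
    )
    print(f"end-to-end {e2e:.3f}s | prefill {prefill:.3f}s | {tok_per_sec:.1f} tok/s")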