import copy
import os
from datetime import timedelta
import sys
from time import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import transformers
from accelerate import (
    Accelerator,
    DistributedType,
    InitProcessGroupKwargs,
    find_executable_batch_size,
)
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from tqdm import tqdm
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
from transformers import TextStreamer

from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,
    get_dtype,
    pad_and_concat,
    stop_sequences_criteria,
)
from lm_eval.models.huggingface import HFLM


class StopWatch(TextStreamer):
    """TextStreamer subclass that records prefill and decoding times via the streamer callbacks instead of printing tokens."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_prefilling = None
        self.prefilling_time = None
        self.start_decoding = None
        self.decoding_time = None
        self.decoding_iterations = 0

    def put(self, value):
        # First call receives the prompt tokens: start timing the prefill phase.
        if self.start_prefilling is None:
            self.start_prefilling = time()
            return
        # Second call is the first generated token: prefill is done, decoding begins.
        elif self.prefilling_time is None:
            self.prefilling_time = time() - self.start_prefilling
            self.start_decoding = time()
        # Every call after the prompt corresponds to one decoding step.
        self.decoding_iterations += 1
        return

    def end(self):
        # Called once when generation finishes; close out the decoding timer.
        if self.decoding_time is None and self.start_decoding is not None:
            self.decoding_time = time() - self.start_decoding
        return


class HFLMWithMeasurement(HFLM):
    """HFLM variant that additionally reports end-to-end latency, prefill time, and decoding throughput."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # temperature = 0.0 if not set
        # if do_sample is False and temperature == 0.0, drop temperature entirely:
        # do_sample=False already forces greedy decoding and we avoid an HF warning
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)

        # temperature must be strictly positive for sampling; 0.0 means greedy decoding
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")

        # build stop-sequence criteria and attach the StopWatch streamer to time generation
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )
        stop_watch = StopWatch(self.tokenizer)
        start = time()
        res = self.model.generate(
            input_ids=context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            streamer=stop_watch,
            **generation_kwargs,
        )
        end = time()

        batch_size = context.shape[0]
        output_length = stop_watch.decoding_iterations

        # normalize the timings per sequence in the batch
        # (assumes at least one token was decoded, otherwise decoding_time stays None)
        end_to_end_time = (end - start) / batch_size
        prefilling_time = stop_watch.prefilling_time / batch_size
        decoding_time = stop_watch.decoding_time / batch_size
        token_per_sec = output_length / decoding_time
        return res, end_to_end_time, prefilling_time, token_per_sec

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[Tuple[str, float, float, float]]:
        res = []

        def _collate(req: Tuple[str, dict]):
            """Defines the key for the sorted method"""
            # sort by descending prompt length so the largest batches run first:
            # time estimates err on the high side and any OOM surfaces immediately
            # rather than near the end of the run
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # detect the largest batch size that fits in memory once, then reuse it
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # group requests by their generation_kwargs so that a batch never mixes
        # different sampling settings (e.g. greedy and temperature sampling)
        re_ords = Collator(
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # all requests in a chunk share the same gen_kwargs (the Collator groups
            # by them), so the first entry is representative of the whole batch
            gen_kwargs = all_gen_kwargs[0]

            until = None
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)
                if "until" in kwargs.keys():
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )

            # add the EOS token to the stop sequences
            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
            if not until:
                until = [eos]
            else:
                until.append(eos)
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            # set the max length in tokens of the inputs ("context_enc")
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # max len for inputs = the encoder's whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate the contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation and collect the timing measurements
            cont, end_to_end_time, prefilling_time, token_per_sec = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding tokens if using a causal decoder-only LM
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    cont_toks = cont_toks[context_enc.shape[1] :]

                s = self.tok_decode(cont_toks)

                # use secondary stop sequences to cut off should-have-been-stopped content post-hoc
                for term in until:
                    if len(term) > 0:
                        # skip empty stop strings (e.g. when the EOS token decodes to '')
                        s = s.split(term)[0]

                res.append((s, end_to_end_time, prefilling_time, token_per_sec))

                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)

        # reorder the results back to their original, unsorted order
        res = re_ords.get_original(res)

        pbar.close()

        return res
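

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It shows
# how the measurement-enabled wrapper could be driven directly; the model name,
# prompt, and the Instance constructor arguments below are assumptions and may
# need adjusting to the installed lm-eval version.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # small placeholder model purely for demonstration
    lm = HFLMWithMeasurement(pretrained="gpt2", batch_size=1)
    request = Instance(
        request_type="generate_until",
        doc={},
        arguments=("The capital of France is", {"until": ["\n"], "max_gen_toks": 16}),
        idx=0,
    )
    # each result is (generated_text, end_to_end_time, prefilling_time, token_per_sec)
    for text, end_to_end_time, prefilling_time, token_per_sec in lm.generate_until([request]):
        print(
            f"prefill: {prefilling_time * 1e3:.1f} ms | "
            f"decode: {token_per_sec:.1f} tok/s | output: {text!r}"
        )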