|
import copy |
|
import os |
|
from datetime import timedelta |
|
import sys |
|
from time import time |
|
from pathlib import Path |
|
from typing import List, Literal, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import transformers |
|
from accelerate import ( |
|
Accelerator, |
|
DistributedType, |
|
InitProcessGroupKwargs, |
|
find_executable_batch_size, |
|
) |
|
from packaging import version |
|
from peft import PeftModel |
|
from peft import __version__ as PEFT_VERSION |
|
from tqdm import tqdm |
|
from transformers.models.auto.modeling_auto import ( |
|
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, |
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, |
|
) |
|
from transformers import TextStreamer |
|
|
|
from lm_eval import utils |
|
from lm_eval.api.instance import Instance |
|
from lm_eval.api.model import TemplateLM |
|
from lm_eval.api.registry import register_model |
|
from lm_eval.models.utils import ( |
|
Collator, |
|
clear_torch_cache, |
|
get_dtype, |
|
pad_and_concat, |
|
stop_sequences_criteria, |
|
) |
|
from lm_eval.models.huggingface import HFLM |
|
|
|
|
|
class StopWatch(TextStreamer):
    """A ``TextStreamer`` used purely as a timing probe for ``generate``.

    It never decodes or prints anything. Instead it records wall-clock
    timestamps (``time()``, in seconds):

    * the first ``put`` stores ``start_prefilling``;
    * the second ``put`` stores ``prefilling_time`` (seconds elapsed since the
      first call) and ``start_decoding``;
    * every ``put`` after the first increments ``decoding_iterations``;
    * ``end`` stores ``decoding_time`` (seconds elapsed since
      ``start_decoding``) once, and only if decoding actually started.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Timestamps / durations stay None until the matching event happens.
        self.start_prefilling = self.prefilling_time = None
        self.start_decoding = self.decoding_time = None
        # Count of put() calls after the first one.
        self.decoding_iterations = 0

    def put(self, value):
        """Record one streaming event; ``value`` is ignored."""
        if self.start_prefilling is None:
            # First event: start the prefill clock and do nothing else.
            self.start_prefilling = time()
            return None

        if self.prefilling_time is None:
            # Second event: prefill has finished, decoding starts now.
            self.prefilling_time = time() - self.start_prefilling
            self.start_decoding = time()

        self.decoding_iterations += 1
        return None

    def end(self):
        """Stop the decoding clock (first call only, and only if decoding began)."""
        decoding_started = self.start_decoding is not None
        if decoding_started and self.decoding_time is None:
            self.decoding_time = time() - self.start_decoding
        return None
|
|
|
|
|
class HFLMWithMeasurement(HFLM):
    """``HFLM`` variant that also reports wall-clock timings per request.

    * ``_loglikelihood_tokens`` yields ``(answer, per_sample_time, 0, 0)``
      for each request, where ``answer`` is ``(logprob_sum, is_greedy)``.
    * ``generate_until`` yields
      ``(text, end_to_end_time, prefilling_time, token_per_sec)``.

    Times are measured with ``time()`` and are in seconds.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``HFLM``."""
        super().__init__(**kwargs)

    @staticmethod
    def _synchronize(tensor: torch.Tensor) -> None:
        """Block until queued CUDA work on ``tensor``'s device has finished.

        This is a no-op for non-CUDA tensors. CUDA kernels are launched
        asynchronously, so a wall-clock interval around a forward pass is
        only meaningful if the device is synchronized at both ends.
        """
        if tensor.is_cuda:
            torch.cuda.synchronize(tensor.device)

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
        override_bs: Optional[int] = None,
    ) -> List[Tuple[Tuple[float, bool], float, int, int]]:
        """Score continuations and time every batched forward pass.

        Args:
            requests: ``(request_key, context_tokens, continuation_tokens)``
                triples. ``request_key`` is a pair of strings that is passed
                through to the cache hook.
            disable_tqdm: Hide the progress bar (it is always hidden on
                non-zero ranks).
            override_bs: Batch size to use when ``self.batch_size == "auto"``.

        Returns:
            One ``((logprob_sum, is_greedy), per_sample_time, 0, 0)`` tuple per
            request, in the original request order. ``per_sample_time`` is the
            batch's forward-pass time divided by the number of rows in the
            batch. The trailing zeros presumably pad the tuple to the 4-field
            shape that ``generate_until`` returns.
        """
        res = []

        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Sort key: longest context+continuation first, ties broken by tokens."""
            toks = req[1] + req[2]
            return -len(toks), tuple(toks)

        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Group key: context tokens plus all but the last continuation token."""
            return req[-2] + req[-1][:-1]

        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by="contexts"
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
            and self.logits_cache
            else None,
            group_fn=_lookup_one_token_cont,
        )

        n_reordered_requests = len(re_ord)
        # A fixed batch size wins; otherwise use override_bs, else 0 (let batch_fn decide).
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto"
            and n_reordered_requests > 0
            and not override_bs
            else None
        )

        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = None
            padding_len_cont = None

            for _, context_enc, continuation_enc in chunk:
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    # Decoder-only: feed context + continuation without its last
                    # token, keeping at most max_length tokens from the right.
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape
                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                    inp = torch.tensor(
                        (context_enc)[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape

                    # Encoder mask: all real (pre-padding) positions are attended.
                    encoder_attns.append(torch.ones_like(inp))

                    cont = torch.tensor(
                        (continuation_enc)[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (contlen,) = cont.shape

                    conts.append(cont)

                    padding_len_cont = (
                        max(padding_len_cont, contlen)
                        if padding_len_cont is not None
                        else contlen
                    )

                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )

                inps.append(inp)
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            call_kwargs = {}
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                batched_inps = pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                batched_inps = pad_and_concat(padding_len_inp, inps)
                batched_conts = pad_and_concat(padding_len_cont, conts)
                batched_encoder_mask = pad_and_concat(padding_len_inp, encoder_attns)
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }

            # Drain pending GPU work (e.g. the input copies above) so it is not
            # charged to the forward pass. Then wait for the forward pass's own
            # kernels before stopping the clock; otherwise only the launch
            # overhead would be measured.
            self._synchronize(batched_inps)
            start = time()
            intermediate_res = self._model_call(batched_inps, **call_kwargs)
            self._synchronize(intermediate_res)
            end = time()
            multi_logits = F.log_softmax(intermediate_res, dim=-1)
            per_sample_time = (end - start) / len(multi_logits)

            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                contlen = len(cont_toks)
                # Causal LMs: shift inplen by any positions the logits have beyond
                # the padded input length (presumably virtual/prefix tokens).
                ctx_len = (
                    inplen + (logits.shape[0] - padding_len_inp)
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                    else None
                )
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                logits = logits.unsqueeze(0)

                greedy_tokens = logits.argmax(dim=-1)

                # get_cache may yield several requests sharing these logits when
                # one-token continuations were grouped by context.
                for request_str, cont_toks, logits in re_ord.get_cache(
                    req_str=request_str,
                    cxt_toks=ctx_tokens,
                    cont_toks=cont_toks,
                    logits=logits,
                ):
                    cont_toks = torch.tensor(
                        cont_toks, dtype=torch.long, device=self.device
                    ).unsqueeze(0)
                    max_equal = (greedy_tokens == cont_toks).all()

                    # Log-probs of the actual continuation tokens.
                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                        -1
                    )

                    answer = (float(logits.sum()), bool(max_equal))

                    res.append((answer, per_sample_time, 0, 0))

                    self.cache_hook.add_partial("loglikelihood", request_str, answer)
                    pbar.update(1)

        pbar.close()

        return re_ord.get_original(res)

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """Run ``self.model.generate`` with a ``StopWatch`` streamer attached.

        Greedy decoding is the default. A missing ``temperature`` becomes
        ``0.0``. If ``do_sample`` is unset and the temperature is ``0.0``,
        ``do_sample=False`` is used and ``temperature`` is dropped before
        calling ``generate``.

        Returns:
            ``(output_ids, end_to_end_time, prefilling_time, token_per_sec)``.
            Both times are divided by the batch size (``context.shape[0]``).
            ``token_per_sec`` is ``decoding_iterations`` divided by that
            per-sample decoding time.

        NOTE(review): if the streamer never records prefill/decoding, the
        divisions below raise ``TypeError`` (on ``None``) or
        ``ZeroDivisionError``.
        """
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)

        # A temperature of 0.0 means greedy decoding.
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")

        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )
        stop_watch = StopWatch(self.tokenizer)
        start = time()
        res = self.model.generate(
            input_ids=context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            streamer=stop_watch,
            **generation_kwargs,
        )
        end = time()

        batch_size = context.shape[0]
        output_length = stop_watch.decoding_iterations

        end_to_end_time = (end - start) / batch_size
        prefilling_time = stop_watch.prefilling_time / batch_size
        decoding_time = stop_watch.decoding_time / batch_size
        token_per_sec = output_length / decoding_time
        return res, end_to_end_time, prefilling_time, token_per_sec

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[Tuple[str, float, float, float]]:
        """Generate text for each request until a stop sequence or length limit.

        Args:
            requests: Instances whose ``args`` are ``(context, gen_kwargs)``.
                ``gen_kwargs`` must be a dict. Its optional ``until`` may be a
                string or a list of strings, and ``max_gen_toks`` overrides
                ``self.max_gen_toks``.
            disable_tqdm: Hide the progress bar (it is always hidden on
                non-zero ranks).

        Returns:
            One ``(text, end_to_end_time, prefilling_time, token_per_sec)``
            tuple per request, in the original order. The timing fields are
            those reported for the batch the request was generated in.

        Raises:
            ValueError: If ``gen_kwargs`` is not a dict, or ``until`` is
                neither a string nor a list.
        """
        res = []

        def _collate(req: Tuple[str, dict]):
            """Sort key: longest tokenized context first, ties broken by text."""
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        adaptive_batch_size = None
        if self.batch_size == "auto":
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # Requests are grouped by gen_kwargs, so every chunk shares one set.
        re_ords = Collator(
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            gen_kwargs = all_gen_kwargs[0]

            until = None
            if isinstance(gen_kwargs, dict):
                # Deep copy so popping keys / appending EOS never mutates the
                # caller's gen_kwargs.
                kwargs = copy.deepcopy(gen_kwargs)
                if "until" in kwargs.keys():
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        # Normalize a single stop string to a one-element list.
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )

            # The EOS text is always a stop sequence.
            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
            if not until:
                until = [eos]
            else:
                until.append(eos)
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # Leave room in the window for the tokens to be generated.
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                max_ctx_len = self.max_length

            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            cont, end_to_end_time, prefilling_time, token_per_sec = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    # Drop the (left-padded) prompt tokens echoed by generate.
                    cont_toks = cont_toks[context_enc.shape[1] :]

                s = self.tok_decode(cont_toks)

                # Truncate post hoc at the first occurrence of each stop
                # sequence, skipping empty ones (''.split would raise).
                for term in until:
                    if len(term) > 0:
                        s = s.split(term)[0]

                res.append((s, end_to_end_time, prefilling_time, token_per_sec))

                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)

        res = re_ords.get_original(res)

        pbar.close()

        return res
|
|