|
from typing import List, Literal, Optional, Tuple, Union |
|
import torch |
|
import transformers |
|
|
|
from lm_eval.api.registry import register_model |
|
|
|
from src.backend.hflm_with_measurement import HFLMWithMeasurement |
|
|
|
|
|
@register_model("hf-chat")
class HFLMwithChatTemplate(HFLMWithMeasurement):
    """HF model wrapper that wraps prompts in the tokenizer's chat template.

    Registered with lm-eval under the name ``"hf-chat"``. When
    ``use_chat_template`` is true, every prompt passed to
    :meth:`tok_batch_encode` is first rendered as a single ``user`` message
    through ``tokenizer.apply_chat_template``.
    """

    def __init__(self, use_chat_template: bool = True, **kwargs):
        """Initialise the parent model and store the chat-template switch.

        Args:
            use_chat_template: Whether ``tok_batch_encode`` should render each
                prompt through the tokenizer's chat template before encoding.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        self.use_chat_template = use_chat_template

    def _apply_chat_template(self, strings: List[str]) -> List[str]:
        """Render each prompt as a single-turn user message.

        Best effort: if templating fails for any prompt, the failure is
        reported on stdout and the original, untemplated ``strings`` are
        returned (no partially-templated batch is ever produced).
        """
        if not strings:
            return strings
        try:
            template_kwargs = {"tokenize": False}
            # Only models whose name contains "dbrx" get the generation prompt
            # appended; all others use the template's default behaviour.
            if "dbrx" in self.model.name_or_path:
                template_kwargs["add_generation_prompt"] = True
            return [
                self.tokenizer.apply_chat_template(
                    [{"role": "user", "content": f"{prompt}"}],
                    **template_kwargs,
                )
                for prompt in strings
            ]
        except Exception as exc:
            # Deliberately non-fatal, but never swallow KeyboardInterrupt /
            # SystemExit, and surface the actual cause for debugging.
            print(
                f"failed to update input string with chat template: {self._model}"
                f" ({exc!r})"
            )
            return strings

    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: Optional[int] = None,
        truncation: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a batch of prompts, optionally chat-templated.

        Args:
            strings: Prompts to encode.
            padding_side: Value temporarily assigned to
                ``tokenizer.padding_side`` while encoding; the previous value
                is restored afterwards, even if encoding raises.
            left_truncate_len: If truthy, keep only the last
                ``left_truncate_len`` positions of each encoded row.
            truncation: Forwarded to the tokenizer's ``truncation`` argument.

        Returns:
            The ``input_ids`` and ``attention_mask`` entries of the tokenizer
            output (requested as PyTorch tensors, padded to the longest row).
        """
        if self.use_chat_template:
            strings = self._apply_chat_template(strings)

        # Causal LMs are encoded without special tokens, seq2seq models with
        # them. For any other model class the flag is not passed at all, so
        # the tokenizer's own default applies instead of hitting an
        # UnboundLocalError.
        tokenizer_kwargs = {}
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            tokenizer_kwargs["add_special_tokens"] = False
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            tokenizer_kwargs["add_special_tokens"] = True

        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            encoding = self.tokenizer(
                strings,
                truncation=truncation,
                padding="longest",
                return_tensors="pt",
                **tokenizer_kwargs,
            )
        finally:
            # The tokenizer is shared state: always put padding_side back.
            self.tokenizer.padding_side = old_padding_side

        if left_truncate_len:
            # Keep the right-most tokens of every row.
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
            ]

        return encoding["input_ids"], encoding["attention_mask"]
|
|