Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
fix generation bugs
Browse files
src/backend/huggingface_generate_until.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Literal, Optional, Tuple, Union
|
2 |
+
import torch
|
3 |
+
import transformers
|
4 |
+
|
5 |
+
from lm_eval.models.huggingface import HFLM
|
6 |
+
from lm_eval.api.registry import register_model
|
7 |
+
|
8 |
+
@register_model('hf-chat')
class HFLMwithChatTemplate(HFLM):
    """HFLM variant that renders each prompt through the tokenizer's chat
    template (single-turn ``user`` message) before batch encoding, for
    instruction/chat-tuned models."""

    def __init__(self, use_chat_template=True, **kwargs):
        """
        Args:
            use_chat_template: when True, wrap every input string in a
                one-message ``user`` turn and format it with
                ``tokenizer.apply_chat_template`` prior to tokenization.
            **kwargs: forwarded unchanged to ``HFLM.__init__``.
        """
        super().__init__(**kwargs)
        self.use_chat_template = use_chat_template

    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: Optional[int] = None,
        truncation: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of strings into padded tensors.

        Args:
            strings: raw prompt strings.
            padding_side: which side the tokenizer pads on ("left"/"right");
                the tokenizer's previous setting is restored afterwards.
            left_truncate_len: if set, keep only the rightmost this-many
                tokens of each sequence (preserves the end of the prompt,
                which the model continues from).
            truncation: passed through to the tokenizer.

        Returns:
            Tuple of ``(input_ids, attention_mask)`` tensors.
        """
        if self.use_chat_template:
            try:
                updated_strings = []
                for input_string in strings:
                    messages = [
                        {"role": "user", "content": f"{input_string}"},
                    ]
                    updated_strings.append(
                        self.tokenizer.apply_chat_template(messages, tokenize=False)
                    )
                strings = updated_strings
            # Was a bare `except:` — narrow it so KeyboardInterrupt/SystemExit
            # are not swallowed, and surface the actual cause. Falling back to
            # the raw strings preserves the original best-effort behavior.
            except Exception as exc:
                print(f"failed to update input string with chat template: {self._model} ({exc})")

        # encode a batch of strings. converts to tensors and pads automatically,
        # unlike tok_encode.
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side

        # Chat templates already embed special tokens for causal LMs; seq2seq
        # models still need the tokenizer to add them. Defaulting to False
        # also fixes the original's UnboundLocalError when AUTO_MODEL_CLASS
        # matched neither auto-class.
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            add_special_tokens = True
        else:
            add_special_tokens = False

        try:
            encoding = self.tokenizer(
                strings,
                truncation=truncation,
                padding="longest",
                return_tensors="pt",
                add_special_tokens=add_special_tokens,
            )
            if left_truncate_len:
                encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
                encoding["attention_mask"] = encoding["attention_mask"][
                    :, -left_truncate_len:
                ]
        finally:
            # Restore the tokenizer's padding side even if encoding raised;
            # the original leaked the mutated setting on error.
            self.tokenizer.padding_side = old_padding_side

        return encoding["input_ids"], encoding["attention_mask"]
|