pingnie committed on
Commit
444730a
1 Parent(s): 07f212a

fix generation bugs

Browse files
src/backend/huggingface_generate_until.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Literal, Optional, Tuple, Union
2
+ import torch
3
+ import transformers
4
+
5
+ from lm_eval.models.huggingface import HFLM
6
+ from lm_eval.api.registry import register_model
7
+
8
@register_model('hf-chat')
class HFLMwithChatTemplate(HFLM):
    """HFLM subclass that applies the tokenizer's chat template to each prompt
    before batch-encoding, for use with instruction-tuned / chat models.

    Registered with lm-eval under the model name ``hf-chat``.
    """

    def __init__(self, use_chat_template=True, **kwargs):
        """Initialize the underlying HFLM and record whether to apply the chat template.

        Args:
            use_chat_template: If True, each input string is wrapped as a single
                user message and passed through ``tokenizer.apply_chat_template``
                before encoding.
            **kwargs: Forwarded unchanged to ``HFLM.__init__``.
        """
        super().__init__(**kwargs)
        self.use_chat_template = use_chat_template

    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: Optional[int] = None,
        truncation: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of strings into padded ``input_ids`` / ``attention_mask`` tensors.

        Args:
            strings: Raw prompt strings; each is optionally rewritten via the
                tokenizer's chat template when ``self.use_chat_template`` is set.
            padding_side: Which side the tokenizer pads on (restored afterwards).
            left_truncate_len: If given, keep only the last ``left_truncate_len``
                tokens of each sequence (left truncation).
            truncation: Passed through to the tokenizer.

        Returns:
            Tuple of ``(input_ids, attention_mask)`` tensors, shape (batch, seq).
        """
        if self.use_chat_template:
            try:
                updated_strings = []
                for input_string in strings:
                    # Wrap each prompt as a single-turn user message.
                    messages = [
                        {"role": "user", "content": f"{input_string}"},
                    ]
                    updated_string = self.tokenizer.apply_chat_template(messages, tokenize=False)
                    updated_strings.append(updated_string)
                strings = updated_strings[:]
            except Exception:
                # Bug fix: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit. Keep the original best-effort
                # behaviour: fall back to the raw strings and report.
                print(f"failed to update input string with chat template: {self._model}")
        # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side

        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            add_special_tokens = False
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            add_special_tokens = True
        else:
            # Bug fix: previously `add_special_tokens` was unbound (NameError)
            # for any other model class. Default to the causal-LM behaviour.
            add_special_tokens = False

        encoding = self.tokenizer(
            strings,
            truncation=truncation,
            padding="longest",
            return_tensors="pt",
            add_special_tokens=add_special_tokens,
        )
        if left_truncate_len:
            # Left-truncate: keep the most recent tokens.
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
            ]
        # Restore the tokenizer's original padding side (shared mutable state).
        self.tokenizer.padding_side = old_padding_side

        return encoding["input_ids"], encoding["attention_mask"]