ieiei committed on
Commit
632029f
1 Parent(s): e5c3ba4

Create modeling_llama3.py

Browse files
Files changed (1) hide show
  1. modeling_llama3.py +66 -0
modeling_llama3.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ import torch
3
+ from transformers.generation.logits_process import LogitsProcessor
4
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM
5
+ from transformers import AutoTokenizer
6
+ import re
7
+
8
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
    """Logits processor that lowers the score of tokens proportionally to how often they already occurred.

    For every row of the batch, the occurrence count of each token id is computed over the
    concatenation of ``penalty_dialog`` and the part of ``input_ids`` that follows the first
    ``input_length`` positions. ``penalty * count`` is then subtracted from that row's scores.

    Args:
        penalty: Amount subtracted from a token's score per occurrence. Must be a strictly
            positive ``float``.
        penalty_dialog: 1-D tensor of token ids that are always counted, in addition to the
            tokens found after ``input_length`` in ``input_ids``.
        input_length: Number of leading positions of each ``input_ids`` row that are excluded
            from the count.

    Raises:
        ValueError: If ``penalty`` is not a ``float`` or is not strictly positive.
    """

    def __init__(self, penalty: float, penalty_dialog: torch.LongTensor, input_length: int):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty
        self.input_length = input_length
        self.penalty_dialog = penalty_dialog

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with the frequency penalty applied, cast to ``float32``.

        ``scores`` is returned untouched when ``self.penalty`` equals ``0.0``.
        """
        # The constructor rejects 0.0, but the attribute can still be reassigned afterwards.
        if self.penalty == 0.0:
            return scores

        vocab_size = scores.size(-1)
        # Loop-invariant: make sure the history lives on the same device as the rows it is
        # concatenated with, otherwise torch.cat raises on mixed CPU/GPU tensors.
        history = self.penalty_dialog.to(input_ids.device)

        new_scores = []
        for input_, score in zip(input_ids, scores):
            generated_tokens = torch.cat((history, input_[self.input_length:]), dim=-1)
            # bincount grows past `minlength` when it meets an id >= vocab_size; truncate so the
            # result always lines up with `score` instead of failing on a shape mismatch.
            token_frequency = torch.bincount(generated_tokens, minlength=vocab_size)[:vocab_size]
            new_scores.append(score - self.penalty * token_frequency.to(score.device))

        return torch.stack(new_scores).float()
27
+
28
+
29
+
30
class LlamaForConditionalGeneration(LlamaForCausalLM):
    """``LlamaForCausalLM`` whose ``generate`` can penalise tokens used in earlier assistant turns.

    ``generate`` understands three extra keyword arguments (all removed before delegating to the
    parent implementation):

    * ``history_penalty`` (default ``0.0``): per-occurrence penalty handed to
      ``FrequencyPenaltyLogitsProcessor``. ``0.0`` disables the feature.
    * ``penalty_turns`` (default ``0``): how many of the most recent assistant messages are
      tokenized and counted as already-used tokens. A negative value disables the feature.
    * ``messages`` (default ``[]``): chat history as a list of dicts with ``'role'`` and
      ``'content'`` keys.
    """

    # Hub repository whose tokenizer is used to encode the assistant history.
    # Subclasses or instances may override it.
    penalty_tokenizer_id = "Collective-Ai/collective-v0.1-chinese-roleplay-8b"

    def __init__(self, config):
        super().__init__(config)

    def generate(self, *args, **kwargs):
        """Run the parent ``generate``, optionally adding a history-aware frequency penalty.

        Positional and keyword arguments other than ``history_penalty``, ``penalty_turns`` and
        ``messages`` are forwarded unchanged. A caller-supplied ``logits_processor`` is kept and
        the penalty processor is appended to it.

        Returns:
            Whatever the parent ``generate`` returns.

        Raises:
            ValueError: Propagated from ``FrequencyPenaltyLogitsProcessor`` when
                ``history_penalty`` is non-zero but not a strictly positive ``float``.
        """
        history_penalty = kwargs.pop("history_penalty", 0.0)
        penalty_turns = kwargs.pop("penalty_turns", 0)
        messages = kwargs.pop("messages", None) or []

        # Feature disabled: behave exactly like the parent class.
        if history_penalty == 0.0 or penalty_turns < 0:
            return super().generate(*args, **kwargs)

        # The prompt may arrive as `input_ids`, as `inputs`, or as the first positional argument.
        # Its length marks where newly generated tokens start.
        input_ids = kwargs.get("input_ids")
        if input_ids is None:
            input_ids = kwargs.get("inputs")
        if input_ids is None and args:
            input_ids = args[0]
        if input_ids is None:
            input_ids = torch.tensor([[]])
        input_length = input_ids.size(-1)

        # Keep the last `penalty_turns` assistant replies, oldest first.
        assistant_replies = [m["content"] for m in messages if m["role"] == "assistant"]
        recent_replies = assistant_replies[-penalty_turns:] if penalty_turns > 0 else []

        # NOTE(review): the original chained the same "(" / ")" replacements twice; the second
        # pair may have been meant for full-width parentheses -- confirm against upstream.
        penalty_dialog = [reply.replace("(", " ").replace(")", " ") for reply in recent_replies]

        # Loading a tokenizer touches disk/network, so do it once per model instance.
        tokenizer = getattr(self, "_penalty_tokenizer", None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(self.penalty_tokenizer_id)
            self._penalty_tokenizer = tokenizer

        # Special tokens are not part of what the assistant said, so they must not be counted.
        penalty_ids = tokenizer.encode(" ".join(penalty_dialog), add_special_tokens=False)
        penalty_token = torch.LongTensor(penalty_ids).to(input_ids.device)

        # Merge with any processors the caller passed instead of clashing on the keyword.
        logits_processor = list(kwargs.pop("logits_processor", None) or [])
        logits_processor.append(
            FrequencyPenaltyLogitsProcessor(
                penalty=history_penalty,
                penalty_dialog=penalty_token,
                input_length=input_length,
            )
        )
        return super().generate(*args, logits_processor=logits_processor, **kwargs)
66
+