import torch

from transformers import AutoTokenizer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from transformers.models.llama.modeling_llama import LlamaForCausalLM


class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
    """Additive frequency penalty over a supplied dialog history.

    Each token's logit is reduced by `penalty` once per occurrence of that token in
    `penalty_dialog` (the penalized history tokens) and in the tokens generated after
    the prompt (`input_length` marks where the prompt ends).
    """

    def __init__(self, penalty: float, penalty_dialog: torch.LongTensor, input_length: int):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty
        self.input_length = input_length
        self.penalty_dialog = penalty_dialog

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        new_scores = []
        for input_, score in zip(input_ids, scores):
            # Count occurrences of every vocabulary id in the penalized history plus the
            # continuation generated so far (everything after the original prompt).
            generated_tokens = torch.cat((self.penalty_dialog, input_[self.input_length:]), dim=-1)
            token_frequency = torch.bincount(generated_tokens, minlength=scores.size(-1)).to(scores.device)
            # Subtract the penalty once per occurrence.
            new_scores.append(score - self.penalty * token_frequency)

        return torch.stack(new_scores).float()

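# Worked example of the penalty arithmetic above (a sketch, not executed): with
# penalty=0.8, a vocabulary id that already occurs 3 times in the penalized history
# plus the continuation generated so far has its logit lowered by 3 * 0.8 = 2.4,
# so tokens repeated across recent assistant turns become progressively less likely.
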
class LlamaForConditionalGeneration(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        # Tokenizer used only to encode the penalized history; loaded lazily and reused.
        self._penalty_tokenizer = None

    def generate(self, **kwargs):
        # Custom arguments; they are popped so the base `generate` never sees them.
        history_penalty = kwargs.pop("history_penalty", 0.0)
        penalty_turns = kwargs.pop("penalty_turns", 0)
        messages = kwargs.pop("messages", [])

        if history_penalty != 0.0 and penalty_turns >= 0:
            input_ids = kwargs.get("input_ids", torch.tensor([[]]))
            input_length = input_ids.size(-1)

            # Collect the assistant turns from the chat history.
            dialogs = [message["content"] for message in messages if message["role"] == "assistant"]

            # Keep the last `penalty_turns` assistant turns, oldest first, replacing
            # ASCII and full-width parentheses with spaces before tokenization.
            penalty_dialog = []
            for i in range(penalty_turns, 0, -1):
                if i <= len(dialogs):
                    dialog = dialogs[-i].replace("(", " ").replace(")", " ").replace("（", " ").replace("）", " ")
                    penalty_dialog.append(dialog)

            if self._penalty_tokenizer is None:
                model_id = "Collective-Ai/collective-v0.1-chinese-roleplay-8b"
                self._penalty_tokenizer = AutoTokenizer.from_pretrained(model_id)
            penalty_token = torch.LongTensor(
                self._penalty_tokenizer.encode(" ".join(penalty_dialog))
            ).to(input_ids.device)

            logits_processor = LogitsProcessorList(
                [
                    FrequencyPenaltyLogitsProcessor(
                        penalty=history_penalty,
                        penalty_dialog=penalty_token,
                        input_length=input_length,
                    )
                ]
            )
            result = super().generate(logits_processor=logits_processor, **kwargs)
        else:
            result = super().generate(**kwargs)

        return result
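
# Minimal usage sketch (not part of the model code): it assumes the checkpoint below
# is available and that its chat template is used to build the prompt. The example
# messages and penalty values are illustrative only; adjust them to your setup.
if __name__ == "__main__":
    model_id = "Collective-Ai/collective-v0.1-chinese-roleplay-8b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = LlamaForConditionalGeneration.from_pretrained(model_id)

    messages = [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "(smiles) Nice to meet you."},
        {"role": "user", "content": "Tell me more about yourself."},
    ]
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

    # history_penalty, penalty_turns and messages are consumed by the generate() override
    # above; everything else is forwarded to the standard Hugging Face generate().
    output_ids = model.generate(
        input_ids=input_ids,
        max_new_tokens=128,
        history_penalty=0.5,
        penalty_turns=2,
        messages=messages,
    )
    print(tokenizer.decode(output_ids[0][input_ids.size(-1):], skip_special_tokens=True))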