# Source snapshot d6f2919 (file size: 8,074 bytes).
import torch
import numpy as np
from queue import Queue
from typing import Tuple, List, Union, Iterable
from transformers.utils import logging, add_start_docstrings
from transformers.generation.logits_process import LogitsProcessor, LOGITS_PROCESSOR_INPUTS_DOCSTRING, LogitsProcessorList
def make_context(model, tokenizer,
                 messages: List[dict],
                 system: str = "You are a helpful assistant.",
                 max_new_tokens: int = 0,
                 ):
    """Build the ChatML-style prompt token tensor for a chat conversation.

    The prompt is laid out as a sequence of turns, each encoded as
    ``im_start + role + newline + content + im_end + newline``:

    * a ``system`` turn (omitted when the resolved system text is empty),
    * the kept ``user`` / ``assistant`` history turns, oldest first,
    * the current ``user`` turn,
    * an *open* ``assistant`` turn (``im_start + "assistant" + newline``,
      no ``im_end``) for the model to complete.

    The input budget is ``model.config.model_max_length - max_new_tokens``.
    History pairs are added newest-first while they still fit. The first
    pair that does not fit, and every older pair, is dropped. The system
    turn and the current query are always included, even if they alone
    exceed the budget.

    Args:
        model: Object exposing ``config.model_max_length``,
            ``generation_config.max_new_tokens`` and ``device``.
        tokenizer: Object exposing ``im_start_id``, ``im_end_id`` and
            ``encode(text, allowed_special=...)``.
        messages: ``{"role": ..., "content": ...}`` dicts. Layout: an optional
            leading ``system`` message, then ``user``/``assistant`` pairs,
            ending with a ``user`` message.
        system: Fallback system prompt, used when ``messages`` has no system
            message or an empty one.
        max_new_tokens: Tokens reserved for generation. ``0`` (the default)
            means ``model.generation_config.max_new_tokens``.

    Returns:
        A ``torch.LongTensor`` of shape ``(1, prompt_len)`` on ``model.device``.

    Raises:
        AssertionError: if ``messages`` does not follow the layout above.
    """
    max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
    max_input_length = model.config.model_max_length - max_new_tokens
    im_start_id = [tokenizer.im_start_id]
    im_end_id = [tokenizer.im_end_id]
    nl_tokens = tokenizer.encode("\n")

    def _tokenize_str(role, content):
        # Encode "role\ncontent". allowed_special=set() presumably keeps
        # special-token text inside user content from being parsed as control
        # tokens (tiktoken-style API) -- confirm against the tokenizer.
        return (tokenizer.encode(role, allowed_special=set())
                + nl_tokens
                + tokenizer.encode(content, allowed_special=set()))

    def _closed_turn(role, content):
        # One complete turn: im_start, role, newline, content, im_end, newline.
        return im_start_id + _tokenize_str(role, content) + im_end_id + nl_tokens

    def _parse_messages(msgs):
        # Split into (system text or "", current query, [[user, assistant], ...]).
        parsed_system, history = "", []
        if msgs[0]["role"] == "system":
            parsed_system = msgs[0]["content"]
            msgs = msgs[1:]
        assert msgs[-1]["role"] == "user", "the last message must come from the user"
        query = msgs[-1]["content"]
        msgs = msgs[:-1]
        assert len(msgs) % 2 == 0, "history must consist of user/assistant pairs"
        for i in range(0, len(msgs), 2):
            assert msgs[i]["role"] == "user" and msgs[i + 1]["role"] == "assistant", \
                "history must alternate user and assistant messages"
            history.append([msgs[i]["content"], msgs[i + 1]["content"]])
        return parsed_system, query, history

    _system, query, history = _parse_messages(messages)

    # A system message inside `messages` overrides the `system` argument.
    system_text = _system if _system != "" else system
    system_tokens = _closed_turn("system", system_text) if system_text else []
    query_tokens = _closed_turn("user", query)
    # Open assistant turn (no im_end) that the model continues.
    final_tokens = im_start_id + tokenizer.encode("assistant", allowed_special=set()) + nl_tokens

    # Token budget left for history once the mandatory parts are placed.
    max_history_length = max_input_length - len(system_tokens) - len(query_tokens) - len(final_tokens)

    context_tokens = []
    for turn_query, turn_response in reversed(history):
        turn_tokens = _closed_turn("user", turn_query) + _closed_turn("assistant", turn_response)
        # `<=`: a history that exactly fills the budget still fits in
        # max_input_length; stop at the first (newest-to-oldest) pair that doesn't.
        if len(turn_tokens) + len(context_tokens) <= max_history_length:
            context_tokens = turn_tokens + context_tokens
        else:
            break

    input_tokens = system_tokens + context_tokens + query_tokens + final_tokens
    return torch.LongTensor([input_tokens]).to(model.device)
class TextIterStreamer:
    """Iterator that yields decoded text while token ids are pushed into it.

    A producer calls :meth:`put` with each new chunk of token ids and
    :meth:`end` when it is done. The producer is presumably
    ``model.generate(..., streamer=...)`` running in another thread; confirm
    against callers. The consumer simply iterates over the streamer.
    Items travel through a ``queue.Queue``.

    Each yielded string is the decode of *all* tokens received so far, so it
    is cumulative rather than a delta. Only the first row of a batched
    (2-D) tensor is streamed.

    Args:
        tokenizer: Object with ``decode(ids, skip_special_tokens=..., errors=...)``.
        skip_prompt: If True, the first :meth:`put` call is discarded (it
            presumably carries the prompt ids).
        skip_special_tokens: Forwarded to ``tokenizer.decode``.
        timeout: Seconds the consumer waits for the next item before
            ``queue.Empty`` is raised. ``None`` (the default) waits
            indefinitely, which hangs if the producer never calls :meth:`end`.
    """

    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False, timeout=None):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.timeout = timeout
        self.tokens = []
        self.text_queue = Queue()
        self.next_tokens_are_prompt = True

    def put(self, value):
        """Accept a tensor of new token ids and enqueue the updated full text."""
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        if len(value.shape) > 1:
            value = value[0]
        self.tokens.extend(value.tolist())
        # Re-decode the whole sequence. errors='ignore' is presumably there to
        # drop partial multi-byte characters at chunk boundaries (tokenizer-specific).
        self.text_queue.put(
            self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens, errors='ignore'))

    def end(self):
        """Enqueue the end-of-stream sentinel; iteration stops after it."""
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        # Raises queue.Empty if `timeout` elapses without a new item.
        value = self.text_queue.get(timeout=self.timeout)
        if value is None:
            raise StopIteration()
        return value
class OutputRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that applies repetition, frequency and presence penalties to next-token logits.

    - Repetition penalty (multiplicative, applied at most once per token, as in the CTRL
      [paper](https://arxiv.org/pdf/1909.05858.pdf)). For every token that appears in the
      prompt *or* the generated output, a positive logit is divided by `repetition_penalties`
      and a non-positive logit is multiplied by it. Values above 1.0 discourage repetition,
      values in (0, 1) encourage it, and 1.0 disables it. The paper suggests about 1.2.
    - Frequency penalty (additive, OpenAI-API definition). `frequency_penalties * count` is
      subtracted from the logit of each token generated `count` times. Prompt tokens are not
      counted.
    - Presence penalty (additive, OpenAI-API definition). `presence_penalties` is subtracted
      once from the logit of each token generated at least once.

    Args:
        input_length (`int`):
            Split point in `input_ids`. Columns `[:input_length + 1]` are treated as prompt;
            the remaining columns are treated as generated output.
        presence_penalties (`float`, in [-2, 2], defaults to 1.0):
            Additive penalty for tokens already present in the output.
        frequency_penalties (`float`, in [-2, 2], defaults to 0):
            Additive penalty per occurrence in the output.
        repetition_penalties (`float`, > 0, defaults to 1.0):
            Multiplicative penalty for tokens seen in prompt or output. 1.0 means no penalty.

    Raises:
        ValueError: if a penalty is outside its valid range.
    """

    # NOTE(review): presence_penalties defaults to 1.0, so a presence penalty is active by
    # default (the OpenAI/vLLM default is 0.0). Kept as-is for compatibility -- confirm intent.
    def __init__(self, input_length: int,
                 presence_penalties: float = 1.0,
                 frequency_penalties: float = 0,
                 repetition_penalties: float = 1.0):
        # The previous default of 0 always failed this check; 1.0 is the neutral value.
        if not (repetition_penalties > 0):
            raise ValueError(f"`repetition_penalties` has to be a strictly positive float, but is {repetition_penalties}")
        if not ( (frequency_penalties >= -2) and (frequency_penalties <= 2) ):
            raise ValueError(f"`frequency_penalties` has to be [-2, 2], but is {frequency_penalties}")
        if not ( (presence_penalties >= -2) and (presence_penalties <= 2) ):
            raise ValueError(f"`presence_penalties` has to be [-2, 2], but is {presence_penalties}")
        self.repetition_penalties = repetition_penalties
        self.frequency_penalties = frequency_penalties
        self.presence_penalties = presence_penalties
        self.input_length = input_length

    def _get_bin_counts_and_mask(
        self,
        tokens: torch.Tensor,
        vocab_size: int,
        num_seqs: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Count each token id per row of `tokens`.

        Returns `(bin_counts, mask)`, both of shape `(num_seqs, vocab_size)`.
        `bin_counts[i, t]` is the number of times id `t` occurs in row `i`, and
        `mask = bin_counts > 0`. Assumes `tokens` is a 2-D long tensor with
        `num_seqs` rows.
        """
        # vocab_size + 1 bins: an id equal to vocab_size (presumably padding) is
        # counted in the extra bin, which is then sliced away.
        bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                                 dtype=torch.long,
                                 device=tokens.device)
        bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
        bin_counts = bin_counts[:, :vocab_size]
        mask = bin_counts > 0
        return bin_counts, mask

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        # NOTE(review): with the `+ 1`, if callers pass the raw prompt length the first
        # generated token is treated as prompt (not frequency/presence-penalized) --
        # confirm what the caller passes as `input_length`.
        prompt_tokens_tensor = input_ids[:, :self.input_length + 1]
        output_tokens_tensor = input_ids[:, self.input_length + 1:]
        num_seqs, vocab_size = logits.shape
        _, prompt_mask = self._get_bin_counts_and_mask(
            prompt_tokens_tensor, vocab_size, num_seqs)
        output_bin_counts, output_mask = self._get_bin_counts_and_mask(
            output_tokens_tensor, vocab_size, num_seqs)

        # Per-(sequence, token) repetition factor in the logits' own dtype and device:
        # the configured penalty where the token was seen, 1.0 elsewhere. Building it at
        # full (num_seqs, vocab_size) shape keeps the boolean mask aligned for any batch
        # size (a (1, vocab_size) tensor cannot be indexed by a (num_seqs, vocab_size) mask).
        repetition_penalties = torch.full_like(logits, self.repetition_penalties)
        repetition_penalties[~(prompt_mask | output_mask)] = 1.0
        logits = torch.where(logits > 0, logits / repetition_penalties,
                             logits * repetition_penalties)

        # We follow the definition in OpenAI API.
        # Refer to https://platform.openai.com/docs/api-reference/parameter-details
        logits = logits - self.frequency_penalties * output_bin_counts.to(logits.dtype)
        logits = logits - self.presence_penalties * output_mask.to(logits.dtype)
        return logits