File size: 7,764 Bytes
af9251e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
from abc import ABC
from langchain.llms.base import LLM
import random
import torch
import transformers
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from typing import Optional, List, Dict, Any
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult)
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Logits processor that overwrites a score tensor containing NaN or Inf.

    If any entry of ``scores`` is NaN or infinite, the tensor is zeroed in
    place and index 5 along the last dimension is set to ``5e4``. Otherwise
    ``scores`` is returned unchanged.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores``, reset in place when it holds non-finite values.

        :param input_ids: token ids generated so far (not used here).
        :param scores: score tensor for the next token; may be modified in place.
        :return: the same ``scores`` tensor object.
        """
        corrupted = torch.isnan(scores).any() or torch.isinf(scores).any()
        if not corrupted:
            return scores

        # Wipe every score and leave a single large value at index 5 so the
        # tensor is finite again.
        # NOTE(review): which token index 5 maps to depends on the tokenizer
        # in use -- confirm it is a sensible fallback for this model.
        scores.zero_()
        scores[..., 5] = 5e4
        return scores
class LLamaLLM(BaseAnswer, LLM, ABC):
    """LangChain ``LLM`` that answers prompts with a checkpoint's ``generate``.

    The model and tokenizer are taken from ``checkPoint``. Conversation
    history is flattened into a ``### Human:`` / ``### Assistant:`` text
    prompt before generation.
    """

    # Holder of the loaded ``model`` and ``tokenizer`` used by every method.
    checkPoint: LoaderCheckPoint = None
    # Number of most recent history turns included in the prompt.
    history_len: int = 3
    # The sampling settings below are forwarded to ``model.generate`` in ``_call``.
    max_new_tokens: int = 500
    num_beams: int = 1
    temperature: float = 0.5
    top_p: float = 0.4
    top_k: int = 10
    repetition_penalty: float = 1.2
    encoder_repetition_penalty: int = 1
    min_length: int = 0
    # Created lazily on the first ``_call`` when left as ``None``.
    logits_processor: LogitsProcessorList = None
    stopping_criteria: Optional[StoppingCriteriaList] = None
    # Not read inside this class: ``_call`` uses the tokenizer's eos_token_id.
    eos_token_id: Optional[List[int]] = [2]
    # Settings dictionary; this class reads only 'add_bos_token' and
    # 'truncation_length' from it.
    state: object = {'max_new_tokens': 50,
                     'seed': 1,
                     'temperature': 0, 'top_p': 0.1,
                     'top_k': 40, 'typical_p': 1,
                     'repetition_penalty': 1.2,
                     'encoder_repetition_penalty': 1,
                     'no_repeat_ngram_size': 0,
                     'min_length': 0,
                     'penalty_alpha': 0,
                     'num_beams': 1,
                     'length_penalty': 1,
                     'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
                     'truncation_length': 2048, 'custom_stopping_strings': '',
                     'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
                     'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
                     'pre_layer': 0, 'gpu_memory_0': 0}

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        """Create the wrapper around an already loaded checkpoint.

        :param checkPoint: object exposing the ``model`` and ``tokenizer``.
        """
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _llm_type(self) -> str:
        """Identifier string for this LLM implementation."""
        return "LLamaLLM"

    @property
    def _check_point(self) -> LoaderCheckPoint:
        """The checkpoint holder given at construction time."""
        return self.checkPoint

    def encode(self, prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
        """Tokenize ``prompt`` into a tensor of ids placed on the model device.

        :param prompt: text to tokenize (converted with ``str``).
        :param add_special_tokens: forwarded to ``tokenizer.encode``.
        :param add_bos_token: when False, a leading BOS token is removed.
        :param truncation_length: when given, only the last
            ``truncation_length`` tokens are kept.
        :return: the token-id tensor produced by the tokenizer with
            ``return_tensors='pt'``, after the adjustments above.
        """
        tokenizer = self.checkPoint.tokenizer
        input_ids = tokenizer.encode(str(prompt), return_tensors='pt',
                                     add_special_tokens=add_special_tokens)

        # This is a hack for making replies more creative.
        # The length checks prevent indexing into an empty sequence.
        if (not add_bos_token and input_ids.shape[1] > 0
                and input_ids[0][0] == tokenizer.bos_token_id):
            input_ids = input_ids[:, 1:]

        # Llama adds this extra token when the first character is '\n', and this
        # compromises the stopping criteria, so we just remove it
        if (type(tokenizer) is transformers.LlamaTokenizer and input_ids.shape[1] > 0
                and input_ids[0][0] == 29871):
            input_ids = input_ids[:, 1:]

        # Handling truncation: keep the tail so the newest text survives.
        if truncation_length is not None:
            input_ids = input_ids[:, -truncation_length:]

        # Put the ids on the model's device rather than assuming CUDA; fall
        # back to the previous ``.cuda()`` behavior if the model has no
        # ``device`` attribute.
        device = getattr(self.checkPoint.model, 'device', None)
        if device is None:
            return input_ids.cuda()
        return input_ids.to(device)

    def decode(self, output_ids):
        """Turn token ids back into text, skipping special tokens."""
        return self.checkPoint.tokenizer.decode(output_ids, skip_special_tokens=True)

    # Convert the conversation history into prompt text.
    def history_to_text(self, query, history):
        """Build the text prompt from the recent history and the new query.

        Only the last ``history_len`` turns are used (none when
        ``history_len`` is not positive). Each turn is rendered as
        ``### Human:<query>\\n### Assistant:<response>\\n`` and the new query
        is appended with an open ``### Assistant:`` for the model to complete.

        :param query: the new user message.
        :param history: sequence of ``(query, response)`` pairs; ``None`` is
            treated as empty.
        :return: the formatted prompt string.
        """
        history = history or []
        recent_turns = history[-self.history_len:] if self.history_len > 0 else []
        pieces = ["### Human:{}\n### Assistant:{}\n".format(old_query, response)
                  for old_query, response in recent_turns]
        pieces.append("### Human:{}\n### Assistant:".format(query))
        return "".join(pieces)

    def prepare_inputs_for_generation(self,
                                      input_ids: torch.LongTensor):
        """Pre-build the attention mask and per-position index tensor.

        TODO: unfinished (the original author noted having no approach yet).
        NOTE(review): ``get_masks`` and ``get_position_ids`` are not defined
        in this class -- presumably expected from a base class; verify before
        calling. ``_call`` does not use this method.

        :return: ``(input_ids, position_ids, attention_mask)``.
        """
        mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)
        attention_mask = self.get_masks(input_ids, input_ids.device)
        position_ids = self.get_position_ids(
            input_ids,
            device=input_ids.device,
            mask_positions=mask_positions
        )
        return input_ids, position_ids, attention_mask

    @property
    def _history_len(self) -> int:
        """Number of history turns included in the prompt."""
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        """Set how many recent history turns are included in the prompt."""
        self.history_len = history_len

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
        """Cut ``text`` at the earliest occurrence of any non-empty stop string."""
        if not stop:
            return text
        hits = [pos for pos in (text.find(marker) for marker in stop if marker) if pos != -1]
        return text[:min(hits)] if hits else text

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a completion for ``prompt``.

        :param prompt: full prompt text sent to the model.
        :param stop: optional stop strings; the reply is cut at the first one
            that appears.
        :return: the decoded newly generated text (prompt excluded).
        """
        print(f"__call:{prompt}")
        if self.logits_processor is None:
            self.logits_processor = LogitsProcessorList()
        # The list lives on the instance across calls, so only add the
        # invalid-score guard once instead of growing the list every call.
        if not any(isinstance(proc, InvalidScoreLogitsProcessor) for proc in self.logits_processor):
            self.logits_processor.append(InvalidScoreLogitsProcessor())

        gen_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "num_beams": self.num_beams,
            "top_p": self.top_p,
            "do_sample": True,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "encoder_repetition_penalty": self.encoder_repetition_penalty,
            "min_length": self.min_length,
            "temperature": self.temperature,
            "eos_token_id": self.checkPoint.tokenizer.eos_token_id,
            "logits_processor": self.logits_processor}

        # Tokenize. The prompt may use whatever remains of the context window
        # after reserving room for the new tokens (it was previously cut to
        # ``max_new_tokens`` tokens, which discarded history needlessly).
        prompt_budget = max(self.state.get('truncation_length', 2048) - self.max_new_tokens, 1)
        input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'],
                                truncation_length=prompt_budget)
        gen_kwargs.update({'inputs': input_ids})
        # TODO: attention_mask / position_ids are not passed yet; see
        # prepare_inputs_for_generation.

        if self.stopping_criteria is None:
            self.stopping_criteria = transformers.StoppingCriteriaList()
        gen_kwargs.update({'stopping_criteria': self.stopping_criteria})

        output_ids = self.checkPoint.model.generate(**gen_kwargs)

        # Slice from the prompt length: a negative-count slice (``[-n:]``)
        # returns the whole sequence, prompt included, when n == 0.
        prompt_len = len(input_ids[0])
        reply = self.decode(output_ids[0][prompt_len:])
        # Honor the stop strings passed by the caller.
        reply = self._truncate_at_stop(reply, stop)

        print(f"response:{reply}")
        print("+++++++++++++++++++++++++++++++++++")
        return reply

    def generatorAnswer(self, prompt: str,
                        history: Optional[List[List[str]]] = None,
                        streaming: bool = False):
        """Answer ``prompt`` given prior turns and yield a single result.

        :param prompt: the new user message.
        :param history: prior ``[query, response]`` turns; defaults to empty.
        :param streaming: accepted for interface compatibility; not used here.
        :return: generator yielding one ``AnswerResult`` whose ``history``
            is the input history plus the new turn and whose ``llm_output``
            is ``{"answer": response}``.
        """
        # TODO: a chat module and attention handling still need to be
        # implemented. ``_call`` is the langchain LLM extension API and runs
        # without a prompt template by default; see the chat_glm
        # implementation for how to work with the attention inputs.
        history = [] if history is None else history
        softprompt = self.history_to_text(prompt, history=history)
        response = self._call(prompt=softprompt, stop=['\n###'])

        answer_result = AnswerResult()
        answer_result.history = history + [[prompt, response]]
        answer_result.llm_output = {"answer": response}
        yield answer_result
|