File size: 4,385 Bytes
6dc0c9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import gc
from threading import Thread
from typing import Iterable
import torch
import transformers
from transformers import TextIteratorStreamer, GenerationConfig
from fastchat.utils import is_partial_stop
@torch.inference_mode()
def generate_stream_yuan2(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
):
    """Stream text generated by ``model`` as a sequence of result dicts.

    Generation runs in a background thread via ``model.generate`` and is read
    back through a ``TextIteratorStreamer``; this generator accumulates the
    streamed text and periodically yields the text produced so far.

    Args:
        model: Object exposing ``generate(...)`` and ``device``.
        tokenizer: Callable tokenizer; invoked as
            ``tokenizer(prompt, return_tensors="pt")``.
        params: Request dict. ``"prompt"`` is required. Optional keys and
            their defaults: ``"temperature"`` (1), ``"repetition_penalty"``
            (1.0), ``"top_p"`` (0), ``"top_k"`` (1), ``"max_new_tokens"``
            (512), ``"stop"`` (``"<eod>"``; a string or an iterable of
            strings) and ``"echo"`` (True). ``"stop_token_ids"`` is not
            consulted and ``params`` is never modified.
        device: Device name; only compared against ``"xpu"`` / ``"npu"`` to
            pick the extra cache to empty during cleanup.
        context_len: Context budget; the prompt is cut from the left to
            ``context_len - max_new_tokens - 8`` tokens.
        stream_interval: An intermediate event is considered on every
            ``stream_interval``-th streamer chunk.
        judge_sent_end: Accepted for interface compatibility; unused here.

    Yields:
        dict: ``{"text", "usage", "finish_reason"}``. Intermediate events
        carry ``finish_reason=None``; the last event carries ``"length"``,
        ``"stop"`` or ``None`` (when the text ends in a partial stop string).

    Raises:
        ValueError: If ``params["stop"]`` is truthy but neither a string nor
            an iterable.
    """
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 1))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 0))
    top_k = int(params.get("top_k", 1))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 512))
    stop_str = params.get("stop", "<eod>")
    echo = bool(params.get("echo", True))

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # The encodings are batched, shape (1, seq_len), so the prompt has to be
    # truncated (from the left) and measured along the sequence axis, not the
    # batch axis.
    max_src_len = context_len - max_new_tokens - 8
    input_ids = inputs["input_ids"][:, -max_src_len:]
    attention_mask = inputs["attention_mask"][:, -max_src_len:]
    input_echo_len = input_ids.shape[1]

    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        # Sampling is only switched on for temperatures of 1.2 and above.
        do_sample=temperature >= 1.2,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=10,
        top_p=top_p,
        top_k=top_k,
    )
    generation_kwargs = dict(
        inputs=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        generation_config=generation_config,
    )

    # generate() blocks until completion, so it runs in its own thread while
    # this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # With echo the prompt is kept in the output, and stop strings are only
    # searched for after it.
    output = prompt if echo else ""
    rfind_start = len_prompt if echo else 0

    # Defaults keep the final event well defined if the streamer yields nothing.
    i = 0
    partially_stopped = False

    try:
        for i, new_text in enumerate(streamer):
            output += new_text
            if i % stream_interval != 0:
                continue

            # NOTE(review): a stop match trims the text but does not stop
            # generation; later chunks are still appended. Confirm this is
            # acceptable for stop strings other than the model's end marker.
            output, partially_stopped = _apply_stop_str(
                output, stop_str, rfind_start
            )

            # prevent yielding partial stop sequence
            if not partially_stopped:
                yield {
                    "text": output,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                    "finish_reason": None,
                }

        # Chunks received after the last checked iteration have not been
        # scanned yet; check once more (before strip() can shift offsets) so
        # the final text never carries a stop string.
        output, partially_stopped = _apply_stop_str(output, stop_str, rfind_start)
        output = output.strip()

        # finish stream event, which contains finish reason
        # NOTE(review): `i` indexes streamer chunks; the "length" test assumes
        # one chunk per generated token - confirm against the streamer.
        if i == max_new_tokens - 1:
            finish_reason = "length"
        elif partially_stopped:
            finish_reason = None
        else:
            finish_reason = "stop"

        yield {
            "text": output,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": i,
                "total_tokens": input_echo_len + i,
            },
            "finish_reason": finish_reason,
        }
    finally:
        # clean - in `finally` so it also runs when the consumer closes the
        # generator before the final event.
        gc.collect()
        torch.cuda.empty_cache()
        if device == "xpu":
            torch.xpu.empty_cache()
        if device == "npu":
            torch.npu.empty_cache()


def _apply_stop_str(output, stop_str, rfind_start):
    """Cut ``output`` at a stop string located at or after ``rfind_start``.

    Args:
        output: Text accumulated so far.
        stop_str: Falsy (no stop handling), a single stop string, or an
            iterable of stop strings tried in order.
        rfind_start: Index from which stop strings are searched.

    Returns:
        tuple: ``(output, partially_stopped)``. ``output`` is truncated at the
        last occurrence of the first stop string that matches. Otherwise
        ``partially_stopped`` is the result of ``is_partial_stop`` - presumably
        whether ``output`` ends in the beginning of a stop string (helper is
        defined outside this file).

    Raises:
        ValueError: If ``stop_str`` is truthy but neither ``str`` nor iterable.
    """
    if not stop_str:
        return output, False

    if isinstance(stop_str, str):
        pos = output.rfind(stop_str, rfind_start)
        if pos != -1:
            return output[:pos], False
        return output, is_partial_stop(output, stop_str)

    if isinstance(stop_str, Iterable):
        partially_stopped = False
        for each_stop in stop_str:
            pos = output.rfind(each_stop, rfind_start)
            if pos != -1:
                output = output[:pos]
                break
            partially_stopped = is_partial_stop(output, each_stop)
            if partially_stopped:
                break
        return output, partially_stopped

    raise ValueError("Invalid stop field type.")
|