File size: 3,197 Bytes
6dc0c9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import gc
from threading import Thread
import torch
import transformers
from transformers import (
GenerationConfig,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer,
)
@torch.inference_mode()
def generate_stream_codet5p(
model,
tokenizer,
params,
device,
context_len=2048,
stream_interval=2,
judge_sent_end=False,
):
prompt = params["prompt"]
temperature = float(params.get("temperature", 1.0))
repetition_penalty = float(params.get("repetition_penalty", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = int(params.get("top_k", 50)) # -1 means disable
max_new_tokens = int(params.get("max_new_tokens", 1024))
stop_token_ids = params.get("stop_token_ids", None) or []
stop_token_ids.append(tokenizer.eos_token_id)
decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
streamer = TextIteratorStreamer(tokenizer, **decode_config)
encoding = tokenizer(prompt, return_tensors="pt").to(device)
input_ids = encoding.input_ids
encoding["decoder_input_ids"] = encoding["input_ids"].clone()
input_echo_len = len(input_ids)
generation_config = GenerationConfig(
max_new_tokens=max_new_tokens,
do_sample=temperature >= 1e-5,
temperature=temperature,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=10,
top_p=top_p,
top_k=top_k,
eos_token_id=stop_token_ids,
)
class CodeBlockStopper(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
# Code-completion is open-end generation.
# We check \n\n to stop at end of a code block.
if list(input_ids[0][-2:]) == [628, 198]:
return True
return False
gen_kwargs = dict(
**encoding,
streamer=streamer,
generation_config=generation_config,
stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]),
)
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
i = 0
output = ""
for new_text in streamer:
i += 1
output += new_text
if i % stream_interval == 0 or i == max_new_tokens - 1:
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": None,
}
if i >= max_new_tokens:
break
if i >= max_new_tokens:
finish_reason = "length"
else:
finish_reason = "stop"
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": finish_reason,
}
thread.join()
# clean
gc.collect()
torch.cuda.empty_cache()
if device == "xpu":
torch.xpu.empty_cache()
if device == "npu":
torch.npu.empty_cache()
|