import gc
from threading import Thread
from typing import Iterable

import torch
import transformers
from transformers import TextIteratorStreamer, GenerationConfig

from src.utils import is_partial_stop


@torch.inference_mode()
def generate_stream_falcon(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
):
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", 50))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    stop_str = params.get("stop", None)
    echo = bool(params.get("echo", True))
    stop_token_ids = params.get("stop_token_ids", None) or []
    stop_token_ids.append(tokenizer.eos_token_id)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    max_src_len = context_len - max_new_tokens - 8

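    # Truncate the prompt from the left along the sequence dimension so that
    # the prompt plus max_new_tokens (plus a small safety margin) fits inside
    # the model's context window.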
    input_ids = input_ids[:, -max_src_len:]
    attention_mask = attention_mask[:, -max_src_len:]
    input_echo_len = input_ids.shape[-1]

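    # skip_prompt=True makes the streamer yield only newly generated text; when
    # echo is requested, the prompt is prepended to the output manually below.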
    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)

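    # Sampling is disabled (greedy decoding) when temperature is effectively
    # zero; eos_token_id accepts a list, so any of the stop token ids ends
    # generation.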
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=temperature >= 1e-5,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=10,
        top_p=top_p,
        top_k=top_k,
        eos_token_id=stop_token_ids,
    )

    generation_kwargs = dict(
        inputs=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        generation_config=generation_config,
    )

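    # model.generate runs in a background thread; TextIteratorStreamer hands the
    # decoded text back to this thread incrementally as tokens are produced.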
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    if echo:
        output = prompt
    else:
        output = ""

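    # Consume the stream, periodically scanning the accumulated output for stop
    # strings; a chunk is only yielded once it is known not to end in a prefix
    # of a stop string.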
    i = 0
    partially_stopped = False  # defaults in case the streamer yields nothing
    for i, new_text in enumerate(streamer):
        output += new_text
        if i % stream_interval == 0:
            if echo:
                rfind_start = len_prompt
            else:
                rfind_start = 0

            partially_stopped = False
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            if not partially_stopped:
                yield {
                    "text": output,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                    "finish_reason": None,
                }
    output = output.strip()

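    # Emit one final chunk carrying the finish reason: "length" when the token
    # budget was exhausted, "stop" when generation ended on a stop condition.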
    if i == max_new_tokens - 1:
        finish_reason = "length"
    elif partially_stopped:
        finish_reason = None
    else:
        finish_reason = "stop"

    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }

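    # Release cached allocator memory on whichever accelerator backend is in use.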
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
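

# A minimal usage sketch, not part of the serving path above: the checkpoint
# name, dtype, prompt, and sampling parameters are illustrative assumptions.
# Each yielded chunk's "text" field holds the full output so far, so the last
# chunk carries the complete response.
if __name__ == "__main__":
    model_path = "tiiuae/falcon-7b-instruct"  # assumed checkpoint, for illustration only
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, device_map="auto"
    )
    params = {
        "prompt": "Write a short poem about autumn.",
        "temperature": 0.7,
        "max_new_tokens": 64,
        "echo": False,
    }
    chunk = None
    for chunk in generate_stream_falcon(model, tokenizer, params, device="cuda"):
        pass  # a real server would forward each chunk to its client here
    if chunk is not None:
        print(chunk["text"])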