import gc
from threading import Thread
from typing import Iterable

import torch
import transformers
from transformers import TextIteratorStreamer, GenerationConfig

from fastchat.utils import is_partial_stop


@torch.inference_mode()
def generate_stream_falcon(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
):
    """Stream generation for Falcon-style models, yielding partial outputs as dicts."""
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", 50))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 256))
    stop_str = params.get("stop", None)
    echo = bool(params.get("echo", True))
    stop_token_ids = params.get("stop_token_ids", None) or []
    stop_token_ids.append(tokenizer.eos_token_id)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Reserve room for the generated tokens plus a small safety margin.
    max_src_len = context_len - max_new_tokens - 8

    # Truncate from the left along the sequence dimension so the prompt fits.
    input_ids = input_ids[:, -max_src_len:]
    attention_mask = attention_mask[:, -max_src_len:]
    input_echo_len = input_ids.shape[1]  # number of prompt tokens actually used

    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)

    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=temperature >= 1e-5,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=10,
        top_p=top_p,
        top_k=top_k,
        eos_token_id=stop_token_ids,
    )

    generation_kwargs = dict(
        inputs=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        generation_config=generation_config,
    )

    # Run generation in a background thread; the streamer yields decoded text chunks.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    if echo:  # keep the prompt in the returned text
        output = prompt
    else:
        output = ""

    i = 0
    partially_stopped = False
    for i, new_text in enumerate(streamer):
        output += new_text
        if i % stream_interval == 0:
            if echo:
                rfind_start = len_prompt
            else:
                rfind_start = 0

            partially_stopped = False
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            # Prevent yielding a partial stop sequence.
            if not partially_stopped:
                yield {
                    "text": output,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                    "finish_reason": None,
                }
    output = output.strip()

    # Final stream event, which contains the finish reason.
    if i == max_new_tokens - 1:
        finish_reason = "length"
    elif partially_stopped:
        finish_reason = None
    else:
        finish_reason = "stop"

    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }

    # Clean up accelerator memory.
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
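
# --- Usage sketch (illustrative only) ---
# A minimal, hedged example of how the streaming generator above might be
# consumed. The checkpoint name, prompt, and sampling parameters below are
# assumptions for illustration, not part of the FastChat API; any causal LM
# checkpoint loadable via transformers' AutoModelForCausalLM should behave
# the same way.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "tiiuae/falcon-7b-instruct"  # assumed checkpoint for illustration
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="auto"
    )

    params = {
        "prompt": "What is the capital of France?",
        "temperature": 0.7,
        "max_new_tokens": 64,
        "stop": "\nUser:",
        "echo": False,
    }

    final_text = ""
    for chunk in generate_stream_falcon(model, tokenizer, params, device="cuda"):
        # Each chunk carries the cumulative text generated so far plus usage counts.
        final_text = chunk["text"]
    print(final_text)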