import gc
import time
import uuid
from threading import Thread
from types import MethodType
from typing import Iterable, Dict, Any

import torch
from transformers import (
    TextIteratorStreamer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from api.generation.qwen import check_is_qwen
from api.generation.utils import (
    prepare_logits_processor,
    is_partial_stop,
    apply_stopping_strings,
)


@torch.inference_mode()
def generate_stream(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    params: Dict[str, Any],
):
    # Read parameters
    input_ids = params.get("inputs")
    prompt = params.get("prompt")
    model_name = params.get("model", "llm")
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_tokens", 256))
    logprobs = params.get("logprobs")
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop")

    stop_token_ids = params.get("stop_token_ids") or []
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )

    output_ids = list(input_ids)
    input_echo_len = len(input_ids)

    device = model.device
    if model.config.is_encoder_decoder:
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )
    else:
        start_ids = torch.as_tensor([input_ids], device=device)

    past_key_values, sent_interrupt = None, False
    token_logprobs = [None]  # The first token has no logprobs.

    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())
    previous_text = ""

    for i in range(max_new_tokens):
        if i == 0:  # prefill
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values

            if logprobs is not None:
                # Prefill logprobs for the prompt.
                shift_input_ids = start_ids[..., 1:].contiguous()
                shift_logits = logits[..., :-1, :].contiguous()
                shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
                for label_id, logit in zip(
                    shift_input_ids[0].tolist(), shift_logits[0]
                ):
                    token_logprobs.append(logit[label_id])

        else:  # decoding
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=torch.as_tensor(
                        [output_ids if sent_interrupt else [token]], device=device
                    ),
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=None if sent_interrupt else past_key_values,
                )
                sent_interrupt = False
                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=torch.as_tensor(
                        [output_ids if sent_interrupt else [token]], device=device
                    ),
                    use_cache=True,
                    past_key_values=None if sent_interrupt else past_key_values,
                )
                sent_interrupt = False
                logits = out.logits
            past_key_values = out.past_key_values

        if logits_processor:
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        if device.type == "mps":
            # Switch to CPU to avoid some bugs in the mps backend.
            last_token_logits = last_token_logits.float().to("cpu")

        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            _, indices = torch.topk(last_token_logits, 2)
            tokens = [int(index) for index in indices.tolist()]
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            indices = torch.multinomial(probs, num_samples=2)
            tokens = [int(token) for token in indices.tolist()]

        token = tokens[0]
        output_ids.append(token)

        if logprobs is not None:
            # Cannot use last_token_logits because logprobs is based on raw logits.
            token_logprobs.append(
                torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
            )

        if token in stop_token_ids:
            stopped = True
        else:
            stopped = False

        # Yield the output tokens
        if i % 2 == 0 or i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                rfind_start = len(prompt)
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=False if check_is_qwen(model) else True,  # fix for qwen react
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )

            ret_logprobs = None
            if logprobs is not None:
                ret_logprobs = {
                    "text_offset": [],
                    "tokens": [
                        tokenizer.decode(token)
                        for token in (
                            output_ids if echo else output_ids[input_echo_len:]
                        )
                    ],
                    "token_logprobs": token_logprobs if echo else token_logprobs[input_echo_len:],
                    "top_logprobs": [{}] * len(
                        token_logprobs if echo else token_logprobs[input_echo_len:]
                    ),
                }
                # Compute text_offset
                curr_pos = 0
                for text in ret_logprobs["tokens"]:
                    ret_logprobs["text_offset"].append(curr_pos)
                    curr_pos += len(text)

            partially_stopped, finish_reason = False, None
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                        stopped = True
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True

                            if each_stop == "Observation:":
                                finish_reason = "function_call"
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            # Prevent yielding partial stop sequence
            if (not partially_stopped) and output and output[-1] != "�":
                delta_text = output[len(previous_text):]
                previous_text = output

                yield {
                    "id": completion_id,
                    "object": "text_completion",
                    "created": created,
                    "model": model_name,
                    "delta": delta_text,
                    "text": output,
                    "logprobs": ret_logprobs,
                    "finish_reason": finish_reason,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                }

        if stopped:
            break

    yield {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "delta": "",
        "text": output,
        "logprobs": ret_logprobs,
        "finish_reason": "stop",
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
    }

    # Clean
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()


@torch.inference_mode()
def generate_stream_v2(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    params: Dict[str, Any],
):
    input_ids = params.get("inputs")
    functions = params.get("functions")
    model_name = params.get("model", "llm")
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", 40))
    max_new_tokens = int(params.get("max_tokens", 256))

    stop_token_ids = params.get("stop_token_ids") or []
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)
    stop_strings = params.get("stop", [])

    input_echo_len = len(input_ids)
    device = model.device

    generation_kwargs = dict(
        input_ids=torch.tensor([input_ids], device=device),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.pad_token_id,
    )
    if temperature <= 1e-5:
        generation_kwargs["do_sample"] = False
        generation_kwargs.pop("top_k")

    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs["streamer"] = streamer

    # If generate() has been overridden by custom model code, rebind the stock
    # Hugging Face implementation so the streamer-based generation path works.
    if "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(PreTrainedModel.generate, model)

    # Run generation in a background thread; decoded text arrives via the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text, func_call_found = "", False
    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())
    previous_text = ""

    for i, new_text in enumerate(streamer):
        generated_text += new_text
        if functions:
            _, func_call_found = apply_stopping_strings(generated_text, ["Observation:"])
        generated_text, stop_found = apply_stopping_strings(generated_text, stop_strings)

        if generated_text and generated_text[-1] != "�":
            delta_text = generated_text[len(previous_text):]
            previous_text = generated_text

            yield {
                "id": completion_id,
                "object": "text_completion",
                "created": created,
                "model": model_name,
                "delta": delta_text,
                "text": generated_text,
                "logprobs": None,
                "finish_reason": "function_call" if func_call_found else None,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": i,
                    "total_tokens": input_echo_len + i,
                },
            }

        if stop_found:
            break

    yield {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "delta": "",
        "text": generated_text,
        "logprobs": None,
        "finish_reason": "stop",
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
    }
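

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative assumption, not part of the serving API):
    # drives generate_stream() directly with a small Hugging Face checkpoint. "gpt2"
    # is only a placeholder model id; any causal LM supported by the helpers above
    # should work, provided the `api` package is importable from this location.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    demo_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    demo_model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

    demo_prompt = "Hello, how are you?"
    demo_params = {
        "inputs": demo_tokenizer(demo_prompt).input_ids,
        "prompt": demo_prompt,
        "temperature": 0.7,
        "top_p": 0.9,
        "max_tokens": 32,
        "echo": False,
    }

    # Each yielded chunk carries the incremental "delta" plus the full text so far.
    for chunk in generate_stream(demo_model, demo_tokenizer, demo_params):
        print(chunk["delta"], end="", flush=True)
    print()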