import gc
import sys
from typing import Dict

import torch


def generate_stream_exllama(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    try:
        from exllamav2.generator import (
            ExLlamaV2Sampler,
            ExLlamaV2StreamingGenerator,
        )
    except ImportError as e:
        print(f"Error: Failed to load Exllamav2. {e}")
        sys.exit(-1)

    prompt = params["prompt"]

    # Build the streaming generator and sampler settings from the request params.
    generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer)
    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = float(params.get("temperature", 0.85))
    settings.top_k = int(params.get("top_k", 50))
    settings.top_p = float(params.get("top_p", 0.8))
    settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15))
    # The EOS token is disallowed, so generation ends via an explicit stop
    # condition or the max_new_tokens limit.
    settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])

    max_new_tokens = int(params.get("max_new_tokens", 256))

    generator.set_stop_conditions(params.get("stop_token_ids", None) or [])
    echo = bool(params.get("echo", True))

    input_ids = generator.tokenizer.encode(prompt)
    prompt_tokens = input_ids.shape[-1]
    generator.begin_stream(input_ids, settings)

    generated_tokens = 0
    if echo:
        output = prompt
    else:
        output = ""

    # Stream tokens, yielding the accumulated text after each chunk.
    while True:
        chunk, eos, _ = generator.stream()
        output += chunk
        generated_tokens += 1
        if generated_tokens == max_new_tokens:
            finish_reason = "length"
            break
        elif eos:
            finish_reason = "stop"
            break
        yield {
            "text": output,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": generated_tokens,
                "total_tokens": prompt_tokens + generated_tokens,
            },
            "finish_reason": None,
        }

    # Final yield carries the finish reason for the completed stream.
    yield {
        "text": output,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": generated_tokens,
            "total_tokens": prompt_tokens + generated_tokens,
        },
        "finish_reason": finish_reason,
    }
    gc.collect()
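

# Usage sketch (illustrative, not part of the original module): assumes `model`
# and `tokenizer` are already-loaded ExLlamaV2 wrappers where `model.model` is
# the ExLlamaV2 instance and `model.cache` its cache, matching what
# generate_stream_exllama expects above. The prompt and sampling values below
# are placeholders.
def _example_consume_stream(model, tokenizer):
    params = {
        "prompt": "Hello, world",
        "temperature": 0.7,
        "max_new_tokens": 64,
        "echo": False,
    }
    last = None
    for out in generate_stream_exllama(
        model, tokenizer, params, device="cuda", context_len=4096
    ):
        # Each yielded dict carries the accumulated text plus token usage.
        last = out
    return last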