# Streaming text generation for ExLlamaV2-backed models.
import gc
import sys
from typing import Dict

import torch
def generate_stream_exllama(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    """Stream completions from an ExLlamaV2 model.

    Args:
        model: Wrapper exposing ``.model`` and ``.cache`` ExLlamaV2 objects.
        tokenizer: Tokenizer passed through to the streaming generator.
        params: Generation parameters. Keys read: ``prompt`` (required),
            ``temperature``, ``top_k``, ``top_p``, ``repetition_penalty``,
            ``max_new_tokens``, ``stop_token_ids``, ``echo``.
        device: Unused here; kept for interface parity with sibling backends.
        context_len: Unused here; kept for interface parity.
        stream_interval: Unused here; every decoded chunk is yielded.
        judge_sent_end: Unused here; kept for interface parity.

    Yields:
        Dicts with the cumulative ``text``, a ``usage`` breakdown of
        prompt/completion/total tokens, and ``finish_reason`` — ``None``
        while streaming, then ``"stop"`` (EOS or stop condition) or
        ``"length"`` (hit ``max_new_tokens``) on the final yield.
    """
    try:
        from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
    except ImportError as e:
        print(f"Error: Failed to load Exllamav2. {e}")
        sys.exit(-1)

    prompt = params["prompt"]

    generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer)
    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = float(params.get("temperature", 0.85))
    settings.top_k = int(params.get("top_k", 50))
    settings.top_p = float(params.get("top_p", 0.8))
    settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15))
    # NOTE(review): banning the EOS token forces generation to continue until a
    # configured stop condition or max_new_tokens is reached — confirm intended.
    settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])

    max_new_tokens = int(params.get("max_new_tokens", 256))
    generator.set_stop_conditions(params.get("stop_token_ids", None) or [])
    echo = bool(params.get("echo", True))

    input_ids = generator.tokenizer.encode(prompt)
    prompt_tokens = input_ids.shape[-1]
    generator.begin_stream(input_ids, settings)

    generated_tokens = 0
    output = prompt if echo else ""
    while True:
        chunk, eos, _ = generator.stream()
        output += chunk
        generated_tokens += 1
        if generated_tokens == max_new_tokens:
            finish_reason = "length"
            break
        elif eos:
            # Fix: the EOS / stop-condition path previously reported "length",
            # making truncation indistinguishable from a natural stop.
            finish_reason = "stop"
            break
        yield {
            "text": output,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": generated_tokens,
                "total_tokens": prompt_tokens + generated_tokens,
            },
            "finish_reason": None,
        }
    yield {
        "text": output,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": generated_tokens,
            "total_tokens": prompt_tokens + generated_tokens,
        },
        "finish_reason": finish_reason,
    }
    gc.collect()