import gc
import sys
from typing import Dict

import torch


def generate_stream_exllama(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    try:
        from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
    except ImportError as e:
        print(f"Error: Failed to load Exllamav2. {e}")
        sys.exit(-1)

    prompt = params["prompt"]

    generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer)
    settings = ExLlamaV2Sampler.Settings()

    # Sampling parameters, with defaults applied when the request omits them.
    settings.temperature = float(params.get("temperature", 0.85))
    settings.top_k = int(params.get("top_k", 50))
    settings.top_p = float(params.get("top_p", 0.8))
    settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15))
    # Suppress the EOS token so generation ends only via an explicit stop
    # condition or the max_new_tokens budget.
    settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])

    max_new_tokens = int(params.get("max_new_tokens", 256))

    generator.set_stop_conditions(params.get("stop_token_ids", None) or [])
    echo = bool(params.get("echo", True))

    input_ids = generator.tokenizer.encode(prompt)
    prompt_tokens = input_ids.shape[-1]
    generator.begin_stream(input_ids, settings)

    generated_tokens = 0
    # When echo is set, the prompt is included in the streamed text.
    if echo:
        output = prompt
    else:
        output = ""
    while True:
        chunk, eos, _ = generator.stream()
        output += chunk
        generated_tokens += 1
        if generated_tokens == max_new_tokens:
            finish_reason = "length"
            break
        elif eos:
            finish_reason = "stop"
            break
        # Intermediate chunks carry no finish_reason.
        yield {
            "text": output,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": generated_tokens,
                "total_tokens": prompt_tokens + generated_tokens,
            },
            "finish_reason": None,
        }

    # Final chunk reports why the stream ended.
    yield {
        "text": output,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": generated_tokens,
            "total_tokens": prompt_tokens + generated_tokens,
        },
        "finish_reason": finish_reason,
    }
    gc.collect()
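

# --- Illustrative usage sketch (not part of the upstream file) ---
# Assumes a local exllamav2 (EXL2) model directory; the loading calls below
# follow the standard exllamav2 API, and SimpleNamespace merely mimics the
# wrapper object exposing `.model` and `.cache` attributes that
# generate_stream_exllama expects.
if __name__ == "__main__":
    from types import SimpleNamespace

    from exllamav2 import (
        ExLlamaV2,
        ExLlamaV2Cache,
        ExLlamaV2Config,
        ExLlamaV2Tokenizer,
    )

    config = ExLlamaV2Config()
    config.model_dir = "/path/to/exl2-model"  # placeholder path
    config.prepare()

    exl2 = ExLlamaV2(config)
    cache = ExLlamaV2Cache(exl2, lazy=True)
    exl2.load_autosplit(cache)
    tokenizer = ExLlamaV2Tokenizer(config)

    wrapper = SimpleNamespace(model=exl2, cache=cache)
    request = {
        "prompt": "Hello, how are you?",
        "temperature": 0.7,
        "max_new_tokens": 64,
        "echo": False,
    }

    # Drain the stream and print the final accumulated text.
    final = None
    for final in generate_stream_exllama(
        wrapper, tokenizer, request, device="cuda", context_len=4096
    ):
        pass
    print(final["text"] if final else "")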