File size: 2,222 Bytes
6dc0c9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gc
import sys
from typing import Dict

import torch


def generate_stream_exllama(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    """Stream a completion from an ExLlamaV2 model as cumulative partial results.

    This is a generator function: nothing (including the ``exllamav2`` import)
    runs until the first item is requested.

    Args:
        model: Object exposing ``.model`` and ``.cache``; both are handed to
            ``ExLlamaV2StreamingGenerator``.
        tokenizer: Tokenizer handed to the streaming generator. Its ``encode``
            method and ``eos_token_id`` attribute are used.
        params: Request parameters. ``"prompt"`` is required. Optional keys and
            their defaults: ``"temperature"`` (0.85), ``"top_k"`` (50),
            ``"top_p"`` (0.8), ``"repetition_penalty"`` (1.15),
            ``"max_new_tokens"`` (256), ``"stop_token_ids"`` (none),
            ``"echo"`` (True; prefix the output text with the prompt).
        device: Not used by this function.
        context_len: Not used by this function.
        stream_interval: Not used by this function; a result is yielded after
            every stream step.
        judge_sent_end: Not used by this function.

    Yields:
        dict with keys ``"text"`` (all text produced so far, prompt included
        when echoing), ``"usage"`` (``prompt_tokens``, ``completion_tokens``,
        ``total_tokens``) and ``"finish_reason"``. ``finish_reason`` is ``None``
        on intermediate items; the last item carries ``"length"`` when the
        ``max_new_tokens`` budget was used up, or ``"stop"`` when the generator
        reported end of stream.

    Raises:
        SystemExit: If ``exllamav2`` cannot be imported (exit status -1).
        KeyError: If ``params`` has no ``"prompt"`` entry.
    """
    try:
        from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
    except ImportError as e:
        print(f"Error: Failed to load Exllamav2. {e}")
        sys.exit(-1)

    prompt = params["prompt"]

    generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer)
    settings = ExLlamaV2Sampler.Settings()

    settings.temperature = float(params.get("temperature", 0.85))
    settings.top_k = int(params.get("top_k", 50))
    settings.top_p = float(params.get("top_p", 0.8))
    settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15))
    settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])

    max_new_tokens = int(params.get("max_new_tokens", 256))

    # A missing or None "stop_token_ids" becomes an empty stop-condition list.
    generator.set_stop_conditions(params.get("stop_token_ids", None) or [])
    echo = bool(params.get("echo", True))

    input_ids = generator.tokenizer.encode(prompt)
    prompt_tokens = input_ids.shape[-1]
    generator.begin_stream(input_ids, settings)

    # Counts stream() calls; reported as completion_tokens.
    # NOTE(review): assumes each stream() call emits one token -- confirm.
    generated_tokens = 0
    output = prompt if echo else ""

    try:
        while True:
            chunk, eos, _ = generator.stream()
            output += chunk
            generated_tokens += 1
            # ">=" rather than "==": a zero or negative budget must still
            # terminate instead of running until the generator reports eos.
            if generated_tokens >= max_new_tokens:
                finish_reason = "length"
                break
            if eos:
                # End of stream is a stop, not a truncation.
                finish_reason = "stop"
                break
            yield {
                "text": output,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": generated_tokens,
                    "total_tokens": prompt_tokens + generated_tokens,
                },
                "finish_reason": None,
            }

        # The step that ended the loop is reported here, with its reason.
        yield {
            "text": output,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": generated_tokens,
                "total_tokens": prompt_tokens + generated_tokens,
            },
            "finish_reason": finish_reason,
        }
    finally:
        # Runs on normal exhaustion and also when the consumer closes the
        # generator early or an error propagates out of the loop.
        gc.collect()