File size: 5,441 Bytes
6a4546d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import copy
import os
from pathlib import Path

import numpy as np
from tokenizers import Tokenizer

import modules.shared as shared
from modules.callbacks import Iteratorize

# Compact numpy output: 4 decimal places, no scientific notation for small values, 200-char lines.
np.set_printoptions(precision=4, suppress=True, linewidth=200)

# RWKV runtime switches, set through environment variables ahead of the rwkv imports below.
# NOTE(review): presumably the rwkv package reads these at import time, which would be why
# they are assigned before `from rwkv.model import RWKV` — confirm before reordering.
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0'  # use CUDA kernel for seq mode (much faster)

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS


class RWKVModel:
    """RWKV model plus tokenizer pipeline, with caching of the processed prompt.

    After each generation the instance remembers the context text it has already
    fed to the model, the model state reached after that text, and the logits
    produced for it. When the next prompt starts with the remembered text, only
    the new suffix is run through the model.
    """

    def __init__(self):
        # Filled in by from_pretrained(); defined here so that every instance
        # carries the attributes the generation methods read.
        self.pipeline = None
        self.model = None
        self.cached_context = ""
        self.cached_model_state = None
        self.cached_output_logits = None

    @classmethod
    def from_pretrained(cls, path, dtype="fp16", device="cuda"):
        """Load the model stored at *path* and return a ready-to-use instance.

        Args:
            path: path object of the model file; ``20B_tokenizer.json`` is looked
                up in ``path.parent``.
            dtype: used together with *device* to build the strategy string when
                ``shared.args.rwkv_strategy`` is not set.
            device: see *dtype*.

        Returns:
            A new instance with ``model`` and ``pipeline`` set and an empty cache.
        """
        tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")

        # An explicit strategy from the command line wins over device/dtype.
        strategy = shared.args.rwkv_strategy
        if strategy is None:
            strategy = f'{device} {dtype}'
        model = RWKV(model=str(path), strategy=strategy)

        result = cls()
        result.pipeline = PIPELINE(model, str(tokenizer_path))
        result.model = model
        result.cached_context = ""
        result.cached_model_state = None
        result.cached_output_logits = None
        return result

    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=None, token_stop=None, callback=None):
        """Generate up to *token_count* tokens continuing *context*.

        If *context* starts with the cached context, only the remaining suffix is
        processed; otherwise the cache is discarded and the full context is used.

        Args:
            context: the full prompt text.
            token_count: maximum number of tokens to sample.
            temperature, top_p, top_k: sampling settings forwarded to the pipeline.
            repetition_penalty: accepted for interface compatibility; not used.
            alpha_frequency: per-occurrence penalty on already generated tokens.
            alpha_presence: flat penalty on already generated tokens.
            token_ban: token ids that must never be sampled (defaults to ``[0]``).
            token_stop: token ids that end generation (defaults to none).
            callback: optional callable receiving each decoded text piece.

        Returns:
            The generated text.
        """
        args = PIPELINE_ARGS(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            alpha_frequency=alpha_frequency,  # Frequency Penalty (as in GPT-3)
            alpha_presence=alpha_presence,  # Presence Penalty (as in GPT-3)
            token_ban=token_ban or [0],  # ban the generation of some tokens
            token_stop=token_stop or []
        )

        if self.cached_context != "":
            if context.startswith(self.cached_context):
                # Only the part the model has not seen yet needs processing.
                context = context[len(self.cached_context):]
            else:
                # The prompt diverged from what was cached: start from scratch.
                self.cached_context = ""
                self.cached_model_state = None
                self.cached_output_logits = None

        return self.generate_from_cached_state(context, token_count=token_count, args=args, callback=callback)

    def generate_with_streaming(self, **kwargs):
        """Yield the cumulative reply each time generate() emits a new piece.

        *kwargs* are handed to generate() through Iteratorize; every item the
        resulting iterator produces is appended to the reply before yielding it.
        """
        with Iteratorize(self.generate, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply

    # Similar to the PIPELINE.generate, but lets us maintain the cached_model_state
    def generate_from_cached_state(self, ctx="", token_count=20, args=None, callback=None):
        """Sample up to *token_count* tokens after feeding *ctx* on top of the cache.

        Args:
            ctx: the part of the prompt not yet covered by the cached state. May be
                empty, in which case the cached logits are reused.
            token_count: maximum number of tokens to sample.
            args: sampling settings (temperature, top_p, top_k, alpha_presence,
                alpha_frequency, token_ban, token_stop, chunk_len). ``None``
                falls back to ``PIPELINE_ARGS()``.
            callback: optional callable receiving each decoded text piece.

        Returns:
            The generated text.

        Raises:
            ValueError: if there is neither context to process nor cached logits,
                so there is nothing to sample from.
        """
        if args is None:
            args = PIPELINE_ARGS()

        all_tokens = []
        out_last = 0  # index of the first sampled token not yet emitted as text
        out_str = ''
        occurrence = {}
        token = None
        state = copy.deepcopy(self.cached_model_state) if self.cached_model_state is not None else None

        # if we ended up with an empty context, just reuse the cached logits
        # this can happen if a user undoes a message and then sends the exact message again
        # in that case the full context ends up being the same as the cached_context, so the remaining context is empty.
        out = self.cached_output_logits if ctx == "" else None

        for i in range(token_count):
            # forward: the context on the first step, the last sampled token afterwards
            tokens = self.pipeline.encode(ctx) if i == 0 else [token]
            while len(tokens) > 0:
                out, state = self.model.forward(tokens[:args.chunk_len], state)
                tokens = tokens[args.chunk_len:]

            if out is None:
                raise ValueError("Cannot generate: the context is empty and no cached logits are available.")

            # cache the model state after scanning the context
            # we don't cache the state after processing our own generated tokens because
            # the output string might be post-processed arbitrarily. Therefore, what's fed into the model
            # on the next round of chat might be slightly different from what it output on the previous round
            if i == 0:
                self.cached_context += ctx
                self.cached_model_state = copy.deepcopy(state)
                # copied before the in-place adjustments below touch `out`
                self.cached_output_logits = copy.deepcopy(out)

            # adjust probabilities
            for n in args.token_ban:
                out[n] = -float('inf')

            for n, count in occurrence.items():
                out[n] -= (args.alpha_presence + count * args.alpha_frequency)

            # sampler
            token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
            if token in args.token_stop:
                break

            all_tokens.append(token)
            occurrence[token] = occurrence.get(token, 0) + 1

            # output: decode every token not emitted yet, so that a character whose
            # bytes are spread over several tokens is emitted once it is complete
            # instead of being dropped piece by piece
            tmp = self.pipeline.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:  # is valid utf-8 string?
                if callback:
                    callback(tmp)

                out_str += tmp
                out_last = len(all_tokens)

        return out_str


class RWKVTokenizer:
    """Thin wrapper exposing encode/decode on top of a `tokenizers.Tokenizer`."""

    def __init__(self):
        """Create an empty wrapper; `tokenizer` is attached by from_pretrained()."""

    @classmethod
    def from_pretrained(cls, path):
        """Build a wrapper from the `20B_tokenizer.json` file located under *path*."""
        wrapper = cls()
        wrapper.tokenizer = Tokenizer.from_file(str(path / "20B_tokenizer.json"))
        return wrapper

    def encode(self, prompt):
        """Return the token ids of *prompt*."""
        encoding = self.tokenizer.encode(prompt)
        return encoding.ids

    def decode(self, ids):
        """Return the text represented by the token *ids*."""
        text = self.tokenizer.decode(ids)
        return text