import gradio as gr
from modeling import global_config, ToyTransformer, AttentionBackend
import torch
from tokenizers import TRIETokenizer
import bisect

# Use the GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    g_device = torch.device('cuda')
else:
    g_device = torch.device('cpu')
global_config['attn_backend'] = AttentionBackend.Naive

g_SEQ_LEN = 1024
g_HIDDEN_SIZE = 768
g_NUM_HEADS = 12
g_NUM_LAYERS = 12
g_DTYPE = torch.float32

g_tokenizer = TRIETokenizer('llama_vocab_pruned_32k.json')
g_model = ToyTransformer(g_tokenizer.get_vocab_size(), g_NUM_LAYERS, g_NUM_HEADS, g_HIDDEN_SIZE, g_SEQ_LEN, g_device, g_DTYPE)

g_model.load_state_dict(torch.load('model.pt', map_location='cpu'))


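def generate(model, tokenizer, prompt, temperature, top_p, rep_penalty,
             max_new_tokens=20, total_tokens=None,
             end_tokens=None,
             enable_kv_cache=True):
    """Sample tokens autoregressively from `model`, yielding each new token as a one-element list.

    `prompt` may be a string (encoded with `tokenizer`) or a list of token ids. If
    `total_tokens` is given, it overrides `max_new_tokens` so that prompt plus reply
    fit within `total_tokens`. Generation stops early when a token in `end_tokens`
    is sampled; that end token is not yielded.
    """
    model.eval()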

    feed_tokens = tokenizer.encode(prompt) if isinstance(prompt, str) else prompt

    all_tokens = feed_tokens.copy()
    if total_tokens is not None:
        max_new_tokens = max(0, total_tokens - len(feed_tokens))

    with torch.no_grad():
        kv_cache = None
        for _ in range(max_new_tokens):
            logits, kv_cache = model.forward(
                torch.tensor([feed_tokens if enable_kv_cache else all_tokens]).to(model.device),
                kv_cache=kv_cache)
            logits = logits[0][-1].cpu()
            if not enable_kv_cache:
                kv_cache = None

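            # apply repetition penalty to every token already in the context:
            # positive logits are divided and negative logits multiplied by rep_penalty,
            # so for rep_penalty > 1 both move toward lower probability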
            logits_rep = torch.gather(logits, 0, torch.tensor(all_tokens))
            logits_rep = torch.where(logits_rep < 0, logits_rep * rep_penalty, logits_rep / rep_penalty)
            logits.scatter_(0, torch.tensor(all_tokens), logits_rep)

            # apply temperature
            logits /= max(temperature, 1e-6)

            probs = torch.softmax(logits, dim=0)

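            # apply top-p (nucleus) filtering: keep the smallest prefix of the sorted
            # distribution whose cumulative probability exceeds top_p (always at least
            # one token), then sample from it; torch.multinomial accepts unnormalized weights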
            ordered_probs, ordered_indices = torch.sort(probs, descending=True)
            cum_probs = torch.cumsum(ordered_probs, dim=0).tolist()
            top_p_index = bisect.bisect_right(cum_probs, top_p) + 1
            ordered_probs, ordered_indices = ordered_probs[:top_p_index], ordered_indices[:top_p_index]
            sampled_index = ordered_indices[torch.multinomial(ordered_probs, num_samples=1).item()].item()

            all_tokens.append(sampled_index)
            feed_tokens = [sampled_index]

            if end_tokens is not None and sampled_index in end_tokens:
                break

            yield feed_tokens
    return
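

# Illustrative sketch (hypothetical helper; the app does not call it).
# The per-step sampling logic inside generate() is factored into a pure function
# here so the repetition penalty, temperature, and top-p steps can be exercised
# without loading a model. Assumes `logits` is a 1-D float tensor over the
# vocabulary and `context_tokens` is a non-empty list of token ids already in
# the sequence. bisect and torch are already imported at the top of this file;
# they are repeated here only so the sketch is self-contained.
import bisect

import torch


def sample_next_token_sketch(logits, context_tokens, temperature, top_p, rep_penalty, generator=None):
    logits = logits.clone()  # do not modify the caller's tensor
    ids = torch.tensor(context_tokens)

    # Repetition penalty: push logits of already-seen tokens toward lower probability.
    seen = torch.gather(logits, 0, ids)
    seen = torch.where(seen < 0, seen * rep_penalty, seen / rep_penalty)
    logits.scatter_(0, ids, seen)

    # Temperature scaling, guarded against division by zero.
    probs = torch.softmax(logits / max(temperature, 1e-6), dim=0)

    # Top-p: keep the smallest sorted prefix whose cumulative probability exceeds top_p.
    ordered_probs, ordered_indices = torch.sort(probs, descending=True)
    cum_probs = torch.cumsum(ordered_probs, dim=0).tolist()
    cutoff = bisect.bisect_right(cum_probs, top_p) + 1
    kept_probs, kept_indices = ordered_probs[:cutoff], ordered_indices[:cutoff]

    # torch.multinomial samples an index into kept_probs; map it back to a vocabulary id.
    choice = torch.multinomial(kept_probs, num_samples=1, generator=generator).item()
    return kept_indices[choice].item()

# Usage: with top_p=0 only the most likely token survives, so sampling is greedy:
# sample_next_token_sketch(torch.tensor([1.0, 3.0, 2.0]), [0], 1.0, 0.0, 1.0) returns 1.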


def predict(user_input, history, max_length, top_p, temperature, rep_penalty, retry):
    if retry and len(history) == 0:
        yield []
        return
    elif retry:
        user_input = history[-1][0]
        history = history[:-1]

    history.append((user_input, ""))

    # Encode each round, including the newline that precedes each turn, so that the
    # length budget below counts every token that ends up in the prompt.
    prefix_tokens = g_tokenizer.encode('<s>A chat between User and Assistant.')
    newline_tokens = g_tokenizer.encode('\n')
    encoded_inputs = [(newline_tokens + g_tokenizer.encode('User:' + h[0]),
                       newline_tokens + g_tokenizer.encode('Assistant:' + h[1])) for h in history]

    # Keep the most recent rounds that fit, reserving max_length tokens for the reply
    # and room for the fixed system prefix.
    budget = g_SEQ_LEN - max_length - len(prefix_tokens)
    taken_rounds, taken_rounds_length = [], 0
    for round_pair in reversed(encoded_inputs):
        round_length = len(round_pair[0]) + len(round_pair[1])
        if taken_rounds_length + round_length >= budget:
            break
        taken_rounds.append(round_pair)
        taken_rounds_length += round_length
    taken_rounds = taken_rounds[::-1]

    input_tokens = prefix_tokens.copy()
    for user_tokens, assistant_tokens in taken_rounds:
        input_tokens += user_tokens + assistant_tokens
    for response in generate(g_model, g_tokenizer, input_tokens, temperature, top_p, rep_penalty,
                             int(max_length), end_tokens=g_tokenizer.encode('</s>')):
        history[-1] = (history[-1][0], history[-1][1] + g_tokenizer.decode(response))
        yield history
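

# Illustrative sketches (hypothetical helpers; the app does not call them).
# select_recent_rounds_sketch mirrors the history-truncation loop in predict() on
# plain (user_length, assistant_length) pairs: walking back from the newest round,
# rounds are kept while the running total stays strictly below `budget`, and the
# first round that does not fit ends the search, so older rounds are dropped first.
# `budget` stands for the number of prompt tokens available for history.
def select_recent_rounds_sketch(round_lengths, budget):
    taken, used = [], 0
    for user_len, assistant_len in reversed(round_lengths):
        if used + user_len + assistant_len >= budget:
            break
        taken.append((user_len, assistant_len))
        used += user_len + assistant_len
    return taken[::-1]  # restore chronological order


# build_prompt_text_sketch shows, as plain text, the chat template predict() assembles.
# predict() tokenizes each piece separately and concatenates the ids, which may not
# match the ids produced by encoding this string in one call.
def build_prompt_text_sketch(rounds):
    parts = ['<s>A chat between User and Assistant.']
    for user_text, assistant_text in rounds:
        parts.append('\nUser:' + user_text + '\nAssistant:' + assistant_text)
    return ''.join(parts)

# Usage: build_prompt_text_sketch([('Hi', '')]) returns
# '<s>A chat between User and Assistant.\nUser:Hi\nAssistant:'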


def main():
    css = '''
        .contain {max-width:50}

        #chatbot {min-height:500px}
    '''

    with gr.Blocks(css=css) as demo:
        gr.HTML('<h1 align="center">ToyTransformer</h1>')

        chatbot = gr.Chatbot(elem_id='chatbot')
        with gr.Column():
            user_input = gr.Textbox(show_label=False, placeholder="Input", lines=1, container=False)
            with gr.Row():
                submitBtn = gr.Button("Send", variant="primary")
                retryBtn = gr.Button("Retry")
                cancelBtn = gr.Button('Undo')
                emptyBtn = gr.Button("Clear")
            with gr.Row():
                max_length = gr.Slider(0, 512, value=200, step=1, label="Max Response Tokens", interactive=True)
                top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top-P", interactive=True)
                temperature = gr.Slider(0, 1, value=0.5, step=0.01, label="Temperature", interactive=True)
                rep_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label='Repetition Penalty', interactive=True)

        submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, rep_penalty, gr.State(False)],
                        [chatbot], show_progress=False)
        submitBtn.click(lambda: '', [], [user_input], show_progress=False)

        retryBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, rep_penalty, gr.State(True)],
                       [chatbot], show_progress=False)

        cancelBtn.click(lambda m: m[:-1], [chatbot], [chatbot], show_progress=False)

        emptyBtn.click(lambda: [], outputs=[chatbot], show_progress=False)

    demo.queue().launch(share=False, inbrowser=True)


if __name__ == '__main__':
    main()