# Gradio demo of streaming generation of multiple LLM response pairs, with the
# shared prefix of each pair highlighted as it streams.

import spaces
import logging
import time
import html
import os
import numpy as np
import gradio as gr
import util

import huggingface_hub
import torch
import transformers
import accelerate

# Print dependency versions, to help pin `requirements.txt`.
print('Dependency versions:')
print(f'huggingface_hub=={huggingface_hub.__version__}')
print(f'numpy=={np.__version__}')
print(f'torch=={torch.__version__}')
print(f'transformers=={transformers.__version__}')
print(f'accelerate=={accelerate.__version__}')
print()

# Initialize logging.
logging.basicConfig(format='%(levelname)s:%(name)s: %(message)s', level=logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# gr.DataFrame currently has a bug when updating values,
# so we render a raw HTML table instead.
# https://github.com/gradio-app/gradio/issues/8160
css = '''
.response-table {
    width: 100%;
    table-layout: fixed;
}
.response-table th, .response-table td {
    width: 50%;
}
.response-table td {
    font-family: monospace;
    white-space: pre-wrap;
    text-align: left;
    vertical-align: top;
}
.highlight {
    background-color: #90FF90;
}
'''

def make_html_table(headers, data):
    rows = ['<tr>' + ''.join(f'<th>{h}</th>' for h in headers) + '</tr>\n']
    for row in data:
        rows.append('<tr>' + ''.join(f'<td>{v}</td>' for v in row) + '</tr>\n')
    return '<table class="response-table">\n' + ''.join(rows) + '</table>\n'
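
# Illustrative example (not executed): make_html_table(['A', 'B'], [['x', 'y']])
# returns roughly:
#   <table class="response-table">
#   <tr><th>A</th><th>B</th></tr>
#   <tr><td>x</td><td>y</td></tr>
#   </table>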

def highlight_prefix(tokens, prefix_len):
    # Decode `tokens` and wrap the text corresponding to the first `prefix_len`
    # tokens in a highlight span. The prefix boundary is re-measured at the
    # character level so a partially generated grapheme cluster is never highlighted.
    prefix_tokens = tokens[:prefix_len]

    s = tokenizer.decode(tokens, skip_special_tokens=True)
    prefix_s = tokenizer.decode(prefix_tokens, skip_special_tokens=True)

    # Character-level longest common prefix of the full decode and the prefix decode.
    s_lcp_len = util.longest_common_prefix(np.array(list(s)), np.array(list(prefix_s)))

    prefix_html = html.escape(s[:s_lcp_len])
    suffix_html = html.escape(s[s_lcp_len:])

    return f'<span class="highlight">{prefix_html}</span>{suffix_html}'

def format_response_pair(tokens_a, tokens_b):
    # Highlight the longest shared token prefix of the two responses. Measuring the
    # prefix in tokens here, then re-measuring it in characters inside
    # highlight_prefix, properly handles grapheme clusters that span token boundaries.
    token_lcp_len = util.longest_common_prefix(tokens_a, tokens_b)
    return highlight_prefix(tokens_a, token_lcp_len), highlight_prefix(tokens_b, token_lcp_len)
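
# Illustrative example (not executed): if the two responses decode to
# "Sure! 1 2 3" and "Sure! 1 2 4", and their shared token prefix decodes to
# "Sure! 1 2 ", both returned strings wrap that shared text in
# <span class="highlight">...</span> and leave the differing tails unhighlighted.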

HEADERS = ['Response (Left)', 'Response (Right)']
repo_id = "Qwen/Qwen2-0.5B-Instruct"

DRY_RUN = os.environ.get('DRY_RUN') == '1'
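# Set DRY_RUN=1 in the environment to skip loading the model and stream canned
# responses instead; handy for exercising the UI without downloading weights.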

if DRY_RUN:
    from load import load_tokenizer

    tokenizer = load_tokenizer(repo_id)

    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        logger.info('Starting generation...')
        generation_start = time.perf_counter()

        # The prompts are ignored in dry-run mode; canned responses exercise the UI.
        rows = [[''] * 2 for _ in range(num_responses)]

        yield make_html_table(HEADERS, rows)

        for j in range(num_responses):
            response_raw_a = 'Sure!\n\n1 2 3 4 & 5.'
            response_raw_b = 'Sure!\n\n1 2 3 4 5 &\n\n\n\n6.'

            response_tok_a = tokenizer.encode(response_raw_a, add_special_tokens=False, return_tensors='np')[0]
            response_tok_b = tokenizer.encode(response_raw_b, add_special_tokens=False, return_tensors='np')[0]

            steps = 1 + max(len(response_tok_a), len(response_tok_b))

            for i in range(steps):
                time.sleep(0.01)
                prefix_tok_a = response_tok_a[:i]
                prefix_tok_b = response_tok_b[:i]

                content_a, content_b = format_response_pair(prefix_tok_a, prefix_tok_b)

                rows[j][0] = content_a
                rows[j][1] = content_b

                yield make_html_table(HEADERS, rows)

        generation_end = time.perf_counter()
        logger.info(f'Generation took {(generation_end - generation_start):.3f} s')
else:
    from load import load_model
    import algorithms
    #algorithms.logger.setLevel(logging.DEBUG)

    model, tokenizer = load_model(repo_id)

    def make_chat(system_msg, prompt):
        chat = [
                {
                    'role': 'system',
                    'content': system_msg,
                },
                {
                    'role': 'user',
                    'content': prompt,
                },
        ]
        return chat

    @spaces.GPU
    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        logger.info('Starting generation...')
        generation_start = time.perf_counter()

        # Is this necessary with ZeroGPU?
        torch.use_deterministic_algorithms(True)

        rows = [[''] * 2 for _ in range(num_responses)]
        yield make_html_table(HEADERS, rows)

        for j in range(num_responses):
            system_msg = "You are a helpful assistant."

            chat_x = make_chat(system_msg, prompt_x)
            chat_y = make_chat(system_msg, prompt_y)

            gen = algorithms.apoc_streaming(
                model,
                model,
                tokenizer,
                chat_x,
                chat_y,
                max_tokens=max_tokens,
            )
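            # apoc_streaming appears to yield (token_a, token_b) pairs, with None for
            # a side that produced no new token this step (or has finished); tokens are
            # accumulated below and the table is re-rendered whenever either side changes.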
            response_a_L = []
            response_b_L = []
            for token_a, token_b in gen:
                dirty = False
                if token_a is not None:
                    response_a_L.append(token_a)
                    dirty = True
                if token_b is not None:
                    response_b_L.append(token_b)
                    dirty = True
                
                if dirty:
                    content_a, content_b = format_response_pair(np.array(response_a_L), np.array(response_b_L))

                    rows[j][0] = content_a
                    rows[j][1] = content_b
                
                yield make_html_table(HEADERS, rows)

        generation_end = time.perf_counter()
        logger.info(f'Generation took {(generation_end - generation_start):.3f} s')

demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Slider(1, 512, label='Max Tokens', value=48),
        gr.Slider(1, 16, step=1, label='Num Responses', value=8),
        gr.Textbox(label='Prompt (Left)'),
        gr.Textbox(label='Prompt (Right)'),
        ],
    outputs=[
        gr.HTML(),
        ],
    css=css,
    title='All-Prefix-Optimal Coupling',
    description='Try similar prompts to see the effect of the difference between them. '
        f'Model: `{repo_id}`.'
        ,
    examples=[
        [48, 8, 'Count from 1 to 5.', 'Count from 1 to 6.'],

        # This would be a good example, but Qwen2-0.5B occasionally goes off-color.
        #[48, 8, 'Tell me a joke.', 'Tell me a funny joke.'],

        [48, 8, 'Calculate 3 + 4', 'Calculate 3 + 5'],
        [48, 8, "What's the capital of Canada?", "What's the capital of France?"],
        [48, 8, "1 3 5. What number is next?", "4 5 6. What number is next?"],
    ],
    # On Hugging Face Spaces this defaults to True, which makes startup
    # take a very long time.
    cache_examples=False,
    )

demo.launch()