import gradio as gr
from modeling import global_config, ToyTransformer, AttentionBackend
import torch
from tokenizers import TRIETokenizer
from threading import Thread
import bisect
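
# Select the device, use the naive attention backend, and build the tokenizer and
# model (12 layers, 12 heads, hidden size 768, 1024-token context) from 'model.pt'.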
if torch.cuda.is_available():
    g_device = torch.device('cuda')
else:
    g_device = torch.device('cpu')
global_config['attn_backend'] = AttentionBackend.Naive
g_SEQ_LEN = 1024
g_HIDDEN_SIZE = 768
g_NUM_HEADS = 12
g_NUM_LAYERS = 12
g_DTYPE = torch.float32
g_tokenizer = TRIETokenizer('llama_vocab_pruned_32k.json')
g_model = ToyTransformer(g_tokenizer.get_vocab_size(), g_NUM_LAYERS, g_NUM_HEADS, g_HIDDEN_SIZE, g_SEQ_LEN, g_device, g_DTYPE)
g_model.load_state_dict(torch.load('model.pt', map_location='cpu'))
def generate(model, tokenizer, prompt, temperature, top_p, rep_penalty,
             max_new_tokens=20, total_tokens=None,
             end_tokens=None,
             enable_kv_cache=True):
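    """Autoregressively sample tokens from the model, yielding them as they are produced.

    Applies repetition penalty, temperature scaling and top-p (nucleus) filtering at
    each step, and stops early once a token from ``end_tokens`` is sampled.
    """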
    model.eval()
    feed_tokens = tokenizer.encode(prompt) if isinstance(prompt, str) else prompt
    all_tokens = feed_tokens.copy()
    if total_tokens is not None:
        max_new_tokens = max(0, total_tokens - len(feed_tokens))
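    # Decoding loop: with the KV cache enabled only the newly sampled token is fed
    # back each step; otherwise the full token sequence is re-encoded every step.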
    with torch.no_grad():
        kv_cache = None
        for _ in range(max_new_tokens):
            logits, kv_cache = model.forward(
                torch.tensor([feed_tokens if enable_kv_cache else all_tokens]).to(model.device),
                kv_cache=kv_cache)
            logits = logits[0][-1].cpu()
            if not enable_kv_cache:
                kv_cache = None
            # apply repetition penalty
            logits_rep = torch.gather(logits, 0, torch.tensor(all_tokens))
            logits_rep = torch.where(logits_rep < 0, logits_rep * rep_penalty, logits_rep / rep_penalty)
            logits.scatter_(0, torch.tensor(all_tokens), logits_rep)
            # apply temperature
            logits /= max(temperature, 1e-6)
            probs = torch.softmax(logits, dim=0)
            # apply top-p
            ordered_probs, ordered_indices = torch.sort(probs, descending=True)
            cum_probs = torch.cumsum(ordered_probs, dim=0).tolist()
            top_p_index = bisect.bisect_right(cum_probs, top_p) + 1
            ordered_probs, ordered_indices = ordered_probs[:top_p_index], ordered_indices[:top_p_index]
            sampled_index = ordered_indices[torch.multinomial(ordered_probs, num_samples=1).item()].item()
            all_tokens.append(sampled_index)
            feed_tokens = [sampled_index]
            if end_tokens is not None and sampled_index in end_tokens:
                break
            yield feed_tokens
    return

def predict(user_input, history, max_length, top_p, temperature, rep_penalty, retry):
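    """Gradio callback: build a prompt from the chat history and stream the assistant reply.

    When ``retry`` is truthy the last user message is resubmitted instead of ``user_input``.
    Yields the updated history as tokens are generated.
    """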
    if retry and len(history) == 0:
        yield []
        return
    elif retry:
        user_input = history[-1][0]
        history = history[:-1]
    history.append((user_input, ""))
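    # Encode each round, then walk the history backwards, keeping as many recent
    # rounds as fit within g_SEQ_LEN - max_length tokens.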
    encoded_inputs = [(g_tokenizer.encode('User:' + h[0]), g_tokenizer.encode('Assistant:' + h[1])) for h in history]
    taken_rounds, taken_rounds_length = [], 0
    while len(taken_rounds) < len(encoded_inputs):
        round_pair = encoded_inputs[len(encoded_inputs) - 1 - len(taken_rounds)]
        if len(round_pair[0]) + len(round_pair[1]) + taken_rounds_length >= g_SEQ_LEN - max_length:
            break
        taken_rounds.append(round_pair)
        taken_rounds_length += len(round_pair[0]) + len(round_pair[1])
    taken_rounds = taken_rounds[::-1]
    input_tokens = g_tokenizer.encode('<s>A chat between User and Assistant.')
    for round_pair in taken_rounds:
        input_tokens += g_tokenizer.encode('\n') + round_pair[0] + g_tokenizer.encode('\n') + round_pair[1]
    # print(taken_rounds, g_tokenizer.decode(input_tokens))
    for response in generate(g_model, g_tokenizer, input_tokens, temperature, top_p, rep_penalty, max_length, end_tokens=g_tokenizer.encode('</s>')):
        history[-1] = (history[-1][0], history[-1][1] + g_tokenizer.decode(response))
        yield history

def main():
    css = '''
    .contain {max-width:50}
    #chatbot {min-height:500px}
    '''
    with gr.Blocks(css=css) as demo:
        gr.HTML('<h1 align="center">ToyTransformer</h1>')
        chatbot = gr.Chatbot(elem_id='chatbot')
        with gr.Column():
            user_input = gr.Textbox(show_label=False, placeholder="Input", lines=1, container=False)
            with gr.Row():
                submitBtn = gr.Button("Send", variant="primary")
                retryBtn = gr.Button("Retry")
                cancelBtn = gr.Button('Undo')
                emptyBtn = gr.Button("Clear")
            with gr.Row():
                max_length = gr.Slider(0, 512, value=200, step=1, label="Max Response Tokens", interactive=True)
                top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top-P", interactive=True)
                temperature = gr.Slider(0, 1, value=0.5, step=0.01, label="Temperature", interactive=True)
                rep_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label='Repetition Penalty', interactive=True)
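        # Wire up the buttons: Send and Retry stream responses from predict(),
        # Undo drops the last round, and Clear resets the conversation.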
        submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, rep_penalty, gr.State(False)],
                        [chatbot], show_progress=False)
        submitBtn.click(lambda: '', [], [user_input], show_progress=False)
        retryBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, rep_penalty, gr.State(True)],
                       [chatbot], show_progress=False)
        cancelBtn.click(lambda m: m[:-1], [chatbot], [chatbot], show_progress=False)
        emptyBtn.click(lambda: [], outputs=[chatbot], show_progress=False)
    demo.queue().launch(share=False, inbrowser=True)

main()