|
import os |
|
|
|
import rwkv_rs |
|
import numpy as np |
|
import huggingface_hub |
|
import tokenizers |
|
|
|
import gradio as gr |
|
|
|
# Resolve the checkpoint: use the local file when it exists, otherwise fetch
# the published weights from the Hugging Face Hub.
model_path = "./rnn.safetensors"
model_path = (
    model_path
    if os.path.exists(model_path)
    else huggingface_hub.hf_hub_download(
        repo_id="mrsteyk/RWKV-LM-safetensors",
        filename="RWKV-4-Pile-7B-Instruct-test1-20230124.rnn.safetensors",
    )
)
assert model_path is not None

# Model and tokenizer are loaded once at import time and shared by all requests.
model = rwkv_rs.Rwkv(model_path)
tokenizer = tokenizers.Tokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
|
|
# Button-visibility updates appended to every streamed result by
# ``generator_wrap``.  Position 0 targets the "Complete" button and position 1
# the "Stop" button (see ``G = [complete, c_stop]`` in the UI section).

# While generating: hide "Complete", show "Stop".
GT = [gr.Button.update(visible=shown) for shown in (False, True)]

# When generation has finished: show "Complete", hide "Stop".
GF = [gr.Button.update(visible=shown) for shown in (True, False)]
|
|
|
def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
    """Stream a greedy completion of ``inpt`` for the Gradio UI.

    This is a generator.  Every item is a 2-tuple ``(text, error_update)``:
    ``text`` is the text to show in the output box and ``error_update`` is
    ``None`` on success, or a ``gr.Text`` update revealing the error message
    when something went wrong.

    Args:
        inpt: Prompt text; it must encode to at least one token.
        max_tokens: Upper bound on the number of generated tokens (coerced
            with ``int``).
        min_tokens: Number of initial generation steps during which token
            id 0 -- the id that stops generation -- is suppressed.
        alpha_f: Frequency penalty, subtracted from a token's logit once per
            time that token has already been generated.
        alpha_p: Presence penalty, subtracted once from the logit of every
            token that has been generated at least once.

    Yields:
        ``(text, None)`` while the prompt is consumed and after each generated
        token, then a final ``(text, None)``.  On any exception a single
        ``("Error...", gr.Text.update(...))`` item is yielded instead.
    """
    try:
        state = rwkv_rs.State(model)
        text = inpt
        tokens = tokenizer.encode(inpt).ids
        if not tokens:
            # Without a token there is nothing to condition the model on;
            # surface a readable message through the error box.
            raise ValueError("Input must contain at least one token")

        # Per-token-id count of how often each id has been generated so far.
        counts = np.zeros(tokenizer.get_vocab_size())
        vocab_size = len(counts)

        # Feed every prompt token except the last through the state-only
        # path, echoing the consumed part of the prompt as progress.
        for pos, prompt_token in enumerate(tokens[:-1]):
            model.forward_token_preproc(prompt_token, state)
            yield (tokenizer.decode(tokens[:pos + 1]), None)

        # The last prompt token is the one whose logits we need.  It must be
        # passed as a single token id, like every other forward call.
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)

        max_tokens = int(max_tokens)
        for step in range(max_tokens):
            logits = np.asarray(logits)
            if step < min_tokens:
                # Make the stop token practically unselectable for now.
                logits[0] = -100

            # Frequency + presence penalties for the whole vocabulary at
            # once.  (A per-index Python loop here previously reused the
            # outer loop variable and corrupted the step counter.)
            logits[:vocab_size] -= counts * alpha_f + (counts > 0) * alpha_p

            token = int(np.argmax(logits))
            counts[token] += 1
            if token == 0:
                break

            tokens.append(token)
            text = tokenizer.decode(tokens)
            yield (text, None)

            # The logits produced after the final token would never be used,
            # so skip that forward pass.
            if step + 1 < max_tokens:
                logits = model.forward_token(token, state)
        yield (text, None)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing
        # the event handler.
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))
|
|
|
|
|
|
|
def generator_wrap(l, fn):
    """Wrap generator function ``fn`` so its items also toggle the buttons.

    Args:
        l: Number of values ``fn`` yields per item; used to build a
            placeholder of ``None`` values in case ``fn`` yields nothing.
        fn: Generator function whose items are sequences of ``l`` values.

    Returns:
        A generator function accepting the same arguments as ``fn``.  Each
        item of ``fn`` is re-yielded as a list with ``GT`` appended; once
        ``fn`` finishes (or raises) the most recent item is yielded one more
        time with ``GF`` appended, after which any exception propagates.
    """
    def wrap(*args):
        latest = [None] * l
        closed = False
        try:
            for item in fn(*args):
                latest = list(item)
                yield latest + GT
        except GeneratorExit:
            # The consumer closed (or dropped) this generator.  Yielding
            # again would raise "generator ignored GeneratorExit", so the
            # trailing item is skipped.
            closed = True
            raise
        finally:
            if not closed:
                yield latest + GF
    return wrap
|
|
|
|
|
# Gradio UI definition.
# NOTE(review): the nesting below is reconstructed from syntax because the
# original indentation was not available -- confirm the layout (in particular
# whether the buttons belong inside or below each Row).
with gr.Blocks() as app:
    gr.Markdown(f"Running on `{model_path}`")
    error_box = gr.Text(label="Error", visible=False)

    with gr.Tab("Complete"):
        with gr.Row():
            inpt = gr.TextArea(label="Input")
            out = gr.TextArea(label="Output")
        complete = gr.Button("Complete", variant="primary")
        c_stop = gr.Button("Stop", variant="stop", visible=False)

    with gr.Tab("Insert (WIP)"):
        gr.Markdown("WIP, use `<|INSERT|>` to indicate a place to replace")
        with gr.Row():
            inpt_i = gr.TextArea(label="Input")
            out_i = gr.TextArea(label="Output")
        insert = gr.Button("Insert")

    # Generation settings passed to complete_fn.
    with gr.Column():
        max_tokens = gr.Slider(
            label="Max Tokens", minimum=1, maximum=4096, step=1, value=767
        )
        min_tokens = gr.Slider(
            label="Min Tokens", minimum=0, maximum=4096, step=1
        )
        alpha_f = gr.Slider(
            label="Alpha Frequency", minimum=0, maximum=100, step=0.01
        )
        alpha_p = gr.Slider(
            label="Alpha Presence", minimum=0, maximum=100, step=0.01
        )

    # Buttons whose visibility is driven by the GT / GF updates that
    # generator_wrap appends to every streamed item.
    G = [complete, c_stop]

    # Streaming completion: outputs are the text box, the error box and the
    # two buttons.
    c = complete.click(
        generator_wrap(2, complete_fn),
        [inpt, max_tokens, min_tokens, alpha_f, alpha_p],
        [out, error_box] + G,
    )

    # "Stop" cancels the running completion event and restores the buttons;
    # it bypasses the queue.
    c_stop.click(
        lambda: (complete.update(visible=True), c_stop.update(visible=False)),
        inputs=None,
        outputs=[complete, c_stop],
        cancels=[c],
        queue=False,
    )

app.queue()
app.launch()