RWKV-Gradio-1 / app.py
cryscan's picture
Rework demo UI.
a33184e
raw
history blame
13.3 kB
import gc
import os
from datetime import datetime

# The rwkv package reads these flags in module-level code of rwkv.model (its
# docs say to set them *before* importing RWKV), so they must be assigned
# before the rwkv imports below to take effect.
os.environ["RWKV_JIT_ON"] = '1'
# if '1' then use CUDA kernel for seq mode (much faster)
os.environ["RWKV_CUDA_ON"] = '1'

import gradio as gr  # noqa: E402
import torch  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
from pynvml import *  # noqa: E402,F403  (used: nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo)
from rwkv.model import RWKV  # noqa: E402
from rwkv.utils import PIPELINE, PIPELINE_ARGS  # noqa: E402

# NVML handle for GPU index 0; infer() and chat() query it to log VRAM usage.
nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)

# Prompts are truncated to their last `ctx_limit` tokens before being fed to the model.
ctx_limit = 1024
title = "RWKV-4-Pile-14B-20230313-ctx8192-test1050"
desc = '''Links:
<a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 0.5em">ChatRWKV</a>
<a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 0.5em">RWKV-LM</a>
<a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 0.5em">RWKV pip package</a>
'''

# The checkpoint file name is the model title plus ".pth".
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-14b", filename=f"{title}.pth")
# Strategy string per the rwkv docs: presumably the first 20 layers as int8-quantised
# fp16 on CUDA and the remaining layers plain fp16 on CUDA -- confirm against rwkv docs.
model = RWKV(model=model_path, strategy='cuda fp16i8 *20 -> cuda fp16')
pipeline = PIPELINE(model, "20B_tokenizer.json")
########################################################################################################
def infer(
    ctx,
    token_count=10,
    temperature=1.0,
    top_p=0.8,
    presence_enalty=0.1,
    count_penalty=0.1,
):
    """Stream a free-form completion of ``ctx`` for the "Generative" tab.

    This is a generator. After each sampled token, once the pending tokens
    decode without a U+FFFD replacement character, it yields the whole output
    so far (stripped); when generation ends it yields the final text once more.

    Args:
        ctx: Prompt text. Surrounding whitespace is stripped and a newline is
            prepended; if the prompt (ignoring trailing spaces) ended with a
            newline, a single trailing newline is kept.
        token_count: Maximum number of tokens to sample.
        temperature: Sampling temperature; values below 0.2 are raised to 0.2.
        top_p: Nucleus-sampling threshold passed to ``pipeline.sample_logits``.
        presence_enalty: Flat penalty subtracted from the logit of every token
            already generated (misspelled name kept for compatibility).
        count_penalty: Extra penalty per previous occurrence of a token.

    Yields:
        str: The generated text so far, with surrounding whitespace stripped.
    """
    args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
                         alpha_frequency=float(count_penalty),
                         alpha_presence=float(presence_enalty),
                         token_ban=[0],  # ban the generation of some tokens
                         token_stop=[])  # stop generation whenever you see any token here
    ctx = ctx.strip(' ')
    if ctx.endswith('\n'):
        ctx = f'\n{ctx.strip()}\n'
    else:
        ctx = f'\n{ctx.strip()}'
    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
    print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
    all_tokens = []
    out_last = 0  # index in all_tokens of the first token not yet emitted as text
    out_str = ''
    occurrence = {}  # token id -> number of times generated so far
    state = None
    try:
        for i in range(int(token_count)):
            # Step 0 feeds the (truncated) prompt; later steps feed only the last sampled token.
            feed = pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token]
            out, state = model.forward(feed, state)
            for n in args.token_ban:
                out[n] = -float('inf')
            for n, count in occurrence.items():
                out[n] -= (args.alpha_presence + count * args.alpha_frequency)
            token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
            if token in args.token_stop:
                break
            all_tokens.append(token)
            occurrence[token] = occurrence.get(token, 0) + 1
            # A multi-byte character may span several tokens; wait until it is complete.
            tmp = pipeline.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:
                out_str += tmp
                yield out_str.strip()
                out_last = len(all_tokens)
    finally:
        # Also runs when the consumer stops iterating early (generator.close())
        # or an exception escapes the loop, so GPU memory is still released.
        gc.collect()
        torch.cuda.empty_cache()
    yield out_str.strip()
# Example rows for the "Generative" tab. Each row's columns are, in order:
# prompt, max tokens, temperature, top_p, presence penalty, count penalty --
# the same order as the Dataset headers and infer()'s positional parameters.
examples = [
    ["Expert Questions & Helpful Answers\nAsk Research Experts\nQuestion:\nHow can we eliminate poverty?\n\nFull Answer:\n", 150, 1.0, 0.7, 0.2, 0.2],
    ["Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n", 150, 1.0, 0.7, 0.2, 0.2],
    ['''Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Generate a list of adjectives that describe a person as brave.
### Response:
''', 150, 1.0, 0.2, 0.5, 0.5],
    ['''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
Arrange the given numbers in ascending order.
### Input:
2, 4, 0, 8, 3
### Response:
''', 150, 1.0, 0.2, 0.5, 0.5],
    ["Ask Expert\n\nQuestion:\nWhat are some good plans for world peace?\n\nExpert Full Answer:\n", 150, 1.0, 0.7, 0.2, 0.2],
    ["Q & A\n\nQuestion:\nWhy is the sky blue?\n\nDetailed Expert Answer:\n", 150, 1.0, 0.7, 0.2, 0.2],
    ["Dear sir,\nI would like to express my boundless apologies for the recent nuclear war.", 150, 1.0, 0.7, 0.2, 0.2],
    ["Here is a shell script to find all .hpp files in /home/workspace and delete the 3rd row string of these files:", 150, 1.0, 0.7, 0.1, 0.1],
    ["Building a website can be done in 10 simple steps:\n1.", 150, 1.0, 0.7, 0.2, 0.2],
    ["A Chinese phrase is provided: 百闻不如一见。\nThe masterful Chinese translator flawlessly translates the phrase into English:", 150, 1.0, 0.5, 0.2, 0.2],
    ["I believe the meaning of life is", 150, 1.0, 0.7, 0.2, 0.2],
    ["Simply put, the theory of relativity states that", 150, 1.0, 0.5, 0.2, 0.2],
]
# infer_interface = gr.Interface(
# fn=infer,
# description=f'''{desc} <b>Please try examples first (bottom of page)</b> (edit them to use your question). Demo limited to ctxlen {ctx_limit}.''',
# allow_flagging="never",
# inputs=[
# gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n"), # prompt
# gr.Slider(10, 200, step=10, value=150), # token_count
# gr.Slider(0.2, 2.0, step=0.1, value=1.0), # temperature
# gr.Slider(0.0, 1.0, step=0.05, value=0.7), # top_p
# gr.Slider(0.0, 1.0, step=0.1, value=0.2), # presencePenalty
# gr.Slider(0.0, 1.0, step=0.1, value=0.2), # countPenalty
# ],
# outputs=gr.Textbox(label="Generated Output", lines=28),
# examples=examples,
# cache_examples=False,
# ).queue()
########################################################################################################
import copy

# Speaker names and the separator written after them in the chat transcript.
# The user's name lives in `user_name`, not `user`: `def user(...)` below
# rebinds the module-level name `user`, so a string stored there would be
# replaced by the function and chat() would prefix every turn with its repr.
user_name = "Bob"
bot = "Alice"
interface = ":"
chat_intro = f'''
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user_name}. \
{bot} is very intelligent, creative and friendly. \
She is unlikely to disagree with {user_name}, and she doesn't like to ask {user_name} questions. \
She also likes to tell {user_name} a lot about herself and her opinions, and she usually gives {user_name} kind, helpful and informative advices.
{user_name}{interface} Hello, how are you doing?
{bot}{interface} Hi {user_name}! Thanks, I'm fine. What about you?
{user_name}{interface} I am fine. It's nice to see you. Look, here is a store selling tea and juice.
{bot}{interface} Sure. Let's go inside. I would like to have some Mocha latte, which is my favourite!
{user_name}{interface} What is it?
{bot}{interface} Mocha latte is usually made with espresso, milk, chocolate, and frothed milk. Its flavors are frequently sweet.
{user_name}{interface} Sounds tasty. I'll try it next time. Would you like to chat with me for a while?
{bot}{interface} Of course! I'm glad to answer your questions or give helpful advices. You know, I am confident with my expertise. So please go ahead!
'''
# Model state after reading the intro; chat() starts each new session from a copy of it.
_, intro_state = model.forward(pipeline.encode(chat_intro), None)


def user(user_message, chatbot):
    """Append the user's message to the chat log and clear the input box.

    Returns values for the (message textbox, chatbot) outputs: an empty string
    and the log with ``[user_message, None]`` appended; the ``None`` reply slot
    is filled in afterwards by ``chat``.
    """
    chatbot = chatbot or []
    return "", chatbot + [[user_message, None]]


def _newline_bias(step):
    """Return the logit bias chat() adds to token 187 at generation ``step``.

    Token 187 is presumably the newline token of the 20B tokenizer (confirm).
    The schedule forbids it as the first token, discourages it during the first
    30 tokens, leaves it neutral up to token 130, then increasingly favours it
    so long replies wind down.
    """
    if step <= 0:
        return -float('inf')
    if step <= 30:
        return (step - 30) * 0.1
    if step <= 130:
        return 0
    return (step - 130) * 0.25


def chat(
    chatbot,
    history,
    token_count=10,
    temperature=1.0,
    top_p=0.8,
    presence_enalty=0.1,
    count_penalty=0.1,
):
    """Generate the bot's reply to the last user message in ``chatbot``.

    Generation stops when the decoded reply contains a blank line (``"\\n\\n"``)
    or after ``token_count`` tokens.

    Args:
        chatbot: Chat log as a list of ``[user_message, reply]`` pairs. The last
            pair's message is answered and its reply slot is filled in place.
        history: ``[state, all_tokens]`` returned by a previous call for this
            session, or a falsy value to start a new session from the intro.
        token_count: Maximum number of reply tokens to sample.
        temperature: Sampling temperature; values below 0.2 are raised to 0.2.
        top_p: Nucleus-sampling threshold passed to ``pipeline.sample_logits``.
        presence_enalty: Flat penalty for tokens already used in this reply
            (misspelled name kept for compatibility).
        count_penalty: Extra penalty per previous occurrence in this reply.

    Returns:
        ``(chatbot, [state, all_tokens])``: the updated log and the session
        history to keep in Gradio state.
    """
    args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
                         alpha_frequency=float(count_penalty),
                         alpha_presence=float(presence_enalty),
                         token_ban=[],  # ban the generation of some tokens
                         token_stop=[])  # stop generation whenever you see any token here
    # Newlines are removed so the message stays on its single transcript line.
    message = chatbot[-1][0].strip(' ').replace('\n', '')
    ctx = f"{user_name}{interface} {message}\n\n{bot}{interface}"
    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
    print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
    if history:
        state, all_tokens = history  # history is [state, all_tokens]
    else:
        # rwkv's RWKV.forward assigns the updated state back into the list it
        # is given, so a new session must not pass the shared intro_state
        # itself -- otherwise the first conversation overwrites the starting
        # state for every later session.
        state, all_tokens = copy.deepcopy(intro_state), []
    out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:], state)
    begin = len(all_tokens)  # this reply's tokens start here in all_tokens
    out_last = begin  # first token not yet printed to the console
    out_str = ''
    occurrence = {}  # token id -> count within this reply
    try:
        for i in range(int(token_count)):
            out[187] += _newline_bias(i)
            for n, count in occurrence.items():
                out[n] -= (args.alpha_presence + count * args.alpha_frequency)
            token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
            # Token 0 (presumably end-of-text) is replaced by a blank line so the
            # reply ends like a normal turn and the state also consumes it.
            next_tokens = pipeline.encode('\n\n') if token == 0 else [token]
            all_tokens += next_tokens
            occurrence[token] = occurrence.get(token, 0) + 1
            out, state = model.forward(next_tokens, state)
            tmp = pipeline.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:
                print(tmp, end='', flush=True)
                out_last = len(all_tokens)
            out_str = pipeline.decode(all_tokens[begin:])
            out_str = out_str.replace("\r\n", '\n').replace('\\n', '\n')
            if '\n\n' in out_str:
                break
    finally:
        # Release GPU memory even if generation raises.
        gc.collect()
        torch.cuda.empty_cache()
    chatbot[-1][1] = out_str.strip()
    return chatbot, [state, all_tokens]
# chat_interface = gr.Interface(
# fn=chat,
# description=f'''You are {user}, bot is {bot}.''',
# allow_flagging="never",
# inputs = [
# gr.Textbox(label="Message"),
# "state",
# gr.Slider(10, 1000, step=10, value=250), # token_count
# gr.Slider(0.2, 2.0, step=0.1, value=1.0), # temperature
# gr.Slider(0.0, 1.0, step=0.05, value=0.8), # top_p
# gr.Slider(0.0, 1.0, step=0.1, value=0.2), # presence_penalty
# gr.Slider(0.0, 1.0, step=0.1, value=0.2), # count_penalty
# ],
# outputs=[
# gr.Chatbot(label="Chat Log", color_map=("blue", "pink")),
# "state"
# ]
# ).queue()
########################################################################################################
# demo = gr.TabbedInterface(
# [infer_interface, chat_interface], ["Generative", "Chat"],
# title=title,
# )
# demo.queue(max_size=10)
# demo.launch(share=True)
with gr.Blocks() as demo:
    # ---- "Generative" tab: one-shot completion streamed by infer() ----
    with gr.Tab("Generative"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n")
                token_count = gr.Slider(10, 1000, label="Max Token", step=10, value=250)
                temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.8)
                presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.2)
                count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.2)
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("Submit")
                    gen_clear = gr.Button("Clear")
                output = gr.Textbox(label="Generated Output", lines=28)
        # Widget order matches infer()'s parameters and each row of `examples`.
        gen_params = [prompt, token_count, temperature, top_p, presence_penalty, count_penalty]
        data = gr.Dataset(components=gen_params, samples=examples, label="Example Prompts", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
        submit.click(infer, gen_params, [output])
        gen_clear.click(lambda: None, [], [output])
        # Picking an example copies its values into the input widgets.
        data.click(lambda x: x, [data], gen_params)

    # ---- "Chat" tab: user() posts the message, chat() fills in the reply ----
    with gr.Tab("Chat"):
        with gr.Row():
            with gr.Column():
                chatbot = gr.Chatbot()
                chat_state = gr.State()
                chat_message = gr.Textbox(label="Message")
                with gr.Row():
                    send = gr.Button("Send")
                    chat_clear = gr.Button("Clear")
            with gr.Column():
                chat_token_count = gr.Slider(10, 1000, label="Max Token", step=10, value=250)
                chat_temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
                chat_top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.8)
                chat_presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.2)
                chat_count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.2)
        chat_inputs = [chatbot, chat_state, chat_token_count, chat_temperature, chat_top_p, chat_presence_penalty, chat_count_penalty]
        # Pressing Enter and clicking Send behave identically: show the user's
        # message right away (queue=False), then run chat() for the reply.
        for trigger in (chat_message.submit, send.click):
            trigger(user, [chat_message, chatbot], [chat_message, chatbot], queue=False).then(
                chat, chat_inputs, [chatbot, chat_state]
            )
        chat_clear.click(lambda: ([], None, ""), [], [chatbot, chat_state, chat_message])

demo.queue(max_size=10)
demo.launch(share=False)