import gradio as gr
import os, gc, copy, torch, re
from datetime import datetime
from huggingface_hub import hf_hub_download
from pynvml import *
nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)
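# NVML handle for GPU 0, used to log VRAM usage after each generation.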
ctx_limit = 1024
title = "rwkv-x060-eng_single_round_qa-3B-20240430-ctx1024"
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
from rwkv.model import RWKV
model_path = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{title}.pth")
model = RWKV(model=model_path, strategy='cuda fp16')
from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
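# PIPELINE wraps the "rwkv_vocab_v20230424" world tokenizer together with the
# encode/decode and logits-sampling helpers from the rwkv package.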
def generate_prompt(instruction):
    instruction = instruction.strip().replace('\r\n', '\n')
    instruction = re.sub(r'\n+', '\n', instruction)
    return f"User: {instruction}\n\nAssistant:"
def evaluate(
    ctx,
    token_count=500,
    temperature=1.0,
    top_p=0.3,
    presencePenalty=0.3,
    countPenalty=0.3,
):
    args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
                         alpha_frequency=countPenalty,
                         alpha_presence=presencePenalty,
                         token_ban=[],    # ban the generation of some tokens
                         token_stop=[0])  # stop generation whenever you see any token here
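    # Penalty applied in the loop below: logits[n] -= alpha_presence + count[n] * alpha_frequency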
    ctx = generate_prompt(ctx)
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    state = None
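    # RWKV is an RNN: the full (truncated) prompt is fed once on the first step,
    # then each later step feeds only the last sampled token while `state`
    # carries the context forward.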
    for i in range(int(token_count)):
        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
        for n in occurrence:
            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
        if token in args.token_stop:
            break
        all_tokens += [token]
        for xxx in occurrence:
            occurrence[xxx] *= 0.996  # decay penalties so early tokens count less over time
        if token not in occurrence:
            occurrence[token] = 1
        else:
            occurrence[token] += 1
        tmp = pipeline.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp:  # only yield once the pending bytes decode to valid UTF-8
            out_str += tmp
            yield out_str.strip()
            out_last = i + 1
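    # Generation finished: log VRAM via NVML and release per-request tensors.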
    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
    del out
    del state
    gc.collect()
    torch.cuda.empty_cache()
    yield out_str.strip()
examples = [
    ["How can I craft an engaging story featuring vampires on Mars?", 700, 2, 0.1, 0.3, 0.3],
    ["Compare the business models of Apple and Google.", 700, 2, 0.1, 0.3, 0.3],
    ["In JSON format, list the top 5 tourist attractions in Paris.", 700, 2, 0.1, 0.3, 0.3],
    ["Write an outline for a fantasy novel where dreams can alter reality.", 700, 2, 0.1, 0.3, 0.3],
    ["Can fish get thirsty?", 700, 2, 0.1, 0.3, 0.3],
    ["Write a Bash script to check disk usage and send alerts if it's too high.", 700, 2, 0.1, 0.3, 0.3],
    ["Write a simple website in HTML. When a user clicks the button, it shows a random joke from a list of 4 jokes.", 700, 2, 0.1, 0.3, 0.3],
]
##########################################################################
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title}</h1>\n</div>")
    with gr.Tab("Raw Generation"):
        gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/temp-latest-training-models) with 3B params, [state-tuned](https://twitter.com/BlinkDL_AI/status/1784496793075744966) on single-round English Q&A. RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ GitHub RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(lines=2, label="Prompt", value="How can I craft an engaging story featuring vampires on Mars?")
                token_count = gr.Slider(10, 700, label="Max Tokens", step=10, value=700)
                temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=2.0)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.1)
                presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
                count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("Submit", variant="primary")
                    clear = gr.Button("Clear", variant="secondary")
                output = gr.Textbox(label="Output", lines=50)
        data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, samples_per_page=50, label="Example Instructions", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
        submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
        clear.click(lambda: None, [], [output])
        data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])

demo.queue(concurrency_count=1, max_size=10)
demo.launch(share=False)
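# To run this Space locally (a sketch; package names assumed, versions not pinned here):
#   pip install torch rwkv gradio pynvml huggingface_hub
#   python app.py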