picocreator committed on
Commit
59618c7
1 Parent(s): f092f6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -4
app.py CHANGED
@@ -3,20 +3,42 @@ import os, gc, copy, torch
3
  from datetime import datetime
4
  from huggingface_hub import hf_hub_download
5
  from pynvml import *
6
- nvmlInit()
7
- gpu_h = nvmlDeviceGetHandleByIndex(0)
 
 
 
8
  ctx_limit = 2000
9
  title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  os.environ["RWKV_JIT_ON"] = '1'
12
  os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
13
 
 
 
 
 
 
 
14
  from rwkv.model import RWKV
15
- model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
16
- model = RWKV(model=model_path, strategy='cuda fp16')
17
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
18
  pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
19
 
 
20
  def generate_prompt(instruction, input=""):
21
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
22
  input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
@@ -35,6 +57,7 @@ User: {instruction}
35
 
36
  Assistant:"""
37
 
 
38
  def evaluate(
39
  ctx,
40
  token_count=200,
@@ -84,6 +107,7 @@ def evaluate(
84
  torch.cuda.empty_cache()
85
  yield out_str.strip()
86
 
 
87
  examples = [
88
  ["Assistant: Sure! Here is a very detailed plan to create flying pigs:", 333, 1, 0.3, 0, 1],
89
  ["Assistant: Sure! Here are some ideas for FTL drive:", 333, 1, 0.3, 0, 1],
@@ -108,6 +132,7 @@ Edward:''', 333, 1, 0.3, 0, 1],
108
 
109
  ##########################################################################
110
 
 
111
  with gr.Blocks(title=title) as demo:
112
  gr.HTML(f"<div style=\"text-align: center;\">\n<h1>RWKV-5 World v2 - {title}</h1>\n</div>")
113
  with gr.Tab("Raw Generation"):
@@ -130,5 +155,6 @@ with gr.Blocks(title=title) as demo:
130
  clear.click(lambda: None, [], [output])
131
  data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
132
 
 
133
  demo.queue(concurrency_count=1, max_size=10)
134
  demo.launch(share=False)
 
3
  from datetime import datetime
4
  from huggingface_hub import hf_hub_download
5
  from pynvml import *
6
# Flag to check if GPU is present
HAS_GPU = False

# Model title and context size limit
ctx_limit = 2000
title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
model_file = "RWKV-5-World-1B5-v2-20231025-ctx4096"

# Best-effort GPU probe via NVML: on machines without NVIDIA drivers,
# nvmlInit() raises NVMLError and HAS_GPU stays False (CPU path).
try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = True
        # gpu_h is a module-level handle; presumably used later for
        # memory queries (e.g. nvmlDeviceGetMemoryInfo) — not visible here.
        gpu_h = nvmlDeviceGetHandleByIndex(0)
except NVMLError as error:
    # Deliberate best-effort: log and continue on CPU rather than crash.
    print(error)
25
 
26
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)

# Model strategy: CUDA fp16 when a GPU is available, otherwise CPU bf16.
# BUG FIX: the original branch body was the bare string expression "cpu bf16",
# which was evaluated and discarded — MODEL_STRAT silently stayed 'cuda fp16'
# on GPU-less machines, so the model load would fail. Assign it explicitly.
MODEL_STRAT = "cuda fp16"
if not HAS_GPU:
    MODEL_STRAT = "cpu bf16"

# Load the model accordingly.
# hf_hub_download fetches (or reuses a cached copy of) the checkpoint and
# returns its local path; RWKV then loads it with the chosen strategy.
from rwkv.model import RWKV
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{model_file}.pth")
model = RWKV(model=model_path, strategy=MODEL_STRAT)
from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
40
 
41
+ # Prompt generation
42
  def generate_prompt(instruction, input=""):
43
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
44
  input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
 
57
 
58
  Assistant:"""
59
 
60
+ # Evaluation logic
61
  def evaluate(
62
  ctx,
63
  token_count=200,
 
107
  torch.cuda.empty_cache()
108
  yield out_str.strip()
109
 
110
+ # Examples and gradio blocks
111
  examples = [
112
  ["Assistant: Sure! Here is a very detailed plan to create flying pigs:", 333, 1, 0.3, 0, 1],
113
  ["Assistant: Sure! Here are some ideas for FTL drive:", 333, 1, 0.3, 0, 1],
 
132
 
133
  ##########################################################################
134
 
135
+ # Gradio blocks
136
  with gr.Blocks(title=title) as demo:
137
  gr.HTML(f"<div style=\"text-align: center;\">\n<h1>RWKV-5 World v2 - {title}</h1>\n</div>")
138
  with gr.Tab("Raw Generation"):
 
155
  clear.click(lambda: None, [], [output])
156
  data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
157
 
158
+ # Gradio launch
159
  demo.queue(concurrency_count=1, max_size=10)
160
  demo.launch(share=False)