# app.py — RWKV translation demo (HuggingFace Space)
# Provenance: author picocreator, "Update app.py", commit 5dbd048, 3.44 kB
import gc
import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from pynvml import *
# Flag to check if a CUDA-capable GPU is present (set below via NVML probe)
HAS_GPU = False

# Model title and generation context-size limit (max tokens per request)
ctx_limit = 2000
title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
model_file = "RWKV-5-World-1B5-v2-20231025-ctx4096"

# Probe NVML for GPUs. On machines without NVIDIA drivers nvmlInit()
# raises NVMLError: log it and fall through to CPU mode.
try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = True
        gpu_h = nvmlDeviceGetHandleByIndex(0)
except NVMLError as error:
    print(error)

os.environ["RWKV_JIT_ON"] = '1'

# Model strategy to use (CPU bf16 by default)
MODEL_STRAT = "cpu bf16"
os.environ["RWKV_CUDA_ON"] = '0'  # if '1' then use CUDA kernel for seq mode (much faster)

# Switch to GPU mode when a GPU was detected above
if HAS_GPU:
    os.environ["RWKV_CUDA_ON"] = '1'
    MODEL_STRAT = "cuda bf16"

# Load the model. NOTE: the RWKV_* env vars above must be set BEFORE
# importing rwkv, which is why these imports are deliberately late.
from rwkv.model import RWKV
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{model_file}.pth")
model = RWKV(model=model_path, strategy=MODEL_STRAT)

from rwkv.utils import PIPELINE
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
# Translation logic
def translate(text, target_language):
    """Stream a translation of *text* into *target_language* with RWKV.

    Generator used by Gradio streaming: yields the accumulated partial
    translation after each cleanly-decodable chunk; the final yield is
    the complete translation.
    """
    prompt = f"Translate the following English text to {target_language}: '{text}'"
    ctx = prompt.strip()
    all_tokens = []
    out_last = 0   # index of the first sampled token not yet flushed to out_str
    out_str = ''
    state = None   # RWKV recurrent state, threaded through forward() calls
    for i in range(ctx_limit):
        # First step feeds the (tail-truncated) prompt; subsequent steps feed
        # only the single token sampled on the previous iteration.
        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
        token = pipeline.sample_logits(out)
        if token == 0:  # EOS token
            break
        all_tokens.append(token)
        tmp = pipeline.decode(all_tokens[out_last:])
        # Flush only when the pending tokens decode to valid text — no U+FFFD
        # replacement char from a multi-byte sequence split across tokens.
        if '\ufffd' not in tmp:
            out_str += tmp
            yield out_str.strip()
            out_last = i + 1
    del out
    del state
    # Release memory between requests
    gc.collect()
    if HAS_GPU:
        torch.cuda.empty_cache()
    yield out_str.strip()
# Example data: each row is [source text, target language], matching the
# (text, target_language) inputs wired to the Gradio Dataset below.
EXAMPLES = [
["Hello, how are you?", "French"],
["Hello, how are you?", "Spanish"],
["Hello, how are you?", "Chinese"],
["Bonjour, comment ça va?", "English"],
["Hola, ¿cómo estás?", "English"],
["你好吗?", "English"],
["Guten Tag, wie geht es Ihnen?", "English"],
["Привет, как ты?", "English"],
["مرحبًا ، كيف حالك؟", "English"],
]
# Gradio interface: text box + language dropdown -> streamed translation
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\"><h1>RWKV-5 World v2 - {title}</h1></div>")
    gr.Markdown("This is the RWKV-5 World v2 1B5 model tailored for translation. Please provide the text and select the target language for translation.")

    # Input and output components
    text = gr.Textbox(lines=5, label="English Text", placeholder="Enter the text you want to translate...")
    target_language = gr.Dropdown(choices=["French", "Spanish", "German", "Chinese", "Japanese", "Russian", "Arabic"], label="Target Language")
    output = gr.Textbox(lines=5, label="Translated Text")
    submit = gr.Button("Translate", variant="primary")

    # Example data: clicking a row copies it into the two input components
    data = gr.Dataset(components=[text, target_language], samples=EXAMPLES, label="Example Translations", headers=["Text", "Target Language"])

    # Button action: translate() is a generator, so output streams live
    submit.click(translate, [text, target_language], [output])
    data.click(lambda x: x, [data], [text, target_language])

# Serialize generation with a small request queue, then launch.
# NOTE(review): queue(concurrency_count=...) is the Gradio 3.x signature;
# Gradio 4.x renamed it — confirm the pinned gradio version before upgrading.
demo.queue(concurrency_count=1, max_size=10)
demo.launch(share=False)