Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os, gc | |
from huggingface_hub import hf_hub_download | |
from pynvml import * | |
# GPU availability flag; stays False unless the NVML probe below finds a device.
HAS_GPU = False

# Token budget: used both to truncate the encoded prompt and as the maximum
# number of tokens generated per request.
ctx_limit = 2000

# Checkpoint name on the Hub (without the ".pth" suffix); the UI title reuses it.
model_file = "RWKV-5-World-1B5-v2-20231025-ctx4096"
title = model_file
# Probe the GPUs through NVML. Any NVMLError is only printed; in that case
# HAS_GPU keeps its default value and the CPU strategy is used later.
try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        # NVML reports at least one device: switch the app to the GPU path.
        HAS_GPU = True
        # Handle to device 0 (not used by the code shown here).
        gpu_h = nvmlDeviceGetHandleByIndex(0)
except NVMLError as nvml_error:
    print(nvml_error)
# rwkv configuration flags; these are set before rwkv.model is imported
# (presumably rwkv reads them at import time — confirm against rwkv docs).
os.environ["RWKV_JIT_ON"] = '1'

# MODEL_STRAT is the rwkv strategy string (device + dtype).
if HAS_GPU:
    # GPU found: run on CUDA and enable the CUDA kernel for seq mode (much faster).
    os.environ["RWKV_CUDA_ON"] = '1'
    MODEL_STRAT = "cuda bf16"
else:
    # No GPU: plain CPU inference without the CUDA kernel.
    os.environ["RWKV_CUDA_ON"] = '0'
    MODEL_STRAT = "cpu bf16"
# Load the model. The rwkv imports stay here, after the RWKV_* environment
# flags have been set above.
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

# Download "<model_file>.pth" from the BlinkDL/rwkv-5-world Hub repository.
checkpoint_name = f"{model_file}.pth"
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=checkpoint_name)

# Build the model with the chosen strategy, then wrap it in a PIPELINE that
# provides encode/decode/sample_logits using the "rwkv_vocab_v20230424" vocabulary.
model = RWKV(model=model_path, strategy=MODEL_STRAT)
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
# Translation logic
def translate(text, target_language):
    """Stream a translation of *text* into *target_language*.

    Feeds an instruction prompt to the RWKV model. It then samples one token
    at a time until the model emits token 0 (EOS) or ``ctx_limit`` tokens
    have been generated.

    Args:
        text: Source text to translate (any language, as in EXAMPLES).
        target_language: Name of the language to translate into, e.g. "French".

    Yields:
        The stripped translation decoded so far, growing as tokens arrive.
        The last yield is the complete stripped output.
    """
    # Language-neutral wording: the examples translate from many languages.
    prompt = f"Translate the following text to {target_language}: '{text}'"
    all_tokens = []
    out_last = 0
    out_str = ''
    state = None
    try:
        # The first forward pass consumes the (truncated) prompt. Later passes
        # feed back only the last sampled token, carrying the recurrent state.
        next_input = pipeline.encode(prompt)[-ctx_limit:]
        # NOTE(review): generation stops only on token 0 or after ctx_limit
        # tokens, so the model may keep writing past the translation itself.
        for i in range(ctx_limit):
            out, state = model.forward(next_input, state)
            token = pipeline.sample_logits(out)
            if token == 0:  # EOS token
                break
            next_input = [token]
            all_tokens.append(token)
            pending = pipeline.decode(all_tokens[out_last:])
            # Hold output back while the pending tokens decode to U+FFFD.
            # That likely means a multi-byte character is split across tokens.
            if '\ufffd' not in pending:
                out_str += pending
                yield out_str.strip()
                out_last = i + 1
    finally:
        # Runs on normal completion and also when the client disconnects and
        # the generator is closed at a yield.
        out = state = None
        gc.collect()
        if HAS_GPU:
            # torch is not imported at module level; rwkv already depends on it.
            import torch
            torch.cuda.empty_cache()
    yield out_str.strip()
# Example data: clickable [text, target language] pairs for the gr.Dataset.
# The first three put one English greeting into other languages; the remaining
# six translate greetings from various languages into English.
EXAMPLES = [
    *(["Hello, how are you?", language] for language in ("French", "Spanish", "Chinese")),
    *([greeting, "English"] for greeting in (
        "Bonjour, comment ça va?",
        "Hola, ¿cómo estás?",
        "你好吗?",
        "Guten Tag, wie geht es Ihnen?",
        "Привет, как ты?",
        "مرحبًا ، كيف حالك؟",
    )),
]
# Gradio interface: a source text box, a target-language picker, a streaming
# output box, and a clickable table of EXAMPLES that fills in both inputs.
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\"><h1>RWKV-5 World v2 - {title}</h1></div>")
    gr.Markdown("This is the RWKV-5 World v2 1B5 model tailored for translation. Please provide the text and select the target language for translation.")
    # Input and output components. The source may be in any language, as the
    # EXAMPLES show, so the label does not say "English".
    text = gr.Textbox(lines=5, label="Source Text", placeholder="Enter the text you want to translate...")
    # "English" must be a choice because six EXAMPLES target it. The default
    # value keeps translate() from receiving None when no language is picked.
    target_language = gr.Dropdown(
        choices=["English", "French", "Spanish", "German", "Chinese", "Japanese", "Russian", "Arabic"],
        value="English",
        label="Target Language",
    )
    output = gr.Textbox(lines=5, label="Translated Text")
    submit = gr.Button("Translate", variant="primary")
    # Example data: each sample is [text, target_language].
    data = gr.Dataset(components=[text, target_language], samples=EXAMPLES, label="Example Translations", headers=["Text", "Target Language"])
    # translate() is a generator, so the output box updates as tokens stream in.
    submit.click(translate, [text, target_language], [output])
    # Pass the clicked sample straight through to the two input components.
    data.click(lambda x: x, [data], [text, target_language])

# Gradio launch: process one request at a time with at most 10 queued.
# NOTE(review): `concurrency_count` is a Gradio 3.x queue() parameter.
demo.queue(concurrency_count=1, max_size=10)
demo.launch(share=False)