picocreator's picture
Update app.py
040cd58
raw
history blame
No virus
5.22 kB
import gradio as gr
import os, gc
from huggingface_hub import hf_hub_download
from pynvml import *
# Checkpoint name on the Hugging Face hub (without the ".pth" suffix).
# The same string doubles as the title shown in the UI.
model_file = "RWKV-5-World-1B5-v2-20231025-ctx4096"
title = model_file

# Token budget: used by translate() both to truncate the encoded prompt and
# to cap the number of generation steps.
ctx_limit = 2000

# Assume CPU-only until the NVML probe further down says otherwise.
HAS_GPU = False
# Get the GPU count
try:
    # Probe NVML; any NVML failure is printed and the app stays in CPU mode.
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT >= 1:
        HAS_GPU = True
        # Handle to the first device only.
        gpu_h = nvmlDeviceGetHandleByIndex(0)
except NVMLError as nvml_err:
    print(nvml_err)
# RWKV reads its configuration from environment variables; they are set here,
# ahead of the `rwkv` imports below (keep this ordering).
os.environ["RWKV_JIT_ON"] = '1'
# '1' selects the CUDA kernel for seq mode (much faster); '0' leaves it off.
os.environ["RWKV_CUDA_ON"] = '1' if HAS_GPU else '0'

# Model strategy: bf16 on the GPU when one was detected, otherwise bf16 on CPU.
MODEL_STRAT = "cuda bf16" if HAS_GPU else "cpu bf16"

# Fetch the checkpoint from the hub and load it with the chosen strategy.
from rwkv.model import RWKV
model_path = hf_hub_download(
    repo_id="BlinkDL/rwkv-5-world",
    filename=f"{model_file}.pth",
)
model = RWKV(model=model_path, strategy=MODEL_STRAT)

# Tokenizer / sampling helper bound to the loaded model.
from rwkv.utils import PIPELINE
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
# Translation logic
def translate(text, target_language):
    """Stream a translation of ``text`` into ``target_language``.

    Builds an instruction prompt, runs it through the module-level RWKV
    ``model`` and samples one token per step via ``pipeline``. This is a
    generator: after every step whose pending tokens decode to complete
    characters it yields the whole output accumulated so far (stripped), and
    it yields the final output once more after generation stops.

    Generation stops at token ``0`` (EOS) or after ``ctx_limit`` steps;
    ``ctx_limit`` also truncates the encoded prompt to its last tokens.

    Args:
        text: Source text to translate.
        target_language: Name of the language to translate into.

    Yields:
        The translated text produced so far. A single empty string is
        yielded when ``text`` is empty or only whitespace.

    Raises:
        gr.Error: If no target language was selected.
    """
    # Without a language the prompt would literally read "... to None".
    if not target_language:
        raise gr.Error("Please select a target language.")
    # Nothing to translate: skip the model entirely.
    if not text or not text.strip():
        yield ""
        return

    prompt = f"Translate the following text to {target_language}\n # Input Text:\n{text}\n\n# Output Text:\n"
    ctx = prompt.strip()

    all_tokens = []
    out_last = 0   # index into all_tokens of the first token not yet emitted
    out_str = ''
    out = None
    state = None
    token = None

    try:
        for i in range(ctx_limit):
            # First step consumes the (truncated) prompt; every later step
            # feeds back only the token sampled on the previous step.
            step_input = pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token]
            out, state = model.forward(step_input, state)
            token = pipeline.sample_logits(out)
            if token == 0:  # EOS token
                break
            all_tokens.append(token)

            # '\ufffd' in the decoded text means the pending tokens end in the
            # middle of a multi-byte character; hold them until it completes.
            tmp = pipeline.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:
                out_str += tmp
                yield out_str.strip()
                out_last = i + 1
    finally:
        # Runs on normal completion, on a model error, and when the consumer
        # closes the generator early, so the buffers are always released.
        del out
        del state
        gc.collect()
        if HAS_GPU:
            # torch is not imported at module level in this file; import it
            # here so the cache flush does not fail with a NameError.
            import torch
            torch.cuda.empty_cache()

    yield out_str.strip()
# Languages
# Target languages offered in the dropdown, in display order.
LANGUAGES = [
    "English",
    "Chinese",
    "Spanish",
    "Bengali",
    "Hindi",
    "Portuguese",
    "Russian",
    "Japanese",
    "German",
    "Chinese (Wu)",
    "Javanese",
    "Korean",
    "French",
    "Vietnamese",
    "Telugu",
    "Chinese (Yue)",
    "Marathi",
    "Tamil",
    "Turkish",
    "Urdu",
    "Chinese (Min Nan)",
    "Chinese (Jin Yu)",
    "Gujarati",
    "Polish",
    "Arabic (Egyptian Spoken)",
    "Ukrainian",
    "Italian",
    "Chinese (Xiang)",
    "Malayalam",
    "Chinese (Hakka)",
    "Kannada",
    "Oriya",
    "Panjabi (Western)",
    "Panjabi (Eastern)",
    "Sunda",
    "Romanian",
    "Bhojpuri",
    "Azerbaijani (South)",
    "Farsi (Western)",
    "Maithili",
    "Hausa",
    "Arabic (Algerian Spoken)",
    "Burmese",
    "Serbo-Croatian",
    "Chinese (Gan)",
    "Awadhi",
    "Thai",
    "Dutch",
    "Yoruba",
    "Sindhi",
    "Arabic (Moroccan Spoken)",
    "Arabic (Saidi Spoken)",
    "Uzbek, Northern",
    "Malay",
    "Amharic",
    "Indonesian",
    "Igbo",
    "Tagalog",
    "Nepali",
    "Arabic (Sudanese Spoken)",
    "Saraiki",
    "Cebuano",
    "Arabic (North Levantine Spoken)",
    "Thai (Northeastern)",
    "Assamese",
    "Hungarian",
    "Chittagonian",
    "Arabic (Mesopotamian Spoken)",
    "Madura",
    "Sinhala",
    "Haryanvi",
    "Marwari",
    "Czech",
    "Greek",
    "Magahi",
    "Chhattisgarhi",
    "Deccan",
    "Chinese (Min Bei)",
    "Belarusan",
    "Zhuang (Northern)",
    "Arabic (Najdi Spoken)",
    "Pashto (Northern)",
    "Somali",
    "Malagasy",
    "Arabic (Tunisian Spoken)",
    "Rwanda",
    "Zulu",
    "Bulgarian",
    "Swedish",
    "Lombard",
    "Oromo (West-central)",
    "Pashto (Southern)",
    "Kazakh",
    "Ilocano",
    "Tatar",
    "Fulfulde (Nigerian)",
    "Arabic (Sanaani Spoken)",
    "Uyghur",
    "Haitian Creole French",
    "Azerbaijani, North",
    "Napoletano-calabrese",
    "Khmer (Central)",
    "Farsi (Eastern)",
    "Akan",
    "Hiligaynon",
    "Kurmanji",
    "Shona",
]
# Example data
# Sample rows for the examples widget; each row is [source text, target language].
EXAMPLES = [
    ["Hello, how are you?", "French"],
    ["Hello, how are you?", "Spanish"],
    ["Hello, how are you?", "Chinese"],
    ["Bonjour, comment ça va?", "English"],
    ["Hola, ¿cómo estás?", "English"],
    ["你好吗?", "English"],
    ["Guten Tag, wie geht es Ihnen?", "English"],
    ["Привет, как ты?", "English"],
    ["مرحبًا ، كيف حالك؟", "English"],
]
# Gradio interface
with gr.Blocks(title=title) as demo:
    # Page header: centered heading followed by a one-paragraph description.
    gr.HTML(f'<div style="text-align: center;"><h1>RWKV-5 World v2 - {title}</h1></div>')
    gr.Markdown(
        "This is the RWKV-5 World v2 1B5 model tailored for translation. "
        "Please provide the text and select the target language for translation."
    )

    # Widgets are created in display order: source text, language picker,
    # translated output, then the submit button.
    text = gr.Textbox(
        lines=5,
        label="Source Text",
        placeholder="Enter the text you want to translate...",
    )
    target_language = gr.Dropdown(choices=LANGUAGES, label="Target Language")
    output = gr.Textbox(lines=5, label="Translated Text")
    submit = gr.Button("Translate", variant="primary")

    # Clickable example rows that fill in both inputs.
    data = gr.Dataset(
        components=[text, target_language],
        samples=EXAMPLES,
        label="Example Translations",
        headers=["Text", "Target Language"],
    )

    # The button feeds both inputs to translate(), whose yielded values go
    # to the output box.
    submit.click(fn=translate, inputs=[text, target_language], outputs=[output])
    # Selecting an example copies its row straight into the two inputs.
    data.click(fn=lambda sample: sample, inputs=[data], outputs=[text, target_language])

# Queue settings: concurrency_count=1, with at most 10 requests waiting.
demo.queue(concurrency_count=1, max_size=10)
demo.launch(share=False)