# Source: app.py at commit 78c7c53 ("Update app.py") by picocreator
import gradio as gr
import os, gc, copy, torch
from huggingface_hub import hf_hub_download
from pynvml import *
# Whether a CUDA GPU will be used; stays False unless the NVML probe below finds one.
HAS_GPU = False

# UI title and the Hugging Face checkpoint (without ".pth") that gets downloaded.
title = "RWKV-5-World-1B5-v2-Translator"
model_file = "RWKV-5-World-1B5-v2-20231025-ctx4096"

# Token budget: caps both the prompt length fed to the model and the number of generated tokens.
ctx_limit = 2000
# Get the GPU count via NVML. NVML only tells us an NVIDIA device exists; the
# installed torch build must also support CUDA, otherwise loading the model with
# a "cuda" strategy would fail, so both conditions are required for HAS_GPU.
try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = torch.cuda.is_available()
        if not HAS_GPU:
            print("NVML reports a GPU, but torch has no CUDA support; falling back to CPU")
        gpu_h = nvmlDeviceGetHandleByIndex(0)
except NVMLError as error:
    # No NVIDIA driver/library available: stay on CPU.
    print(error)

# RWKV runtime flags. These are set before `rwkv.model` is imported further down.
os.environ["RWKV_JIT_ON"] = '1'

# Model strategy to use: CUDA when a usable GPU was found, CPU otherwise.
if HAS_GPU:
    MODEL_STRAT = "cuda bf16"
    os.environ["RWKV_CUDA_ON"] = '1'  # CUDA kernel for seq mode (much faster)
else:
    MODEL_STRAT = "cpu bf16"
    os.environ["RWKV_CUDA_ON"] = '0'
# Load the model. rwkv.model is imported only here, after the RWKV_* environment
# variables above have been set (presumably the library reads them at import time).
from rwkv.model import RWKV

# hf_hub_download fetches the checkpoint from the Hub and returns its local path.
checkpoint_name = f"{model_file}.pth"
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=checkpoint_name)
model = RWKV(model=model_path, strategy=MODEL_STRAT)

# Tokenizer + sampling helper bound to the loaded model (encode/decode/sample_logits).
from rwkv.utils import PIPELINE
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
# State copy
def universal_deepcopy(obj):
    """Return an independent copy of *obj* (used for RWKV model states).

    Anything exposing a ``clone`` method (e.g. a torch tensor) is cloned,
    lists are copied element by element (recursively), and every other
    object falls back to ``copy.deepcopy``.
    """
    if hasattr(obj, 'clone'):
        return obj.clone()
    if not isinstance(obj, list):
        return copy.deepcopy(obj)
    return list(map(universal_deepcopy, obj))
# For debugging mostly
def inspect_structure(obj, depth=0):
    """Print a tree describing how lists and dicts are nested inside *obj*.

    Lists and dicts print their length and recurse one level deeper (dicts also
    print each key); any other value prints just its type name. Each nesting
    level is indented by one extra space. Returns None.
    """
    pad = " " * depth
    if isinstance(obj, list):
        print(f"{pad}List (length {len(obj)}):")
        for child in obj:
            inspect_structure(child, depth + 1)
        return
    if isinstance(obj, dict):
        print(f"{pad}Dict (length {len(obj)}):")
        for k, v in obj.items():
            print(f"{pad} Key: {k}")
            inspect_structure(v, depth + 1)
        return
    print(f"{pad}{type(obj).__name__}")
# Precomputation of the state
def precompute_state(text):
    """Feed *text* through the model with no prior state and return the new state.

    The logits are discarded; only the state is kept so it can be reused as a
    starting point for later generation (see ``inState`` of ``translate``).
    """
    tokens = pipeline.encode(text)
    _logits, final_state = model.forward(tokens, None)
    return final_state
# Precomputing the base instruction set: a short instruction plus two few-shot
# examples written in the same "## From X:" / "## To Y:" format that translate()
# builds its prompt with. (Plain string: it has no placeholders, so no f-prefix.)
INSTRUCT_PREFIX = '''
You are a translator bot that can translate text to any language.
And will respond only with the translated text, without additional comments.
## From English:
It is not enough to know, we must also apply; it is not enough to will, we must also do.
## To Polish:
Nie wystarczy wiedzieć, trzeba także zastosować; nie wystarczy chcieć, trzeba też działać.
## From Spanish:
La muerte no nos concierne, porque mientras existamos, la muerte no está aquí. Y cuando llega, ya no existimos.
## To English:
Death does not concern us, because as long as we exist, death is not here. And when it does come, we no longer exist.
'''
# Run the few-shot instruction prefix through the model once, at import time;
# translate() starts from a copy of this state by default.
PREFIX_STATE = precompute_state(text=INSTRUCT_PREFIX)
# Translation logic
def _truncate_at_stop_marker(text, markers):
    """Return *text* cut just before the earliest of *markers*, or None if none occurs."""
    hits = [pos for pos in (text.find(marker) for marker in markers) if pos != -1]
    if not hits:
        return None
    return text[:min(hits)]


def translate(
    text, source_language, target_language,
    inState=PREFIX_STATE,
    temperature=0.2,
    top_p=0.5,
    presencePenalty=0.1,
    countPenalty=0.1,
):
    """Stream a translation of *text* from *source_language* to *target_language*.

    This is a generator. After each sampled token it yields the translation
    produced so far, stripped of surrounding whitespace, so Gradio can stream it.
    The last value yielded is the final, cleaned translation.

    Args:
        text: Source text to translate. Blank input yields "" without running the model.
        source_language: Language name inserted into the prompt (e.g. "Hungarian").
        target_language: Language name inserted into the prompt (e.g. "English").
        inState: Model state to continue from. Defaults to the state produced
            by INSTRUCT_PREFIX. It is copied first, so the shared default is
            not modified. None starts with no prior state.
        temperature: Sampling temperature passed to pipeline.sample_logits.
        top_p: Nucleus-sampling threshold passed to pipeline.sample_logits.
        presencePenalty: Flat logit penalty for every token generated so far.
        countPenalty: Extra logit penalty per (decayed) occurrence of a token.

    Yields:
        The stripped translation text accumulated so far.
    """
    # Nothing to translate: emit an empty result instead of letting the model
    # free-run on a prompt with no source text.
    if not text or not text.strip():
        yield ""
        return

    prompt = f"## From {source_language}:\n{text}\n\n## To {target_language}:\n"
    ctx = prompt.strip()

    # Once any of these appears, the model has finished the translation and
    # started a new section or dialogue turn, so the output is cut there.
    stop_markers = (
        f"{source_language}:",
        f"{target_language}:",
        "\nHuman:",
        "\nAssistant:",
        "\n#",
        "\n:",
    )

    all_tokens = []
    out_last = 0  # index of the first token in all_tokens not yet added to out_str
    out_str = ''
    occurrence = {}
    alpha_frequency = countPenalty
    alpha_presence = presencePenalty

    # Work on a copy, in case model.forward updates the state it is given.
    state = universal_deepcopy(inState) if inState is not None else None

    # Clear GC
    gc.collect()
    if HAS_GPU:
        torch.cuda.empty_cache()

    # Generate token by token, at most ctx_limit tokens.
    for i in range(ctx_limit):
        # The first step feeds the prompt (only its last ctx_limit tokens);
        # later steps feed just the previously sampled token.
        step_input = pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token]
        out, state = model.forward(step_input, state)

        # Repetition penalty: flat presence term plus a per-occurrence term.
        for n in occurrence:
            out[n] -= (alpha_presence + occurrence[n] * alpha_frequency)

        token = pipeline.sample_logits(out, temperature=temperature, top_p=top_p)
        if token == 0:  # EOS token
            break
        all_tokens.append(token)

        # Decay old counts so the penalty fades for tokens not seen recently.
        for n in occurrence:
            occurrence[n] *= 0.996
        occurrence[token] = occurrence.get(token, 0) + 1

        # Decode only tokens not yet flushed. If the chunk contains U+FFFD, a
        # multi-byte character is presumably split across tokens. Keep them
        # pending until a later token completes it. (Previously this aborted
        # the whole translation, truncating CJK/Indic output.)
        tmp = pipeline.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp:
            out_str += tmp
            out_last = len(all_tokens)

        finished = _truncate_at_stop_marker(out_str, stop_markers)
        if finished is not None:
            # A generator's `return value` is never produced by iteration, so
            # Gradio would not see it; the cleaned result must be yielded.
            yield finished.strip()
            return

        # Yield for streaming
        yield out_str.strip()

    # EOS or token limit reached. Yield once more so a result is always
    # emitted, even when EOS is the very first sampled token.
    yield out_str.strip()
# Languages offered in both dropdowns. The selected name is inserted verbatim
# into the translation prompt, so the spelling here is what the model sees.
LANGUAGES = [
    "English", "Chinese", "Spanish", "Bengali",
    "Hindi", "Portuguese", "Russian", "Japanese",
    "German", "Chinese (Wu)", "Javanese", "Korean",
    "French", "Vietnamese", "Telugu", "Chinese (Yue)",
    "Marathi", "Tamil", "Turkish", "Urdu",
    "Chinese (Min Nan)", "Chinese (Jin Yu)", "Gujarati", "Polish",
    "Arabic (Egyptian Spoken)", "Ukrainian", "Italian", "Chinese (Xiang)",
    "Malayalam", "Chinese (Hakka)", "Kannada", "Oriya",
    "Panjabi (Western)", "Panjabi (Eastern)", "Sunda", "Romanian",
    "Bhojpuri", "Azerbaijani (South)", "Farsi (Western)", "Maithili",
    "Hausa", "Arabic (Algerian Spoken)", "Burmese", "Serbo-Croatian",
    "Chinese (Gan)", "Awadhi", "Thai", "Dutch",
    "Yoruba", "Sindhi", "Arabic (Moroccan Spoken)", "Arabic (Saidi Spoken)",
    "Uzbek, Northern", "Malay", "Amharic", "Indonesian",
    "Igbo", "Tagalog", "Nepali", "Arabic (Sudanese Spoken)",
    "Saraiki", "Cebuano", "Arabic (North Levantine Spoken)", "Thai (Northeastern)",
    "Assamese", "Hungarian", "Chittagonian", "Arabic (Mesopotamian Spoken)",
    "Madura", "Sinhala", "Haryanvi", "Marwari",
    "Czech", "Greek", "Magahi", "Chhattisgarhi",
    "Deccan", "Chinese (Min Bei)", "Belarusan", "Zhuang (Northern)",
    "Arabic (Najdi Spoken)", "Pashto (Northern)", "Somali", "Malagasy",
    "Arabic (Tunisian Spoken)", "Rwanda", "Zulu", "Latin",
    "Bulgarian", "Swedish", "Lombard", "Oromo (West-central)",
    "Pashto (Southern)", "Kazakh", "Ilocano", "Tatar",
    "Fulfulde (Nigerian)", "Arabic (Sanaani Spoken)", "Uyghur", "Haitian Creole French",
    "Azerbaijani, North", "Napoletano-calabrese", "Khmer (Central)", "Farsi (Eastern)",
    "Akan", "Hiligaynon", "Kurmanji", "Shona",
]
# Example data: each row is [source text, source language, target language],
# matching the three components the example gr.Dataset is built from.
EXAMPLES = [
    # More people would learn from their mistakes if they weren't so busy denying them.
    ["Többen tanulnának a hibáikból, ha nem lennének annyira elfoglalva, hogy tagadják azokat.", "Hungarian", "English"],
    ["La mejor venganza es el éxito masivo.", "Spanish", "English"],
    ["Tout est bien qui finit bien.", "French", "English"],
    ["Lasciate ogne speranza, voi ch'intrate.", "Italian", "English"],
    ["Errare humanum est.", "Latin", "English"],
]
# Gradio interface
with gr.Blocks(title=title) as demo:
    # Header and a short description of the demo
    header_html = f"<div style=\"text-align: center;\"><h1>RWKV-5 World v2 - {title}</h1></div>"
    gr.HTML(header_html)
    gr.Markdown("This is the RWKV-5 World v2 1B5 model tailored for translation tasks. All on 8 vCPUs")

    # Input and output components, pre-filled from the first example row
    first_example = EXAMPLES[0]
    text = gr.Textbox(lines=5, label="Source Text", placeholder="Enter the text you want to translate...", value=first_example[0])
    source_language = gr.Dropdown(choices=LANGUAGES, label="Source Language", value=first_example[1])
    target_language = gr.Dropdown(choices=LANGUAGES, label="Target Language", value=first_example[2])
    output = gr.Textbox(lines=5, label="Translated Text")

    # Submission
    submit = gr.Button("Translate", variant="primary")

    # Clickable example rows; picking one copies its values into the inputs
    data = gr.Dataset(
        components=[text, source_language, target_language],
        samples=EXAMPLES,
        label="Example Translations",
        headers=["Source Text", "Source Language", "Target Language"],
    )

    # Event wiring: translate() is a generator, so partial results stream into `output`
    submit.click(fn=translate, inputs=[text, source_language, target_language], outputs=[output])
    data.click(fn=lambda sample: sample, inputs=[data], outputs=[text, source_language, target_language])

# Queue requests (one worker, at most 10 waiting), then start the server
demo.queue(max_size=10, concurrency_count=1)
demo.launch(debug=True, share=False)