Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os, gc, copy, torch | |
from huggingface_hub import hf_hub_download | |
from pynvml import * | |
# Flag to check if GPU is present | |
HAS_GPU = False | |
# Model title and context size limit | |
ctx_limit = 2000 | |
title = "RWKV-5-World-1B5-v2-Translator" | |
model_file = "RWKV-5-World-1B5-v2-20231025-ctx4096" | |
# Get the GPU count | |
try: | |
nvmlInit() | |
GPU_COUNT = nvmlDeviceGetCount() | |
if GPU_COUNT > 0: | |
HAS_GPU = True | |
gpu_h = nvmlDeviceGetHandleByIndex(0) | |
except NVMLError as error: | |
print(error) | |
os.environ["RWKV_JIT_ON"] = '1' | |
# Model strategy to use | |
MODEL_STRAT = "cpu bf16" | |
os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster) | |
# Switch to GPU mode | |
if HAS_GPU: | |
os.environ["RWKV_CUDA_ON"] = '1' | |
MODEL_STRAT = "cuda bf16" | |
# Load the model | |
from rwkv.model import RWKV | |
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{model_file}.pth") | |
model = RWKV(model=model_path, strategy=MODEL_STRAT) | |
from rwkv.utils import PIPELINE | |
pipeline = PIPELINE(model, "rwkv_vocab_v20230424") | |
# State copy | |
def universal_deepcopy(obj): | |
if hasattr(obj, 'clone'): # Assuming it's a tensor if it has a clone method | |
return obj.clone() | |
elif isinstance(obj, list): | |
return [universal_deepcopy(item) for item in obj] | |
else: | |
return copy.deepcopy(obj) | |
# For debgging mostly | |
def inspect_structure(obj, depth=0): | |
indent = " " * depth | |
obj_type = type(obj).__name__ | |
if isinstance(obj, list): | |
print(f"{indent}List (length {len(obj)}):") | |
for item in obj: | |
inspect_structure(item, depth + 1) | |
elif isinstance(obj, dict): | |
print(f"{indent}Dict (length {len(obj)}):") | |
for key, value in obj.items(): | |
print(f"{indent} Key: {key}") | |
inspect_structure(value, depth + 1) | |
else: | |
print(f"{indent}{obj_type}") | |
# Precomputation of the state | |
def precompute_state(text): | |
state = None | |
text_encoded = pipeline.encode(text) | |
_, state = model.forward(text_encoded, state) | |
return state | |
# Precomputing the base instruction set | |
INSTRUCT_PREFIX = f''' | |
You are a translator bot that can translate text to any language. | |
And will respond only with the translated text, without additional comments. | |
## From English: | |
It is not enough to know, we must also apply; it is not enough to will, we must also do. | |
## To Polish: | |
Nie wystarczy wiedzieć, trzeba także zastosować; nie wystarczy chcieć, trzeba też działać. | |
## From Spanish: | |
La muerte no nos concierne, porque mientras existamos, la muerte no está aquí. Y cuando llega, ya no existimos. | |
## To English: | |
Death does not concern us, because as long as we exist, death is not here. And when it does come, we no longer exist. | |
''' | |
# Get the prefix state | |
PREFIX_STATE = precompute_state(INSTRUCT_PREFIX) | |
# Translation logic | |
def translate( | |
text, source_language, target_language, | |
inState=PREFIX_STATE, | |
temperature=0.2, | |
top_p=0.5, | |
presencePenalty = 0.1, | |
countPenalty = 0.1, | |
): | |
prompt = f"## From {source_language}:\n{text}\n\n## To {target_language}:\n" | |
ctx = prompt.strip() | |
all_tokens = [] | |
out_last = 0 | |
out_str = '' | |
occurrence = {} | |
alpha_frequency = countPenalty | |
alpha_presence = presencePenalty | |
state = None | |
if inState != None: | |
state = universal_deepcopy(inState) | |
# Clear GC | |
gc.collect() | |
if HAS_GPU == True : | |
torch.cuda.empty_cache() | |
# Generate things token by token | |
for i in range(ctx_limit): | |
out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state) | |
for n in occurrence: | |
out[n] -= (alpha_presence + occurrence[n] * alpha_frequency) | |
token = pipeline.sample_logits(out, temperature=temperature, top_p=top_p) | |
if token in [0]: # EOS token | |
break | |
all_tokens += [token] | |
for xxx in occurrence: | |
occurrence[xxx] *= 0.996 | |
if token not in occurrence: | |
occurrence[token] = 1 | |
else: | |
occurrence[token] += 1 | |
tmp = pipeline.decode(all_tokens[out_last:]) | |
if '\ufffd' not in tmp: | |
out_str += tmp | |
out_last = i + 1 | |
else: | |
return out_str.strip() | |
if "\n:" in out_str : | |
out_str = out_str.split("\n\nHuman:")[0].split("\nHuman:")[0] | |
return out_str.strip() | |
if "{source_language}:" in out_str : | |
out_str = out_str.split("{source_language}:")[0] | |
return out_str.strip() | |
if "{target_language}:" in out_str : | |
out_str = out_str.split("{target_language}:")[0] | |
return out_str.strip() | |
if "\nHuman:" in out_str : | |
out_str = out_str.split("\n\nHuman:")[0].split("\nHuman:")[0] | |
return out_str.strip() | |
if "\nAssistant:" in out_str : | |
out_str = out_str.split("\n\nAssistant:")[0].split("\nAssistant:")[0] | |
return out_str.strip() | |
if "\n#" in out_str : | |
out_str = out_str.split("\n\n#")[0].split("\n#")[0] | |
return out_str.strip() | |
# Yield for streaming | |
yield out_str.strip() | |
del out | |
del state | |
# # Clear GC | |
# gc.collect() | |
# if HAS_GPU == True : | |
# torch.cuda.empty_cache() | |
# yield out_str.strip() | |
return out_str.strip() | |
# Languages | |
LANGUAGES = [ | |
"English", | |
"Chinese", | |
"Spanish", | |
"Bengali", | |
"Hindi", | |
"Portuguese", | |
"Russian", | |
"Japanese", | |
"German", | |
"Chinese (Wu)", | |
"Javanese", | |
"Korean", | |
"French", | |
"Vietnamese", | |
"Telugu", | |
"Chinese (Yue)", | |
"Marathi", | |
"Tamil", | |
"Turkish", | |
"Urdu", | |
"Chinese (Min Nan)", | |
"Chinese (Jin Yu)", | |
"Gujarati", | |
"Polish", | |
"Arabic (Egyptian Spoken)", | |
"Ukrainian", | |
"Italian", | |
"Chinese (Xiang)", | |
"Malayalam", | |
"Chinese (Hakka)", | |
"Kannada", | |
"Oriya", | |
"Panjabi (Western)", | |
"Panjabi (Eastern)", | |
"Sunda", | |
"Romanian", | |
"Bhojpuri", | |
"Azerbaijani (South)", | |
"Farsi (Western)", | |
"Maithili", | |
"Hausa", | |
"Arabic (Algerian Spoken)", | |
"Burmese", | |
"Serbo-Croatian", | |
"Chinese (Gan)", | |
"Awadhi", | |
"Thai", | |
"Dutch", | |
"Yoruba", | |
"Sindhi", | |
"Arabic (Moroccan Spoken)", | |
"Arabic (Saidi Spoken)", | |
"Uzbek, Northern", | |
"Malay", | |
"Amharic", | |
"Indonesian", | |
"Igbo", | |
"Tagalog", | |
"Nepali", | |
"Arabic (Sudanese Spoken)", | |
"Saraiki", | |
"Cebuano", | |
"Arabic (North Levantine Spoken)", | |
"Thai (Northeastern)", | |
"Assamese", | |
"Hungarian", | |
"Chittagonian", | |
"Arabic (Mesopotamian Spoken)", | |
"Madura", | |
"Sinhala", | |
"Haryanvi", | |
"Marwari", | |
"Czech", | |
"Greek", | |
"Magahi", | |
"Chhattisgarhi", | |
"Deccan", | |
"Chinese (Min Bei)", | |
"Belarusan", | |
"Zhuang (Northern)", | |
"Arabic (Najdi Spoken)", | |
"Pashto (Northern)", | |
"Somali", | |
"Malagasy", | |
"Arabic (Tunisian Spoken)", | |
"Rwanda", | |
"Zulu", | |
"Latin", | |
"Bulgarian", | |
"Swedish", | |
"Lombard", | |
"Oromo (West-central)", | |
"Pashto (Southern)", | |
"Kazakh", | |
"Ilocano", | |
"Tatar", | |
"Fulfulde (Nigerian)", | |
"Arabic (Sanaani Spoken)", | |
"Uyghur", | |
"Haitian Creole French", | |
"Azerbaijani, North", | |
"Napoletano-calabrese", | |
"Khmer (Central)", | |
"Farsi (Eastern)", | |
"Akan", | |
"Hiligaynon", | |
"Kurmanji", | |
"Shona" | |
] | |
# Example data | |
EXAMPLES = [ | |
# More people would learn from their mistakes if they weren't so busy denying them. | |
["Többen tanulnának a hibáikból, ha nem lennének annyira elfoglalva, hogy tagadják azokat.", "Hungarian", "English"], | |
["La mejor venganza es el éxito masivo.", "Spanish", "English"], | |
["Tout est bien qui finit bien.", "French", "English"], | |
["Lasciate ogne speranza, voi ch'intrate.", "Italian", "English"], | |
["Errare humanum est.", "Latin", "English"], | |
] | |
# Gradio interface | |
with gr.Blocks(title=title) as demo: | |
gr.HTML(f"<div style=\"text-align: center;\"><h1>RWKV-5 World v2 - {title}</h1></div>") | |
gr.Markdown("This is the RWKV-5 World v2 1B5 model tailored for translation tasks") | |
# Input and output components | |
text = gr.Textbox(lines=5, label="Source Text", placeholder="Enter the text you want to translate...", value=EXAMPLES[0][0]) | |
source_language = gr.Dropdown(choices=LANGUAGES, label="Source Language", value=EXAMPLES[0][1]) | |
target_language = gr.Dropdown(choices=LANGUAGES, label="Target Language", value=EXAMPLES[0][2]) | |
output = gr.Textbox(lines=5, label="Translated Text") | |
# Submission | |
submit = gr.Button("Translate", variant="primary") | |
# Example data | |
data = gr.Dataset(components=[text, source_language, target_language], samples=EXAMPLES, label="Example Translations", headers=["Source Text", "Target Language"]) | |
# Button action | |
submit.click(translate, [text, source_language, target_language], [output]) | |
data.click(lambda x: x, [data], [text, source_language, target_language]) | |
# Gradio launch | |
demo.queue(concurrency_count=1, max_size=10) | |
demo.launch(share=False, debug=True) |