# %%
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
import gradio as gr
import hanzidentifier
import re
import chinese_converter
import pathlib
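# Resolve this script's directory so local assets under data/ can be served by Gradio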
current_path = str(pathlib.Path(__file__).parent.resolve())
# %%
# Load the riddle-SFT chat model and the base Qwen tokenizer
llm_model_name = "Qwen/Qwen1.5-0.5B-Chat"
model = AutoModelForCausalLM.from_pretrained(
    "ytyeung/Qwen1.5-0.5B-Chat-SFT-riddles",
)
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
# %%
# Load the Chinese sentence-embedding model used for retrieval
vec_model_name = "shibing624/text2vec-base-chinese"
encode_kwargs = {'normalize_embeddings': False}
model_kwargs = {'device': 'cpu'}
huggingface_embeddings = HuggingFaceEmbeddings(
    model_name=vec_model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)
# %%
persist_directory = 'chroma/'
vectordb = Chroma(embedding_function=huggingface_embeddings,
                  persist_directory=persist_directory)
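# Sanity check: print the number of riddle documents in the persisted collection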
print(vectordb._collection.count())
# %%
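# UI strings indexed by language: 0 = Simplified Chinese, 1 = Traditional Chinese, 2 = English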
text_input_label=["谜面","謎面","Riddle"]
text_output_label=["谜底","謎底","Answer"]
clear_label = ["清除","清除","Clear"]
submit_label = ["提交","提交","Submit"]
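# Minimum relevance score for a retrieved riddle to be used as a hint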
threshold = 0.6
# %%
# Get an answer from the LLM, optionally grounding the prompt with retrieved hints
def answer(input_text, context=None):
if context:
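        # Build a hint block from the retrieved riddle-answer pairs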
tips = "提示:\n"
tips += "\n".join([x[0] for x in context])
print (f"====\n{input_text}\n{context[0][0]} {context[0][1]}")
        if context[0][1] >= 0.9:
ans = re.search(r"谜底:(\w+)", context[0][0])
if ans:
return f"谜底是:{ans.group(1)}"
else:
tips=""
prompt = f"{input_text}\n{tips}\n谜底是什么?"
prompt = prompt.strip()
print(f"===\n{prompt}")
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(device="cpu")
    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=128,
        do_sample=False,  # greedy decoding; top_p is ignored when sampling is off
    )
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
# Retrieve the most similar stored riddle; keep it only if it clears the relevance threshold
def helper_rag(text):
    docs_out = vectordb.similarity_search_with_relevance_scores(text, k=1)
context = []
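    # Each context entry pairs "riddle text + answer" with its relevance score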
for doc in docs_out:
if doc[1] > threshold:
context.append((f"{doc[0].page_content}{doc[0].metadata['answer']}", doc[1]))
return context
# Normalize the input, retrieve hints, query the LLM, and convert the answer back if needed
def helper_text(text_input, radio=None):
chinese_type = "simplified"
if hanzidentifier.is_traditional(text_input):
text_input = chinese_converter.to_simplified(text_input)
chinese_type = "traditional"
text_input = re.sub(r'hint',"猜",text_input,flags=re.I)
text=f"""猜谜语:\n谜面:{text_input}
"""
context = helper_rag(text)
output = answer(text,context=context)
print(output)
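    # Convert the answer back to Traditional Chinese if the input was Traditional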
if chinese_type == "traditional":
output = chinese_converter.to_traditional(output)
    return output
# Translate text to English with the same LLM
def translate(input_text):
    '''Use the LLM for translation'''
prompt = f"""翻译以下內容成英语:
{input_text}
"""
print(prompt)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(device="cpu")
    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=128,
        do_sample=False,  # greedy decoding; top_p is ignored when sampling is off
    )
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
# Gradio callback: switch the UI language and relabel the widgets
def change_language(radio, text_input, text_output, markdown,
                    markdown_msg1, markdown_msg2, translate_btn):
if radio == "简体中文":
index = 0
text_input=gr.Textbox(value = chinese_converter.to_simplified(text_input), label = text_input_label[index])
text_output=gr.Textbox(value = chinese_converter.to_simplified(text_output),label = text_output_label[index])
markdown=chinese_converter.to_simplified(markdown)
markdown_msg1=chinese_converter.to_simplified(markdown_msg1)
markdown_msg2=chinese_converter.to_simplified(markdown_msg2)
translate_btn=gr.Button(visible=False)
elif radio == "繁體中文":
index = 1
text_input=gr.Textbox(value = chinese_converter.to_traditional(text_input),label = text_input_label[index])
text_output=gr.Textbox(value = chinese_converter.to_traditional(text_output),label = text_output_label[index])
markdown=chinese_converter.to_traditional(markdown)
markdown_msg1=chinese_converter.to_traditional(markdown_msg1)
markdown_msg2=chinese_converter.to_traditional(markdown_msg2)
translate_btn=gr.Button(visible=False)
elif radio == "English":
index = 2
text_input=gr.Textbox(label = text_input_label[index])
text_output=gr.Textbox(label = text_output_label[index])
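        # Show the Translate button only in the English UI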
translate_btn=gr.Button(visible=True)
else:
index = 0
text_input=gr.Textbox(label = text_input_label[index])
text_output=gr.Textbox(label = text_output_label[index])
markdown=chinese_converter.to_simplified(markdown)
markdown_msg1=chinese_converter.to_simplified(markdown_msg1)
markdown_msg2=chinese_converter.to_simplified(markdown_msg2)
translate_btn=gr.Button(visible=False)
clear_btn = clear_label[index]
submit_btn = submit_label[index]
return [text_input,text_output,clear_btn,submit_btn,markdown,
markdown_msg1 ,markdown_msg2,translate_btn]
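# Clear both the riddle and the answer text boxes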
def clear_text():
text_input_update=""
text_output_update=""
return [text_input_update,text_output_update]
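# Translate both text boxes to English using the LLM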
def translate_text(text_input,text_output):
text_input = translate(f"{text_input}")
text_output = translate(f"{text_output}")
return text_input,text_output
# %%
# css = """
# #markdown { background-image: url("file/data/DSC_0105.jpg");
# background-size: cover;
# }
# """
with gr.Blocks() as demo:
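    # Default UI language index (0 = Simplified Chinese)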
index = 0
example_list = [
["小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)"],
["一物生来身穿三百多件衣,每天脱一件,年底剩张皮。(猜一物品)"],
["A thousand threads, a million strands. Reaching the water, vanishing all at once. (Hint: natural phenomenon)"],
["无底洞"],
]
radio = gr.Radio(
["简体中文","繁體中文", "English"],show_label=False,value="简体中文"
)
markdown = gr.Markdown(
"""
# Chinese Lantern Riddles Solver with LLM
## 用大语言模型来猜灯谜
""",elem_id="markdown")
with gr.Row():
with gr.Column():
text_input = gr.Textbox(label=text_input_label[index],
value="小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)", lines = 2)
with gr.Row():
clear_btn = gr.ClearButton(value=clear_label[index],components=[text_input])
submit_btn = gr.Button(value=submit_label[index], variant = "primary")
text_output = gr.Textbox(label=text_output_label[index])
translate_btn = gr.Button(value="Translate", variant = "primary", scale=0, visible=False)
examples = gr.Examples(
examples=example_list,
inputs=text_input,
outputs=text_output,
fn=helper_text,
cache_examples=True,
)
markdown_msg1 = gr.Markdown(
"""
灯谜是中华文化特色文娱活动,自北宋盛行。每年逢正月十五元宵节,将谜语贴在花灯上,让大家可一起猜谜。
Lantern riddles are a traditional Chinese cultural pastime, popular since the Northern Song dynasty (960-1127). \
During the Lantern Festival (the 15th day of the first lunar month), riddles are pasted on flower lanterns \
so that people can guess them together while viewing the lanterns.
"""
)
with gr.Column():
markdown_msg2 = gr.Markdown(
"""
![lantern](file/data/DSC_0105.jpg)
---
# 声明 Disclaimer
本应用输出的文本为机器基于模型生成的结果,不代表任何人观点,请谨慎辨别和参考。请在法律允许的范围内使用。
本应用调用了 [Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat) 对话语言大模型,\
使用本应用前请务必阅读和同意遵守其[使用授权许可证](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat/blob/main/LICENSE)。
本应用仅供非商业用途。
The outputs of this application are machine-generated with a statistical model. \
They do not reflect the opinions of any human subject, so please evaluate them with caution. \
It is your responsibility to decide whether to accept the outputs, and you must use the application in compliance with the law.
This application utilizes the [Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat) \
conversational large language model. Before using this application, you must read and agree to abide by \
its [LICENSE](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat/blob/main/LICENSE).
This application is for non-commercial use only.
---
# 感谢 Acknowledgement
本应用调用了 [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) 生成 text vector embeddings.
该模型是以 [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) 发行。
This application utilizes [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) to generate text vector embeddings.
The model is released under [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0).
""")
submit_btn.click(fn=helper_text, inputs=[text_input,radio], outputs=text_output)
translate_btn.click(fn=translate_text, inputs=[text_input,text_output], outputs=[text_input,text_output])
clear_btn.click(fn=clear_text,outputs=[text_input,text_output])
radio.change(fn=change_language,inputs=[radio,text_input,text_output,
markdown, markdown_msg1,markdown_msg2,translate_btn],
outputs=[text_input,text_output,clear_btn,submit_btn,
markdown, markdown_msg1,markdown_msg2,translate_btn])
demo.launch(show_api=False,allowed_paths=[current_path+"/data/"])
# %%