|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
from langchain.vectorstores import Chroma |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from peft import AutoPeftModelForCausalLM |
|
|
|
import gradio as gr |
|
import hanzidentifier |
|
import re |
|
|
|
import chinese_converter |
|
|
|
import pathlib |
|
# Absolute directory of this script; the launch call appends "/data/" to it
# to build Gradio's allowed_paths.
current_path = str(pathlib.Path(__file__).parent.resolve())

# Base chat model id. Only the tokenizer (and its chat template) is loaded
# from this id.
llm_model_name = "Qwen/Qwen1.5-0.5B-Chat"

# The generating weights come from a separate checkpoint (presumably a riddle
# SFT of the base model, judging by the repo name - confirm), while the
# tokenizer is taken from the base model id above.
_riddle_checkpoint = "ytyeung/Qwen1.5-0.5B-Chat-SFT-riddles"
model = AutoModelForCausalLM.from_pretrained(_riddle_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
|
|
|
|
|
|
|
|
|
# Sentence-embedding model used to query the riddle vector store.
vec_model_name = "shibing624/text2vec-base-chinese"

# NOTE(review): embeddings are not normalised. The persisted collection must
# have been built with the same setting, so do not change this in isolation.
encode_kwargs = {'normalize_embeddings': False}
model_kwargs = {'device': 'cpu'}

huggingface_embeddings = HuggingFaceEmbeddings(
    model_name=vec_model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)

# Resolve the Chroma store next to this script instead of relative to the
# current working directory. Pointed at a non-existent path, Chroma creates a
# new empty store, and every riddle lookup then returns no hints.
persist_directory = str(pathlib.Path(current_path) / "chroma")
vectordb = Chroma(embedding_function=huggingface_embeddings, persist_directory=persist_directory)

# Startup sanity check: number of stored riddles.
_doc_count = vectordb._collection.count()
print(_doc_count)
if _doc_count == 0:
    print(f"WARNING: vector store at {persist_directory} is empty; riddle hints are disabled.")
|
|
|
|
|
# Per-language UI strings. The position matches the index chosen in
# change_language: 0 = 简体中文, 1 = 繁體中文, 2 = English.
text_input_label = ["谜面", "謎面", "Riddle"]
text_output_label = ["谜底", "謎底", "Answer"]
clear_label = ["清除", "清除", "Clear"]
submit_label = ["提交", "提交", "Submit"]

# Retrieved riddles must score strictly above this relevance to be used as hints.
threshold = 0.6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def answer(input_text, context=None):
    """Answer a riddle prompt, optionally guided by retrieved riddles.

    Args:
        input_text: The riddle prompt (helper_text passes "猜谜语:\\n谜面:...").
        context: Optional list of ``(hint_text, score)`` pairs, best match
            first, as produced by helper_rag. Empty or None means no hints.

    Returns:
        ``"谜底是:<answer>"`` taken directly from the top hint when its score
        is at least 0.9 and it contains ``谜底:<answer>``. Otherwise, the chat
        model's greedy-decoded reply to the prompt plus hints.
    """
    if context:
        tips = "提示:\n" + "\n".join(hint[0] for hint in context)
        top_text, top_score = context[0][0], context[0][1]
        print(f"====\n{input_text}\n{top_text} {top_score}")
        # A near-exact match already carries its answer, so skip the LLM.
        if top_score >= 0.9:
            ans = re.search(r"谜底:(\w+)", top_text)
            if ans:
                return f"谜底是:{ans.group(1)}"
    else:
        tips = ""

    prompt = f"{input_text}\n{tips}\n谜底是什么?".strip()
    print(f"===\n{prompt}")

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Follow the model's device rather than hard-coding "cpu".
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # **model_inputs passes attention_mask along with input_ids, so generate()
    # does not have to infer it. Decoding is greedy; top_p is omitted because
    # it only affects sampling.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=128,
        do_sample=False,
    )
    # Strip the echoed prompt tokens and keep only the newly generated reply.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
|
|
|
|
|
def helper_rag(text):
    """Look up the closest stored riddle to use as a hint for ``text``.

    Queries the vector store for the single best match (k=1). Each hit becomes
    a ``(hint, score)`` tuple, where ``hint`` is the stored page content
    immediately followed by its ``answer`` metadata. Hits scoring at or below
    the module-level ``threshold`` are discarded, so the result may be empty.
    """
    hits = vectordb.similarity_search_with_relevance_scores(text, k=1)
    return [
        (f"{document.page_content}{document.metadata['answer']}", score)
        for document, score in hits
        if score > threshold
    ]
|
|
|
|
|
def helper_text(text_input, radio=None):
    """Solve a riddle from the UI and reply in the script it was written in.

    Args:
        text_input: Riddle text (Simplified, Traditional or English).
        radio: Selected UI language ("简体中文", "繁體中文" or "English").
            None when called from gr.Examples.

    Returns:
        The answer. It is converted to Traditional Chinese when the riddle is
        written in Traditional characters, or when it uses only characters
        shared by both scripts and the Traditional UI is selected.
    """
    script = hanzidentifier.identify(text_input)
    # hanzidentifier.is_traditional() is also true for BOTH (characters common
    # to both scripts). Treat such ambiguous input as Traditional only when the
    # user picked the Traditional UI; otherwise it would always get a
    # Traditional answer.
    traditional = script == hanzidentifier.TRADITIONAL or (
        script == hanzidentifier.BOTH and radio == "繁體中文"
    )
    if traditional:
        text_input = chinese_converter.to_simplified(text_input)

    # English riddles mark the category with "(Hint: ...)"; map it to the
    # "猜" marker used by the Chinese riddles.
    text_input = re.sub(r'hint', "猜", text_input, flags=re.I)

    text = f"猜谜语:\n谜面:{text_input}\n"

    context = helper_rag(text)
    output = answer(text, context=context)
    print(output)

    if traditional:
        output = chinese_converter.to_traditional(output)
    return output
|
|
|
|
|
|
|
|
|
def translate(input_text):
    """Translate ``input_text`` into English with the chat LLM.

    The text is wrapped in a Chinese instruction ("translate the following
    into English"), formatted with the chat template and decoded greedily
    (at most 128 new tokens).

    Returns:
        The model's reply with special tokens removed.
    """
    prompt = f"翻译以下內容成英语:\n\n{input_text}\n"
    print(prompt)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Follow the model's device rather than hard-coding "cpu".
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # **model_inputs passes attention_mask along with input_ids. Decoding is
    # greedy; top_p is omitted because it only affects sampling.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=128,
        do_sample=False,
    )
    # Strip the echoed prompt tokens and keep only the newly generated reply.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
|
|
|
|
|
|
|
|
|
def change_language(radio, text_input, text_output, markdown,
                    markdown_msg1, markdown_msg2, translate_btn):
    """Relabel and re-script the UI after the language radio changes.

    Args:
        radio: Selected language ("简体中文", "繁體中文" or "English").
        text_input: Current riddle textbox value.
        text_output: Current answer textbox value.
        markdown, markdown_msg1, markdown_msg2: Current markdown panel texts.
        translate_btn: Current translate-button value (always replaced).

    Returns:
        A list of updates, in order: riddle textbox, answer textbox, clear
        button label, submit button label, the three markdown panels, and the
        translate button.
    """
    chinese_scripts = {
        "简体中文": (0, chinese_converter.to_simplified),
        "繁體中文": (1, chinese_converter.to_traditional),
    }

    if radio == "English":
        # Only relabel. Texts and markdown are left as they are, and the
        # translate button is revealed.
        index = 2
        text_input = gr.Textbox(label=text_input_label[index])
        text_output = gr.Textbox(label=text_output_label[index])
        translate_btn = gr.Button(visible=True)
    else:
        # An unrecognised choice falls back to Simplified labels and markdown,
        # but leaves the textbox contents untouched.
        index, convert = chinese_scripts.get(radio, (0, chinese_converter.to_simplified))
        if radio in chinese_scripts:
            text_input = gr.Textbox(value=convert(text_input), label=text_input_label[index])
            text_output = gr.Textbox(value=convert(text_output), label=text_output_label[index])
        else:
            text_input = gr.Textbox(label=text_input_label[index])
            text_output = gr.Textbox(label=text_output_label[index])
        markdown = convert(markdown)
        markdown_msg1 = convert(markdown_msg1)
        markdown_msg2 = convert(markdown_msg2)
        translate_btn = gr.Button(visible=False)

    return [text_input, text_output, clear_label[index], submit_label[index],
            markdown, markdown_msg1, markdown_msg2, translate_btn]
|
|
|
|
|
def clear_text():
    """Return empty values for the riddle and answer textboxes, in that order."""
    return ["", ""]
|
|
|
def translate_text(text_input, text_output):
    """Translate the riddle and answer textboxes into English.

    Blank fields (None, empty or whitespace-only) are returned unchanged
    rather than sent to the LLM. Pressing Translate before Submit would
    otherwise spend a full generation on an empty answer and can put
    unrelated model text into the answer box.

    Returns:
        A ``(translated_input, translated_output)`` tuple.
    """
    def _translate_field(value):
        # Skip the model entirely for blank fields.
        if value is None or not str(value).strip():
            return value
        return translate(f"{value}")

    return _translate_field(text_input), _translate_field(text_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    # The UI starts in Simplified Chinese. index selects from the *_label lists
    # (0 = 简体中文, 1 = 繁體中文, 2 = English).
    index = 0
    example_list = [
        ["小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)"],
        ["一物生来身穿三百多件衣,每天脱一件,年底剩张皮。(猜一物品)"],
        ["A thousand threads, a million strands. Reaching the water, vanishing all at once. (Hint: natural phenomenon)"],
        ["无底洞"],
    ]
    radio = gr.Radio(
        ["简体中文", "繁體中文", "English"], show_label=False, value="简体中文"
    )
    markdown = gr.Markdown(
        """
        # Chinese Lantern Riddles Solver with LLM
        ## 用大语言模型来猜灯谜
        """, elem_id="markdown")

    with gr.Row():
        # Left column: riddle input, buttons, answer, examples and intro text.
        # NOTE(review): nesting reconstructed from a whitespace-stripped copy;
        # confirm against the deployed layout.
        with gr.Column():
            text_input = gr.Textbox(label=text_input_label[index],
                                    value="小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)", lines=2)
            with gr.Row():
                # ClearButton empties text_input itself; the click handler
                # registered below also empties text_output.
                clear_btn = gr.ClearButton(value=clear_label[index], components=[text_input])
                submit_btn = gr.Button(value=submit_label[index], variant="primary")

            text_output = gr.Textbox(label=text_output_label[index])

            # Hidden until change_language switches the UI to English.
            translate_btn = gr.Button(value="Translate", variant="primary", scale=0, visible=False)

            # With cache_examples=True, Gradio precomputes the outputs by
            # calling helper_text on each example (radio is not passed).
            examples = gr.Examples(
                examples=example_list,
                inputs=text_input,
                outputs=text_output,
                fn=helper_text,
                cache_examples=True,
            )
            markdown_msg1 = gr.Markdown(
                """
                灯谜是中华文化特色文娱活动,自北宋盛行。每年逢正月十五元宵节,将谜语贴在花灯上,让大家可一起猜谜。

                Lantern riddle is a traditional Chinese cultural activity. Being popular since the Song Dynasty (960-1279), it \
                is held in the Lantern Festival (15th day of the first lunar month). \
                When people are viewing the flower lanterns, they can guess the riddles on the lanterns together.


                """
            )

        # Right column: lantern image, disclaimer and acknowledgements.
        with gr.Column():
            markdown_msg2 = gr.Markdown(
                """
                ![lantern](file/data/DSC_0105.jpg)

                ---
                # 声明 Disclaimer

                本应用输出的文本为机器基于模型生成的结果,不代表任何人观点,请谨慎辨别和参考。请在法律允许的范围内使用。

                本应用调用了 [Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat) 对话语言大模型,\
                使用本应用前请务必阅读和同意遵守其[使用授权许可证](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat/blob/main/LICENSE)。

                本应用仅供非商业用途。

                The outputs of this application are machine-generated with a statistical model. \
                The outputs do not reflect any opinions of any human subjects. You must evaluate the outputs with caution. \
                It is your responsibility to decide whether to accept the outputs. You must use the application in obedience to the Law.

                This application utilizes [Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat) \
                Conversational Large Language Model. Before using this application, you must read and agree to follow \
                the [LICENSE](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat/blob/main/LICENSE).

                This application is for non-commercial use only.

                ---

                # 感谢 Acknowledgement

                本应用调用了 [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) 生成 text vector embeddings.
                该模型是以 [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) 发行。

                This application utilizes [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) to generate text vector embeddings.
                The model is released under [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0).
                """)

    # Event wiring. radio is passed to helper_text so that it can resolve
    # script-ambiguous riddles.
    submit_btn.click(fn=helper_text, inputs=[text_input, radio], outputs=text_output)

    translate_btn.click(fn=translate_text, inputs=[text_input, text_output], outputs=[text_input, text_output])

    clear_btn.click(fn=clear_text, outputs=[text_input, text_output])

    # change_language returns the updates in exactly this outputs order.
    radio.change(fn=change_language, inputs=[radio, text_input, text_output,
                                             markdown, markdown_msg1, markdown_msg2, translate_btn],
                 outputs=[text_input, text_output, clear_btn, submit_btn,
                          markdown, markdown_msg1, markdown_msg2, translate_btn])
|
|
|
|
|
|
|
|
|
|
|
# Expose only this script's data/ folder to Gradio's file server. The
# disclaimer panel embeds file/data/DSC_0105.jpg from it.
_data_dir = current_path + "/data/"
demo.launch(show_api=False, allowed_paths=[_data_dir])
|
|
|
|
|
|
|
|
|
|
|
|
|
|