|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
from langchain.vectorstores import Chroma |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
|
|
import gradio as gr |
|
import hanzidentifier |
|
import re |
|
|
|
import chinese_converter |
|
|
|
|
|
|
|
# Chat LLM that solves the riddles; the tokenizer and causal-LM weights come
# from the same Hugging Face checkpoint.
llm_model_name = "Qwen/Qwen1.5-0.5B-Chat"
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
model = AutoModelForCausalLM.from_pretrained(llm_model_name)
|
|
|
|
|
|
|
|
|
# Sentence-embedding model used by the Chroma store below, run on CPU and
# returning un-normalized vectors.
vec_model_name = "shibing624/text2vec-base-chinese"
model_kwargs = dict(device="cpu")
encode_kwargs = dict(normalize_embeddings=False)
huggingface_embeddings = HuggingFaceEmbeddings(
    model_name=vec_model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)

# Vector store persisted under chroma/. The number of stored documents is
# printed at start-up (presumably as a sanity check that the DB was found).
persist_directory = "chroma/"
vectordb = Chroma(
    persist_directory=persist_directory,
    embedding_function=huggingface_embeddings,
)
print(vectordb._collection.count())
|
|
|
|
|
# UI strings, one row per language. The column position is the index used by
# change_language(): 0 = 简体中文, 1 = 繁體中文, 2 = English.
# Each label table is a list with one entry per language.
text_input_label, text_output_label, clear_label, submit_label = (
    list(per_language)
    for per_language in zip(
        ("谜面", "谜底", "清除", "提交"),  # simplified Chinese
        ("謎面", "謎底", "清除", "提交"),  # traditional Chinese
        ("Riddle", "Answer", "Clear", "Submit"),  # English
    )
)
|
|
|
|
|
|
|
|
|
def preprocess(text):
    """Escape control characters as visible two-character sequences.

    Every newline becomes a backslash followed by ``n`` and every tab becomes
    a backslash followed by ``t``; all other characters are left untouched.
    """
    for raw, escaped in (("\n", "\\n"), ("\t", "\\t")):
        text = text.replace(raw, escaped)
    return text
|
|
|
def postprocess(text):
    """Turn escaped control sequences and ``%20`` back into real characters.

    The replacements are applied in order: backslash-n to newline,
    backslash-t to tab, then ``%20`` to a space.
    """
    substitutions = (("\\n", "\n"), ("\\t", "\t"), ("%20", " "))
    result = text
    for token, replacement in substitutions:
        result = result.replace(token, replacement)
    return result
|
|
|
|
|
|
|
def answer(input_text, context=""):
    """Ask the chat LLM to solve a riddle and explain its answer.

    Args:
        input_text: The riddle prompt (helper_text passes "猜谜语:\\n谜面:...").
        context: Reference text retrieved from the vector store, placed under a
            "提示:" (hint) heading; may be empty.

    Returns:
        The model's decoded reply, with the prompt tokens removed.
    """
    prompt = f"{input_text}\n提示:\n{context}\n谜底是什么?请解释。".strip()
    print(prompt)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device="cpu")

    # Pass input_ids *and* attention_mask (**model_inputs) so generate() does not
    # have to guess the mask, which it cannot do reliably when pad == eos.
    # Decoding is greedy (do_sample=False), so no top_p is passed: it would be
    # ignored, and the previous value 0.0 was outside the valid (0, 1] range.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
        do_sample=False,
    )

    # generate() returns prompt + continuation; keep only the new tokens.
    new_token_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]
    return response
|
|
|
|
|
|
|
def helper_rag(text, k=1, score_threshold=0.7):
    """Look up similar riddles in the vector store to use as hints.

    Args:
        text: Query text (the riddle prompt).
        k: Number of nearest documents to retrieve (default 1, as before).
        score_threshold: Only documents whose relevance score is strictly
            greater than this are kept (default 0.7, as before).

    Returns:
        The page contents of the kept documents, each followed by a newline,
        concatenated; an empty string when nothing is relevant enough.
    """
    docs_out = vectordb.similarity_search_with_relevance_scores(text, k=k)
    # Weak matches are dropped so the LLM is not fed misleading hints.
    return "".join(
        doc.page_content + "\n"
        for doc, score in docs_out
        if score > score_threshold
    )
|
|
|
|
|
def helper_text(text_input, radio=None):
    """Answer one riddle coming from the UI (submit button or cached examples).

    Args:
        text_input: Riddle text from the input textbox; ``None`` is treated as
            empty input.
        radio: Selected UI language ("简体中文", "繁體中文" or "English"), or
            ``None`` when called from ``gr.Examples`` (which passes only the
            textbox).

    Returns:
        The LLM's answer, or a message asking the user to add a hint such as
        "猜一水果". Both are returned in traditional Chinese when the riddle
        was written in traditional Chinese or the "繁體中文" UI language is
        selected.
    """
    text_input = text_input or ""

    is_traditional_input = hanzidentifier.is_traditional(text_input)
    if is_traditional_input:
        # The prompt template below is simplified Chinese; keep one script.
        text_input = chinese_converter.to_simplified(text_input)

    # One rule for both the warning and the model's answer (previously only
    # the warning honoured the 繁體中文 radio selection).
    want_traditional = is_traditional_input or radio == "繁體中文"

    # English riddles may mark their category with "Hint" (any case); map it to
    # the "猜" keyword so the check below accepts it.
    text_input = re.sub(r'hint', "猜", text_input, flags=re.I)

    if not any(c in text_input for c in ["猜", "打"]):
        warning = "请给一个提示,提示格式,例子:猜一水果,打一字。"
        if want_traditional:
            warning = chinese_converter.to_traditional(warning)
        return warning

    text = f"猜谜语:\n谜面:{text_input}\n"
    context = helper_rag(text)
    output = answer(text, context=context)
    print(output)

    if want_traditional:
        output = chinese_converter.to_traditional(output)
    return output
|
|
|
|
|
|
|
|
|
|
|
|
|
def change_language(radio, text_input, text_output, markdown,
                    markdown_msg1, markdown_msg2):
    """Switch the UI language when the language radio changes.

    Args:
        radio: Selected language: "简体中文", "繁體中文" or "English".
        text_input, text_output: Current contents of the riddle/answer boxes.
        markdown, markdown_msg1, markdown_msg2: Current markdown texts.

    Returns:
        Seven updates, in the order of the ``outputs`` wired to
        ``radio.change``: [text_input, text_output, clear_btn, submit_btn,
        markdown, markdown_msg1, markdown_msg2].
        - "简体中文" / "繁體中文": relabel the widgets and convert the textbox
          contents and the markdown to that script.
        - "English": relabel only; all values are passed through unchanged.
        - Anything else: simplified-Chinese labels, markdown converted to
          simplified, textbox contents left untouched.
    """
    def unchanged(value):
        return value

    # radio value -> (label index, textbox converter or None, markdown converter)
    language_settings = {
        "简体中文": (0, chinese_converter.to_simplified, chinese_converter.to_simplified),
        "繁體中文": (1, chinese_converter.to_traditional, chinese_converter.to_traditional),
        "English": (2, None, unchanged),
    }
    index, convert_text, convert_markdown = language_settings.get(
        radio, (0, None, chinese_converter.to_simplified)
    )

    def textbox_update(value, labels):
        if convert_text is None:
            return gr.update(label=labels[index])
        # A cleared textbox may hold None or ""; don't feed that to the converter.
        return gr.update(value=convert_text(value) if value else value,
                         label=labels[index])

    # gr.update() works on Gradio 3.x and 4.x; the per-component .update()
    # classmethods (gr.Textbox.update, ...) were removed in Gradio 4.
    return [
        textbox_update(text_input, text_input_label),
        textbox_update(text_output, text_output_label),
        gr.update(value=clear_label[index]),
        gr.update(value=submit_label[index]),
        gr.update(value=convert_markdown(markdown)),
        gr.update(value=convert_markdown(markdown_msg1)),
        gr.update(value=convert_markdown(markdown_msg2)),
    ]
|
|
|
|
|
def clear_text():
    """Empty both the riddle input and the answer output textboxes.

    Returns:
        Two updates, for [text_input, text_output], that set their value to None.
    """
    # gr.update() works on Gradio 3.x and 4.x; gr.Textbox.update() was removed
    # in Gradio 4.
    return [gr.update(value=None), gr.update(value=None)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    # Initial UI language index into the *_label tables (0 = simplified Chinese).
    index = 0
    example_list = [
        ["小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)"],
        ["一物生来身穿三百多件衣,每天脱一件,年底剩张皮。(猜一物品)"],
        ["A thousand threads, a million strands. Reaching the water, vanishing all at once. (Hint: natural phenomenon)"],
        ["无底洞(猜成语)"],
    ]
    radio = gr.Radio(
        ["简体中文", "繁體中文", "English"], show_label=False, value="简体中文"
    )
    markdown = gr.Markdown(
        """
        # Chinese Lantern Riddles Solver with LLM

        ## 用大语言模型来猜灯谜
        """,
        elem_id="markdown",
    )
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label=text_input_label[index],
                value="小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)",
                lines=2,
            )
            with gr.Row():
                clear_btn = gr.ClearButton(value=clear_label[index], components=[text_input])
                submit_btn = gr.Button(value=submit_label[index], variant="primary")

            text_output = gr.Textbox(label=text_output_label[index])

            # Examples are run through helper_text once and cached
            # (cache_examples=True); they pass only the textbox, so radio is None.
            examples = gr.Examples(
                examples=example_list,
                inputs=text_input,
                outputs=text_output,
                fn=helper_text,
                cache_examples=True,
            )
            markdown_msg1 = gr.Markdown(
                """
                灯谜是中华文化特色文娱活动,自北宋盛行。每年逢正月十五元宵节,将谜语贴在花灯上,让大家可一起猜谜。

                Lantern riddles are a traditional Chinese cultural activity. Popular since the Song Dynasty (960-1279), they \
                are held during the Lantern Festival (15th day of the first lunar month). \
                While viewing the lanterns, people can solve the riddles posted on them together.
                """
            )

        with gr.Column():
            # The disclaimer names the chat model actually loaded (llm_model_name)
            # so the license notice shown to users stays in sync with the code.
            markdown_msg2 = gr.Markdown(
                f"""
                ![lantern](file/data/DSC_0105.jpg)

                ---

                # 声明 Disclaimer

                本应用输出的文本为机器基于模型生成的结果,不代表任何人观点,请谨慎辨别和参考。请在法律允许的范围内使用。

                本应用调用了 [{llm_model_name}](https://huggingface.co/{llm_model_name}) 对话语言大模型,\
                使用本应用前请务必阅读和同意遵守其[使用授权许可证](https://huggingface.co/{llm_model_name}/blob/main/LICENSE)。

                本应用仅供非商业用途。

                The outputs of this application are machine-generated with a statistical model. \
                The outputs do not reflect any opinions of any human subjects. You must examine the outputs with caution. \
                It is your responsibility to decide whether to accept the outputs. You must use the application in obedience to the law.

                This application utilizes the [{llm_model_name}](https://huggingface.co/{llm_model_name}) \
                conversational large language model. Before using this application, you must read and agree to follow \
                its [LICENSE](https://huggingface.co/{llm_model_name}/blob/main/LICENSE).

                This application is for non-commercial use only.

                ---

                # 感谢 Acknowledgement

                本应用调用了 [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) 生成 text vector embeddings.
                该模型是以 [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) 发行。

                This application utilizes [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) to generate text vector embeddings.
                The model is released under [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0).
                """
            )

    # Event wiring. The outputs order of radio.change must match the list
    # returned by change_language.
    submit_btn.click(fn=helper_text, inputs=[text_input, radio], outputs=text_output)
    clear_btn.click(fn=clear_text, outputs=[text_input, text_output])
    radio.change(
        fn=change_language,
        inputs=[radio, text_input, text_output, markdown, markdown_msg1, markdown_msg2],
        outputs=[text_input, text_output, clear_btn, submit_btn,
                 markdown, markdown_msg1, markdown_msg2],
    )
|
|
|
|
|
|
|
|
|
# Enable request queuing (with the queue's HTTP API closed) and start the
# server without the "Use via API" page; Blocks.queue() returns the Blocks
# itself, so the calls can be chained.
demo.queue(api_open=False).launch(show_api=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|