Spaces:
Runtime error
Runtime error
# %% | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
from langchain.vectorstores import Chroma | |
from langchain.embeddings import HuggingFaceEmbeddings | |
import gradio as gr | |
import hanzidentifier | |
import re | |
import chinese_converter | |
# %% | |
# Build the text2text-generation pipeline once at import time; `answer`
# sends every prompt through `pipe`.
model = "ClueAI/ChatYuan-large-v2"
pipe = pipeline(task="text2text-generation", model=model)
# %% | |
# %% | |
# Sentence-embedding model used to encode queries for the vector store.
# It is configured to run on CPU and to leave embeddings un-normalized.
model_name = "shibing624/text2vec-base-chinese"
model_kwargs = dict(device="cpu")
encode_kwargs = dict(normalize_embeddings=False)
huggingface_embeddings = HuggingFaceEmbeddings(
    encode_kwargs=encode_kwargs,
    model_kwargs=model_kwargs,
    model_name=model_name,
)
# %% | |
# Open the Chroma store persisted under `chroma/` and report how many
# vectors it holds, as a quick sanity check at startup.
persist_directory = "chroma/"
vectordb = Chroma(
    persist_directory=persist_directory,
    embedding_function=huggingface_embeddings,
)
print("Vector count: {}".format(vectordb._collection.count()))
# %% | |
# UI captions, one row per language in the order Simplified Chinese,
# Traditional Chinese, English. Columns are: riddle box label, answer box
# label, clear button caption, submit button caption.
_UI_CAPTIONS = (
    ("谜面", "谜底", "清除", "提交"),
    ("謎面", "謎底", "清除", "提交"),
    ("Riddle", "Answer", "Clear", "Submit"),
)
# Transpose the table so each name holds one caption per language, indexed
# 0 = Simplified, 1 = Traditional, 2 = English.
text_input_label, text_output_label, clear_label, submit_label = (
    list(column) for column in zip(*_UI_CAPTIONS)
)
# helper functions for prompt processing for this LLM | |
def preprocess(text):
    """Return *text* with newlines and tabs written as literal escapes.

    Every newline character becomes the two characters ``\\n`` and every
    tab character becomes ``\\t``; all other characters are unchanged.
    """
    escape_table = str.maketrans({"\n": "\\n", "\t": "\\t"})
    return text.translate(escape_table)
def postprocess(text):
    """Undo the escaping in generated text.

    Replaces the literal sequences ``\\n`` and ``\\t`` with a real newline
    and tab, then replaces ``%20`` with a space, in that order.
    """
    substitutions = (("\\n", "\n"), ("\\t", "\t"), ("%20", " "))
    for escaped, literal in substitutions:
        text = text.replace(escaped, literal)
    return text
# get answer from LLM with prompt input | |
def answer(text, context=""):
    """Ask the LLM pipeline for the answer to a riddle prompt.

    The prompt is *context*, *text* and the cue "谜底:" on separate lines,
    stripped of surrounding whitespace. It is printed, escaped with
    `preprocess`, and passed to `pipe`.

    Returns the `generated_text` of the first pipeline result after
    `postprocess`.
    """
    prompt = f"{context}\n{text}\n谜底:".strip()
    print(prompt)
    generations = pipe(preprocess(prompt))
    first_generation = generations[0]
    return postprocess(first_generation["generated_text"])
# helper function for RAG | |
def helper_rag(text):
    """Collect retrieval context for *text* from the vector store.

    Queries `vectordb` for the 5 most similar entries with their relevance
    scores and keeps only those scoring above 0.7.

    Returns the kept entries' `page_content`, each followed by a newline,
    concatenated in result order (an empty string when nothing qualifies).
    """
    # docs_out = vectordb.max_marginal_relevance_search(text,k=5,fetch_k = 20, lambda_mult = 0.5)
    scored_hits = vectordb.similarity_search_with_relevance_scores(text, k=5)
    return "".join(
        document.page_content + "\n"
        for document, score in scored_hits
        if score > 0.7
    )
# helper function for prompt | |
def helper_text(text_input, lang=None):
    """Solve one riddle and return the answer in the matching script.

    Args:
        text_input: The riddle text. ``None`` or an empty string is
            treated as a riddle without a hint.
        lang: UI language; "繁體中文" or "traditional" requests a
            Traditional Chinese reply. A riddle written in Traditional
            Chinese also forces a Traditional reply.

    Returns:
        The first whitespace-delimited token of the model output, or a
        warning asking for a hint when the riddle contains neither "猜"
        nor "打" (the English word "hint" is accepted and mapped to "猜").
    """
    # A cleared textbox may deliver None; normalise it so the hint warning
    # is returned instead of raising inside hanzidentifier / re.sub.
    text_input = text_input or ""

    reply_in_traditional = lang in ("繁體中文", "traditional")
    if hanzidentifier.is_traditional(text_input):
        # The prompt is always sent in Simplified Chinese.
        text_input = chinese_converter.to_simplified(text_input)
        reply_in_traditional = True

    text_input = re.sub(r'hint', "猜", text_input, flags=re.I)
    if not any(c in text_input for c in ["猜", "打"]):
        warning = "请给一个提示,提示格式,例子:猜一水果,打一字。"
        if reply_in_traditional:
            warning = chinese_converter.to_traditional(warning)
        return warning

    text = f"""谜面:{text_input}
"""
    context = helper_rag(text)
    output = answer(text, context=context)
    if reply_in_traditional:
        output = chinese_converter.to_traditional(output)
    # Strip first: with leading whitespace, re.split would yield an empty
    # first element and the answer would be returned as "".
    tokens = re.split(r'\s+', output.strip())
    return tokens[0]
# Gradio function for configure the language of UI | |
def change_language(radio, text_input, text_output, markdown,
                    markdown_msg1, markdown_msg2, language=None):
    """Build the component updates for a UI language switch.

    Args:
        radio: Selected language: "简体中文", "繁體中文" or "English".
            Any other value falls back to Simplified Chinese labels.
        text_input: Current riddle textbox content.
        text_output: Current answer textbox content.
        markdown: Current title Markdown text.
        markdown_msg1: Current introduction Markdown text.
        markdown_msg2: Current disclaimer Markdown text.
        language: Unused on input; kept for signature compatibility. It is
            optional because the `radio.change` wiring passes only the six
            values above, so a required seventh parameter raises TypeError.

    Returns:
        A list of updates for the riddle box, answer box, clear button,
        submit button and the three Markdown blocks, followed by the
        selected language (the new value of the `language` state).
    """
    # NOTE(review): the `Component.update` classmethods look like the
    # Gradio 3.x API; confirm against the pinned Gradio version.

    # Per language: label index, Chinese script converter (None keeps text
    # as-is), and whether the textbox contents are converted as well.
    if radio == "简体中文":
        index, convert, convert_boxes = 0, chinese_converter.to_simplified, True
    elif radio == "繁體中文":
        index, convert, convert_boxes = 1, chinese_converter.to_traditional, True
    elif radio == "English":
        index, convert, convert_boxes = 2, None, False
    else:
        index, convert, convert_boxes = 0, chinese_converter.to_simplified, False

    def _converted(value):
        # Empty or None content has nothing to convert; passing it through
        # avoids feeding None to the converter.
        if convert is None or not value:
            return value
        return convert(value)

    if convert_boxes:
        text_input_update = gr.Textbox.update(
            value=_converted(text_input), label=text_input_label[index])
        text_output_update = gr.Textbox.update(
            value=_converted(text_output), label=text_output_label[index])
    else:
        # Only relabel; leave whatever the user typed untouched.
        text_input_update = gr.Textbox.update(label=text_input_label[index])
        text_output_update = gr.Textbox.update(label=text_output_label[index])

    markdown_update = gr.Markdown.update(value=_converted(markdown))
    markdown_msg1_update = gr.Markdown.update(value=_converted(markdown_msg1))
    markdown_msg2_update = gr.Markdown.update(value=_converted(markdown_msg2))
    clear_btn_update = gr.ClearButton.update(value=clear_label[index])
    submit_btn_update = gr.Button.update(value=submit_label[index])

    language = radio
    return [text_input_update, text_output_update, clear_btn_update,
            submit_btn_update, markdown_update, markdown_msg1_update,
            markdown_msg2_update, language]
def clear_text():
    """Return updates that blank the riddle and answer textboxes.

    The result is a two-element list, one update per textbox, each
    setting the value to None.
    """
    return [gr.Textbox.update(value=None) for _ in range(2)]
# %% | |
# css = """ | |
# #markdown { background-image: url("file/data/DSC_0105.jpg"); | |
# background-size: cover; | |
# } | |
# """ | |
# Gradio UI: language selector and title on top, then a two-column row with
# the riddle form on the left and the picture/disclaimer on the right.
# NOTE(review): confirm the layout nesting and the Markdown paragraph breaks
# against the rendered page.
with gr.Blocks() as demo:
    index = 0  # label index used for the initial captions (Simplified Chinese)
    language = gr.State()
    example_list = [
        ["小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)"],
        ["一物生来身穿三百多件衣,每天脱一件,年底剩张皮。(猜一物品)"],
        ["A thousand threads, a million strands. Reaching the water, vanishing all at once. (Hint: natural phenomenon)"],
        ["无底洞(猜成语)"],
    ]
    radio = gr.Radio(
        ["简体中文", "繁體中文", "English"], show_label=False, value="简体中文"
    )
    markdown = gr.Markdown(
        """
# Chinese Lantern Riddles Solver with LLM
## 用语言大模型来猜灯谜
""", elem_id="markdown")
    with gr.Row():
        with gr.Column():
            # The first example doubles as the textbox's initial content.
            text_input = gr.Textbox(label=text_input_label[index],
                                    value=example_list[0][0], lines=2)
            with gr.Row():
                clear_btn = gr.ClearButton(value=clear_label[index],
                                           components=[text_input])
                submit_btn = gr.Button(value=submit_label[index],
                                       variant="primary")
            text_output = gr.Textbox(label=text_output_label[index])
            # Examples are cached, so helper_text runs on each of them with
            # the riddle text as its only argument.
            examples = gr.Examples(
                examples=example_list,
                inputs=text_input,
                outputs=text_output,
                fn=helper_text,
                cache_examples=True,
            )
            markdown_msg1 = gr.Markdown(
                """
灯谜是中华文化特色文娱活动,自北宋盛行。每年逢正月十五元宵节,将谜语贴在花灯上,让大家可一起猜谜。
Lantern riddle is a traditional Chinese cultural activity. Being popular since the Song Dynasty (960-1276), it \
is held in the Lantern Festival (15th day of the first lunar month). \
When people are viewing the flower lanterns, they can guess the riddles on the lanterns together.
"""
            )
        with gr.Column():
            markdown_msg2 = gr.Markdown(
                """
![lantern](file/data/DSC_0105.jpg)
---
# 声明 Disclaimer
本应用输出的文本为机器基于模型生成的结果,不代表任何人观点,请谨慎辨别和参考。请在法律允许的范围内使用。
本应用调用了 [ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2) 对话语言大模型,\
使用本应用前请务必阅读和同意遵守其[使用授权许可证](https://huggingface.co/ClueAI/ChatYuan-large-v2/blob/main/LICENSE)。
本应用仅供非商业用途。
The outputs of this application are machine-generated with a statistical model. \
The outputs do not reflect any opinions of any human subjects. You must identify the outputs in caution. \
It is your responsibility to decide whether to accept the outputs. You must use the application in obedience to the Law.
This application utilizes [ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2) \
Conversational Large Language Model. Before using this application, you must read and accept to follow \
the [LICENSE](https://huggingface.co/ClueAI/ChatYuan-large-v2/blob/main/LICENSE).
This application is for non-commercial use only.
---
# 感谢 Acknowledgement
本应用调用了 [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) 生成 text vector embeddings.
该模型是以 [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) 发行。
This application utilizes [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) to generate text vector embeddings.
The model is released under [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0)。
""")
    # Event wiring: solve on submit, blank both boxes on clear, and
    # re-caption/convert the whole UI when the language radio changes.
    submit_btn.click(fn=helper_text, inputs=[text_input, radio],
                     outputs=text_output, api_name="answer-the-riddle")
    clear_btn.click(fn=clear_text, outputs=[text_input, text_output])
    radio.change(fn=change_language,
                 inputs=[radio, text_input, text_output,
                         markdown, markdown_msg1, markdown_msg2],
                 outputs=[text_input, text_output, clear_btn, submit_btn,
                          markdown, markdown_msg1, markdown_msg2, language])
# Queue requests and start the app; the API is neither opened on the queue
# nor shown in the UI.
demo.queue(api_open=False)
demo.launch(show_api=False)
# %% | |