# %%
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
import gradio as gr
import hanzidentifier
import re
import chinese_converter
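# The app: a Chinese lantern-riddle solver. Qwen1.5-0.5B-Chat generates the answer,
# shibing624/text2vec-base-chinese embeds the riddle text, and a persisted Chroma
# collection supplies similar riddles as retrieved context (simple RAG).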
# %%
# Load the LLM model and tokenizer directly
llm_model_name = "Qwen/Qwen1.5-0.5B-Chat"
#pipe = pipeline("text2text-generation", model=model)
model = AutoModelForCausalLM.from_pretrained(
    llm_model_name
)
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
# %%
# loading the vector encoder
vec_model_name = "shibing624/text2vec-base-chinese"
encode_kwargs = {'normalize_embeddings': False}
model_kwargs = {'device': 'cpu'}
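# normalize_embeddings=False keeps the raw (unnormalized) embedding vectors; everything runs on CPU here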
huggingface_embeddings = HuggingFaceEmbeddings(
    model_name=vec_model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)
# %%
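# Load the persisted Chroma collection from chroma/; it is assumed to have been built
# beforehand with the same embedding model (this script only queries it).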
persist_directory = 'chroma/'
vectordb = Chroma(embedding_function=huggingface_embeddings,persist_directory=persist_directory)
print(vectordb._collection.count())
# %%
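# UI strings indexed by language: 0 = Simplified Chinese, 1 = Traditional Chinese, 2 = English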
text_input_label=["谜面","謎面","Riddle"]
text_output_label=["谜底","謎底","Answer"]
clear_label = ["清除","清除","Clear"]
submit_label = ["提交","提交","Submit"]
# %%
# helper functions for prompt processing for this LLM
def preprocess(text):
    text = text.replace("\n", "\\n").replace("\t", "\\t")
    return text

def postprocess(text):
    return text.replace("\\n", "\n").replace("\\t", "\t").replace('%20', ' ')
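# Note: preprocess/postprocess are only referenced by the commented-out text2text pipeline
# path and are unused in the current chat-template flow.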
# get answer from LLM with prompt input
def answer(input_text, context=""):
    prompt = f"{input_text}\n提示:\n{context}\n谜底是什么?请解释。"
    prompt = prompt.strip()
    print(prompt)
    #text = preprocess(text)
    #out_text = pipe(text)
    # wrap the prompt in the chat format expected by Qwen1.5-*-Chat
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device="cpu")
    # greedy decoding (do_sample=False); sampling parameters such as top_p have no effect here
    generated_ids = model.generate(
        model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=512,
        do_sample=False
    )
    # keep only the newly generated tokens (strip the prompt part)
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    #return out_text[0]["generated_text"]
    return response
    #return postprocess(out_text[0]["generated_text"])
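# Illustrative call (actual output depends on the model):
#   answer("小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)")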
# helper function for RAG
def helper_rag(text):
    # retrieve the single most similar riddle from the vector store
    docs_out = vectordb.similarity_search_with_relevance_scores(text, k=1)
    #docs_out = vectordb.max_marginal_relevance_search(text,k=5,fetch_k = 20, lambda_mult = 0.5)
    context = ""
    for doc in docs_out:
        if doc[1] > 0.7:  # use the document only if its relevance score clears the threshold
            context += doc[0].page_content + "\n"
    return context
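# Retrieval design: only the nearest riddle (k=1) is considered, and it is passed to the
# LLM as context only when its relevance score exceeds 0.7; otherwise the model answers unaided.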
# helper function for prompt
def helper_text(text_input, radio=None):
    chinese_type = "simplified"
    # normalize traditional-character input to simplified for the model, remembering the original script
    if hanzidentifier.is_traditional(text_input):
        text_input = chinese_converter.to_simplified(text_input)
        chinese_type = "traditional"
    text_input = re.sub(r'hint', "猜", text_input, flags=re.I)
    if not any(c in text_input for c in ["猜", "打"]):
        # no hint category given: ask the user for one (e.g. "猜一水果", "打一字")
        warning = "请给一个提示,提示格式,例子:猜一水果,打一字。"
        if chinese_type == "traditional" or radio == "繁體中文":
            warning = chinese_converter.to_traditional(warning)
        return warning
    text = f"""猜谜语:\n谜面:{text_input}
"""
    context = helper_rag(text)
    output = answer(text, context=context)
    print(output)
    if chinese_type == "traditional":
        output = chinese_converter.to_traditional(output)
    #output = re.split(r'\s+',output)
    return output
    #return output[0]
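# Note: English riddles pass the hint check because "hint" is rewritten to "猜" above;
# the surrounding prompt sent to the model remains Chinese ("猜谜语:/谜面:").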
# Gradio function to configure the language of the UI
def change_language(radio, text_input, text_output, markdown,
                    markdown_msg1, markdown_msg2):
    if radio == "简体中文":
        index = 0
        text_input_update = gr.Textbox.update(value=chinese_converter.to_simplified(text_input), label=text_input_label[index])
        text_output_update = gr.Textbox.update(value=chinese_converter.to_simplified(text_output), label=text_output_label[index])
        markdown_update = gr.Markdown.update(value=chinese_converter.to_simplified(markdown))
        markdown_msg1_update = gr.Markdown.update(value=chinese_converter.to_simplified(markdown_msg1))
        markdown_msg2_update = gr.Markdown.update(value=chinese_converter.to_simplified(markdown_msg2))
    elif radio == "繁體中文":
        index = 1
        text_input_update = gr.Textbox.update(value=chinese_converter.to_traditional(text_input), label=text_input_label[index])
        text_output_update = gr.Textbox.update(value=chinese_converter.to_traditional(text_output), label=text_output_label[index])
        markdown_update = gr.Markdown.update(value=chinese_converter.to_traditional(markdown))
        markdown_msg1_update = gr.Markdown.update(value=chinese_converter.to_traditional(markdown_msg1))
        markdown_msg2_update = gr.Markdown.update(value=chinese_converter.to_traditional(markdown_msg2))
    elif radio == "English":
        # labels switch to English; the markdown content is left unchanged
        index = 2
        text_input_update = gr.Textbox.update(label=text_input_label[index])
        text_output_update = gr.Textbox.update(label=text_output_label[index])
        markdown_update = gr.Markdown.update(value=markdown)
        markdown_msg1_update = gr.Markdown.update(value=markdown_msg1)
        markdown_msg2_update = gr.Markdown.update(value=markdown_msg2)
    else:
        # fall back to Simplified Chinese
        index = 0
        text_input_update = gr.Textbox.update(label=text_input_label[index])
        text_output_update = gr.Textbox.update(label=text_output_label[index])
        markdown_update = gr.Markdown.update(value=chinese_converter.to_simplified(markdown))
        markdown_msg1_update = gr.Markdown.update(value=chinese_converter.to_simplified(markdown_msg1))
        markdown_msg2_update = gr.Markdown.update(value=chinese_converter.to_simplified(markdown_msg2))
    clear_btn_update = gr.ClearButton.update(value=clear_label[index])
    submit_btn_update = gr.Button.update(value=submit_label[index])
    return [text_input_update, text_output_update, clear_btn_update, submit_btn_update, markdown_update,
            markdown_msg1_update, markdown_msg2_update]

def clear_text():
    text_input_update = gr.Textbox.update(value=None)
    text_output_update = gr.Textbox.update(value=None)
    return [text_input_update, text_output_update]
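# Note: the *.update() calls above use the Gradio 3.x API; Gradio 4 removed them in favour
# of returning configured component instances (e.g. gr.Textbox(value=...)).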
# %%
# css = """
# #markdown { background-image: url("file/data/DSC_0105.jpg");
# background-size: cover;
# }
# """
with gr.Blocks() as demo:
    index = 0
    example_list = [
        ["小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)"],
        ["一物生来身穿三百多件衣,每天脱一件,年底剩张皮。(猜一物品)"],
        ["A thousand threads, a million strands. Reaching the water, vanishing all at once. (Hint: natural phenomenon)"],
        ["无底洞(猜成语)"],
    ]
    radio = gr.Radio(
        ["简体中文", "繁體中文", "English"], show_label=False, value="简体中文"
    )
    markdown = gr.Markdown(
        """
# Chinese Lantern Riddles Solver with LLM
## 用大语言模型来猜灯谜
""", elem_id="markdown")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label=text_input_label[index],
                                    value="小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)", lines=2)
            with gr.Row():
                clear_btn = gr.ClearButton(value=clear_label[index], components=[text_input])
                submit_btn = gr.Button(value=submit_label[index], variant="primary")
            text_output = gr.Textbox(label=text_output_label[index])
            examples = gr.Examples(
                examples=example_list,
                inputs=text_input,
                outputs=text_output,
                fn=helper_text,
                cache_examples=True,
            )
            markdown_msg1 = gr.Markdown(
                """
灯谜是中华文化特色文娱活动,自北宋盛行。每年逢正月十五元宵节,将谜语贴在花灯上,让大家可一起猜谜。

Lantern riddles are a traditional Chinese cultural activity, popular since the Northern Song dynasty (960-1127). \
During the Lantern Festival (the 15th day of the first lunar month), riddles are pasted on the flower lanterns \
so that people can guess them together while viewing the lanterns.
"""
            )
        with gr.Column():
            markdown_msg2 = gr.Markdown(
                """
![lantern](file/data/DSC_0105.jpg)

---

# 声明 Disclaimer

本应用输出的文本为机器基于模型生成的结果,不代表任何人观点,请谨慎辨别和参考。请在法律允许的范围内使用。

本应用调用了 [Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat) 对话语言大模型,\
使用本应用前请务必阅读和同意遵守其模型页面所列的使用授权许可。

本应用仅供非商业用途。

The outputs of this application are machine-generated with a statistical model. \
They do not reflect the opinions of any person. Treat the outputs with caution; \
it is your responsibility to decide whether to accept them, and you must use the application in compliance with the law.

This application utilizes the [Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat) \
conversational large language model. Before using this application, you must read and agree to \
follow the license terms listed on its model page.

This application is for non-commercial use only.

---

# 感谢 Acknowledgement

本应用调用了 [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) 生成文本向量 (text vector embeddings)。
该模型以 [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) 许可证发行。

This application utilizes [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) to generate text vector embeddings.
The model is released under the [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) license.
""")
    submit_btn.click(fn=helper_text, inputs=[text_input, radio], outputs=text_output)
    clear_btn.click(fn=clear_text, outputs=[text_input, text_output])
    radio.change(fn=change_language,
                 inputs=[radio, text_input, text_output,
                         markdown, markdown_msg1, markdown_msg2],
                 outputs=[text_input, text_output, clear_btn, submit_btn,
                          markdown, markdown_msg1, markdown_msg2])

#demo = gr.Interface(fn=helper_text, inputs=text_input, outputs=text_output,
#                    flagging_options=["Inappropriate"], allow_flagging="never",
#                    title="aaa", description="aaa", article="aaa")
demo.queue(api_open=False)
demo.launch(show_api=False)
# %%