File size: 11,726 Bytes
936a3fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
# %%
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings

import gradio as gr
import hanzidentifier
import re

import chinese_converter

# %%
# Load the chat LLM and its tokenizer directly from the Hugging Face Hub;
# no transformers pipeline wrapper is used.
llm_model_name = "Qwen/Qwen1.5-0.5B-Chat"

model = AutoModelForCausalLM.from_pretrained(llm_model_name)
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)

# %%
# %%
# Sentence-embedding model used as the Chroma embedding function. It runs on
# CPU and leaves the produced embeddings un-normalized.
vec_model_name = "shibing624/text2vec-base-chinese"

model_kwargs = dict(device="cpu")
encode_kwargs = dict(normalize_embeddings=False)

huggingface_embeddings = HuggingFaceEmbeddings(
    model_name=vec_model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)


# %%
# Open the Chroma store persisted under chroma/ and print how many documents
# its collection holds, as a start-up sanity check.
persist_directory = "chroma/"
vectordb = Chroma(
    persist_directory=persist_directory,
    embedding_function=huggingface_embeddings,
)
print(vectordb._collection.count())

# %%
# UI strings with one entry per language, in the same order as the language
# radio: [简体中文, 繁體中文, English].
text_input_label = ["谜面", "謎面", "Riddle"]
text_output_label = ["谜底", "謎底", "Answer"]

clear_label = ["清除", "清除", "Clear"]
submit_label = ["提交", "提交", "Submit"]

# %%
# helper functions for prompt processing for this LLM

def preprocess(text):
    r"""Escape real newlines and tabs as the literal two-character
    sequences ``\n`` and ``\t`` so the text fits on a single line."""
    escapes = str.maketrans({"\n": "\\n", "\t": "\\t"})
    return text.translate(escapes)

def postprocess(text):
    r"""Reverse the escaping done by ``preprocess``.

    Literal ``\n`` / ``\t`` sequences become real newlines / tabs, and every
    ``%20`` is expanded into two spaces.
    """
    for escaped, raw in (("\\n", "\n"), ("\\t", "\t"), ("%20", "  ")):
        text = text.replace(escaped, raw)
    return text


def answer(input_text,context="",max_new_tokens=512):
    """Ask the chat LLM to solve a riddle and return its reply.

    Args:
        input_text: The riddle prompt (``helper_text`` passes the
            "猜谜语:/谜面:" text it builds).
        context: Optional retrieved text, placed under a "提示:" (hint)
            heading in the prompt; may be empty.
        max_new_tokens: Upper bound on the number of generated tokens
            (defaults to the previously hard-coded 512).

    Returns:
        The decoded reply, with the prompt tokens and special tokens removed.
    """
    prompt = f"{input_text}\n提示:\n{context}\n谜底是什么?请解释。".strip()
    print(prompt)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    # Render the conversation with the model's chat template and append the
    # assistant-turn header so the model continues as the assistant.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Deterministic decoding (do_sample=False). The attention mask is passed
    # explicitly because transformers warns when it has to infer it. top_p is
    # not passed: it only affects sampling and is ignored when do_sample=False.
    generated_ids = model.generate(
        model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
    # generate() returns prompt + completion; keep only the new tokens.
    new_token_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]

def helper_rag(text,k=1,score_threshold=0.7):
    """Build retrieval context for *text* from the Chroma vector store.

    Args:
        text: The query text.
        k: Number of nearest documents to fetch (default 1).
        score_threshold: Only documents whose relevance score is strictly
            greater than this value are kept (default 0.7).

    Returns:
        The page contents of the kept documents, each followed by a newline,
        or an empty string when no document passes the threshold.
    """
    # NOTE(review): the scale of relevance scores depends on the collection's
    # distance function and on embedding normalization (the embedder is
    # created with normalize_embeddings=False) — confirm 0.7 filters sensibly.
    docs_and_scores = vectordb.similarity_search_with_relevance_scores(text, k=k)
    return "".join(
        doc.page_content + "\n"
        for doc, score in docs_and_scores
        if score > score_threshold
    )

def helper_text(text_input,radio=None):
    """Solve a lantern riddle with retrieval-augmented generation.

    The riddle is converted to Simplified Chinese when it looks Traditional,
    the English word "hint" is mapped to the hint marker 猜, and a riddle
    without a hint (猜/打) is rejected with a warning. Otherwise retrieved
    context is added and the LLM is asked for the answer.

    Args:
        text_input: Riddle text from the user; ``None`` is treated as "".
        radio: Selected UI language ("简体中文", "繁體中文" or "English"),
            or ``None`` when called without one (e.g. example caching).

    Returns:
        The LLM answer, or a warning asking for a hint. Either is returned in
        Traditional Chinese when the riddle is unambiguously Traditional or
        the UI language is "繁體中文".
    """
    text_input = text_input or ""

    # hanzidentifier.is_traditional() is also True for text made only of
    # characters shared by both scripts (e.g. "十(猜一字)"). Decide the reply
    # script from input that is unambiguously Traditional, or from the
    # selected UI language (the warning path already honored the radio).
    reply_traditional = (
        hanzidentifier.identify(text_input) == hanzidentifier.TRADITIONAL
        or radio == "繁體中文"
    )

    # The prompt and hint keywords below are Simplified Chinese.
    if hanzidentifier.is_traditional(text_input):
        text_input = chinese_converter.to_simplified(text_input)

    # Accept the English word "hint" (any case) as the hint marker 猜.
    text_input = re.sub(r'hint',"猜",text_input,flags=re.I)

    if not any(c in text_input for c in ["猜", "打"]):
        warning = "请给一个提示,提示格式,例子:猜一水果,打一字。"
        if reply_traditional:
            warning = chinese_converter.to_traditional(warning)
        return warning

    # Same text as before (including the trailing newline + 4 spaces); it is
    # both the retrieval query and the LLM prompt.
    text = f"猜谜语:\n谜面:{text_input}\n    "

    context = helper_rag(text)
    output = answer(text,context=context)
    print(output)

    if reply_traditional:
        output = chinese_converter.to_traditional(output)
    return output



def change_language(radio,text_input,text_output,markdown,
                    markdown_msg1, markdown_msg2):
    """Switch the UI language and convert the text currently displayed.

    Args:
        radio: Selected language, one of "简体中文", "繁體中文" or "English".
            Any other value falls back to Simplified Chinese.
        text_input: Current riddle textbox value (may be ``None`` once cleared).
        text_output: Current answer textbox value (may be ``None`` once cleared).
        markdown, markdown_msg1, markdown_msg2: Current Markdown texts.

    Returns:
        Seven updates matching the ``outputs`` wired to ``radio.change``:
        riddle box, answer box, clear button, submit button, and the three
        Markdown components.
    """
    # Pick the label column and the script converter for the chosen language.
    # English keeps whatever text is displayed and only swaps the labels.
    if radio == "繁體中文":
        index, convert = 1, chinese_converter.to_traditional
    elif radio == "English":
        index, convert = 2, None
    else:
        index, convert = 0, chinese_converter.to_simplified

    def _convert(value):
        # Skip empty/None values (a cleared textbox) instead of handing them
        # to the converter.
        if convert is None or not value:
            return value
        return convert(value)

    # gr.update() works on Gradio 3.x and 4.x; the per-component
    # Component.update() helpers were removed in Gradio 4.
    if convert is None:
        text_input_update = gr.update(label=text_input_label[index])
        text_output_update = gr.update(label=text_output_label[index])
    else:
        text_input_update = gr.update(value=_convert(text_input),
                                      label=text_input_label[index])
        text_output_update = gr.update(value=_convert(text_output),
                                       label=text_output_label[index])

    markdown_update, markdown_msg1_update, markdown_msg2_update = (
        gr.update(value=_convert(md))
        for md in (markdown, markdown_msg1, markdown_msg2)
    )

    clear_btn_update = gr.update(value=clear_label[index])
    submit_btn_update = gr.update(value=submit_label[index])

    return [text_input_update,text_output_update,clear_btn_update,submit_btn_update,markdown_update,
            markdown_msg1_update ,markdown_msg2_update]


def clear_text():
    """Return updates that empty the riddle and answer textboxes.

    Uses ``gr.update`` rather than ``gr.Textbox.update``, which Gradio 4
    removed; ``gr.update`` also works on Gradio 3.x.
    """
    text_input_update = gr.update(value=None)
    text_output_update = gr.update(value=None)
    return [text_input_update,text_output_update]
    

# %%
# css = """
# #markdown { background-image: url("file/data/DSC_0105.jpg");
#             background-size: cover;
#           }
# """

with gr.Blocks() as demo:
    # Column of the *_label lists to start with: the UI opens in 简体中文.
    index = 0
    example_list = [
        ["小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)"],
        ["一物生来身穿三百多件衣,每天脱一件,年底剩张皮。(猜一物品)"],
        ["A thousand threads, a million strands. Reaching the water, vanishing all at once. (Hint: natural phenomenon)"],
        ["无底洞(猜成语)"],
    ]
    radio = gr.Radio(
        ["简体中文","繁體中文", "English"],show_label=False,value="简体中文"
    )
    markdown = gr.Markdown(
            """
            # Chinese Lantern Riddles Solver with LLM
            ## 用大语言模型来猜灯谜
            """,elem_id="markdown")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label=text_input_label[index],
                         value="小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)", lines = 2)
            with gr.Row():
                clear_btn = gr.ClearButton(value=clear_label[index],components=[text_input])
                submit_btn = gr.Button(value=submit_label[index], variant = "primary")

            text_output = gr.Textbox(label=text_output_label[index])

            # Example answers are computed once at start-up (cache_examples)
            # by calling helper_text with the riddle only.
            examples = gr.Examples(
                       examples=example_list,
                       inputs=text_input,
                       outputs=text_output,
                       fn=helper_text,
                       cache_examples=True,
            )
            markdown_msg1 = gr.Markdown(
                """
                灯谜是中华文化特色文娱活动,自北宋盛行。每年逢正月十五元宵节,将谜语贴在花灯上,让大家可一起猜谜。

                Lantern riddle is a traditional Chinese cultural activity. Popular since the Song Dynasty (960-1279), it \
                is held during the Lantern Festival (15th day of the first lunar month). \
                When people are viewing the flower lanterns, they can guess the riddles on the lanterns together.


                """
            )

        with gr.Column():
            # The disclaimer names the LLM actually loaded (llm_model_name)
            # so it cannot drift from the model in use.
            markdown_msg2 = gr.Markdown(
            f"""
            ![lantern](file/data/DSC_0105.jpg)

            ---
            # 声明 Disclaimer

            本应用输出的文本为机器基于模型生成的结果,不代表任何人观点,请谨慎辨别和参考。请在法律允许的范围内使用。

            本应用调用了 [{llm_model_name}](https://huggingface.co/{llm_model_name}) 对话语言大模型,\
            使用本应用前请务必阅读和同意遵守其[使用授权许可证](https://huggingface.co/{llm_model_name}/blob/main/LICENSE)。

            本应用仅供非商业用途。

            The outputs of this application are machine-generated with a statistical model. \
            The outputs do not reflect any opinions of any human subjects. Please evaluate the outputs with caution. \
            It is your responsibility to decide whether to accept the outputs. You must use the application in obedience to the Law.

            This application utilizes [{llm_model_name}](https://huggingface.co/{llm_model_name}) \
            Conversational Large Language Model. Before using this application, you must read and agree to follow \
            the [LICENSE](https://huggingface.co/{llm_model_name}/blob/main/LICENSE).

            This application is for non-commercial use only.

            ---

            # 感谢 Acknowledgement

            本应用调用了 [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) 生成 text vector embeddings.
            该模型是以 [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) 发行。

            This application utilizes [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) to generate text vector embeddings.
            The model is released under [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0).
            """)

    submit_btn.click(fn=helper_text, inputs=[text_input,radio], outputs=text_output)

    clear_btn.click(fn=clear_text,outputs=[text_input,text_output])
    radio.change(fn=change_language,inputs=[radio,text_input,text_output,
                                            markdown, markdown_msg1,markdown_msg2],
                 outputs=[text_input,text_output,clear_btn,submit_btn,
                          markdown, markdown_msg1,markdown_msg2])

demo.queue(api_open=False)
demo.launch(show_api=False)
 

# %%