AngoHF committed
Commit f293d7e
1 Parent(s): a9b3003

first commit

Files changed (1)
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
import gradio as gr

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer
)
from peft import PeftModel
import torch

# Base model on the Hugging Face Hub, plus the LoRA adapter to merge into it.
model_path = "Qwen/Qwen1.5-1.8B-Chat"
lora_path = "AngoHF/EssayGPT"  # + "/checkpoint-100"

# Prefer the first GPU when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
)
config_kwargs = {"device_map": device}  # lets from_pretrained place weights on `device`

# Load the base model in float16. Note that fp16 inference on CPU is typically
# slow, so the CPU fallback is best treated as a last resort.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    **config_kwargs
)

# Apply the LoRA adapter, then fold its weights into the base model so that
# inference no longer needs the PEFT wrapper.
model = PeftModel.from_pretrained(model, lora_path)
model = model.merge_and_unload()
model.eval()

# model.config.use_cache = True
# model.to("cpu")
# model.save_pretrained("/data/ango/EssayGPT")

# tokenizer.save_pretrained("/data/ango/EssayGPT")


# Number of material tabs shown in the UI.
MAX_MATERIALS = 4


def call(related_materials, materials, question):
    # Assemble the prompt: each selected material as "材料N" ("Material N"),
    # then "问题:" ("Question:") followed by the user's question.
    query_texts = [f"材料{i + 1}\n{material}" for i, material in enumerate(materials) if i in related_materials]
    query_texts.append(f"问题:{question}")
    query = "\n".join(query_texts)
    messages = [
        # System prompt: "Please answer the question based on the materials provided below."
        {"role": "system", "content": "请你根据以下提供的材料来回答问题"},
        {"role": "user", "content": query}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    print(len(model_inputs.input_ids[0]))  # debug: prompt length in tokens
    generated_ids = model.generate(
        model_inputs.input_ids,
        max_length=8096  # cap on prompt + generated tokens (8192 may have been intended)
    )
    # Drop the prompt tokens so only the newly generated answer is decoded.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
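
# For reference, Qwen1.5 chat checkpoints use a ChatML-style template, so `text`
# above should look roughly like this (illustrative sketch, not verified output):
#   <|im_start|>system
#   请你根据以下提供的材料来回答问题<|im_end|>
#   <|im_start|>user
#   材料1
#   ...
#   问题:...<|im_end|>
#   <|im_start|>assistant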


def create_ui():
    with gr.Blocks() as app:
        # Title: "EssayGPT - shenlun (civil-service essay) LLM".
        gr.Markdown("""<center><font size=8>EssayGPT-申论大模型</center>""")
        # Usage: 1. paste each material into its tab  2. enter the question and
        # requirements  3. select the materials the question needs  4. click "提问!" (Ask!).
        gr.Markdown(
            """<center><font size=4>1.把材料填入对应位置 2.输入问题和要求 3.选择解答问题需要的相关材料 4.点击"提问!"</center>""")
        with gr.Row():
            with gr.Column():
                materials = []

                # One tab per material slot ("材料N"), each holding a textbox
                # labelled 材料内容 ("material content").
                for i in range(MAX_MATERIALS):
                    with gr.Tab(f"材料{i + 1}"):
                        materials.append(gr.Textbox(label="材料内容"))
            with gr.Column():
                # Multi-select of the 1-based material numbers the question relies on.
                related_materials = gr.Dropdown(
                    choices=list(range(1, MAX_MATERIALS + 1)), multiselect=True,
                    label="问题所需相关材料")
                question = gr.Textbox(label="问题")  # "Question"
                submit = gr.Button("提问!")  # "Ask!"
                answer = gr.Textbox(label="回答")  # "Answer"
        build_ui({"materials": materials, "related_materials": related_materials, "question": question,
                  "submit": submit, "answer": answer})
    return app


def build_ui(components):
    def func(related_materials, question, *materials):
        if not related_materials:
            # "Please select the materials the question needs."
            return "请选择问题所需相关材料"
        # Convert the 1-based dropdown values to 0-based indices into `materials`.
        related_materials = [i - 1 for i in related_materials]
        return call(related_materials, materials, question)

    components["submit"].click(func,
                               [components["related_materials"], components["question"], *components["materials"]],
                               components["answer"])
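
# Example: with materials 1 and 3 selected, Gradio passes related_materials=[1, 3]
# and the question string, followed by all four textbox values; `func` converts
# the selection to 0-based indices [0, 2] before delegating to call().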


def run():
    app = create_ui()
    app.queue()  # enable request queuing so concurrent generations are serialized
    app.launch(share=True)


if __name__ == '__main__':
    run()
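
# To run this app locally (a minimal sketch; assumes gradio, transformers, peft,
# and torch are installed, e.g. `pip install gradio transformers peft torch`):
#   python app.py
# launch(share=True) also requests a temporary public Gradio share link.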