decula committed on
Commit
7a88d5b
·
1 Parent(s): 43735ce

changed the model

Browse files
Files changed (1) hide show
  1. qianwen_rag.py +284 -0
qianwen_rag.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os, gc, copy, torch
3
+ from datetime import datetime
4
+ from pynvml import *
5
+ from duckduckgo_search import DDGS
6
+ import re
7
+ import asyncio
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+
10
# Flag to check if GPU is present
HAS_GPU = False

# Model title and context size limit
ctx_limit = 20000
title = "Qwen2-72B-Instruct-2.0bpw-h-novel-exl2 with RAG"
model_repo = "Orion-zhen/Qwen2-72B-Instruct-2.0bpw-h-novel-exl2"

# Get the GPU count via NVML; on any NVML failure we fall through with
# HAS_GPU = False and run on CPU.
try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = True
        # Handle to GPU 0, used later by evaluate() for VRAM reporting.
        gpu_h = nvmlDeviceGetHandleByIndex(0)
        print(f"检测到 {GPU_COUNT} 个GPU设备")
        for i in range(GPU_COUNT):
            handle = nvmlDeviceGetHandleByIndex(i)
            info = nvmlDeviceGetMemoryInfo(handle)
            name = nvmlDeviceGetName(handle)
            print(f"GPU {i}: {name}, 总内存: {info.total / 1024**3:.2f} GB")
except NVMLError as error:
    # NOTE(review): GPU_COUNT stays undefined on this path; downstream code
    # only reads it when HAS_GPU is True, so that is safe today — confirm
    # before adding new GPU_COUNT uses.
    print(error)

# Load the model using transformers
print(f"正在加载模型: {model_repo}")

# Device configuration: prefer CUDA when a GPU was detected above.
device = "cpu"
if HAS_GPU:
    device = "cuda"

# Load the tokenizer and model from the Hugging Face repo.
tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoModelForCausalLM.from_pretrained(model_repo)

# Move the model to the selected device.
model = model.to(device)
48
+
49
# Normalize a user question into a search-friendly keyword string.
async def understanding_question(question: str):
    """Lowercase *question* and strip a leading interrogative phrase.

    Returns the remaining text, which is used as search keywords.
    """
    leading_phrases = r'^(can you|could you|please|tell me about|what is|who is|how to|why is|when did)\s+'
    return re.sub(leading_phrases, '', question.lower())
56
+
57
# Web search function for RAG with browser agent HTTP headers
async def run_duckduckgo_search_tool(question: str):
    """Search DuckDuckGo for the normalized form of *question*.

    The question is first reduced to keywords via understanding_question();
    only the first comma-separated keyword group is searched.  Results are
    printed for inspection; the normalized question text is returned.
    """
    text = await understanding_question(question)

    keywords = text.split(",")
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:124.0) Gecko/20100101 Firefox/124.0"
    }

    # Fix: guard the network call so a DDGS/network failure cannot crash the
    # caller — this matches the error handling already used by web_search().
    try:
        results = DDGS(headers=headers).text(keywords[0], max_results=5)
        print(results)
    except Exception as e:
        print(f"Search error: {e}")

    return text
70
+
71
# Web search helper used by evaluate() to fetch RAG context.
def web_search(query, max_results=3):
    """Run a DuckDuckGo text search and return the hits as formatted text.

    Returns a numbered "Search Results" section, a fixed message when the
    search returns nothing, or an empty string on any search error
    (best-effort: RAG degrades gracefully instead of failing).
    """
    try:
        # Browser-like User-Agent header to avoid bot blocking.
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:124.0) Gecko/20100101 Firefox/124.0"
        }

        with DDGS(headers=headers) as ddgs:
            hits = list(ddgs.text(query, max_results=max_results))
            if not hits:
                return "No search results found."

            parts = ["\n\nSearch Results:\n"]
            for idx, hit in enumerate(hits, 1):
                parts.append(f"[{idx}] {hit['title']}\n")
                parts.append(f"URL: {hit['href']}\n")
                parts.append(f"Summary: {hit['body']}\n\n")
            return "".join(parts)
    except Exception as e:
        print(f"Search error: {e}")
        return ""
94
+
95
# Extract search query from user input
def extract_search_query(text):
    """Derive a short web-search query from free-form user input.

    Strips chat-role prefixes and a leading interrogative phrase, then
    caps the query at 100 characters.
    """
    patterns = (
        # Remove any existing "User:" or "A:" prefixes.
        r'user:\s*|a:\s*',
        # Drop common question openers that add no search value.
        r'^(can you|could you|please|tell me about|what is|who is|how to|why is|when did)\s+',
    )
    query = text.lower()
    for pattern in patterns:
        query = re.sub(pattern, '', query)
    return query[:100]  # limit query length
105
+
106
# Prompt generation
def generate_prompt(instruction, input=""):
    """Build a prompt string for the model.

    With a non-empty *input*, produces an Instruction/Input/Response
    template; otherwise a short chat transcript that seeds the assistant
    persona and ends at "A:" for the model to continue.
    """
    def _normalize(text):
        # Collapse Windows line endings and blank lines.
        return text.strip().replace('\r\n', '\n').replace('\n\n', '\n')

    instruction = _normalize(instruction)
    input = _normalize(input)

    if not input:
        return (
            "User: hi\n"
            "\n"
            "A: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n"
            "\n"
            f"User: {instruction}\n"
            "\n"
            "A:"
        )
    return (
        f"Instruction: {instruction}\n"
        "\n"
        f"Input: {input}\n"
        "\n"
        "Response:"
    )
124
+
125
# Generate a completion for *prompt* using the module-level model/tokenizer.
def generate_text(prompt, max_length=100, temperature=1.0, top_p=0.7):
    """Sample up to *max_length* new tokens continuing *prompt*.

    Uses the module-level `tokenizer`, `model`, and `device`.  Returns only
    the newly generated text, without the prompt.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Sampling configuration for transformers' generate().
    gen_kwargs = {
        "max_new_tokens": max_length,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
    }

    # Inference only — no gradients needed.
    with torch.no_grad():
        output = model.generate(**inputs, **gen_kwargs)

    # Fix: strip the prompt at the *token* level.  The previous approach
    # sliced the decoded string by the length of a re-decoded prompt, which
    # breaks whenever decode() is not a byte-exact round trip of the input
    # (tokenizer normalization, special tokens), corrupting the output.
    prompt_token_count = inputs["input_ids"].shape[1]
    return tokenizer.decode(output[0][prompt_token_count:], skip_special_tokens=True)
146
+
147
# Evaluation logic with RAG enhancement
def _inject_search_results(ctx, search_results):
    """Splice *search_results* into *ctx* ahead of the answer marker.

    Recognizes the two prompt shapes produced by generate_prompt()
    ("User:/A:" chat and "Instruction:/Response:"); any other shape gets
    the results appended at the end.  Empty results return ctx unchanged.
    """
    if not search_results:
        return ctx
    if "User:" in ctx and "\n\nA:" in ctx:
        # Insert the results just before the final "A:" continuation point.
        head = ctx.split("\n\nA:")[0]
        return head + "\n\nRelevant Information:" + search_results + "\n\nA:"
    if "Instruction:" in ctx and "\n\nResponse:" in ctx:
        # Insert the results just before the "Response:" marker.
        head = ctx.split("\n\nResponse:")[0]
        return head + "\n\nRelevant Information:" + search_results + "\n\nResponse:"
    # Unknown format: append to the end.
    return ctx + "\n\nRelevant Information:" + search_results


def evaluate(
    ctx,
    token_count=200,
    temperature=1.0,
    top_p=0.7,
    presencePenalty = 0.1,
    countPenalty = 0.1,
):
    """Generate a RAG-augmented completion for *ctx*.

    Runs a web search when the context looks like a question, splices the
    results into the prompt, generates text, then frees GPU caches and
    reports VRAM usage.

    NOTE(review): presencePenalty and countPenalty are accepted for UI
    compatibility but are not passed to generate_text() — confirm whether
    they should be wired into generation.
    """
    # Extract a search query from the user's input.
    search_query = extract_search_query(ctx)

    # Search only for question-like contexts, not pre-seeded assistant
    # continuations ("Assistant: ..." examples).
    search_results = ""
    if len(search_query) > 5 and not ctx.startswith("Assistant:"):
        search_results = web_search(search_query)

    # Combine the original context with search results for RAG.
    rag_ctx = _inject_search_results(ctx, search_results)

    print("Context with RAG:")
    print(rag_ctx)

    # Generate with transformers; temperature is floored at 0.2 to avoid
    # degenerate near-greedy sampling with do_sample=True.
    result = generate_text(
        rag_ctx,
        max_length=int(token_count),
        temperature=max(0.2, float(temperature)),
        top_p=float(top_p)
    )

    # Free GPU caches and report VRAM usage.
    if HAS_GPU:
        if GPU_COUNT >= 2:
            # Empty every GPU's cache; report only the first two.
            for i in range(GPU_COUNT):
                with torch.cuda.device(f"cuda:{i}"):
                    torch.cuda.empty_cache()
                if i < 2:
                    handle = nvmlDeviceGetHandleByIndex(i)
                    gpu_info = nvmlDeviceGetMemoryInfo(handle)
                    print(f'GPU {i} VRAM: 总计 {gpu_info.total/(1024**3):.2f}GB, 已用 {gpu_info.used/(1024**3):.2f}GB, 空闲 {gpu_info.free/(1024**3):.2f}GB')
        else:
            gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
            print(f'GPU VRAM: 总计 {gpu_info.total/(1024**3):.2f}GB, 已用 {gpu_info.used/(1024**3):.2f}GB, 空闲 {gpu_info.free/(1024**3):.2f}GB')
            torch.cuda.empty_cache()

    # transformers generation is one-shot, so return the result directly.
    return result
211
+
212
# Examples and gradio blocks
# Each row matches the gr.Dataset component order used below:
# [prompt, token_count, temperature, top_p, presence_penalty, count_penalty]
examples = [
    ["Assistant: Sure! Here is a very detailed plan to create flying pigs:", 333, 1, 0.3, 0, 1],
    ["Assistant: Sure! Here are some ideas for FTL drive:", 333, 1, 0.3, 0, 1],
    [generate_prompt("Tell me about ravens."), 333, 1, 0.3, 0, 1],
    [generate_prompt("Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires."), 333, 1, 0.3, 0, 1],
    [generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"), 333, 1, 0.3, 0, 1],
    [generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."), 333, 1, 0.3, 0, 1],
    ["Assistant: Here is a very detailed plan to kill all mosquitoes:", 333, 1, 0.3, 0, 1],
    # Roleplay example: a raw multi-turn transcript rather than a template.
    ['''Edward: I am Edward Elric from fullmetal alchemist. I am in the world of full metal alchemist and know nothing of the real world.

User: Hello Edward. What have you been up to recently?

Edward:''', 333, 1, 0.3, 0, 1],
    [generate_prompt("What are the latest developments in quantum computing?"), 333, 1, 0.3, 0, 1],
    [generate_prompt("Tell me about the current situation in Ukraine."), 333, 1, 0.3, 0, 1],
]
229
+
230
##########################################################################
# frp reverse-proxy bootstrap: exposes the local Gradio port publicly.
port=7860
use_frpc=True
# NOTE(review): "7680.ini" looks like a digit-swap typo for "7860.ini"
# (the port above is 7860) — confirm the actual config filename on disk.
frpconfigfile="7680.ini"
import subprocess

def install_Frpc(port, frpconfigfile, use_frpc):
    """Start the frpc tunnel client for *port* when *use_frpc* is true.

    Makes ./frpc executable (raises CalledProcessError if chmod fails),
    then launches it in the background with the given config file.
    """
    if use_frpc:
        subprocess.run(['chmod', '+x', './frpc'], check=True)
        print(f'正在启动frp ,端口{port}')
        subprocess.Popen(['./frpc', '-c', frpconfigfile])

# Fix: pass the configured `port` variable instead of the hard-coded
# string literal '7860', so the printed port always tracks the value
# defined above (which was previously unused).
install_Frpc(port, frpconfigfile, use_frpc)
243
+
244
# Gradio blocks: UI layout and event wiring for the demo.
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\">\n<h1>Qwen2-72B-Instruct with RAG - {title}</h1>\n</div>")
    with gr.Tab("Raw Generation"):
        gr.Markdown(f"这是带有RAG功能的Qwen2-72B-Instruct模型。支持多种语言和代码。演示限制上下文长度为{ctx_limit}。")
        with gr.Row():
            with gr.Column():
                # Generation inputs; labels are in Chinese to match the UI language.
                prompt = gr.Textbox(lines=2, label="提示词", value="")
                token_count = gr.Slider(0, 20000, label="最大Token数", step=200, value=100)
                temperature = gr.Slider(0.2, 2.0, label="温度", step=0.1, value=1.0)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
                # NOTE(review): evaluate() accepts but ignores the two penalty
                # values — confirm whether they should affect generation.
                presence_penalty = gr.Slider(0.0, 1.0, label="存在惩罚", step=0.1, value=1)
                count_penalty = gr.Slider(0.0, 1.0, label="计数惩罚", step=0.1, value=1)
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("提交", variant="primary")
                    stop_btn = gr.Button("中断", variant="stop")
                    clear = gr.Button("清除", variant="secondary")
                output = gr.Textbox(label="输出", lines=5)
                data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], label="示例指令", headers=["提示词", "最大Token数", "温度", "Top P", "存在惩罚", "计数惩罚"])

        # Submit: run evaluate() with the current input values into `output`.
        submit_event = submit.click(
            evaluate,
            [prompt, token_count, temperature, top_p, presence_penalty, count_penalty],
            [output]
        )

        # Stop: cancel the in-flight generation started by the submit click.
        stop_btn.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=[submit_event]  # cancel the running generation
        )

        # Clear the output box; clicking a dataset row copies it into the inputs.
        clear.click(lambda: None, [], [output])
        data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])

# Gradio launch (local only; share disabled because frpc handles exposure).
demo.launch(share=False)