JiangYH committed on
Commit
87818fb
1 parent: c34f5bb

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. ChatWorld/ChatWorld.py +6 -8
  2. ChatWorld/NaiveDB.py +5 -2
  3. ChatWorld/models.py +8 -5
  4. app.py +31 -5
ChatWorld/ChatWorld.py CHANGED
@@ -1,7 +1,7 @@
1
  from jinja2 import Template
2
  import torch
3
 
4
- from .models import GLM
5
 
6
  from .NaiveDB import NaiveDB
7
  from .utils import *
@@ -19,7 +19,7 @@ class ChatWorld:
19
 
20
  self.history = []
21
 
22
- self.client = None
23
  self.model = GLM()
24
  self.db = NaiveDB()
25
  self.prompt = Template(('Please be aware that your codename in this conversation is "{{model_role_name}}"'
@@ -81,17 +81,15 @@ class ChatWorld:
81
  return {"role": "system", "content": self.prompt.render(model_role_name=self.model_role_name, model_role_nickname=self.model_role_nickname, role_name=role_name, role_nickname=role_nick_name, RAG=rag)}
82
 
83
  def chat(self, text: str, user_role_name: str, user_role_nick_name: str = None, use_local_model=False):
 
 
84
  message = [self.getSystemPrompt(text,
85
- user_role_name, user_role_nick_name)] + self.history
86
- print(message)
87
  if use_local_model:
88
  response = self.model.get_response(message)
89
  else:
90
- response = self.client.chat(
91
- user_role_name, text, user_role_nick_name)
92
 
93
- self.history.append(
94
- {"role": "user", "content": f"{user_role_name}:「{text}」"})
95
  self.history.append(
96
  {"role": "assistant", "content": f"{self.model_role_name}:「{response}」"})
97
  return response
 
1
  from jinja2 import Template
2
  import torch
3
 
4
+ from .models import GLM, GLM_api
5
 
6
  from .NaiveDB import NaiveDB
7
  from .utils import *
 
19
 
20
  self.history = []
21
 
22
+ self.client = GLM_api()
23
  self.model = GLM()
24
  self.db = NaiveDB()
25
  self.prompt = Template(('Please be aware that your codename in this conversation is "{{model_role_name}}"'
 
81
  return {"role": "system", "content": self.prompt.render(model_role_name=self.model_role_name, model_role_nickname=self.model_role_nickname, role_name=role_name, role_nickname=role_nick_name, RAG=rag)}
82
 
83
  def chat(self, text: str, user_role_name: str, user_role_nick_name: str = None, use_local_model=False):
84
+ self.history.append(
85
+ {"role": "user", "content": f"{user_role_name}:「{text}」"})
86
  message = [self.getSystemPrompt(text,
87
+ user_role_name, user_role_nick_name), {"role": "user", "content": f"{user_role_name}:「{text}」"}]
 
88
  if use_local_model:
89
  response = self.model.get_response(message)
90
  else:
91
+ response = self.client.chat(message)
 
92
 
 
 
93
  self.history.append(
94
  {"role": "assistant", "content": f"{self.model_role_name}:「{response}」"})
95
  return response
ChatWorld/NaiveDB.py CHANGED
@@ -81,7 +81,10 @@ class NaiveDB:
81
  similarities.sort(key=lambda x: x[0], reverse=True)
82
  self.last_search_ids = [x[1] for x in similarities[:n_results]]
83
 
84
-
 
 
85
 
86
- top_stories = [self.stories[_id] for _id in self.last_search_ids]
 
87
  return top_stories
 
81
  similarities.sort(key=lambda x: x[0], reverse=True)
82
  self.last_search_ids = [x[1] for x in similarities[:n_results]]
83
 
84
+ stories_length = len(self.stories)
85
+ search_id_range = [(max(0, i-3), min(i+4, stories_length))
86
+ for i in self.last_search_ids]
87
 
88
+ top_stories = ["\n".join(self.stories[start:end+1])
89
+ for start, end in search_id_range]
90
  return top_stories
ChatWorld/models.py CHANGED
@@ -40,7 +40,7 @@ class GLM():
40
 
41
  self.client = client.eval()
42
 
43
- def message2query(messages) -> str:
44
  # [{'role': 'user', 'content': '老师: 同学请自我介绍一下'}]
45
  # <|system|>
46
  # You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
@@ -53,7 +53,9 @@ class GLM():
53
  return "".join([template.substitute(message) for message in messages])
54
 
55
  def get_response(self, message):
56
- response, history = self.client.chat(self.tokenizer, message)
 
 
57
  return response
58
 
59
 
@@ -62,7 +64,8 @@ class GLM_api:
62
  self.client = ZhipuAI(api_key=os.environ["ZHIPU_API_KEY"])
63
  self.model = model_name
64
 
65
- def getResponse(self, message):
 
66
  response = self.client.chat.completions.create(
67
- model=self.model, prompt=message)
68
- return response.choices[0].message
 
40
 
41
  self.client = client.eval()
42
 
43
+ def message2query(self, messages) -> str:
44
  # [{'role': 'user', 'content': '老师: 同学请自我介绍一下'}]
45
  # <|system|>
46
  # You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
 
53
  return "".join([template.substitute(message) for message in messages])
54
 
55
  def get_response(self, message):
56
+ response, history = self.client.chat(
57
+ self.tokenizer, self.message2query(message))
58
+ print(self.message2query(message))
59
  return response
60
 
61
 
 
64
  self.client = ZhipuAI(api_key=os.environ["ZHIPU_API_KEY"])
65
  self.model = model_name
66
 
67
+ def chat(self, message):
68
+ print(message)
69
  response = self.client.chat.completions.create(
70
+ model=self.model, messages=message)
71
+ return response.choices[0].message.content
app.py CHANGED
@@ -11,6 +11,8 @@ logging.basicConfig(level=logging.INFO, filename="demo.log", filemode="w",
11
 
12
  chatWorld = ChatWorld()
13
 
 
 
14
 
15
  def getContent(input_file):
16
  # 读取文件内容
@@ -31,33 +33,57 @@ def getContent(input_file):
31
  role_name_list = [i for i in role_name_set if i != ""]
32
  logging.info(f"role_name_list: {role_name_list}")
33
 
 
 
 
34
  return gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[0]), gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[-1])
35
 
36
 
37
  def submit_message(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
 
38
  chatWorld.setRoleName(model_role_name, model_role_nickname)
39
  response = chatWorld.chat(message,
40
  role_name, role_nickname, use_local_model=True)
41
  return response
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  with gr.Blocks() as demo:
45
 
46
  upload_c = gr.File(label="上传文档文件")
47
 
48
  with gr.Row():
49
- model_role_name = gr.Radio([], label="模型角色名")
50
  model_role_nickname = gr.Textbox(label="模型角色昵称")
51
 
52
  with gr.Row():
53
- role_name = gr.Radio([], label="角色名")
54
  role_nickname = gr.Textbox(label="角色昵称")
55
 
56
  upload_c.upload(fn=getContent, inputs=upload_c,
57
  outputs=[model_role_name, role_name])
58
 
59
- chatBox = gr.ChatInterface(
60
- submit_message, chatbot=gr.Chatbot(height=400, render=False), additional_inputs=[model_role_name, role_name, model_role_nickname, role_nickname])
 
 
 
 
61
 
62
 
63
- demo.launch(share=True, debug=True, server_name="0.0.0.0")
 
11
 
12
  chatWorld = ChatWorld()
13
 
14
+ role_name_list_global = None
15
+
16
 
17
  def getContent(input_file):
18
  # 读取文件内容
 
33
  role_name_list = [i for i in role_name_set if i != ""]
34
  logging.info(f"role_name_list: {role_name_list}")
35
 
36
+ global role_name_list_global
37
+ role_name_list_global = role_name_list
38
+
39
  return gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[0]), gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[-1])
40
 
41
 
42
  def submit_message(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
43
+ print(f"history: {history}")
44
  chatWorld.setRoleName(model_role_name, model_role_nickname)
45
  response = chatWorld.chat(message,
46
  role_name, role_nickname, use_local_model=True)
47
  return response
48
 
49
 
50
+ def submit_message_api(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
51
+ print(f"history: {history}")
52
+ chatWorld.setRoleName(model_role_name, model_role_nickname)
53
+ response = chatWorld.chat(message,
54
+ role_name, role_nickname, use_local_model=False)
55
+ return response
56
+
57
+
58
+ def get_role_list():
59
+ global role_name_list_global
60
+ if role_name_list_global:
61
+ return role_name_list_global
62
+ else:
63
+ return []
64
+
65
+
66
  with gr.Blocks() as demo:
67
 
68
  upload_c = gr.File(label="上传文档文件")
69
 
70
  with gr.Row():
71
+ model_role_name = gr.Radio(get_role_list(), label="模型角色名")
72
  model_role_nickname = gr.Textbox(label="模型角色昵称")
73
 
74
  with gr.Row():
75
+ role_name = gr.Radio(get_role_list(), label="角色名")
76
  role_nickname = gr.Textbox(label="角色昵称")
77
 
78
  upload_c.upload(fn=getContent, inputs=upload_c,
79
  outputs=[model_role_name, role_name])
80
 
81
+ with gr.Row():
82
+ chatBox_local = gr.ChatInterface(
83
+ submit_message, chatbot=gr.Chatbot(height=400, label="本地模型", render=False), additional_inputs=[model_role_name, role_name, model_role_nickname, role_nickname])
84
+
85
+ chatBox_api = gr.ChatInterface(
86
+ submit_message_api, chatbot=gr.Chatbot(height=400, label="API模型", render=False), additional_inputs=[model_role_name, role_name, model_role_nickname, role_nickname])
87
 
88
 
89
+ demo.launch(server_name="0.0.0.0")