hiwei committed on
Commit 353256e
1 Parent(s): 2bebf39

init project

README.md DELETED
@@ -1,13 +0,0 @@
- ---
- title: Chatglm2 6b Explorer
- emoji: 🐠
- colorFrom: green
- colorTo: blue
- sdk: gradio
- sdk_version: 3.37.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
apps/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .translator import translator_demo
+ from .simple_chat import simple_chat_demo
+ from .instruction_chat import instruction_chat_demo
apps/components.py ADDED
@@ -0,0 +1,22 @@
+ import gradio as gr
+
+
+ def chat_accordion():
+     with gr.Accordion("参数设置", open=False):
+         temperature = gr.Slider(
+             minimum=0.1,
+             maximum=2.0,
+             value=0.8,
+             step=0.1,
+             interactive=True,
+             label="Temperature",
+         )
+         top_p = gr.Slider(
+             minimum=0.1,
+             maximum=0.99,
+             value=0.9,
+             step=0.01,
+             interactive=True,
+             label="top_p",
+         )
+     return temperature, top_p
apps/instruction_chat.py ADDED
@@ -0,0 +1,106 @@
+ import traceback
+
+ import gradio as gr
+
+ from chatClient import ChatClient
+ from apps.components import chat_accordion
+
+ BOT_NAME = "ChatGLM2-6B"
+ TITLE = """<h3 align="center">🤗 ChatGLM2-6B 预设指令对话</h3>"""
+ RETRY_COMMAND = "/retry"
+ DEFAULT_INSTRUCTIONS = "你是一个爱笑的机器人,名字叫小莫,在回答问题的时候会添加合适的emoji来表达情绪。"
+
+
+ def chat(client: ChatClient):
+     with gr.Row():
+         with gr.Column(elem_id="chat_container", scale=3):
+             with gr.Row():
+                 chatbot = gr.Chatbot(elem_id="chatbot")
+             with gr.Row():
+                 inputs = gr.Textbox(
+                     placeholder=f"你好 {BOT_NAME} !",
+                     label="输入内容后点击回车",
+                     max_lines=3,
+                 )
+             with gr.Row(elem_id="button_container"):
+                 with gr.Column():
+                     retry_button = gr.Button("♻️ 重试上一轮对话")
+                 with gr.Column():
+                     delete_turn_button = gr.Button("🧽 删除上一轮对话")
+                 with gr.Column():
+                     clear_chat_button = gr.Button("✨ 删除全部对话历史")
+
+         with gr.Column(elem_id="param_container", scale=1):
+             with gr.Row():
+                 with gr.Accordion("对话预设指令", open=True):
+                     instructions = gr.Textbox(
+                         placeholder="LLM instructions",
+                         value=DEFAULT_INSTRUCTIONS,
+                         lines=10,
+                         interactive=True,
+                         label="指令",
+                         max_lines=16,
+                         show_label=False,
+                     )
+             with gr.Row():
+                 temperature, top_p = chat_accordion()
+
+     def run_chat(message: str, chat_history, instructions: str, temperature: float, top_p: float):
+         if not message or (message == RETRY_COMMAND and len(chat_history) == 0):
+             yield chat_history
+             return
+
+         if message == RETRY_COMMAND and chat_history:
+             prev_turn = chat_history.pop(-1)
+             user_message, _ = prev_turn
+             message = user_message
+
+         # chat_history = chat_history + [[message, ""]]
+         try:
+             stream = client.instruct_chat(
+                 message,
+                 chat_history,
+                 instructions,
+                 temperature=temperature,
+                 top_p=top_p,
+             )
+             for resp, history in stream:
+                 chat_history = history
+                 yield chat_history
+         except Exception as e:
+             if not chat_history:
+                 chat_history = []
+             chat_history += [["有错误了", traceback.format_exc()]]
+             yield chat_history
+
+     def delete_last_turn(chat_history):
+         if chat_history:
+             chat_history.pop(-1)
+         return {chatbot: gr.update(value=chat_history)}
+
+     def run_retry(message: str, chat_history, instructions, temperature: float, top_p: float):
+         yield from run_chat(RETRY_COMMAND, chat_history, instructions, temperature, top_p)
+
+     def clear_chat():
+         return []
+
+     inputs.submit(
+         run_chat,
+         [inputs, chatbot, instructions, temperature, top_p],
+         outputs=[chatbot],
+         show_progress=False,
+     )
+     inputs.submit(lambda: "", inputs=None, outputs=inputs)
+     delete_turn_button.click(delete_last_turn, inputs=[chatbot], outputs=[chatbot])
+     retry_button.click(
+         run_retry,
+         [inputs, chatbot, instructions, temperature, top_p],
+         outputs=[chatbot],
+         show_progress=False,
+     )
+     clear_chat_button.click(clear_chat, [], chatbot)
+
+
+ def instruction_chat_demo(client: ChatClient):
+     gr.HTML(TITLE)
+     chat(client)
apps/simple_chat.py ADDED
@@ -0,0 +1,90 @@
+ import gradio as gr
+ from chatClient import ChatClient
+ import traceback
+ from apps.components import chat_accordion
+
+ BOT_NAME = "ChatGLM2-6B"
+ TITLE = """<h3 align="center">🤖 通用对话</h3>"""
+ RETRY_COMMAND = "/retry"
+
+
+ def chat(client: ChatClient):
+     with gr.Row():
+         with gr.Column(elem_id="chat_container", scale=3):
+             with gr.Row():
+                 chatbot = gr.Chatbot(elem_id="chatbot")
+             with gr.Row():
+                 inputs = gr.Textbox(
+                     placeholder=f"你好 {BOT_NAME} !",
+                     label="输入内容后点击回车",
+                     max_lines=3,
+                 )
+             with gr.Row(elem_id="button_container"):
+                 with gr.Column():
+                     retry_button = gr.Button("♻️ 重试上一轮对话")
+                 with gr.Column():
+                     delete_turn_button = gr.Button("🧽 删除上一轮对话")
+                 with gr.Column():
+                     clear_chat_button = gr.Button("✨ 删除全部对话历史")
+
+         with gr.Column(elem_id="param_container", scale=1):
+             temperature, top_p = chat_accordion()
+
+     def run_chat(message: str, chat_history, temperature: float, top_p: float):
+         if not message or (message == RETRY_COMMAND and len(chat_history) == 0):
+             yield chat_history
+             return
+
+         if message == RETRY_COMMAND and chat_history:
+             prev_turn = chat_history.pop(-1)
+             user_message, _ = prev_turn
+             message = user_message
+
+         # chat_history = chat_history + [[message, ""]]
+         try:
+             stream = client.simple_chat(
+                 message,
+                 chat_history,
+                 temperature=temperature,
+                 top_p=top_p,
+             )
+             for resp, history in stream:
+                 chat_history = history
+                 yield chat_history
+         except Exception as e:
+             if not chat_history:
+                 chat_history = []
+             chat_history += [["有错误了", traceback.format_exc()]]
+             yield chat_history
+
+     def delete_last_turn(chat_history):
+         if chat_history:
+             chat_history.pop(-1)
+         return {chatbot: gr.update(value=chat_history)}
+
+     def run_retry(message: str, chat_history, temperature: float, top_p: float):
+         yield from run_chat(RETRY_COMMAND, chat_history, temperature, top_p)
+
+     def clear_chat():
+         return []
+
+     inputs.submit(
+         run_chat,
+         [inputs, chatbot, temperature, top_p],
+         outputs=[chatbot],
+         show_progress=False,
+     )
+     inputs.submit(lambda: "", inputs=None, outputs=inputs)
+     delete_turn_button.click(delete_last_turn, inputs=[chatbot], outputs=[chatbot])
+     retry_button.click(
+         run_retry,
+         [inputs, chatbot, temperature, top_p],
+         outputs=[chatbot],
+         show_progress=False,
+     )
+     clear_chat_button.click(clear_chat, [], chatbot)
+
+
+ def simple_chat_demo(client: ChatClient):
+     gr.HTML(TITLE)
+     chat(client)
apps/translator.py ADDED
@@ -0,0 +1,57 @@
+ import traceback
+
+ import gradio as gr
+
+ from apps.components import chat_accordion
+
+ IDEA_TITLE = "ChatGLM2-6B 翻译官"
+
+ prompt_tmpl = """Imagine you are a professional translator. Your task is to translate the text enclosed in ``` into Chinese.
+
+ input text:
+
+ ```
+ {input_text}
+ ```
+
+ translation result:"""
+
+
+ def translator_demo(client):
+
+     def stream_translate(input_text, temperature: float, top_p: float):
+         if not input_text:
+             return None
+         message = prompt_tmpl.format(input_text=input_text)
+         try:
+             resp = None  # holds the final, fully generated response
+             stream = client.simple_chat(
+                 message,
+                 [],
+                 temperature=temperature,
+                 top_p=top_p,
+             )
+             # Drain the stream; only the last (complete) response is returned.
+             for resp, _ in stream:
+                 pass
+             return resp
+         except Exception as e:
+             return traceback.format_exc()
+
+     def clear_content():
+         return None, None
+
+     with gr.Row():
+         with gr.Column():
+             inputs = gr.Textbox(label="请输入原文", max_lines=5)
+             gr.Dropdown(["en -> zh"], value="en -> zh", label="翻译语言")
+             temperature, top_p = chat_accordion()
+             with gr.Row(elem_id="button_container"):
+                 with gr.Column():
+                     commit_btn = gr.Button(value="翻译", variant='primary')
+                 with gr.Column():
+                     clear_btn = gr.Button(value="清空")
+
+         with gr.Column():
+             outputs = gr.Textbox(label="译文", max_lines=5)
+
+     commit_btn.click(stream_translate, inputs=[inputs, temperature, top_p], outputs=[outputs])
+     clear_btn.click(clear_content, inputs=None, outputs=[inputs, outputs])
chatClient.py ADDED
@@ -0,0 +1,99 @@
+ import json
+
+ from websockets.exceptions import ConnectionClosedOK
+ from websockets.sync.client import connect
+
+ from chatglm2_6b.modelClient import ChatGLM2
+ import abc
+
+
+ class ChatClient(abc.ABC):
+     @abc.abstractmethod
+     def simple_chat(self, query, history, temperature, top_p):
+         pass
+
+     @abc.abstractmethod
+     def instruct_chat(self, message, chat_history, instructions, temperature, top_p):
+         pass
+
+
+ def format_chat_prompt(message: str, chat_history, instructions: str) -> str:
+     instructions = instructions.strip(" ").strip("\n")
+     prompt = f"对话背景设定:{instructions}"
+     for i, (user_message, bot_message) in enumerate(chat_history):
+         prompt = f"{prompt}\n\n[Round {i + 1}]\n\n问:{user_message}\n\n答:{bot_message}"
+     prompt = f"{prompt}\n\n[Round {len(chat_history) + 1}]\n\n问:{message}\n\n答:"
+     return prompt
+
+
+ class ChatGLM2APIClient(ChatClient):
+     def __init__(self, ws_url=None):
+         self.ws_url = "ws://localhost:10001"
+         if ws_url:
+             self.ws_url = ws_url
+
+     def simple_chat(self, query, history, temperature, top_p):
+         """Chat method defined by the chatglm2-6b model, consumed over the websocket API."""
+         url = f"{self.ws_url}/streamChat"
+         with connect(url) as websocket:
+             msg = json.dumps({
+                 "query": query, "history": history,
+                 "temperature": temperature, "top_p": top_p,
+             })
+             websocket.send(msg)
+
+             data = None
+             try:
+                 while True:
+                     data = websocket.recv()
+                     data = json.loads(data)
+                     yield data['resp'], data['history']
+             except ConnectionClosedOK:
+                 print("generation is finished")
+
+     def instruct_chat(self, message, chat_history, instructions, temperature, top_p):
+         """Instruction-preset chat built on chatglm2-6b text generation, consumed over the websocket API."""
+         url = f"{self.ws_url}/streamGenerate"
+
+         prompt = format_chat_prompt(message, chat_history, instructions)
+         chat_history = chat_history + [[message, ""]]
+         params = json.dumps({"prompt": prompt, "temperature": temperature, "top_p": top_p})
+         with connect(url) as websocket:
+             websocket.send(params)
+
+             data = None
+             try:
+                 while True:
+                     data = websocket.recv()
+                     data = json.loads(data)
+                     resp = data['text']
+
+                     last_turn = list(chat_history.pop(-1))
+                     last_turn[-1] = resp
+                     chat_history = chat_history + [last_turn]
+                     yield resp, chat_history
+             except ConnectionClosedOK:
+                 print("generation is finished")
+
+
+ class ChatGLM2ModelClient(ChatClient):
+     def __init__(self, model_path=None):
+         self.model = ChatGLM2(model_path)
+
+     def simple_chat(self, query, history, temperature, top_p):
+         kwargs = {
+             "query": query, "history": history,
+             "temperature": temperature, "top_p": top_p,
+         }
+         for resp, history in self.model.stream_chat(**kwargs):
+             yield resp, history
+
+     def instruct_chat(self, message, chat_history, instructions, temperature, top_p):
+         prompt = format_chat_prompt(message, chat_history, instructions)
+         chat_history = chat_history + [[message, ""]]
+         kwargs = {"prompt": prompt, "temperature": temperature, "top_p": top_p}
+         for resp in self.model.stream_generate(**kwargs):
+             last_turn = list(chat_history.pop(-1))
+             last_turn[-1] = resp
+             chat_history = chat_history + [last_turn]
+             yield resp, chat_history
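
For orientation, a minimal consumption sketch of these streaming clients (it assumes the websocket server from chatglm2_6b/server.py is already listening on ws://localhost:10001; the query string is illustrative only):

from chatClient import ChatGLM2APIClient

client = ChatGLM2APIClient("ws://localhost:10001")

# simple_chat and instruct_chat are generators: each iteration yields the
# partial response so far together with the updated chat history.
history = []
for resp, history in client.simple_chat("你好", history, temperature=0.8, top_p=0.9):
    print(resp)  # progressively longer partial answer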
chatglm2_6b/__init__.py ADDED
File without changes
chatglm2_6b/modelClient.py ADDED
@@ -0,0 +1,82 @@
+ from typing import List, Tuple
+
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+ from transformers.generation.logits_process import LogitsProcessor
+ from transformers.generation.utils import LogitsProcessorList
+
+ DEFAULT_MODEL_PATH = "THUDM/chatglm2-6b"
+
+
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+         if torch.isnan(scores).any() or torch.isinf(scores).any():
+             scores.zero_()
+             scores[..., 5] = 5e4
+         return scores
+
+
+ class ChatGLM2(object):
+     def __init__(self, model_path=None):
+         # Fall back to the default hub path when no explicit model path is given.
+         self.model_path = model_path if model_path else DEFAULT_MODEL_PATH
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+         model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half().cuda()
+         self.model = model.eval()
+
+     def generate(
+         self,
+         prompt: str,
+         do_sample: bool = True,
+         max_length: int = 8192,
+         num_beams: int = 1,
+         temperature: float = 0.8,
+         top_p: float = 0.8,
+     ):
+         logits_processor = LogitsProcessorList()
+         logits_processor.append(InvalidScoreLogitsProcessor())
+         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
+                       "temperature": temperature, "logits_processor": logits_processor}
+         inputs = self.tokenizer([prompt], return_tensors="pt")
+         inputs = inputs.to(self.model.device)
+         outputs = self.model.generate(**inputs, **gen_kwargs)
+         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+         response = self.tokenizer.decode(outputs)
+         response = self.model.process_response(response)
+         return response
+
+     def stream_generate(
+         self,
+         prompt: str,
+         do_sample: bool = True,
+         max_length: int = 8192,
+         temperature: float = 0.8,
+         top_p: float = 0.8,
+     ):
+         logits_processor = LogitsProcessorList()
+         logits_processor.append(InvalidScoreLogitsProcessor())
+         gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                       "temperature": temperature, "logits_processor": logits_processor}
+         inputs = self.tokenizer([prompt], return_tensors="pt")
+         inputs = inputs.to(self.model.device)
+         for outputs in self.model.stream_generate(**inputs, **gen_kwargs):
+             outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+             response = self.tokenizer.decode(outputs)
+             if response and response[-1] != "�":
+                 response = self.model.process_response(response)
+                 yield response
+
+     def stream_chat(
+         self,
+         query: str,
+         history: List[Tuple[str, str]],
+         max_length: int = 8192,
+         do_sample=True,
+         top_p=0.8,
+         temperature=0.8
+     ):
+         stream = self.model.stream_chat(self.tokenizer, query, history,
+                                         max_length=max_length, do_sample=do_sample, top_p=top_p, temperature=temperature)
+         for resp, new_history in stream:
+             yield resp, new_history
+
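For reference, a minimal direct-use sketch of this wrapper (it assumes a CUDA GPU, since the weights are loaded with .half().cuda(); the prompt text is illustrative only):

from chatglm2_6b.modelClient import ChatGLM2

model = ChatGLM2()  # falls back to DEFAULT_MODEL_PATH ("THUDM/chatglm2-6b")

# Blocking, single-shot generation from a raw prompt.
print(model.generate("用一句话介绍 ChatGLM2-6B。"))

# Streaming chat: each step yields the partial reply and the updated history.
for resp, history in model.stream_chat("你好", history=[]):
    print(resp)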
chatglm2_6b/server.py ADDED
@@ -0,0 +1,49 @@
+ import functools
+
+ import anyio
+ from fastapi import FastAPI, WebSocket
+ from pydantic import BaseModel
+
+ from chatglm2_6b.modelClient import ChatGLM2
+ from config import Settings
+
+ app = FastAPI()
+
+ chat_glm2 = ChatGLM2(Settings.CHATGLM_MODEL_PATH)
+
+
+ class ChatParams(BaseModel):
+     prompt: str
+     do_sample: bool = True
+     max_length: int = 2048
+     temperature: float = 0.8
+     top_p: float = 0.8
+
+
+ @app.post("/generate")
+ def generate(params: ChatParams):
+     input_params = params.dict()
+     text = chat_glm2.generate(**input_params)
+     return {"text": text}
+
+
+ @app.websocket("/streamGenerate")
+ async def stream_generate(websocket: WebSocket):
+     await websocket.accept()
+     params = await websocket.receive_json()
+     func = functools.partial(chat_glm2.stream_generate, **params)
+     stream = await anyio.to_thread.run_sync(func)
+     for resp in stream:
+         await websocket.send_json({"text": resp})
+     await websocket.close()
+
+
+ @app.websocket("/streamChat")
+ async def stream_chat(websocket: WebSocket):
+     await websocket.accept()
+     params = await websocket.receive_json()
+     func = functools.partial(chat_glm2.stream_chat, **params)
+     stream = await anyio.to_thread.run_sync(func)
+     for resp, history in stream:
+         await websocket.send_json({"resp": resp, "history": history})
+     await websocket.close()
config.py ADDED
@@ -0,0 +1,7 @@
+ import os
+
+
+ class Settings:
+     CHAT_CLIENT = os.environ.get('CHAT_CLIENT', "ChatGLM2APIClient")
+     MODEL_WS_URL = os.environ.get('MODEL_WS_URL', "ws://localhost:10001")
+     CHATGLM_MODEL_PATH = os.environ.get('CHATGLM_MODEL_PATH', "THUDM/chatglm2-6b")
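
Settings reads its values from the environment at class-definition time, so the variables must be set before config is imported. A hedged sketch of switching the gallery to the in-process model client (the values shown are just the project defaults, used here as examples):

import os

# Choose which ChatClient gallery.gradio.py builds, before importing config.
os.environ["CHAT_CLIENT"] = "ChatGLM2ModelClient"
os.environ["CHATGLM_MODEL_PATH"] = "THUDM/chatglm2-6b"

# Or keep the default websocket client and point it at the model server:
# os.environ["CHAT_CLIENT"] = "ChatGLM2APIClient"
# os.environ["MODEL_WS_URL"] = "ws://localhost:10001"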
gallery.gradio.py ADDED
@@ -0,0 +1,47 @@
+ import gradio as gr
+ from apps import translator_demo, simple_chat_demo, instruction_chat_demo
+ from chatClient import ChatClient, ChatGLM2APIClient, ChatGLM2ModelClient
+ from config import Settings
+
+ TITLE = """<h2 align="center">🚀 ChatGLM2-6B apps gallery</h2>"""
+
+ demo_register = {
+     "通用对话": simple_chat_demo,
+     "预设指令对话": instruction_chat_demo,
+     "翻译器": translator_demo,
+ }
+
+
+ def get_gallery(client: ChatClient):
+     with gr.Blocks(
+         # css=None
+         # css="""#chat_container {width: 700px; margin-left: auto; margin-right: auto;}
+         # #button_container {width: 700px; margin-left: auto; margin-right: auto;}
+         # #param_container {width: 700px; margin-left: auto; margin-right: auto;}"""
+         css="""#chatbot {
+             font-size: 14px;
+             min-height: 300px;
+         }"""
+     ) as demo:
+         gr.HTML(TITLE)
+         for name, demo_func in demo_register.items():
+             with gr.Tab(name):
+                 demo_func(client)
+     return demo
+
+
+ def build_client():
+     client_class = Settings.CHAT_CLIENT
+     if client_class == 'ChatGLM2ModelClient':
+         return ChatGLM2ModelClient(Settings.CHATGLM_MODEL_PATH)
+     if client_class == 'ChatGLM2APIClient':
+         return ChatGLM2APIClient(Settings.MODEL_WS_URL)
+     raise Exception(f"Wrong ChatClient: {client_class}")
+
+
+ if __name__ == "__main__":
+     client = build_client()
+     demo = get_gallery(client)
+     demo.queue(max_size=128, concurrency_count=16)
+     demo.launch()
runserver.py ADDED
@@ -0,0 +1,11 @@
+ import uvicorn
+
+ from chatglm2_6b.server import app
+
+
+ def runserver():
+     uvicorn.run(app, host="0.0.0.0", port=10001)
+
+
+ if __name__ == '__main__':
+     runserver()
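
With the default ChatGLM2APIClient configuration the project runs as two processes: this model server and the Gradio gallery. Normally that means two terminals; the subprocess variant below and its fixed warm-up wait are illustrative assumptions only, not part of the project:

import subprocess
import time

# Start the ChatGLM2-6B websocket server (listens on port 10001).
server = subprocess.Popen(["python", "runserver.py"])
time.sleep(30)  # crude wait for the model weights to finish loading

# Start the Gradio gallery, which talks to the server via ChatGLM2APIClient.
subprocess.run(["python", "gallery.gradio.py"])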