zyliu commited on
Commit
f289b70
·
1 Parent(s): 8fc777e

update gradio demo

Browse files
.streamlit/config.toml DELETED
@@ -1,7 +0,0 @@
1
- [server]
2
- enableStaticServing = false
3
- enableXsrfProtection = false
4
- enableCORS = false
5
-
6
- [browser] # This ip and port will show in command prompt
7
- enableCORS = false
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -3,8 +3,8 @@ title: InternVL
3
  emoji: ⚡
4
  colorFrom: yellow
5
  colorTo: gray
6
- sdk: streamlit
7
- sdk_version: 1.28.2
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
3
  emoji: ⚡
4
  colorFrom: yellow
5
  colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py CHANGED
@@ -1,60 +1,116 @@
1
- import streamlit as st
2
-
3
- st.set_page_config(layout="wide")
4
-
5
- hide_streamlit_style = """
6
- <style>
7
- /* Hide the Streamlit header and menu */
8
- header {visibility: hidden;}
9
- </style>
10
- """
11
-
12
- st.markdown(hide_streamlit_style, unsafe_allow_html=True)
13
-
14
- st.markdown(
15
- """
16
- <style>
17
- html, body, .fullScreenFrame, .fullScreenFrame iframe {
18
- margin: 0;
19
- padding: 0;
20
- height: 100%;
21
- width: 100%;
22
- border: none;
23
- display: block;
24
- overflow: hidden;
25
- }
26
-
27
- .fullScreenFrame {
28
- position: fixed;
29
- top: 0;
30
- left: 0;
31
- right: 0;
32
- bottom: 0;
33
- z-index: 9999;
34
- }
35
-
36
- .main .block-container {
37
- padding: 0;
38
- margin: 0;
39
- height: 100vh;
40
- }
41
-
42
- /* Hide Streamlit header and footer */
43
- header, footer {
44
- display: none;
45
- }
46
- </style>
47
- """,
48
- unsafe_allow_html=True,
49
- )
50
-
51
- # Embed the external Streamlit webpage
52
- st.markdown(
53
- """
54
- <div class="fullScreenFrame">
55
- <iframe src="https://internvl.opengvlab.com/"></iframe>
56
- </div>
57
- """,
58
- unsafe_allow_html=True,
59
- )
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fire
2
+ import subprocess
3
+ import os
4
+ import time
5
+ import signal
6
+ import subprocess
7
+ import atexit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+
10
+ def kill_processes_by_cmd_substring(cmd_substring):
11
+ # execute `ps -ef` and obtain its output
12
+ result = subprocess.run(["ps", "-ef"], stdout=subprocess.PIPE, text=True)
13
+ lines = result.stdout.splitlines()
14
+
15
+ # visit each line
16
+ for line in lines:
17
+ if cmd_substring in line:
18
+ # extract PID
19
+ parts = line.split()
20
+ pid = int(parts[1])
21
+ print(f"Killing process with PID: {pid}, CMD: {line}")
22
+ os.kill(pid, signal.SIGTERM)
23
+
24
+
25
+ def main(
26
+ python_path="python",
27
+ run_controller=True,
28
+ run_worker=True,
29
+ run_gradio=True,
30
+ controller_port=10086,
31
+ gradio_port=10087,
32
+ worker_names=[
33
+ "OpenGVLab/InternVL2-8B",
34
+ ],
35
+ run_sd_worker=False,
36
+ **kwargs,
37
+ ):
38
+ host = "http://0.0.0.0"
39
+ controller_process = None
40
+ if run_controller:
41
+ # python controller.py --host 0.0.0.0 --port 10086
42
+ cmd_args = [
43
+ f"{python_path}",
44
+ "controller.py",
45
+ "--host",
46
+ "0.0.0.0",
47
+ "--port",
48
+ f"{controller_port}",
49
+ ]
50
+ kill_processes_by_cmd_substring(" ".join(cmd_args))
51
+ print("Launching controller: ", " ".join(cmd_args))
52
+ controller_process = subprocess.Popen(cmd_args)
53
+ atexit.register(controller_process.terminate)
54
+
55
+ worker_processes = []
56
+ if run_worker:
57
+ worker_port = 10088
58
+ for worker_name in worker_names:
59
+ cmd_args = [
60
+ f"{python_path}",
61
+ "model_worker.py",
62
+ "--port",
63
+ f"{worker_port}",
64
+ "--controller-url",
65
+ f"{host}:{controller_port}",
66
+ "--model-path",
67
+ f"{worker_name}",
68
+ "--load-8bit",
69
+ ]
70
+ kill_processes_by_cmd_substring(" ".join(cmd_args))
71
+ print("Launching worker: ", " ".join(cmd_args))
72
+ worker_process = subprocess.Popen(cmd_args)
73
+ worker_processes.append(worker_process)
74
+ atexit.register(worker_process.terminate)
75
+ worker_port += 1
76
+
77
+ time.sleep(10)
78
+ gradio_process = None
79
+ if run_gradio:
80
+ # python gradio_web_server.py --port 10088 --controller-url http://0.0.0.0:10086
81
+ cmd_args = [
82
+ f"{python_path}",
83
+ "gradio_web_server.py",
84
+ "--port",
85
+ f"{gradio_port}",
86
+ "--controller-url",
87
+ f"{host}:{controller_port}",
88
+ "--model-list-mode",
89
+ "reload",
90
+ ]
91
+ kill_processes_by_cmd_substring(" ".join(cmd_args))
92
+ print("Launching gradio: ", " ".join(cmd_args))
93
+ gradio_process = subprocess.Popen(cmd_args)
94
+ atexit.register(gradio_process.terminate)
95
+
96
+ sd_worker_process = None
97
+ if run_sd_worker:
98
+ # python model_worker.py --port 10088 --controller-address http://
99
+ cmd_args = [f"{python_path}", "sd_worker.py"]
100
+ kill_processes_by_cmd_substring(" ".join(cmd_args))
101
+ print("Launching sd_worker: ", " ".join(cmd_args))
102
+ sd_worker_process = subprocess.Popen(cmd_args)
103
+ atexit.register(sd_worker_process.terminate)
104
+
105
+ for worker_process in worker_processes:
106
+ worker_process.wait()
107
+ if controller_process:
108
+ controller_process.wait()
109
+ if gradio_process:
110
+ gradio_process.wait()
111
+ if sd_worker_process:
112
+ sd_worker_process.wait()
113
+
114
+
115
+ if __name__ == "__main__":
116
+ fire.Fire(main)
{static → assets}/SimHei.ttf RENAMED
File without changes
assets/assistant.png ADDED
assets/human.png ADDED
controller.py CHANGED
@@ -5,9 +5,9 @@ It sends worker addresses to clients.
5
  import argparse
6
  import dataclasses
7
  import json
 
8
  import threading
9
  import time
10
- import re
11
  from enum import Enum, auto
12
  from typing import List
13
 
@@ -113,6 +113,8 @@ class Controller:
113
  model_names.update(w_info.model_names)
114
 
115
  def extract_key(s):
 
 
116
  match = re.match(r'InternVL2-(\d+)B', s)
117
  if match:
118
  return int(match.group(1))
 
5
  import argparse
6
  import dataclasses
7
  import json
8
+ import re
9
  import threading
10
  import time
 
11
  from enum import Enum, auto
12
  from typing import List
13
 
 
113
  model_names.update(w_info.model_names)
114
 
115
  def extract_key(s):
116
+ if 'Pro' in s:
117
+ return 999
118
  match = re.match(r'InternVL2-(\d+)B', s)
119
  if match:
120
  return int(match.group(1))
conversation.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import dataclasses
3
+ import base64
4
+ import copy
5
+ import hashlib
6
+ import datetime
7
+ from io import BytesIO
8
+ from PIL import Image
9
+ from typing import Any, List, Dict, Union
10
+ from dataclasses import field
11
+
12
+ from utils import LOGDIR
13
+
14
+
15
+ def pil2base64(img: Image.Image) -> str:
16
+ buffered = BytesIO()
17
+ img.save(buffered, format="PNG")
18
+ return base64.b64encode(buffered.getvalue()).decode()
19
+
20
+
21
+ def resize_img(img: Image.Image, max_len: int, min_len: int) -> Image.Image:
22
+ max_hw, min_hw = max(img.size), min(img.size)
23
+ aspect_ratio = max_hw / min_hw
24
+ # max_len, min_len = 800, 400
25
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
26
+ longest_edge = int(shortest_edge * aspect_ratio)
27
+ W, H = img.size
28
+ if H > W:
29
+ H, W = longest_edge, shortest_edge
30
+ else:
31
+ H, W = shortest_edge, longest_edge
32
+ return img.resize((W, H))
33
+
34
+
35
+ @dataclasses.dataclass
36
+ class Conversation:
37
+ """A class that keeps all conversation history."""
38
+
39
+ SYSTEM = "system"
40
+ USER = "user"
41
+ ASSISTANT = "assistant"
42
+
43
+ roles: List[str] = field(
44
+ default_factory=lambda: [
45
+ Conversation.SYSTEM,
46
+ Conversation.USER,
47
+ Conversation.ASSISTANT,
48
+ ]
49
+ )
50
+ mandatory_system_message = "我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
51
+ system_message: str = "请尽可能详细地回答用户的问题。"
52
+ messages: List[Dict[str, Any]] = field(default_factory=lambda: [])
53
+ max_image_limit: int = 4
54
+ skip_next: bool = False
55
+ streaming_placeholder: str = "▌"
56
+
57
+ def get_system_message(self):
58
+ return self.mandatory_system_message + "\n\n" + self.system_message
59
+
60
+ def set_system_message(self, system_message: str):
61
+ self.system_message = system_message
62
+ return self
63
+
64
+ def get_prompt(self, inlude_image=False):
65
+ send_messages = [{"role": "system", "content": self.get_system_message()}]
66
+ # send_messages = []
67
+ for message in self.messages:
68
+ if message["role"] == self.USER:
69
+ user_message = {
70
+ "role": self.USER,
71
+ "content": message["content"],
72
+ }
73
+ if inlude_image and "image" in message:
74
+ user_message["image"] = []
75
+ for image in message["image"]:
76
+ user_message["image"].append(pil2base64(image))
77
+ send_messages.append(user_message)
78
+ elif message["role"] == self.ASSISTANT:
79
+ send_messages.append(
80
+ {"role": self.ASSISTANT, "content": message["content"]}
81
+ )
82
+ elif message["role"] == self.SYSTEM:
83
+ send_messages.append(
84
+ {
85
+ "role": self.SYSTEM,
86
+ "content": message["content"],
87
+ }
88
+ )
89
+ else:
90
+ raise ValueError(f"Invalid role: {message['role']}")
91
+ return send_messages
92
+
93
+ def append_message(
94
+ self,
95
+ role,
96
+ content,
97
+ image_list=None,
98
+ ):
99
+ self.messages.append(
100
+ {
101
+ "role": role,
102
+ "content": content,
103
+ "image": [] if image_list is None else image_list,
104
+ # "filenames": save_filenames,
105
+ }
106
+ )
107
+
108
+ def get_images(
109
+ self,
110
+ return_copy=False,
111
+ return_base64=False,
112
+ source: Union[str, None] = None,
113
+ ):
114
+ assert source in [self.USER, self.ASSISTANT, None], f"Invalid source: {soure}"
115
+ images = []
116
+ for i, msg in enumerate(self.messages):
117
+ if source and msg["role"] != source:
118
+ continue
119
+
120
+ for image in msg.get("image", []):
121
+ # org_image = [i.copy() for i in image]
122
+ if return_copy:
123
+ image = image.copy()
124
+
125
+ if return_base64:
126
+ image = pil2base64(image)
127
+
128
+ images.append(image)
129
+
130
+ return images
131
+
132
+ def to_gradio_chatbot(self):
133
+ ret = []
134
+ for i, msg in enumerate(self.messages):
135
+ if msg["role"] == self.SYSTEM:
136
+ continue
137
+
138
+ alt_str = (
139
+ "user upload image" if msg["role"] == self.USER else "output image"
140
+ )
141
+ image = msg.get("image", [])
142
+ if not isinstance(image, list):
143
+ images = [image]
144
+ else:
145
+ images = image
146
+
147
+ img_str_list = []
148
+ for i in range(len(images)):
149
+ image = resize_img(
150
+ images[i],
151
+ 400,
152
+ 800,
153
+ )
154
+ img_b64_str = pil2base64(image)
155
+ W, H = image.size
156
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="{alt_str}" style="width: {W}px; max-width:none; max-height:none"></img>'
157
+ img_str = (
158
+ f'<img src="data:image/png;base64,{img_b64_str}" alt="{alt_str}" />'
159
+ )
160
+ img_str_list.append(img_str)
161
+
162
+ if msg["role"] == self.USER:
163
+ msg_str = " ".join(img_str_list) + msg["content"]
164
+ ret.append([msg_str, None])
165
+ else:
166
+ msg_str = msg["content"] + " ".join(img_str_list)
167
+ ret[-1][-1] = msg_str
168
+ return ret
169
+
170
+ def update_message(self, role, content, image=None, idx=-1):
171
+ assert len(self.messages) > 0, "No message in the conversation."
172
+
173
+ idx = (idx + len(self.messages)) % len(self.messages)
174
+
175
+ assert (
176
+ self.messages[idx]["role"] == role
177
+ ), f"Role mismatch: {role} vs {self.messages[idx]['role']}"
178
+
179
+ self.messages[idx]["content"] = content
180
+ if image is not None:
181
+ if image not in self.messages[idx]["image"]:
182
+ self.messages[idx]["image"] = []
183
+ if not isinstance(image, list):
184
+ image = [image]
185
+ self.messages[idx]["image"].extend(image)
186
+
187
+ def return_last_message(self):
188
+ return self.messages[-1]["content"]
189
+
190
+ def end_of_current_turn(self):
191
+ assert len(self.messages) > 0, "No message in the conversation."
192
+ assert (
193
+ self.messages[-1]["role"] == self.ASSISTANT
194
+ ), f"It should end with the message from assistant instead of {self.messages[-1]['role']}."
195
+
196
+ if self.messages[-1]["content"][-1] != self.streaming_placeholder:
197
+ return
198
+
199
+ self.update_message(self.ASSISTANT, self.messages[-1]["content"][:-1], None)
200
+
201
+ def copy(self):
202
+ return Conversation(
203
+ mandatory_system_message=self.mandatory_system_message,
204
+ system_message=self.system_message,
205
+ roles=copy.deepcopy(self.roles),
206
+ messages=copy.deepcopy(self.messages),
207
+ )
208
+
209
+ def dict(self):
210
+ """
211
+ all_images = state.get_images()
212
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
213
+ t = datetime.datetime.now()
214
+ for image, hash in zip(all_images, all_image_hash):
215
+ filename = os.path.join(
216
+ LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg"
217
+ )
218
+ if not os.path.isfile(filename):
219
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
220
+ image.save(filename)
221
+ """
222
+ messages = []
223
+ for message in self.messages:
224
+ images = []
225
+ for image in message.get("image", []):
226
+ filename = self.save_image(image)
227
+ images.append(filename)
228
+
229
+ messages.append(
230
+ {
231
+ "role": message["role"],
232
+ "content": message["content"],
233
+ "image": images,
234
+ }
235
+ )
236
+ if len(images) == 0:
237
+ messages[-1].pop("image")
238
+
239
+ return {
240
+ "mandatory_system_message": self.mandatory_system_message,
241
+ "system_message": self.system_message,
242
+ "roles": self.roles,
243
+ "messages": messages,
244
+ }
245
+
246
+ def save_image(self, image: Image.Image) -> str:
247
+ t = datetime.datetime.now()
248
+ image_hash = hashlib.md5(image.tobytes()).hexdigest()
249
+ filename = os.path.join(
250
+ LOGDIR,
251
+ "serve_images",
252
+ f"{t.year}-{t.month:02d}-{t.day:02d}",
253
+ f"{image_hash}.jpg",
254
+ )
255
+ if not os.path.isfile(filename):
256
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
257
+ image.save(filename)
258
+
259
+ return filename
gallery/child_1.jpg ADDED
gallery/child_2.jpg ADDED
gallery/child_3.jpg ADDED
gradio_web_server.py ADDED
@@ -0,0 +1,824 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from ast import parse
3
+ import datetime
4
+ import json
5
+ import os
6
+ import time
7
+ import hashlib
8
+ import re
9
+
10
+ import gradio as gr
11
+ import requests
12
+ import random
13
+ from filelock import FileLock
14
+ from io import BytesIO
15
+ from PIL import Image, ImageDraw, ImageFont
16
+
17
+ from constants import LOGDIR
18
+ from utils import (
19
+ build_logger,
20
+ server_error_msg,
21
+ violates_moderation,
22
+ moderation_msg,
23
+ load_image_from_base64,
24
+ get_log_filename,
25
+ )
26
+ from conversation import Conversation
27
+
28
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
29
+
30
+ headers = {"User-Agent": "InternVL-Chat Client"}
31
+
32
+ no_change_btn = gr.Button()
33
+ enable_btn = gr.Button(interactive=True)
34
+ disable_btn = gr.Button(interactive=False)
35
+
36
+
37
+ def write2file(path, content):
38
+ lock = FileLock(f"{path}.lock")
39
+ with lock:
40
+ with open(path, "a") as fout:
41
+ fout.write(content)
42
+
43
+
44
+ def sort_models(models):
45
+ def custom_sort_key(model_name):
46
+ # InternVL-Chat-V1-5 should be the first item
47
+ if model_name == "InternVL-Chat-V1-5":
48
+ return (1, model_name) # 1 indicates highest precedence
49
+ elif model_name.startswith("InternVL-Chat-V1-5-"):
50
+ return (1, model_name) # 1 indicates highest precedence
51
+ else:
52
+ return (0, model_name) # 0 indicates normal order
53
+
54
+ models.sort(key=custom_sort_key, reverse=True)
55
+ try: # We have five InternVL-Chat-V1-5 models, randomly choose one to be the first
56
+ first_three = models[:4]
57
+ random.shuffle(first_three)
58
+ models[:4] = first_three
59
+ except:
60
+ pass
61
+ return models
62
+
63
+
64
+ def get_model_list():
65
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
66
+ assert ret.status_code == 200
67
+ ret = requests.post(args.controller_url + "/list_models")
68
+ models = ret.json()["models"]
69
+ models = sort_models(models)
70
+
71
+ logger.info(f"Models: {models}")
72
+ return models
73
+
74
+
75
+ get_window_url_params = """
76
+ function() {
77
+ const params = new URLSearchParams(window.location.search);
78
+ url_params = Object.fromEntries(params);
79
+ console.log(url_params);
80
+ return url_params;
81
+ }
82
+ """
83
+
84
+
85
+ def init_state(state=None):
86
+ if state is not None:
87
+ del state
88
+ return Conversation()
89
+
90
+
91
+ def find_bounding_boxes(state, response):
92
+ pattern = re.compile(r"<ref>\s*(.*?)\s*</ref>\s*<box>\s*(\[\[.*?\]\])\s*</box>")
93
+ matches = pattern.findall(response)
94
+ results = []
95
+ for match in matches:
96
+ results.append((match[0], eval(match[1])))
97
+ returned_image = None
98
+ latest_image = state.get_images(source=state.USER)[-1]
99
+ returned_image = latest_image.copy()
100
+ width, height = returned_image.size
101
+ draw = ImageDraw.Draw(returned_image)
102
+ for result in results:
103
+ line_width = max(1, int(min(width, height) / 200))
104
+ random_color = (
105
+ random.randint(0, 128),
106
+ random.randint(0, 128),
107
+ random.randint(0, 128),
108
+ )
109
+ category_name, coordinates = result
110
+ coordinates = [
111
+ (
112
+ float(x[0]) / 1000,
113
+ float(x[1]) / 1000,
114
+ float(x[2]) / 1000,
115
+ float(x[3]) / 1000,
116
+ )
117
+ for x in coordinates
118
+ ]
119
+ coordinates = [
120
+ (
121
+ int(x[0] * width),
122
+ int(x[1] * height),
123
+ int(x[2] * width),
124
+ int(x[3] * height),
125
+ )
126
+ for x in coordinates
127
+ ]
128
+ for box in coordinates:
129
+ draw.rectangle(box, outline=random_color, width=line_width)
130
+ font = ImageFont.truetype("assets/SimHei.ttf", int(20 * line_width / 2))
131
+ text_size = font.getbbox(category_name)
132
+ text_width, text_height = (
133
+ text_size[2] - text_size[0],
134
+ text_size[3] - text_size[1],
135
+ )
136
+ text_position = (box[0], max(0, box[1] - text_height))
137
+ draw.rectangle(
138
+ [
139
+ text_position,
140
+ (text_position[0] + text_width, text_position[1] + text_height),
141
+ ],
142
+ fill=random_color,
143
+ )
144
+ draw.text(text_position, category_name, fill="white", font=font)
145
+ return returned_image if len(matches) > 0 else None
146
+
147
+
148
+ def query_image_generation(response, sd_worker_url, timeout=15):
149
+ if not sd_worker_url:
150
+ return None
151
+ sd_worker_url = f"{sd_worker_url}/generate_image/"
152
+ pattern = r"```drawing-instruction\n(.*?)\n```"
153
+ match = re.search(pattern, response, re.DOTALL)
154
+ if match:
155
+ payload = {"caption": match.group(1)}
156
+ print("drawing-instruction:", payload)
157
+ response = requests.post(sd_worker_url, json=payload, timeout=timeout)
158
+ response.raise_for_status() # 检查HTTP请求是否成功
159
+ image = Image.open(BytesIO(response.content))
160
+ return image
161
+ else:
162
+ return None
163
+
164
+
165
+ def load_demo(url_params, request: gr.Request):
166
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
167
+
168
+ dropdown_update = gr.Dropdown(visible=True)
169
+ if "model" in url_params:
170
+ model = url_params["model"]
171
+ if model in models:
172
+ dropdown_update = gr.Dropdown(value=model, visible=True)
173
+
174
+ state = init_state()
175
+ return state, dropdown_update
176
+
177
+
178
+ def load_demo_refresh_model_list(request: gr.Request):
179
+ logger.info(f"load_demo. ip: {request.client.host}")
180
+ models = get_model_list()
181
+ state = init_state()
182
+ dropdown_update = gr.Dropdown(
183
+ choices=models, value=models[0] if len(models) > 0 else ""
184
+ )
185
+ return state, dropdown_update
186
+
187
+
188
+ def vote_last_response(state, liked, model_selector, request: gr.Request):
189
+ conv_data = {
190
+ "tstamp": round(time.time(), 4),
191
+ "like": liked,
192
+ "model": model_selector,
193
+ "state": state.dict(),
194
+ "ip": request.client.host,
195
+ }
196
+ write2file(get_log_filename(), json.dumps(conv_data) + "\n")
197
+
198
+
199
+ def upvote_last_response(state, model_selector, request: gr.Request):
200
+ logger.info(f"upvote. ip: {request.client.host}")
201
+ vote_last_response(state, True, model_selector, request)
202
+ textbox = gr.MultimodalTextbox(value=None, interactive=True)
203
+ return (textbox,) + (disable_btn,) * 3
204
+
205
+
206
+ def downvote_last_response(state, model_selector, request: gr.Request):
207
+ logger.info(f"downvote. ip: {request.client.host}")
208
+ vote_last_response(state, False, model_selector, request)
209
+ textbox = gr.MultimodalTextbox(value=None, interactive=True)
210
+ return (textbox,) + (disable_btn,) * 3
211
+
212
+
213
+ def vote_selected_response(
214
+ state, model_selector, request: gr.Request, data: gr.LikeData
215
+ ):
216
+ logger.info(
217
+ f"Vote: {data.liked}, index: {data.index}, value: {data.value} , ip: {request.client.host}"
218
+ )
219
+ conv_data = {
220
+ "tstamp": round(time.time(), 4),
221
+ "like": data.liked,
222
+ "index": data.index,
223
+ "model": model_selector,
224
+ "state": state.dict(),
225
+ "ip": request.client.host,
226
+ }
227
+ write2file(get_log_filename(), json.dumps(conv_data) + "\n")
228
+ return
229
+
230
+
231
+ def flag_last_response(state, model_selector, request: gr.Request):
232
+ logger.info(f"flag. ip: {request.client.host}")
233
+ vote_last_response(state, "flag", model_selector, request)
234
+ textbox = gr.MultimodalTextbox(value=None, interactive=True)
235
+ return (textbox,) + (disable_btn,) * 3
236
+
237
+
238
+ def regenerate(state, image_process_mode, request: gr.Request):
239
+ logger.info(f"regenerate. ip: {request.client.host}")
240
+ # state.messages[-1][-1] = None
241
+ state.update_message(Conversation.ASSISTANT, None, -1)
242
+ prev_human_msg = state.messages[-2]
243
+ if type(prev_human_msg[1]) in (tuple, list):
244
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
245
+ state.skip_next = False
246
+ textbox = gr.MultimodalTextbox(value=None, interactive=True)
247
+ return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
248
+
249
+
250
+ def clear_history(request: gr.Request):
251
+ logger.info(f"clear_history. ip: {request.client.host}")
252
+ state = init_state()
253
+ textbox = gr.MultimodalTextbox(value=None, interactive=True)
254
+ return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
255
+
256
+
257
+ def change_system_prompt(state, system_prompt, request: gr.Request):
258
+ logger.info(f"Change system prompt. ip: {request.client.host}")
259
+ state.set_system_message(system_prompt)
260
+ return state
261
+
262
+
263
+ def add_text(state, message, system_prompt, request: gr.Request):
264
+ images = message.get("files", [])
265
+ text = message.get("text", "").strip()
266
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
267
+ # import pdb; pdb.set_trace()
268
+ textbox = gr.MultimodalTextbox(value=None, interactive=False)
269
+ if len(text) <= 0 and len(images) == 0:
270
+ state.skip_next = True
271
+ return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
272
+ if args.moderate:
273
+ flagged = violates_moderation(text)
274
+ if flagged:
275
+ state.skip_next = True
276
+ textbox = gr.MultimodalTextbox(
277
+ value={"text": moderation_msg}, interactive=True
278
+ )
279
+ return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
280
+ images = [Image.open(path).convert("RGB") for path in images]
281
+
282
+ if len(images) > 0 and len(state.get_images(source=state.USER)) > 0:
283
+ state = init_state(state)
284
+ state.set_system_message(system_prompt)
285
+ state.append_message(Conversation.USER, text, images)
286
+ state.skip_next = False
287
+ return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
288
+
289
+
290
+ def http_bot(
291
+ state,
292
+ model_selector,
293
+ temperature,
294
+ top_p,
295
+ repetition_penalty,
296
+ max_new_tokens,
297
+ max_input_tiles,
298
+ # bbox_threshold,
299
+ # mask_threshold,
300
+ request: gr.Request,
301
+ ):
302
+ logger.info(f"http_bot. ip: {request.client.host}")
303
+ start_tstamp = time.time()
304
+ model_name = model_selector
305
+ if hasattr(state, "skip_next") and state.skip_next:
306
+ # This generate call is skipped due to invalid inputs
307
+ yield (
308
+ state,
309
+ state.to_gradio_chatbot(),
310
+ gr.MultimodalTextbox(interactive=False),
311
+ ) + (no_change_btn,) * 5
312
+ return
313
+
314
+ # Query worker address
315
+ controller_url = args.controller_url
316
+ ret = requests.post(
317
+ controller_url + "/get_worker_address", json={"model": model_name}
318
+ )
319
+ worker_addr = ret.json()["address"]
320
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
321
+
322
+ # No available worker
323
+ if worker_addr == "":
324
+ # state.messages[-1][-1] = server_error_msg
325
+ state.update_message(Conversation.ASSISTANT, server_error_msg)
326
+ yield (
327
+ state,
328
+ state.to_gradio_chatbot(),
329
+ gr.MultimodalTextbox(interactive=False),
330
+ disable_btn,
331
+ disable_btn,
332
+ disable_btn,
333
+ enable_btn,
334
+ enable_btn,
335
+ )
336
+ return
337
+
338
+ all_images = state.get_images(source=state.USER)
339
+ all_image_paths = [state.save_image(image) for image in all_images]
340
+
341
+ # Make requests
342
+ pload = {
343
+ "model": model_name,
344
+ "prompt": state.get_prompt(),
345
+ "temperature": float(temperature),
346
+ "top_p": float(top_p),
347
+ "max_new_tokens": max_new_tokens,
348
+ "max_input_tiles": max_input_tiles,
349
+ # "bbox_threshold": bbox_threshold,
350
+ # "mask_threshold": mask_threshold,
351
+ "repetition_penalty": repetition_penalty,
352
+ "images": f"List of {len(all_images)} images: {all_image_paths}",
353
+ }
354
+ logger.info(f"==== request ====\n{pload}")
355
+ pload.pop("images")
356
+ pload["prompt"] = state.get_prompt(inlude_image=True)
357
+ state.append_message(Conversation.ASSISTANT, state.streaming_placeholder)
358
+ yield (
359
+ state,
360
+ state.to_gradio_chatbot(),
361
+ gr.MultimodalTextbox(interactive=False),
362
+ ) + (disable_btn,) * 5
363
+
364
+ try:
365
+ # Stream output
366
+ response = requests.post(
367
+ worker_addr + "/worker_generate_stream",
368
+ headers=headers,
369
+ json=pload,
370
+ stream=True,
371
+ timeout=20,
372
+ )
373
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
374
+ if chunk:
375
+ data = json.loads(chunk.decode())
376
+ if data["error_code"] == 0:
377
+ if "text" in data:
378
+ output = data["text"].strip()
379
+ output += state.streaming_placeholder
380
+
381
+ image = None
382
+ if "image" in data:
383
+ image = load_image_from_base64(data["image"])
384
+ _ = state.save_image(image)
385
+
386
+ state.update_message(Conversation.ASSISTANT, output, image)
387
+ yield (
388
+ state,
389
+ state.to_gradio_chatbot(),
390
+ gr.MultimodalTextbox(interactive=False),
391
+ ) + (disable_btn,) * 5
392
+ else:
393
+ output = (
394
+ f"**{data['text']}**" + f" (error_code: {data['error_code']})"
395
+ )
396
+
397
+ state.update_message(Conversation.ASSISTANT, output, None)
398
+ yield (
399
+ state,
400
+ state.to_gradio_chatbot(),
401
+ gr.MultimodalTextbox(interactive=True),
402
+ ) + (
403
+ disable_btn,
404
+ disable_btn,
405
+ disable_btn,
406
+ enable_btn,
407
+ enable_btn,
408
+ )
409
+ return
410
+ except requests.exceptions.RequestException as e:
411
+ state.update_message(Conversation.ASSISTANT, server_error_msg, None)
412
+ yield (
413
+ state,
414
+ state.to_gradio_chatbot(),
415
+ gr.MultimodalTextbox(interactive=True),
416
+ ) + (
417
+ disable_btn,
418
+ disable_btn,
419
+ disable_btn,
420
+ enable_btn,
421
+ enable_btn,
422
+ )
423
+ return
424
+
425
+ ai_response = state.return_last_message()
426
+ if "<ref>" in ai_response:
427
+ returned_image = find_bounding_boxes(state, ai_response)
428
+ returned_image = [returned_image] if returned_image else []
429
+ state.update_message(Conversation.ASSISTANT, ai_response, returned_image)
430
+ if "```drawing-instruction" in ai_response:
431
+ returned_image = query_image_generation(
432
+ ai_response, sd_worker_url=sd_worker_url
433
+ )
434
+ returned_image = [returned_image] if returned_image else []
435
+ state.update_message(Conversation.ASSISTANT, ai_response, returned_image)
436
+
437
+ state.end_of_current_turn()
438
+
439
+ yield (
440
+ state,
441
+ state.to_gradio_chatbot(),
442
+ gr.MultimodalTextbox(interactive=True),
443
+ ) + (enable_btn,) * 5
444
+
445
+ finish_tstamp = time.time()
446
+ logger.info(f"{output}")
447
+ data = {
448
+ "tstamp": round(finish_tstamp, 4),
449
+ "like": None,
450
+ "model": model_name,
451
+ "start": round(start_tstamp, 4),
452
+ "finish": round(start_tstamp, 4),
453
+ "state": state.dict(),
454
+ "images": all_image_paths,
455
+ "ip": request.client.host,
456
+ }
457
+ write2file(get_log_filename(), json.dumps(data) + "\n")
458
+
459
+
460
+ title_html = """
461
+ <h2> <span class="gradient-text" id="text">InternVL2</span><span class="plain-text">: Better than the Best—Expanding Performance Boundaries of Open-Source Multimodal Models with the Progressive Scaling Strategy</span></h2>
462
+ <a href="https://internvl.github.io/blog/2024-07-02-InternVL-2.0/">[📜 InternVL2 Blog]</a>
463
+ <a href="https://huggingface.co/spaces/OpenGVLab/InternVL">[🤗 HF Demo]</a>
464
+ <a href="https://github.com/OpenGVLab/InternVL?tab=readme-ov-file#quick-start-with-huggingface">[🚀 Quick Start]</a>
465
+ <a href="https://github.com/OpenGVLab/InternVL/blob/main/document/How_to_use_InternVL_API.md">[🌐 API]</a>
466
+ """
467
+
468
+ tos_markdown = """
469
+ ### Terms of use
470
+ By using this service, users are required to agree to the following terms:
471
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
472
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
473
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
474
+ """
475
+
476
+
477
+ learn_more_markdown = """
478
+ ### License
479
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
480
+
481
+ ### Acknowledgement
482
+ This demo is modified from LLaVA's demo. Thanks for their awesome work!
483
+ """
484
+ # .gradio-container {margin: 5px 10px 0 10px !important};
485
+ block_css = """
486
+ .gradio-container {margin: 0.1% 1% 0 1% !important; max-width: 98% !important;};
487
+ #buttons button {
488
+ min-width: min(120px,100%);
489
+ }
490
+
491
+ .gradient-text {
492
+ font-size: 28px;
493
+ width: auto;
494
+ font-weight: bold;
495
+ background: linear-gradient(45deg, red, orange, yellow, green, blue, indigo, violet);
496
+ background-clip: text;
497
+ -webkit-background-clip: text;
498
+ color: transparent;
499
+ }
500
+
501
+ .plain-text {
502
+ font-size: 22px;
503
+ width: auto;
504
+ font-weight: bold;
505
+ }
506
+ """
507
+
508
+ js = """
509
+ function createWaveAnimation() {
510
+ const text = document.getElementById('text');
511
+ var i = 0;
512
+ setInterval(function() {
513
+ const colors = [
514
+ 'red, orange, yellow, green, blue, indigo, violet, purple',
515
+ 'orange, yellow, green, blue, indigo, violet, purple, red',
516
+ 'yellow, green, blue, indigo, violet, purple, red, orange',
517
+ 'green, blue, indigo, violet, purple, red, orange, yellow',
518
+ 'blue, indigo, violet, purple, red, orange, yellow, green',
519
+ 'indigo, violet, purple, red, orange, yellow, green, blue',
520
+ 'violet, purple, red, orange, yellow, green, blue, indigo',
521
+ 'purple, red, orange, yellow, green, blue, indigo, violet',
522
+ ];
523
+ const angle = 45;
524
+ const colorIndex = i % colors.length;
525
+ text.style.background = `linear-gradient(${angle}deg, ${colors[colorIndex]})`;
526
+ text.style.webkitBackgroundClip = 'text';
527
+ text.style.backgroundClip = 'text';
528
+ text.style.color = 'transparent';
529
+ text.style.fontSize = '28px';
530
+ text.style.width = 'auto';
531
+ text.textContent = 'InternVL2';
532
+ text.style.fontWeight = 'bold';
533
+ i += 1;
534
+ }, 200);
535
+ const params = new URLSearchParams(window.location.search);
536
+ url_params = Object.fromEntries(params);
537
+ console.log(url_params);
538
+ return url_params;
539
+ }
540
+
541
+ """
542
+
543
+
544
+ def build_demo(embed_mode):
545
+ textbox = gr.MultimodalTextbox(
546
+ interactive=True,
547
+ file_types=["image", "video"],
548
+ placeholder="Enter message or upload file...",
549
+ show_label=False,
550
+ )
551
+
552
+ with gr.Blocks(
553
+ title="InternVL-Chat",
554
+ theme=gr.themes.Default(),
555
+ css=block_css,
556
+ ) as demo:
557
+ state = gr.State()
558
+
559
+ if not embed_mode:
560
+ # gr.Markdown(title_markdown)
561
+ gr.HTML(title_html)
562
+
563
+ with gr.Row():
564
+ with gr.Column(scale=2):
565
+
566
+ with gr.Row(elem_id="model_selector_row"):
567
+ model_selector = gr.Dropdown(
568
+ choices=models,
569
+ value=models[0] if len(models) > 0 else "",
570
+ # value="InternVL-Chat-V1-5",
571
+ interactive=True,
572
+ show_label=False,
573
+ container=False,
574
+ )
575
+
576
+ with gr.Accordion("System Prompt", open=False) as system_prompt_row:
577
+ system_prompt = gr.Textbox(
578
+ value="请尽可能详细地回答用户的问题。",
579
+ label="System Prompt",
580
+ interactive=True,
581
+ )
582
+ with gr.Accordion("Parameters", open=False) as parameter_row:
583
+ temperature = gr.Slider(
584
+ minimum=0.0,
585
+ maximum=1.0,
586
+ value=0.2,
587
+ step=0.1,
588
+ interactive=True,
589
+ label="Temperature",
590
+ )
591
+ top_p = gr.Slider(
592
+ minimum=0.0,
593
+ maximum=1.0,
594
+ value=0.7,
595
+ step=0.1,
596
+ interactive=True,
597
+ label="Top P",
598
+ )
599
+ repetition_penalty = gr.Slider(
600
+ minimum=1.0,
601
+ maximum=1.5,
602
+ value=1.1,
603
+ step=0.02,
604
+ interactive=True,
605
+ label="Repetition penalty",
606
+ )
607
+ max_output_tokens = gr.Slider(
608
+ minimum=0,
609
+ maximum=4096,
610
+ value=1024,
611
+ step=64,
612
+ interactive=True,
613
+ label="Max output tokens",
614
+ )
615
+ max_input_tiles = gr.Slider(
616
+ minimum=1,
617
+ maximum=32,
618
+ value=12,
619
+ step=1,
620
+ interactive=True,
621
+ label="Max input tiles (control the image size)",
622
+ )
623
+ examples = gr.Examples(
624
+ examples=[
625
+ [
626
+ {
627
+ "files": [
628
+ "gallery/prod_9.jpg",
629
+ ],
630
+ "text": "What's at the far end of the image?",
631
+ }
632
+ ],
633
+ [
634
+ {
635
+ "files": [
636
+ "gallery/astro_on_unicorn.png",
637
+ ],
638
+ "text": "What does this image mean?",
639
+ }
640
+ ],
641
+ [
642
+ {
643
+ "files": [
644
+ "gallery/prod_12.png",
645
+ ],
646
+ "text": "What are the consequences of the easy decisions shown in this image?",
647
+ }
648
+ ],
649
+ [
650
+ {
651
+ "files": [
652
+ "gallery/child_1.jpg",
653
+ "gallery/child_2.jpg",
654
+ f"gallery/child_3.jpg",
655
+ ],
656
+ "text": "这三帧图片讲述了一件什么事情?",
657
+ }
658
+ ],
659
+ ],
660
+ inputs=[textbox],
661
+ )
662
+
663
+ with gr.Column(scale=8):
664
+ chatbot = gr.Chatbot(
665
+ elem_id="chatbot",
666
+ label="InternVL2",
667
+ height=580,
668
+ show_copy_button=True,
669
+ show_share_button=True,
670
+ avatar_images=[
671
+ "assets/human.png",
672
+ "assets/assistant.png",
673
+ ],
674
+ bubble_full_width=False,
675
+ )
676
+ with gr.Row():
677
+ with gr.Column(scale=8):
678
+ textbox.render()
679
+ with gr.Column(scale=1, min_width=50):
680
+ submit_btn = gr.Button(value="Send", variant="primary")
681
+ with gr.Row(elem_id="buttons") as button_row:
682
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
683
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
684
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
685
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
686
+ regenerate_btn = gr.Button(
687
+ value="🔄 Regenerate", interactive=False
688
+ )
689
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
690
+
691
+ if not embed_mode:
692
+ gr.Markdown(tos_markdown)
693
+ gr.Markdown(learn_more_markdown)
694
+ url_params = gr.JSON(visible=False)
695
+
696
+ # Register listeners
697
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
698
+ upvote_btn.click(
699
+ upvote_last_response,
700
+ [state, model_selector],
701
+ [textbox, upvote_btn, downvote_btn, flag_btn],
702
+ )
703
+ downvote_btn.click(
704
+ downvote_last_response,
705
+ [state, model_selector],
706
+ [textbox, upvote_btn, downvote_btn, flag_btn],
707
+ )
708
+ chatbot.like(
709
+ vote_selected_response,
710
+ [state, model_selector],
711
+ [],
712
+ )
713
+ flag_btn.click(
714
+ flag_last_response,
715
+ [state, model_selector],
716
+ [textbox, upvote_btn, downvote_btn, flag_btn],
717
+ )
718
+ regenerate_btn.click(
719
+ regenerate,
720
+ [state, system_prompt],
721
+ [state, chatbot, textbox] + btn_list,
722
+ ).then(
723
+ http_bot,
724
+ [
725
+ state,
726
+ model_selector,
727
+ temperature,
728
+ top_p,
729
+ repetition_penalty,
730
+ max_output_tokens,
731
+ max_input_tiles,
732
+ # bbox_threshold,
733
+ # mask_threshold,
734
+ ],
735
+ [state, chatbot, textbox] + btn_list,
736
+ )
737
+ clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
738
+
739
+ textbox.submit(
740
+ add_text,
741
+ [state, textbox, system_prompt],
742
+ [state, chatbot, textbox] + btn_list,
743
+ ).then(
744
+ http_bot,
745
+ [
746
+ state,
747
+ model_selector,
748
+ temperature,
749
+ top_p,
750
+ repetition_penalty,
751
+ max_output_tokens,
752
+ max_input_tiles,
753
+ # bbox_threshold,
754
+ # mask_threshold,
755
+ ],
756
+ [state, chatbot, textbox] + btn_list,
757
+ )
758
+ submit_btn.click(
759
+ add_text,
760
+ [state, textbox, system_prompt],
761
+ [state, chatbot, textbox] + btn_list,
762
+ ).then(
763
+ http_bot,
764
+ [
765
+ state,
766
+ model_selector,
767
+ temperature,
768
+ top_p,
769
+ repetition_penalty,
770
+ max_output_tokens,
771
+ max_input_tiles,
772
+ # bbox_threshold,
773
+ # mask_threshold,
774
+ ],
775
+ [state, chatbot, textbox] + btn_list,
776
+ )
777
+
778
+ if args.model_list_mode == "once":
779
+ demo.load(
780
+ load_demo,
781
+ [url_params],
782
+ [state, model_selector],
783
+ js=js,
784
+ )
785
+ elif args.model_list_mode == "reload":
786
+ demo.load(
787
+ load_demo_refresh_model_list,
788
+ None,
789
+ [state, model_selector],
790
+ js=js,
791
+ )
792
+ else:
793
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
794
+
795
+ return demo
796
+
797
+
798
+ if __name__ == "__main__":
799
+ parser = argparse.ArgumentParser()
800
+ parser.add_argument("--host", type=str, default="0.0.0.0")
801
+ parser.add_argument("--port", type=int, default=11000)
802
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
803
+ parser.add_argument("--concurrency-count", type=int, default=10)
804
+ parser.add_argument(
805
+ "--model-list-mode", type=str, default="once", choices=["once", "reload"]
806
+ )
807
+ parser.add_argument("--sd-worker-url", type=str, default=None)
808
+ parser.add_argument("--share", action="store_true")
809
+ parser.add_argument("--moderate", action="store_true")
810
+ parser.add_argument("--embed", action="store_true")
811
+ args = parser.parse_args()
812
+ logger.info(f"args: {args}")
813
+
814
+ models = get_model_list()
815
+
816
+ sd_worker_url = args.sd_worker_url
817
+ logger.info(args)
818
+ demo = build_demo(args.embed)
819
+ demo.queue(api_open=False).launch(
820
+ server_name=args.host,
821
+ server_port=args.port,
822
+ share=args.share,
823
+ max_threads=args.concurrency_count,
824
+ )
library.py DELETED
@@ -1,95 +0,0 @@
1
- # --------------------------------------------------------
2
- # InternVL
3
- # Copyright (c) 2024 OpenGVLab
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Modified from https://github.com/hreikin/streamlit-uploads-library/blob/main/streamlit_uploads_library/library.py
6
- # --------------------------------------------------------
7
-
8
- import logging
9
- from math import ceil
10
-
11
- import streamlit as st
12
-
13
- logger = logging.getLogger(__name__)
14
-
15
-
16
- class Library():
17
- """Create a simple library out of streamlit widgets.
18
-
19
- Using the library is simple, import `streamlit_uploads_library` and then instantiate the class with the
20
- required `directory` variable. Other options can be configured by passing in different variables
21
- when instantiating the class.
22
-
23
- Example Usage:
24
- python
25
- import streamlit as st
26
- from library import Library
27
-
28
- st.set_page_config(page_title="Streamlit Uploads Library", layout="wide")
29
- default_library = Library(images=pil_images)
30
- """
31
-
32
- def __init__(self, images, image_alignment='end', number_of_columns=5):
33
- self.images = images
34
- self.image_alignment = image_alignment
35
- self.number_of_columns = number_of_columns
36
- self.root_container = self.create(images=self.images,
37
- image_alignment=self.image_alignment,
38
- number_of_columns=self.number_of_columns)
39
-
40
- def create(_self, images, image_alignment, number_of_columns):
41
- """Creates a simple library or gallery with columns.
42
-
43
- Creates a library or gallery using columns out of streamlit widgets.
44
- """
45
- root_container = st.container()
46
- with root_container:
47
- # To be able to display the images, details and buttons all in one row and aligned
48
- # correctly so that images of different sizes don't affect the alignment of the details
49
- # and buttons we need do some minor maths and keep track of multiple index values.
50
- # First we instantiate some defaults.
51
- col_idx = 0
52
- filename_idx = 0
53
- max_idx = number_of_columns - 1
54
- # Get the file list and filename list, work out the total number of files from the
55
- # length of the file list.
56
- library_files = images
57
- num_of_files = len(library_files)
58
- # Work out the number of rows required by dividing the number of files by the number of
59
- # columns and rounding up using `math.ceil`.
60
- num_of_rows_req = ceil(num_of_files / number_of_columns)
61
- # Create the required number of rows (st.container).
62
- library_rows = list()
63
- library_rows_idx = 0
64
- for i in range(num_of_rows_req):
65
- library_rows.append(st.container())
66
- # For each library row we need to create separate rows (st.container) for images,
67
- # and rows (st.expander) for details and buttons to keep them in the correct columns.
68
- for idx in range(num_of_rows_req):
69
- with library_rows[library_rows_idx]:
70
- imgs_columns = list(st.columns(number_of_columns))
71
- # Since we are keeping track of the column and filename indexes we can use
72
- # those to slice the `library_files` list at the correct points for each row
73
- # and then increase or reset the indexes as required.
74
- for img in library_files[filename_idx:(filename_idx + number_of_columns)]:
75
- with imgs_columns[col_idx]:
76
- st.image(img, use_column_width='auto')
77
- st.write(
78
- f"""<style>
79
- [data-testid="stHorizontalBlock"] {{
80
- align-items: {image_alignment};
81
- }}
82
- </style>
83
- """,
84
- unsafe_allow_html=True
85
- )
86
- # Keeps track of the current column, if we reach the `max_idx` we reset it
87
- # to 0 and increase the row index. This combined with the slicing should
88
- # ensure all images, details and buttons are in the correct columns.
89
- if col_idx < max_idx:
90
- col_idx += 1
91
- else:
92
- col_idx = 0
93
- library_rows_idx += 1
94
- filename_idx += 1
95
- return root_container
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mm_utils.py DELETED
@@ -1,102 +0,0 @@
1
- import base64
2
- from io import BytesIO
3
-
4
- import torch
5
- from PIL import Image
6
- from transformers import StoppingCriteria
7
-
8
- from .constants import IMAGE_TOKEN_INDEX
9
-
10
-
11
- def load_image_from_base64(image):
12
- return Image.open(BytesIO(base64.b64decode(image)))
13
-
14
-
15
- def expand2square(pil_img, background_color):
16
- width, height = pil_img.size
17
- if width == height:
18
- return pil_img
19
- elif width > height:
20
- result = Image.new(pil_img.mode, (width, width), background_color)
21
- result.paste(pil_img, (0, (width - height) // 2))
22
- return result
23
- else:
24
- result = Image.new(pil_img.mode, (height, height), background_color)
25
- result.paste(pil_img, ((height - width) // 2, 0))
26
- return result
27
-
28
-
29
- def process_images(images, image_processor, model_cfg):
30
- image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None)
31
- new_images = []
32
- if image_aspect_ratio == 'pad':
33
- for image in images:
34
- image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
35
- image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
36
- new_images.append(image)
37
- else:
38
- return image_processor(images, return_tensors='pt')['pixel_values']
39
- if all(x.shape == new_images[0].shape for x in new_images):
40
- new_images = torch.stack(new_images, dim=0)
41
- return new_images
42
-
43
-
44
- def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
45
- num_image_tokens=None, return_tensors=None):
46
- prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
47
-
48
- def insert_separator(X, sep):
49
- return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
50
-
51
- input_ids = []
52
- offset = 0
53
- if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
54
- offset = 1
55
- input_ids.append(prompt_chunks[0][0])
56
-
57
- for x in insert_separator(prompt_chunks, [image_token_index] * (offset + num_image_tokens)):
58
- input_ids.extend(x[offset:])
59
-
60
- if return_tensors is not None:
61
- if return_tensors == 'pt':
62
- return torch.tensor(input_ids, dtype=torch.long)
63
- raise ValueError(f'Unsupported tensor type: {return_tensors}')
64
- return input_ids
65
-
66
-
67
- def get_model_name_from_path(model_path):
68
- model_path = model_path.strip('/')
69
- model_paths = model_path.split('/')
70
- if model_paths[-1].startswith('checkpoint-'):
71
- return model_paths[-2] + '_' + model_paths[-1]
72
- else:
73
- return model_paths[-1]
74
-
75
-
76
- class KeywordsStoppingCriteria(StoppingCriteria):
77
- def __init__(self, keywords, tokenizer, input_ids):
78
- self.keywords = keywords
79
- self.keyword_ids = []
80
- self.max_keyword_len = 0
81
- for keyword in keywords:
82
- cur_keyword_ids = tokenizer(keyword).input_ids
83
- if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
84
- cur_keyword_ids = cur_keyword_ids[1:]
85
- if len(cur_keyword_ids) > self.max_keyword_len:
86
- self.max_keyword_len = len(cur_keyword_ids)
87
- self.keyword_ids.append(torch.tensor(cur_keyword_ids))
88
- self.tokenizer = tokenizer
89
- self.start_len = input_ids.shape[1]
90
-
91
- def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
92
- assert output_ids.shape[0] == 1, 'Only support batch size 1 (yet)' # TODO
93
- offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
94
- self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
95
- for keyword_id in self.keyword_ids:
96
- if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
97
- return True
98
- outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
99
- for keyword in self.keywords:
100
- if keyword in outputs:
101
- return True
102
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model_worker.py CHANGED
@@ -9,14 +9,15 @@ A model worker executes the model.
9
  """
10
  import argparse
11
  import asyncio
12
- import base64
13
  import json
14
- import os
15
  import threading
16
  import time
17
  import uuid
 
18
  from functools import partial
19
- from io import BytesIO
20
  from threading import Thread
21
 
22
  import requests
@@ -28,33 +29,36 @@ from fastapi import BackgroundTasks, FastAPI, Request
28
  from fastapi.responses import StreamingResponse
29
  from PIL import Image
30
  from torchvision.transforms.functional import InterpolationMode
31
- from transformers import (AutoModelForCausalLM, AutoTokenizer,
32
- TextIteratorStreamer)
33
- from utils import build_logger, pretty_print_semaphore, server_error_msg
 
 
 
 
 
34
 
35
  worker_id = str(uuid.uuid4())[:6]
36
- logger = build_logger('model_worker', f'model_worker_{worker_id}.log')
37
  global_counter = 0
38
  model_semaphore = None
39
 
40
 
41
- def load_image_from_base64(image):
42
- return Image.open(BytesIO(base64.b64decode(image)))
43
-
44
-
45
  def build_transform(input_size):
46
  MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
47
- transform = T.Compose([
48
- T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
49
- T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
50
- T.ToTensor(),
51
- T.Normalize(mean=MEAN, std=STD)
52
- ])
 
 
53
  return transform
54
 
55
 
56
  def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
57
- best_ratio_diff = float('inf')
58
  best_ratio = (1, 1)
59
  area = width * height
60
  for ratio in target_ratios:
@@ -69,19 +73,26 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_
69
  return best_ratio
70
 
71
 
72
- def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
 
 
73
  orig_width, orig_height = image.size
74
  aspect_ratio = orig_width / orig_height
75
 
76
  # calculate the existing image aspect ratio
77
  target_ratios = set(
78
- (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
79
- i * j <= max_num and i * j >= min_num)
 
 
 
 
80
  target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
81
 
82
  # find the closest aspect ratio to the target
83
  target_aspect_ratio = find_closest_aspect_ratio(
84
- aspect_ratio, target_ratios, orig_width, orig_height, image_size)
 
85
 
86
  # calculate the target width and height
87
  target_width = image_size * target_aspect_ratio[0]
@@ -96,7 +107,7 @@ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnai
96
  (i % (target_width // image_size)) * image_size,
97
  (i // (target_width // image_size)) * image_size,
98
  ((i % (target_width // image_size)) + 1) * image_size,
99
- ((i // (target_width // image_size)) + 1) * image_size
100
  )
101
  # split the image
102
  split_img = resized_img.crop(box)
@@ -114,78 +125,163 @@ def heart_beat_worker(controller):
114
  controller.send_heart_beat()
115
 
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  class ModelWorker:
118
- def __init__(self, controller_addr, worker_addr, worker_id, model_path, model_name,
119
- load_8bit, device, context_len=8192):
 
 
 
 
 
 
 
 
 
120
  self.controller_addr = controller_addr
121
  self.worker_addr = worker_addr
122
  self.worker_id = worker_id
123
- if model_path.endswith('/'):
124
  model_path = model_path[:-1]
125
  if model_name is None:
126
- model_paths = model_path.split('/')
127
- if model_paths[-1].startswith('checkpoint-'):
128
- self.model_name = model_paths[-2] + '_' + model_paths[-1]
129
  else:
130
  self.model_name = model_paths[-1]
131
  else:
132
  self.model_name = model_name
133
 
134
- logger.info(f'Loading the model {self.model_name} on worker {worker_id} ...')
135
 
136
- self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
137
- if device == 'auto':
138
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
139
- # This can make distributed deployment work properly
140
- self.model = AutoModelForCausalLM.from_pretrained(
 
 
 
 
 
 
 
 
 
141
  model_path,
142
  load_in_8bit=load_8bit,
143
- torch_dtype=torch.float16,
144
- device_map='auto',
145
- trust_remote_code=True).eval()
 
146
  else:
147
- self.model = AutoModelForCausalLM.from_pretrained(
148
  model_path,
149
  load_in_8bit=load_8bit,
150
- torch_dtype=torch.float16,
151
- trust_remote_code=True).eval()
152
- if not load_8bit and not device == 'auto':
 
153
  self.model = self.model.cuda()
 
 
 
154
  self.image_size = self.model.config.force_image_size
155
  self.context_len = context_len
156
  self.register_to_controller()
157
  self.heart_beat_thread = threading.Thread(
158
- target=heart_beat_worker, args=(self,))
 
159
  self.heart_beat_thread.start()
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  def register_to_controller(self):
162
- logger.info('Register to controller')
163
 
164
- url = self.controller_addr + '/register_worker'
165
  data = {
166
- 'worker_name': self.worker_addr,
167
- 'check_heart_beat': True,
168
- 'worker_status': self.get_status()
169
  }
170
  r = requests.post(url, json=data)
171
  assert r.status_code == 200
172
 
173
  def send_heart_beat(self):
174
- logger.info(f'Send heart beat. Models: {[self.model_name]}. '
175
- f'Semaphore: {pretty_print_semaphore(model_semaphore)}. '
176
- f'global_counter: {global_counter}')
 
 
177
 
178
- url = self.controller_addr + '/receive_heart_beat'
179
 
180
  while True:
181
  try:
182
- ret = requests.post(url, json={
183
- 'worker_name': self.worker_addr,
184
- 'queue_length': self.get_queue_length()}, timeout=5)
185
- exist = ret.json()['exist']
 
 
 
 
 
186
  break
187
  except requests.exceptions.RequestException as e:
188
- logger.error(f'heart beat error: {e}')
189
  time.sleep(5)
190
 
191
  if not exist:
@@ -195,80 +291,115 @@ class ModelWorker:
195
  if model_semaphore is None:
196
  return 0
197
  else:
198
- return args.limit_model_concurrency - model_semaphore._value + (len(
199
- model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
 
 
 
 
 
 
 
200
 
201
  def get_status(self):
202
  return {
203
- 'model_names': [self.model_name],
204
- 'speed': 1,
205
- 'queue_length': self.get_queue_length(),
206
  }
207
 
 
208
  @torch.inference_mode()
209
  def generate_stream(self, params):
210
- system_message = params['prompt'][0]['content']
211
- send_messages = params['prompt'][1:]
212
- max_input_tiles = params['max_input_tiles']
213
- temperature = params['temperature']
214
- top_p = params['top_p']
215
- max_new_tokens = params['max_new_tokens']
216
- repetition_penalty = params['repetition_penalty']
217
  do_sample = True if temperature > 0.0 else False
218
 
219
- global_image_cnt = 1
220
  history, pil_images, max_input_tile_list = [], [], []
221
  for message in send_messages:
222
- if message['role'] == 'user':
223
- prefix = ''
224
- if 'image' in message:
225
  max_input_tile_temp = []
226
- for image_str in message['image']:
227
  pil_images.append(load_image_from_base64(image_str))
228
- prefix += f'Image-{global_image_cnt}: <image>\n\n'
229
  global_image_cnt += 1
230
- max_input_tile_temp.append(max(1, max_input_tiles // len(message['image'])))
 
 
231
  if len(max_input_tile_temp) > 0:
232
  max_input_tile_list.append(max_input_tile_temp)
233
- content = prefix + message['content']
234
- history.append([content, ])
 
 
 
 
235
  else:
236
- history[-1].append(message['content'])
237
  question, history = history[-1][0], history[:-1]
238
 
 
 
 
 
 
 
 
239
  # Create a new list to store processed sublists
240
  flattened_list = []
241
  # Iterate through all but the last sublist in max_input_tile_list and process them
242
  for sublist in max_input_tile_list[:-1]:
243
- processed_sublist = [1] * len(sublist) # Change each element in the sublist to 1
244
- flattened_list.extend(processed_sublist) # Flatten the processed sublist and add to the new list
 
 
 
 
245
  # If max_input_tile_list is not empty, add the last sublist to the new list
246
  if max_input_tile_list:
247
  flattened_list.extend(max_input_tile_list[-1])
248
  max_input_tile_list = flattened_list
249
- assert len(max_input_tile_list) == len(pil_images), 'The number of max_input_tile_list and pil_images should be the same.'
250
- logger.info(f'max_input_tile_list: {max_input_tile_list}')
 
251
 
252
  old_system_message = self.model.system_message
253
  self.model.system_message = system_message
254
  image_tiles = []
255
  transform = build_transform(input_size=self.image_size)
256
  if len(pil_images) > 0:
257
- for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
 
 
258
  if self.model.config.dynamic_image_size:
259
  tiles = dynamic_preprocess(
260
- pil_image, image_size=self.image_size, max_num=current_max_input_tiles,
261
- use_thumbnail=self.model.config.use_thumbnail)
 
 
 
262
  else:
263
  tiles = [pil_image]
264
  image_tiles += tiles
265
  pixel_values = [transform(item) for item in image_tiles]
266
- pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.float16)
267
- logger.info(f'Split images to {pixel_values.shape}')
 
 
268
  else:
269
  pixel_values = None
270
 
271
- streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=False, timeout=10)
 
 
272
  generation_config = dict(
273
  num_beams=1,
274
  max_new_tokens=max_new_tokens,
@@ -279,53 +410,61 @@ class ModelWorker:
279
  top_p=top_p,
280
  streamer=streamer,
281
  )
282
- logger.info(history)
283
- logger.info(f'Generation config: {generation_config}')
284
- try:
285
- thread = Thread(target=self.model.chat, kwargs=dict(
 
286
  tokenizer=self.tokenizer,
287
  pixel_values=pixel_values,
288
  question=question,
289
  history=history,
290
  return_history=False,
291
  generation_config=generation_config,
292
- ))
293
- thread.start()
294
-
295
- generated_text = ''
296
- for new_text in streamer:
297
- generated_text += new_text
298
- yield json.dumps({'text': generated_text.replace(self.model.conv_template.sep, ''),
299
- 'error_code': 0}).encode() + b'\0'
300
- self.model.system_message = old_system_message
301
- except:
302
- torch.cuda.empty_cache()
 
 
 
 
303
 
304
  def generate_stream_gate(self, params):
305
  try:
306
  for x in self.generate_stream(params):
307
  yield x
308
  except ValueError as e:
309
- print('Caught ValueError:', e)
 
310
  ret = {
311
- 'text': server_error_msg,
312
- 'error_code': 1,
313
  }
314
- yield json.dumps(ret).encode() + b'\0'
315
  except torch.cuda.CudaError as e:
316
- print('Caught torch.cuda.CudaError:', e)
 
317
  ret = {
318
- 'text': server_error_msg,
319
- 'error_code': 1,
320
  }
321
- yield json.dumps(ret).encode() + b'\0'
322
  except Exception as e:
323
- print('Caught Unknown Error', e)
 
324
  ret = {
325
- 'text': server_error_msg,
326
- 'error_code': 1,
327
  }
328
- yield json.dumps(ret).encode() + b'\0'
329
 
330
 
331
  app = FastAPI()
@@ -337,7 +476,7 @@ def release_model_semaphore(fn=None):
337
  fn()
338
 
339
 
340
- @app.post('/worker_generate_stream')
341
  async def generate_stream(request: Request):
342
  global model_semaphore, global_counter
343
  global_counter += 1
@@ -349,35 +488,39 @@ async def generate_stream(request: Request):
349
  worker.send_heart_beat()
350
  generator = worker.generate_stream_gate(params)
351
  background_tasks = BackgroundTasks()
352
- background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
 
 
353
  return StreamingResponse(generator, background=background_tasks)
354
 
355
 
356
- @app.post('/worker_get_status')
357
  async def get_status(request: Request):
358
  return worker.get_status()
359
 
360
 
361
- if __name__ == '__main__':
362
  parser = argparse.ArgumentParser()
363
- parser.add_argument('--host', type=str, default='0.0.0.0')
364
- parser.add_argument('--port', type=int, default=21002)
365
- parser.add_argument('--worker-address', type=str, default='http://localhost:21002')
366
- parser.add_argument('--controller-address', type=str, default='http://localhost:21001')
367
- parser.add_argument('--model-path', type=str, default='facebook/opt-350m')
368
- parser.add_argument('--model-name', type=str)
369
- parser.add_argument('--device', type=str, default='cuda')
370
- parser.add_argument('--limit-model-concurrency', type=int, default=5)
371
- parser.add_argument('--stream-interval', type=int, default=1)
372
- parser.add_argument('--load-8bit', action='store_true')
373
  args = parser.parse_args()
374
- logger.info(f'args: {args}')
375
-
376
- worker = ModelWorker(args.controller_address,
377
- args.worker_address,
378
- worker_id,
379
- args.model_path,
380
- args.model_name,
381
- args.load_8bit,
382
- args.device)
383
- uvicorn.run(app, host=args.host, port=args.port, log_level='info')
 
 
 
9
  """
10
  import argparse
11
  import asyncio
12
+
13
  import json
14
+ import math
15
  import threading
16
  import time
17
  import uuid
18
+ import traceback
19
  from functools import partial
20
+
21
  from threading import Thread
22
 
23
  import requests
 
29
  from fastapi.responses import StreamingResponse
30
  from PIL import Image
31
  from torchvision.transforms.functional import InterpolationMode
32
+ from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
33
+ from utils import (
34
+ build_logger,
35
+ pretty_print_semaphore,
36
+ server_error_msg,
37
+ load_image_from_base64,
38
+ )
39
+ import spaces
40
 
41
  worker_id = str(uuid.uuid4())[:6]
42
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
43
  global_counter = 0
44
  model_semaphore = None
45
 
46
 
 
 
 
 
47
  def build_transform(input_size):
48
  MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
49
+ transform = T.Compose(
50
+ [
51
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
52
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
53
+ T.ToTensor(),
54
+ T.Normalize(mean=MEAN, std=STD),
55
+ ]
56
+ )
57
  return transform
58
 
59
 
60
  def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
61
+ best_ratio_diff = float("inf")
62
  best_ratio = (1, 1)
63
  area = width * height
64
  for ratio in target_ratios:
 
73
  return best_ratio
74
 
75
 
76
+ def dynamic_preprocess(
77
+ image, min_num=1, max_num=6, image_size=448, use_thumbnail=False
78
+ ):
79
  orig_width, orig_height = image.size
80
  aspect_ratio = orig_width / orig_height
81
 
82
  # calculate the existing image aspect ratio
83
  target_ratios = set(
84
+ (i, j)
85
+ for n in range(min_num, max_num + 1)
86
+ for i in range(1, n + 1)
87
+ for j in range(1, n + 1)
88
+ if i * j <= max_num and i * j >= min_num
89
+ )
90
  target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
91
 
92
  # find the closest aspect ratio to the target
93
  target_aspect_ratio = find_closest_aspect_ratio(
94
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size
95
+ )
96
 
97
  # calculate the target width and height
98
  target_width = image_size * target_aspect_ratio[0]
 
107
  (i % (target_width // image_size)) * image_size,
108
  (i // (target_width // image_size)) * image_size,
109
  ((i % (target_width // image_size)) + 1) * image_size,
110
+ ((i // (target_width // image_size)) + 1) * image_size,
111
  )
112
  # split the image
113
  split_img = resized_img.crop(box)
 
125
  controller.send_heart_beat()
126
 
127
 
128
+ def split_model(model_name):
129
+ device_map = {}
130
+ world_size = torch.cuda.device_count()
131
+ num_layers = {
132
+ "InternVL2-8B": 32,
133
+ "InternVL2-26B": 48,
134
+ "InternVL2-40B": 60,
135
+ "InternVL2-Llama3-76B": 80,
136
+ "InternVL2-78B": 80,
137
+ "InternVL2-Pro": 80,
138
+ }[model_name]
139
+ # Since the first GPU will be used for ViT, treat it as half a GPU.
140
+ num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
141
+ num_layers_per_gpu = [num_layers_per_gpu] * world_size
142
+ num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
143
+ layer_cnt = 0
144
+ for i, num_layer in enumerate(num_layers_per_gpu):
145
+ for j in range(num_layer):
146
+ device_map[f"language_model.model.layers.{layer_cnt}"] = i
147
+ layer_cnt += 1
148
+ device_map["vision_model"] = 0
149
+ device_map["mlp1"] = 0
150
+ device_map["language_model.model.tok_embeddings"] = 0
151
+ device_map["language_model.model.embed_tokens"] = 0
152
+ device_map["language_model.output"] = 0
153
+ device_map["language_model.model.norm"] = 0
154
+ device_map["language_model.lm_head"] = 0
155
+ device_map[f"language_model.model.layers.{num_layers - 1}"] = 0
156
+
157
+ return device_map
158
+
159
+
160
  class ModelWorker:
161
+ def __init__(
162
+ self,
163
+ controller_addr,
164
+ worker_addr,
165
+ worker_id,
166
+ model_path,
167
+ model_name,
168
+ load_8bit,
169
+ device,
170
+ context_len=8192,
171
+ ):
172
  self.controller_addr = controller_addr
173
  self.worker_addr = worker_addr
174
  self.worker_id = worker_id
175
+ if model_path.endswith("/"):
176
  model_path = model_path[:-1]
177
  if model_name is None:
178
+ model_paths = model_path.split("/")
179
+ if model_paths[-1].startswith("checkpoint-"):
180
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
181
  else:
182
  self.model_name = model_paths[-1]
183
  else:
184
  self.model_name = model_name
185
 
186
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
187
 
188
+ tokenizer = AutoTokenizer.from_pretrained(
189
+ model_path, trust_remote_code=True, use_fast=False
190
+ )
191
+ tokens_to_keep = ["<box>", "</box>", "<ref>", "</ref>"]
192
+ tokenizer.additional_special_tokens = [
193
+ item
194
+ for item in tokenizer.additional_special_tokens
195
+ if item not in tokens_to_keep
196
+ ]
197
+ self.tokenizer = tokenizer
198
+
199
+ if device == "auto":
200
+ device_map = split_model(self.model_name)
201
+ self.model = AutoModel.from_pretrained(
202
  model_path,
203
  load_in_8bit=load_8bit,
204
+ torch_dtype=torch.bfloat16,
205
+ device_map=device_map,
206
+ trust_remote_code=True,
207
+ ).eval()
208
  else:
209
+ self.model = AutoModel.from_pretrained(
210
  model_path,
211
  load_in_8bit=load_8bit,
212
+ torch_dtype=torch.bfloat16,
213
+ trust_remote_code=True,
214
+ ).eval()
215
+ if not load_8bit and not device == "auto":
216
  self.model = self.model.cuda()
217
+ self.load_8bit = load_8bit
218
+ self.device = device
219
+ self.model_path = model_path
220
  self.image_size = self.model.config.force_image_size
221
  self.context_len = context_len
222
  self.register_to_controller()
223
  self.heart_beat_thread = threading.Thread(
224
+ target=heart_beat_worker, args=(self,)
225
+ )
226
  self.heart_beat_thread.start()
227
 
228
+ def reload_model(self):
229
+ del self.model
230
+ torch.cuda.empty_cache()
231
+ if self.device == "auto":
232
+ device_map = split_model(self.model_name)
233
+ self.model = AutoModel.from_pretrained(
234
+ self.model_path,
235
+ load_in_8bit=self.load_8bit,
236
+ torch_dtype=torch.bfloat16,
237
+ device_map=device_map,
238
+ trust_remote_code=True,
239
+ ).eval()
240
+ else:
241
+ self.model = AutoModel.from_pretrained(
242
+ self.model_path,
243
+ load_in_8bit=self.load_8bit,
244
+ torch_dtype=torch.bfloat16,
245
+ trust_remote_code=True,
246
+ ).eval()
247
+ if not self.load_8bit and not self.device == "auto":
248
+ self.model = self.model.cuda()
249
+
250
  def register_to_controller(self):
251
+ logger.info("Register to controller")
252
 
253
+ url = self.controller_addr + "/register_worker"
254
  data = {
255
+ "worker_name": self.worker_addr,
256
+ "check_heart_beat": True,
257
+ "worker_status": self.get_status(),
258
  }
259
  r = requests.post(url, json=data)
260
  assert r.status_code == 200
261
 
262
  def send_heart_beat(self):
263
+ logger.info(
264
+ f"Send heart beat. Models: {[self.model_name]}. "
265
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
266
+ f"global_counter: {global_counter}"
267
+ )
268
 
269
+ url = self.controller_addr + "/receive_heart_beat"
270
 
271
  while True:
272
  try:
273
+ ret = requests.post(
274
+ url,
275
+ json={
276
+ "worker_name": self.worker_addr,
277
+ "queue_length": self.get_queue_length(),
278
+ },
279
+ timeout=5,
280
+ )
281
+ exist = ret.json()["exist"]
282
  break
283
  except requests.exceptions.RequestException as e:
284
+ logger.error(f"heart beat error: {e}")
285
  time.sleep(5)
286
 
287
  if not exist:
 
291
  if model_semaphore is None:
292
  return 0
293
  else:
294
+ return (
295
+ args.limit_model_concurrency
296
+ - model_semaphore._value
297
+ + (
298
+ len(model_semaphore._waiters)
299
+ if model_semaphore._waiters is not None
300
+ else 0
301
+ )
302
+ )
303
 
304
  def get_status(self):
305
  return {
306
+ "model_names": [self.model_name],
307
+ "speed": 1,
308
+ "queue_length": self.get_queue_length(),
309
  }
310
 
311
+ @spaces.GPU
312
  @torch.inference_mode()
313
  def generate_stream(self, params):
314
+ system_message = params["prompt"][0]["content"]
315
+ send_messages = params["prompt"][1:]
316
+ max_input_tiles = params["max_input_tiles"]
317
+ temperature = params["temperature"]
318
+ top_p = params["top_p"]
319
+ max_new_tokens = params["max_new_tokens"]
320
+ repetition_penalty = params["repetition_penalty"]
321
  do_sample = True if temperature > 0.0 else False
322
 
323
+ global_image_cnt = 0
324
  history, pil_images, max_input_tile_list = [], [], []
325
  for message in send_messages:
326
+ if message["role"] == "user":
327
+ prefix = ""
328
+ if "image" in message:
329
  max_input_tile_temp = []
330
+ for image_str in message["image"]:
331
  pil_images.append(load_image_from_base64(image_str))
332
+ prefix += f"Image-{global_image_cnt + 1}: <image>\n\n"
333
  global_image_cnt += 1
334
+ max_input_tile_temp.append(
335
+ max(1, max_input_tiles // len(message["image"]))
336
+ )
337
  if len(max_input_tile_temp) > 0:
338
  max_input_tile_list.append(max_input_tile_temp)
339
+ content = prefix + message["content"]
340
+ history.append(
341
+ [
342
+ content,
343
+ ]
344
+ )
345
  else:
346
+ history[-1].append(message["content"])
347
  question, history = history[-1][0], history[:-1]
348
 
349
+ if global_image_cnt == 1:
350
+ question = question.replace("Image-1: <image>\n\n", "<image>\n")
351
+ history = [
352
+ [item[0].replace("Image-1: <image>\n\n", "<image>\n"), item[1]]
353
+ for item in history
354
+ ]
355
+
356
  # Create a new list to store processed sublists
357
  flattened_list = []
358
  # Iterate through all but the last sublist in max_input_tile_list and process them
359
  for sublist in max_input_tile_list[:-1]:
360
+ processed_sublist = [1] * len(
361
+ sublist
362
+ ) # Change each element in the sublist to 1
363
+ flattened_list.extend(
364
+ processed_sublist
365
+ ) # Flatten the processed sublist and add to the new list
366
  # If max_input_tile_list is not empty, add the last sublist to the new list
367
  if max_input_tile_list:
368
  flattened_list.extend(max_input_tile_list[-1])
369
  max_input_tile_list = flattened_list
370
+ assert len(max_input_tile_list) == len(
371
+ pil_images
372
+ ), "The number of max_input_tile_list and pil_images should be the same."
373
 
374
  old_system_message = self.model.system_message
375
  self.model.system_message = system_message
376
  image_tiles = []
377
  transform = build_transform(input_size=self.image_size)
378
  if len(pil_images) > 0:
379
+ for current_max_input_tiles, pil_image in zip(
380
+ max_input_tile_list, pil_images
381
+ ):
382
  if self.model.config.dynamic_image_size:
383
  tiles = dynamic_preprocess(
384
+ pil_image,
385
+ image_size=self.image_size,
386
+ max_num=current_max_input_tiles,
387
+ use_thumbnail=self.model.config.use_thumbnail,
388
+ )
389
  else:
390
  tiles = [pil_image]
391
  image_tiles += tiles
392
  pixel_values = [transform(item) for item in image_tiles]
393
+ pixel_values = torch.stack(pixel_values).to(
394
+ self.model.device, dtype=torch.bfloat16
395
+ )
396
+ logger.info(f"Split images to {pixel_values.shape}")
397
  else:
398
  pixel_values = None
399
 
400
+ streamer = TextIteratorStreamer(
401
+ self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
402
+ )
403
  generation_config = dict(
404
  num_beams=1,
405
  max_new_tokens=max_new_tokens,
 
410
  top_p=top_p,
411
  streamer=streamer,
412
  )
413
+ logger.info(f"Generation config: {generation_config}")
414
+
415
+ thread = Thread(
416
+ target=self.model.chat,
417
+ kwargs=dict(
418
  tokenizer=self.tokenizer,
419
  pixel_values=pixel_values,
420
  question=question,
421
  history=history,
422
  return_history=False,
423
  generation_config=generation_config,
424
+ ),
425
+ )
426
+ thread.start()
427
+
428
+ generated_text = ""
429
+ for new_text in streamer:
430
+ generated_text += new_text
431
+ if generated_text.endswith(self.model.conv_template.sep):
432
+ generated_text = generated_text[: -len(self.model.conv_template.sep)]
433
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
434
+ logger.info(
435
+ f"max_input_tile_list: {max_input_tile_list}, history: {history}, "
436
+ f"question: {question}, answer: {generated_text}"
437
+ )
438
+ self.model.system_message = old_system_message
439
 
440
  def generate_stream_gate(self, params):
441
  try:
442
  for x in self.generate_stream(params):
443
  yield x
444
  except ValueError as e:
445
+ print("Caught ValueError:", e)
446
+ traceback.print_exc()
447
  ret = {
448
+ "text": server_error_msg,
449
+ "error_code": 1,
450
  }
451
+ yield json.dumps(ret).encode() + b"\0"
452
  except torch.cuda.CudaError as e:
453
+ traceback.print_exc()
454
+ print("Caught torch.cuda.CudaError:", e)
455
  ret = {
456
+ "text": server_error_msg,
457
+ "error_code": 1,
458
  }
459
+ yield json.dumps(ret).encode() + b"\0"
460
  except Exception as e:
461
+ traceback.print_exc()
462
+ print("Caught Unknown Error", e)
463
  ret = {
464
+ "text": server_error_msg,
465
+ "error_code": 1,
466
  }
467
+ yield json.dumps(ret).encode() + b"\0"
468
 
469
 
470
  app = FastAPI()
 
476
  fn()
477
 
478
 
479
+ @app.post("/worker_generate_stream")
480
  async def generate_stream(request: Request):
481
  global model_semaphore, global_counter
482
  global_counter += 1
 
488
  worker.send_heart_beat()
489
  generator = worker.generate_stream_gate(params)
490
  background_tasks = BackgroundTasks()
491
+ background_tasks.add_task(
492
+ partial(release_model_semaphore, fn=worker.send_heart_beat)
493
+ )
494
  return StreamingResponse(generator, background=background_tasks)
495
 
496
 
497
+ @app.post("/worker_get_status")
498
  async def get_status(request: Request):
499
  return worker.get_status()
500
 
501
 
502
+ if __name__ == "__main__":
503
  parser = argparse.ArgumentParser()
504
+ parser.add_argument("--host", type=str, default="0.0.0.0")
505
+ parser.add_argument("--port", type=int, default=21002)
506
+ parser.add_argument("--worker-url", type=str, default="http://localhost")
507
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
508
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
509
+ parser.add_argument("--model-name", type=str)
510
+ parser.add_argument("--device", type=str, default="cuda")
511
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
512
+ parser.add_argument("--stream-interval", type=int, default=1)
513
+ parser.add_argument("--load-8bit", action="store_true")
514
  args = parser.parse_args()
515
+ logger.info(f"args: {args}")
516
+
517
+ worker = ModelWorker(
518
+ args.controller_url,
519
+ args.worker_url + f":{args.port}",
520
+ worker_id,
521
+ args.model_path,
522
+ args.model_name,
523
+ args.load_8bit,
524
+ args.device,
525
+ )
526
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
requirements.txt CHANGED
@@ -1,4 +1,14 @@
1
- opencv-python
2
- streamlit_image_select
3
- streamlit==1.36.0
4
- flask
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.29.2
2
+ fastapi==0.111.1
3
+ filelock==3.15.4
4
+ fire==0.6.0
5
+ gradio==4.38.1
6
+ numpy==2.0.1
7
+ Pillow==10.4.0
8
+ pydantic==2.8.2
9
+ Requests==2.32.3
10
+ spaces==0.28.3
11
+ torch==2.0.1
12
+ torchvision==0.15.2
13
+ transformers==4.37.2
14
+ uvicorn==0.30.3
utils.py CHANGED
@@ -1,13 +1,22 @@
 
1
  import logging
2
  import logging.handlers
3
  import os
4
  import sys
5
-
 
 
 
6
  import requests
7
  from constants import LOGDIR
 
8
 
9
- server_error_msg = '**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**'
10
- moderation_msg = 'YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN.'
 
 
 
 
11
 
12
  handler = None
13
 
@@ -16,8 +25,8 @@ def build_logger(logger_name, logger_filename):
16
  global handler
17
 
18
  formatter = logging.Formatter(
19
- fmt='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
20
- datefmt='%Y-%m-%d %H:%M:%S',
21
  )
22
 
23
  # Set the format of root handlers
@@ -26,12 +35,12 @@ def build_logger(logger_name, logger_filename):
26
  logging.getLogger().handlers[0].setFormatter(formatter)
27
 
28
  # Redirect stdout and stderr to loggers
29
- stdout_logger = logging.getLogger('stdout')
30
  stdout_logger.setLevel(logging.INFO)
31
  sl = StreamToLogger(stdout_logger, logging.INFO)
32
  sys.stdout = sl
33
 
34
- stderr_logger = logging.getLogger('stderr')
35
  stderr_logger.setLevel(logging.ERROR)
36
  sl = StreamToLogger(stderr_logger, logging.ERROR)
37
  sys.stderr = sl
@@ -45,7 +54,8 @@ def build_logger(logger_name, logger_filename):
45
  os.makedirs(LOGDIR, exist_ok=True)
46
  filename = os.path.join(LOGDIR, logger_filename)
47
  handler = logging.handlers.TimedRotatingFileHandler(
48
- filename, when='D', utc=True)
 
49
  handler.setFormatter(formatter)
50
 
51
  for name, item in logging.root.manager.loggerDict.items():
@@ -59,33 +69,34 @@ class StreamToLogger(object):
59
  """
60
  Fake file-like stream object that redirects writes to a logger instance.
61
  """
 
62
  def __init__(self, logger, log_level=logging.INFO):
63
  self.terminal = sys.stdout
64
  self.logger = logger
65
  self.log_level = log_level
66
- self.linebuf = ''
67
 
68
  def __getattr__(self, attr):
69
  return getattr(self.terminal, attr)
70
 
71
  def write(self, buf):
72
  temp_linebuf = self.linebuf + buf
73
- self.linebuf = ''
74
  for line in temp_linebuf.splitlines(True):
75
  # From the io.TextIOWrapper docs:
76
  # On output, if newline is None, any '\n' characters written
77
  # are translated to the system default line separator.
78
  # By default sys.stdout.write() expects '\n' newlines and then
79
  # translates them so this is still cross platform.
80
- if line[-1] == '\n':
81
  self.logger.log(self.log_level, line.rstrip())
82
  else:
83
  self.linebuf += line
84
 
85
  def flush(self):
86
- if self.linebuf != '':
87
  self.logger.log(self.log_level, self.linebuf.rstrip())
88
- self.linebuf = ''
89
 
90
 
91
  def disable_torch_init():
@@ -93,23 +104,26 @@ def disable_torch_init():
93
  Disable the redundant torch default initialization to accelerate model creation.
94
  """
95
  import torch
96
- setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
97
- setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
 
98
 
99
 
100
  def violates_moderation(text):
101
  """
102
  Check whether the text violates OpenAI moderation API.
103
  """
104
- url = 'https://api.openai.com/v1/moderations'
105
- headers = {'Content-Type': 'application/json',
106
- 'Authorization': 'Bearer ' + os.environ['OPENAI_API_KEY']}
107
- text = text.replace('\n', '')
108
- data = '{' + '"input": ' + f'"{text}"' + '}'
109
- data = data.encode('utf-8')
 
 
110
  try:
111
  ret = requests.post(url, headers=headers, data=data, timeout=5)
112
- flagged = ret.json()['results'][0]['flagged']
113
  except requests.exceptions.RequestException as e:
114
  flagged = False
115
  except KeyError as e:
@@ -120,5 +134,30 @@ def violates_moderation(text):
120
 
121
  def pretty_print_semaphore(semaphore):
122
  if semaphore is None:
123
- return 'None'
124
- return f'Semaphore(value={semaphore._value}, locked={semaphore.locked()})'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ast import Dict
2
  import logging
3
  import logging.handlers
4
  import os
5
  import sys
6
+ import base64
7
+ from PIL import Image
8
+ from io import BytesIO
9
+ import json
10
  import requests
11
  from constants import LOGDIR
12
+ import datetime
13
 
14
+ server_error_msg = (
15
+ "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
16
+ )
17
+ moderation_msg = (
18
+ "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
19
+ )
20
 
21
  handler = None
22
 
 
25
  global handler
26
 
27
  formatter = logging.Formatter(
28
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
29
+ datefmt="%Y-%m-%d %H:%M:%S",
30
  )
31
 
32
  # Set the format of root handlers
 
35
  logging.getLogger().handlers[0].setFormatter(formatter)
36
 
37
  # Redirect stdout and stderr to loggers
38
+ stdout_logger = logging.getLogger("stdout")
39
  stdout_logger.setLevel(logging.INFO)
40
  sl = StreamToLogger(stdout_logger, logging.INFO)
41
  sys.stdout = sl
42
 
43
+ stderr_logger = logging.getLogger("stderr")
44
  stderr_logger.setLevel(logging.ERROR)
45
  sl = StreamToLogger(stderr_logger, logging.ERROR)
46
  sys.stderr = sl
 
54
  os.makedirs(LOGDIR, exist_ok=True)
55
  filename = os.path.join(LOGDIR, logger_filename)
56
  handler = logging.handlers.TimedRotatingFileHandler(
57
+ filename, when="D", utc=True
58
+ )
59
  handler.setFormatter(formatter)
60
 
61
  for name, item in logging.root.manager.loggerDict.items():
 
69
  """
70
  Fake file-like stream object that redirects writes to a logger instance.
71
  """
72
+
73
  def __init__(self, logger, log_level=logging.INFO):
74
  self.terminal = sys.stdout
75
  self.logger = logger
76
  self.log_level = log_level
77
+ self.linebuf = ""
78
 
79
  def __getattr__(self, attr):
80
  return getattr(self.terminal, attr)
81
 
82
  def write(self, buf):
83
  temp_linebuf = self.linebuf + buf
84
+ self.linebuf = ""
85
  for line in temp_linebuf.splitlines(True):
86
  # From the io.TextIOWrapper docs:
87
  # On output, if newline is None, any '\n' characters written
88
  # are translated to the system default line separator.
89
  # By default sys.stdout.write() expects '\n' newlines and then
90
  # translates them so this is still cross platform.
91
+ if line[-1] == "\n":
92
  self.logger.log(self.log_level, line.rstrip())
93
  else:
94
  self.linebuf += line
95
 
96
  def flush(self):
97
+ if self.linebuf != "":
98
  self.logger.log(self.log_level, self.linebuf.rstrip())
99
+ self.linebuf = ""
100
 
101
 
102
  def disable_torch_init():
 
104
  Disable the redundant torch default initialization to accelerate model creation.
105
  """
106
  import torch
107
+
108
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
109
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
110
 
111
 
112
  def violates_moderation(text):
113
  """
114
  Check whether the text violates OpenAI moderation API.
115
  """
116
+ url = "https://api.openai.com/v1/moderations"
117
+ headers = {
118
+ "Content-Type": "application/json",
119
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
120
+ }
121
+ text = text.replace("\n", "")
122
+ data = "{" + '"input": ' + f'"{text}"' + "}"
123
+ data = data.encode("utf-8")
124
  try:
125
  ret = requests.post(url, headers=headers, data=data, timeout=5)
126
+ flagged = ret.json()["results"][0]["flagged"]
127
  except requests.exceptions.RequestException as e:
128
  flagged = False
129
  except KeyError as e:
 
134
 
135
  def pretty_print_semaphore(semaphore):
136
  if semaphore is None:
137
+ return "None"
138
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
139
+
140
+
141
+ def load_image_from_base64(image):
142
+ return Image.open(BytesIO(base64.b64decode(image)))
143
+
144
+
145
+ def get_log_filename():
146
+ t = datetime.datetime.now()
147
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
148
+ return name
149
+
150
+
151
+ def data_wrapper(data):
152
+ if isinstance(data, bytes):
153
+ return data
154
+ elif isinstance(data, Image.Image):
155
+ buffered = BytesIO()
156
+ data.save(buffered, format="PNG")
157
+ return buffered.getvalue()
158
+ elif isinstance(data, str):
159
+ return data.encode()
160
+ elif isinstance(data, Dict):
161
+ return json.dumps(data).encode()
162
+ else:
163
+ raise ValueError(f"Unsupported data type: {type(data)}")