zwgao committed on
Commit 3fdcc70
1 Parent(s): 9704f7c
Files changed (47)
  1. .gitignore +4 -0
  2. app.py +690 -0
  3. assets/assistant.png +0 -0
  4. assets/human.png +0 -0
  5. builtin_plan.json +15 -0
  6. cllm/agents/__init__.py +2 -0
  7. cllm/agents/base.py +173 -0
  8. cllm/agents/builtin/__init__.py +3 -0
  9. cllm/agents/builtin/plans.py +634 -0
  10. cllm/agents/builtin/tools.py +1512 -0
  11. cllm/agents/container.py +98 -0
  12. cllm/agents/tog/__init__.py +2 -0
  13. cllm/agents/tog/compiler.py +62 -0
  14. cllm/agents/tog/controller.py +157 -0
  15. cllm/agents/tog/interpretor.py +262 -0
  16. cllm/agents/tog/planner.py +156 -0
  17. cllm/agents/tog/responser.py +66 -0
  18. cllm/services/audio/__init__.py +0 -0
  19. cllm/services/audio/api.py +140 -0
  20. cllm/services/general/__init__.py +0 -0
  21. cllm/services/general/api.py +65 -0
  22. cllm/services/image_editing/__init__.py +0 -0
  23. cllm/services/image_editing/api.py +277 -0
  24. cllm/services/image_generation/__init__.py +0 -0
  25. cllm/services/image_generation/api.py +96 -0
  26. cllm/services/image_inpainting/__init__.py +0 -0
  27. cllm/services/image_inpainting/api.py +76 -0
  28. cllm/services/image_perception/__init__.py +0 -0
  29. cllm/services/image_perception/api.py +202 -0
  30. cllm/services/image_processing/__init__.py +0 -0
  31. cllm/services/image_processing/api.py +63 -0
  32. cllm/services/nlp/__init__.py +0 -0
  33. cllm/services/nlp/api.py +163 -0
  34. cllm/services/nlp/llms/__init__.py +2 -0
  35. cllm/services/nlp/llms/chat_models.py +219 -0
  36. cllm/services/nlp/llms/memory/__init__.py +1 -0
  37. cllm/services/nlp/llms/memory/message_memory.py +131 -0
  38. cllm/services/nlp/llms/memory/utils.py +52 -0
  39. cllm/services/tog/__init__.py +2 -0
  40. cllm/services/tog/api.py +40 -0
  41. cllm/services/utils.py +50 -0
  42. cllm/services/video/__init__.py +0 -0
  43. cllm/services/video/api.py +135 -0
  44. cllm/services/vqa/__init__.py +0 -0
  45. cllm/services/vqa/api.py +28 -0
  46. cllm/utils.py +79 -0
  47. requirements.txt +14 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ __pycache__/
2
+ run.sh
3
+ client_resources/
4
+ cllm.log
app.py ADDED
@@ -0,0 +1,690 @@
1
+ import copy
2
+ import logging
3
+ import os
4
+ import os.path as osp
5
+ from functools import partial
6
+ from pydoc import locate
7
+ import shutil
8
+ import json
9
+ from traceback import print_exc
10
+ import uuid
11
+ from pathlib import Path
12
+ from collections import OrderedDict
13
+ import numpy as np
14
+ from PIL import Image
15
+
16
+ import whisper
17
+ import fire
18
+ import gradio as gr
19
+ import gradio.themes.base as ThemeBase
20
+ from gradio.themes.utils import colors, fonts, sizes
21
+ import os
22
+ import sys
23
+
24
+ sys.path.append(os.getcwd())
25
+ from cllm.agents.builtin import plans
26
+ from cllm.services.general.api import remote_logging
27
+ from cllm.agents import container, FILE_EXT
28
+ from cllm.utils import get_real_path, plain2md, md2plain
29
+ import openai
30
+
31
+ openai.api_base = os.environ.get("OPENAI_API_BASE", None)
32
+ openai.api_key = os.environ.get("OPENAI_API_KEY", None)
33
+
34
+
35
+ logging.basicConfig(
36
+ filename="cllm.log",
37
+ level=logging.INFO,
38
+ format="%(asctime)s %(levelname)-8s %(message)s",
39
+ )
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+ RESOURCE_ROOT = os.environ.get("CLIENT_ROOT", "./client_resources")
44
+
45
+
46
+ def is_image(file_path):
47
+ ext = FILE_EXT["image"]
48
+ _, extension = os.path.splitext(file_path)
49
+ return extension[1:] in ext
50
+
51
+
52
+ def is_video(file_path):
53
+ ext = FILE_EXT["video"]
54
+ _, extension = os.path.splitext(file_path)
55
+ return extension[1:] in ext
56
+
57
+
58
+ def is_audio(file_path):
59
+ ext = FILE_EXT["audio"]
60
+ _, extension = os.path.splitext(file_path)
61
+ return extension[1:] in ext
62
+
63
+
64
+ def get_file_type(file_path):
65
+ if is_image(file_path):
66
+ if "mask" in file_path:
67
+ return "mask"
68
+ return "image"
69
+ elif is_video(file_path):
70
+ return "video"
71
+ elif is_audio(file_path):
72
+ return "audio"
73
+ raise ValueError("Invalid file type")
74
+
75
+
76
+ def convert_dict_to_frame(data):
77
+ import pandas
78
+
79
+ outputs = []
80
+ for k, v in data.items():
81
+ output = {"Resource": k}
82
+ if not isinstance(v, str):
83
+ output["Type"] = str(v.__class__)
84
+ else:
85
+ output["Type"] = v
86
+ outputs.append(output)
87
+ if len(outputs) == 0:
88
+ return None
89
+ return pandas.DataFrame(outputs)
90
+
91
+
92
+ class Seafoam(ThemeBase.Base):
93
+ def __init__(
94
+ self,
95
+ *,
96
+ primary_hue=colors.emerald,
97
+ secondary_hue=colors.blue,
98
+ neutral_hue=colors.gray,
99
+ spacing_size=sizes.spacing_md,
100
+ radius_size=sizes.radius_md,
101
+ text_size=sizes.text_sm,
102
+ ):
103
+ super().__init__(
104
+ primary_hue=primary_hue,
105
+ secondary_hue=secondary_hue,
106
+ neutral_hue=neutral_hue,
107
+ spacing_size=spacing_size,
108
+ radius_size=radius_size,
109
+ text_size=text_size,
110
+ )
111
+ super().set(
112
+ body_background_fill_dark="#111111",
113
+ button_primary_background_fill="*primary_300",
114
+ button_primary_background_fill_hover="*primary_200",
115
+ button_primary_text_color="black",
116
+ button_secondary_background_fill="*secondary_300",
117
+ button_secondary_background_fill_hover="*secondary_200",
118
+ border_color_primary="#0BB9BF",
119
+ slider_color="*secondary_300",
120
+ slider_color_dark="*secondary_600",
121
+ block_title_text_weight="600",
122
+ block_border_width="3px",
123
+ block_shadow="*shadow_drop_lg",
124
+ button_shadow="*shadow_drop_lg",
125
+ button_large_padding="10px",
126
+ )
127
+
128
+
129
+ class InteractionLoop:
130
+ def __init__(
131
+ self,
132
+ controller="cllm.agents.code.Controller",
133
+ ):
134
+ self.stream = True
135
+ Controller = locate(controller)
136
+ self.controller = Controller(stream=self.stream, interpretor_kwargs=dict())
137
+ self.whisper = whisper.load_model("base")
138
+
139
+ def _gen_new_name(self, r_type, ext="png"):
140
+ this_new_uuid = str(uuid.uuid4())[:6]
141
+ new_file_name = f"{this_new_uuid}_{r_type}.{ext}"
142
+ return new_file_name
143
+
144
+ def init_state(self):
145
+ user_state = OrderedDict()
146
+ user_state["resources"] = OrderedDict()
147
+ user_state["history_msgs"] = []
148
+ resources = OrderedDict()
149
+ for item in sorted(os.listdir("./assets/resources")):
150
+ if item.startswith("."):
151
+ continue
152
+ shutil.copy(
153
+ osp.join("./assets/resources", item),
154
+ osp.join(RESOURCE_ROOT, item),
155
+ )
156
+ resources[item] = get_file_type(item)
157
+ # return user_state, user_state["resources"]
158
+ return user_state, resources
159
+
160
+ def add_file(self, user_state, history, file):
161
+ if user_state.get("resources", None) is None:
162
+ user_state["resources"] = OrderedDict()
163
+
164
+ if file is None:
165
+ return user_state, None, history, None
166
+ # filename = os.path.basename(file.name)
167
+ file = Path(file)
168
+ ext = file.suffix[1:]
169
+ if ext in FILE_EXT["image"]:
170
+ ext = "png"
171
+ r_type = get_file_type(file.name)
172
+ new_filename = self._gen_new_name(r_type, ext)
173
+ saved_path = get_real_path(new_filename)
174
+ if ext in FILE_EXT["image"]:
175
+ Image.open(file).convert("RGB").save(saved_path, "png")
176
+ user_state["input_image"] = new_filename
177
+ else:
178
+ shutil.copy(file, saved_path)
179
+ logger.info(f"add file: {saved_path}")
180
+ user_state["resources"][new_filename] = r_type
181
+ for key, val in user_state["resources"].items():
182
+ if key == "prompt_points":
183
+ user_state["resources"].pop(key)
184
+ break
185
+ history, _ = self.add_text(history, (saved_path,), role="human", append=False)
186
+ history, _ = self.add_text(
187
+ history, f"Received file {new_filename}", role="assistant", append=False
188
+ )
189
+ memory = convert_dict_to_frame(user_state["resources"])
190
+ image_name = None
191
+ if Path(saved_path).suffix[1:] in FILE_EXT["image"]:
192
+ image_name = saved_path
193
+ return user_state, image_name, history, memory
194
+
195
+ def add_msg(self, history, text, audio, role="assistant", append=False):
196
+ if text is not None and text.strip() != "":
197
+ return self.add_text(history, text, role=role, append=append)
198
+ elif audio is not None:
199
+ return self.add_audio(history, audio, role=role, append=append)
200
+ return history, ""
201
+
202
+ def add_text(self, history, text, role="assistant", append=False):
203
+ if history is None:
204
+ return history, ""
205
+ assert role in ["human", "assistant"]
206
+ idx = 0
207
+ if len(history) == 0 or role == "human":
208
+ history.append([None, None])
209
+ if role == "assistant":
210
+ idx = 1
211
+ if not append and history[-1][1] is not None:
212
+ history.append([None, None])
213
+
214
+ if append:
215
+ history[-1][idx] = (
216
+ text if history[-1][idx] is None else history[-1][idx] + text
217
+ )
218
+ else:
219
+ history[-1][idx] = text
220
+ if isinstance(text, str):
221
+ logger.info(f"add text: {md2plain(text)}")
222
+
223
+ return history, ""
224
+
225
+ def add_audio(self, history, audio, role="assistant", append=False):
226
+ assert role in ["human", "assistant"]
227
+ result = self.whisper.transcribe(audio)
228
+ text = result["text"]
229
+ logger.info(f"add audio: {text}")
230
+ return self.add_text(history, text, role=role, append=append)
231
+
232
+ def plan(self, user_state, input_image, history, history_plan):
233
+ logger.info(f"Task plan...")
234
+ if user_state.get("resources", None) is None:
235
+ user_state["resources"] = OrderedDict()
236
+
237
+ request = history[-1][0]
238
+ user_state["request"] = request
239
+ if isinstance(request, str) and request.startswith("$"):
240
+ solution = f'show$("{request[1:]}")'
241
+ else:
242
+ solution = self.controller.plan(request, state=user_state)
243
+ print(f"request: {request}")
244
+ if solution == self.controller.SHORTCUT:
245
+ # md_text = "**Using builtin shortcut solution.**"
246
+ history, _ = self.add_text(
247
+ history, solution, role="assistant", append=False
248
+ )
249
+ user_state["solution"] = solution
250
+ user_state["history_msgs"] = history
251
+ yield user_state, input_image, history, [solution]
252
+ elif isinstance(solution, str) and solution.startswith("show$"):
253
+ user_state["solution"] = solution
254
+ yield user_state, input_image, history, solution
255
+ else:
256
+ output_text = (
257
+ "The whole process will take some time, please be patient.<br><br>"
258
+ )
259
+ history, _ = self.add_text(
260
+ history, output_text, role="assistant", append=True
261
+ )
262
+ yield user_state, input_image, history, history_plan
263
+ task_decomposition = next(solution)
264
+ if task_decomposition in [None, [], ""]:
265
+ output = "Error: unrecognized resource(s) in task decomposition."
266
+ task_decomposition = "[]"
267
+ else:
268
+ output = task_decomposition
269
+
270
+ output = f"**Task Decomposition:**\n{output}"
271
+ output = plain2md(output)
272
+ history, _ = self.add_text(history, output, role="assistant", append=True)
273
+ user_state["task_decomposition"] = json.loads(task_decomposition)
274
+ yield user_state, input_image, history, history_plan
275
+
276
+ history, _ = self.add_text(
277
+ history,
278
+ plain2md("\n\n**Thoughts-on-Graph:**\n"),
279
+ role="assistant",
280
+ append=True,
281
+ )
282
+ yield user_state, input_image, history, history_plan
283
+ solution_str = next(solution)
284
+ logger.info(f"Thoughts-on-Graph: \n{solution_str}")
285
+ if solution_str in [None, [], ""]:
286
+ output = "Empty solution possibly due to some internal errors."
287
+ solution_str = "[]"
288
+ else:
289
+ output = solution_str
290
+
291
+ output_md = plain2md(output)
292
+ history, _ = self.add_text(
293
+ history, output_md, role="assistant", append=True
294
+ )
295
+ solution = json.loads(solution_str)
296
+ user_state["solution"] = solution
297
+ user_state["history_msgs"] = history
298
+ yield user_state, input_image, history, solution
299
+
300
+ def execute(self, user_state, input_image, history, history_plan):
301
+ resources_state = user_state.get("resources", OrderedDict())
302
+ solution = user_state.get("solution", None)
303
+ if not solution:
304
+ yield user_state, input_image, history, history_plan
305
+ return
306
+ logger.info(f"Tool execution...")
307
+ if isinstance(solution, str) and solution.startswith("show$"):
308
+ key = solution[7:-2]
309
+ r_type = resources_state.get(key)
310
+ if r_type is None:
311
+ resource = f"{key} not found"
312
+ resource = container.auto_type("None", r_type, key)
313
+ history, _ = self.add_text(
314
+ history, (resource.to_chatbot(),), role="assistant"
315
+ )
316
+ user_state["history_msgs"] = history
317
+ yield user_state, input_image, history, history_plan
318
+ return
319
+ elif solution:
320
+ results = self.controller.execute(solution, state=user_state)
321
+ if not results:
322
+ yield user_state, input_image, history, history_plan
323
+ return
324
+
325
+ user_state["outputs"] = []
326
+ for result_per_step, executed_solutions, wrapped_outputs in results:
327
+ tool_name = json.dumps(result_per_step[0], ensure_ascii=False)
328
+ args = json.dumps(result_per_step[1], ensure_ascii=False)
329
+ if isinstance(result_per_step[2], Exception):
330
+ ret = f"Internal error: {result_per_step[2]}"
331
+ else:
332
+ ret = json.dumps(result_per_step[2], ensure_ascii=False)
333
+ history, _ = self.add_text(
334
+ history,
335
+ f"Call **{tool_name}:**<br>&nbsp;&nbsp;&nbsp;&nbsp;**Args**: {plain2md(args)}<br>&nbsp;&nbsp;&nbsp;&nbsp;**Ret**: {plain2md(ret)}",
336
+ role="assistant",
337
+ )
338
+ user_state["history_msgs"] = history
339
+ user_state["executed_solutions"] = executed_solutions
340
+ yield user_state, input_image, history, history_plan
341
+ for _, output in enumerate(wrapped_outputs):
342
+ if output is None or output.value is None:
343
+ continue
344
+ if isinstance(output, container.File):
345
+ history, _ = self.add_text(
346
+ history,
347
+ f"Here is {output.filename}:",
348
+ role="assistant",
349
+ )
350
+ history, _ = self.add_text(
351
+ history, (output.to_chatbot(),), role="assistant"
352
+ )
353
+ user_state["outputs"].extend(wrapped_outputs)
354
+ user_state["history_msgs"] = history
355
+ yield user_state, input_image, history, history_plan
356
+
357
+ else:
358
+ yield user_state, input_image, history, history_plan
359
+
360
+ def reply(self, user_state, history):
361
+ logger.info(f"Make response...")
362
+ executed_solution = user_state.get("executed_solutions", None)
363
+ resources_state = user_state.get("resources", OrderedDict())
364
+ solution = user_state.get("solution", None)
365
+ memory = convert_dict_to_frame(resources_state)
366
+ if isinstance(solution, str) and solution.startswith("show$"):
367
+ return user_state, history, memory
368
+
369
+ outputs = user_state.get("outputs", None)
370
+ response, user_state = self.controller.reply(
371
+ executed_solution, outputs, user_state
372
+ )
373
+ # prompt_mask_out = None
374
+ for i, output in enumerate(response):
375
+ if isinstance(output, container.File):
376
+ history, _ = self.add_text(history, f"Here is [{output.filename}]: ")
377
+ history, _ = self.add_text(history, (output.to_chatbot(),))
378
+ elif i == 0:
379
+ history, _ = self.add_text(history, output.to_chatbot())
380
+
381
+ user_state["history_msgs"] = history
382
+ return user_state, history, memory
383
+
384
+ def vote(self, user_state, history, data: gr.LikeData):
385
+ data_value = data.value
386
+ if isinstance(data_value, dict):
387
+ data_value = json.dumps(data_value)
388
+
389
+ if data.liked:
390
+ print("You upvoted this response: ", data_value)
391
+ logger.info("You upvoted this response: " + data_value)
392
+ else:
393
+ print("You downvoted this response: ", data_value)
394
+ logger.info("You downvoted this response: " + data_value)
395
+
396
+ remote_logging(
397
+ user_state.get("history_msgs", []),
398
+ user_state.get("task_decomposition", ""),
399
+ user_state.get("solution", []),
400
+ data_value,
401
+ data.liked,
402
+ )
403
+
404
+ msg = "Thanks for your feedback! It will contribute a lot to improving our ControlLLM."
405
+ history, _ = self.add_text(history, msg)
406
+ user_state["history_msgs"] = history
407
+ return user_state, history
408
+
409
+ def save_point(self, user_state, history, data: gr.SelectData):
410
+ if isinstance(data, gr.LikeData):
411
+ return self.vote(user_state, history, data)
412
+
413
+ if not isinstance(data, gr.SelectData):
414
+ return user_state, history
415
+
416
+ resource_state = user_state.get("resources")
417
+ input_image = user_state.get("input_image", None)
418
+ if input_image is None:
419
+ history, _ = self.add_text(history, "Please upload an image at first.")
420
+ history, _ = self.add_text(history, plans.BUILTIN_SEG_BY_POINTS, "human")
421
+ user_state["history_msgs"] = history
422
+ return user_state, history
423
+
424
+ resource_state.pop(input_image, None)
425
+ resource_state[input_image] = "image"
426
+
427
+ history = history + [[plans.BUILTIN_SEG_BY_POINTS, None]]
428
+ points = []
429
+ if isinstance(points, str):
430
+ points = json.loads(points)
431
+
432
+ points.append(data.index)
433
+ resource_state[json.dumps(points)] = "prompt_points"
434
+ user_state["resources"] = resource_state
435
+ return user_state, history
436
+
437
+
438
+ def on_switch_input(state_input, text, audio, disable=False):
439
+ if state_input == "audio" or disable:
440
+ return "text", gr.update(visible=True), gr.update(visible=False)
441
+ return "audio", gr.update(visible=False), gr.update(visible=True)
442
+
443
+
444
+ def on_mask_submit(history):
445
+ history = history + [(plans.BUILTIN_SEG_BY_MASK, None)]
446
+ return history
447
+
448
+
449
+ def app(controller="cllm.agents.tog.Controller", https=False, **kwargs):
450
+ loop = InteractionLoop(controller=controller)
451
+ init_state, builtin_resources = loop.init_state()
452
+ css = """
453
+ code {
454
+ font-size: var(--text-sm);
455
+ white-space: pre-wrap; /* Since CSS 2.1 */
456
+ white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
457
+ white-space: -pre-wrap; /* Opera 4-6 */
458
+ white-space: -o-pre-wrap; /* Opera 7 */
459
+ word-wrap: break-word; /* Internet Explorer 5.5+ */
460
+ }
461
+ """
462
+ with gr.Blocks(theme=Seafoam(), css=css) as demo:
463
+ gr.HTML(
464
+ """
465
+ <div align='center'> <h1>ControlLLM </h1> </div>
466
+ <p align="center"> A framework for multi-modal interaction that enables LLMs to invoke tools more accurately. </p>
467
+ <p align="center"><a href="https://github.com/OpenGVLab/ControlLLM"><b>GitHub</b></a>
468
+ &nbsp;&nbsp;&nbsp; <a href="https://arxiv.org/abs/2311.11797"><b>ArXiv</b></a></p>
469
+ """,
470
+ )
471
+
472
+ state_input = gr.State("text")
473
+ user_state = gr.State(copy.deepcopy(init_state))
474
+ with gr.Row():
475
+ with gr.Column(scale=6):
476
+ with gr.Tabs():
477
+ with gr.Tab("Chat"):
478
+ chatbot = gr.Chatbot(
479
+ [],
480
+ elem_id="chatbot",
481
+ avatar_images=[
482
+ "assets/human.png",
483
+ "assets/assistant.png",
484
+ ],
485
+ show_copy_button=True,
486
+ height=550,
487
+ )
488
+
489
+ with gr.Row():
490
+ with gr.Column(scale=12):
491
+ text = gr.Textbox(
492
+ show_label=False,
493
+ placeholder="Enter text and press enter, or upload an image.",
494
+ container=False,
495
+ )
496
+ audio = gr.Audio(
497
+ sources="microphone", type="filepath", visible=False
498
+ )
499
+ with gr.Column(scale=2, min_width=80):
500
+ submit = gr.Button("Submit", variant="primary")
501
+ with gr.Column(scale=1, min_width=40):
502
+ record = gr.Button("🎙️")
503
+ with gr.Column(scale=1, min_width=40):
504
+ upload_btn = gr.UploadButton(
505
+ "📁",
506
+ file_types=[
507
+ "image",
508
+ "video",
509
+ "audio",
510
+ ".pdf",
511
+ ],
512
+ )
513
+
514
+ gr.Examples(
515
+ [
516
+ "Who are you?",
517
+ "How is the weather in Beijing?",
518
+ "Describe the given image.",
519
+ "find the woman wearing the red skirt in the image",
520
+ "Generate a video that shows Pikachu surfing in waves.",
521
+ "How many horses are there in the image?",
522
+ "Can you erase the dog in the given image?",
523
+ "Remove the object based on the given mask.",
524
+ "Can you make a video of a serene lake with vibrant green grass and trees all around? And then create a webpage using HTML to showcase this video?",
525
+ "Generate an image that shows a beautiful landscape with a calm lake reflecting the blue sky and white clouds. Then generate a video to introduce this image.",
526
+ "replace the masked object with a cute yellow dog",
527
+ "replace the sheep with a cute dog in the image",
528
+ "Recognize the action in the video",
529
+ "Generate an image where an astronaut is riding a horse",
530
+ "Please generate a piece of music from the given image",
531
+ "Please give me an image that shows an astronaut riding a horse on mars.",
532
+ "What’s the weather situation in Berlin? Can you generate a new image that represents the weather there?",
533
+ "Can you recognize the text from the image and tell me how much is Eggs Florentine?",
534
+ "Generate a piece of music for this video and dub this video with generated music",
535
+ "Generate a new image based on depth map from input image",
536
+ "Remove the cats from the image_1.png, image_2.png, image_3.png",
537
+ "I need the banana removed from the c4c40e_image.png, 9e867c_image.png, 9e13sc_image.png",
538
+ "I would be so happy if you could create a new image using the scribble from input image. The new image should be a tropical island with a dog. Write a detailed description of the given image. and highlight the dog in image",
539
+ "Please generate a piece of music and a new video from the input image",
540
+ "generate a new image conditioned on the segmentation from input image and the new image shows that a gorgeous lady is dancing",
541
+ "generate a new image with a different background but maintaining the same composition as input image",
542
+ "Generate a new image that shows an insect robot preparing a delicious meal. Then give me a video based on new image. Finally, dub the video with suitable background music.",
543
+ "Translate the text into speech: I have a dream that one day this nation will rise up and live out the true meaning of its creed: We hold these truths to be self-evident that all men are created equal.I have a dream that one day on the red hills of Georgia the sons of former slaves and the sons of former slave owners will be able to sit down together at the table of brotherhood. I have a dream that one day even the state of Mississippi, a state sweltering with the heat of injustice, sweltering with the heat of oppression, will be transformed into an oasis of freedom and justice.",
544
+ ],
545
+ inputs=[text],
546
+ )
547
+ gr.Examples(
548
+ list(plans.BUILTIN_PLANS.keys()),
549
+ inputs=[text],
550
+ label="Builtin Examples",
551
+ )
552
+
553
+ with gr.Column(scale=5):
554
+ with gr.Tabs():
555
+ with gr.Tab("Mask Input"):
556
+ image_mask = gr.components.Image(
557
+ sources="upload",
558
+ interactive=True,
559
+ type="filepath",
560
+ )
561
+ # with gr.Row():
562
+ # mask_submit_btn = gr.Button("Segment", variant="primary")
563
+ with gr.Row():
564
+ image_submit_btn = gr.Button("Upload", variant="primary")
565
+
566
+ with gr.Tab("Plan"):
567
+ planbot = gr.JSON(elem_classes="json")
568
+
569
+ with gr.Tab("Memory"):
570
+ memory_table = gr.DataFrame(
571
+ # value=convert_dict_to_frame(builtin_resources),
572
+ label="Memory",
573
+ headers=["Resource", "Type"],
574
+ row_count=5,
575
+ wrap=True,
576
+ )
577
+ gr.Examples(
578
+ [
579
+ osp.join("./assets/resources", item)
580
+ for item in builtin_resources.keys()
581
+ if item.endswith(".png")
582
+ ],
583
+ inputs=[image_mask],
584
+ label="File Examples",
585
+ )
586
+
587
+ chatbot.like(
588
+ loop.vote,
589
+ [
590
+ user_state,
591
+ chatbot,
592
+ ],
593
+ [
594
+ user_state,
595
+ chatbot,
596
+ ],
597
+ )
598
+ reply_inputs = [user_state, image_mask, chatbot, planbot]
599
+ reply_outputs = [
600
+ user_state,
601
+ # image_mask,
602
+ chatbot,
603
+ memory_table,
604
+ # planbot,
605
+ ]
606
+
607
+ add_text = [
608
+ partial(loop.add_text, role="human"),
609
+ [chatbot, text],
610
+ [chatbot, text],
611
+ ]
612
+
613
+ text.submit(*add_text).then(loop.plan, reply_inputs, reply_inputs).then(
614
+ loop.execute, reply_inputs, reply_inputs
615
+ ).then(loop.reply, [user_state, chatbot], reply_outputs)
616
+
617
+ add_msg = [
618
+ partial(loop.add_msg, role="human"),
619
+ [chatbot, text, audio],
620
+ [chatbot, text],
621
+ ]
622
+
623
+ submit.click(*add_msg).then(
624
+ partial(on_switch_input, disable=True),
625
+ [state_input, text, audio],
626
+ [state_input, text, audio],
627
+ ).then(loop.plan, reply_inputs, reply_inputs).then(
628
+ loop.execute, reply_inputs, reply_inputs
629
+ ).then(
630
+ loop.reply, [user_state, chatbot], reply_outputs
631
+ )
632
+
633
+ upload_btn.upload(
634
+ loop.add_file,
635
+ inputs=[user_state, chatbot, upload_btn],
636
+ outputs=[user_state, image_mask, chatbot, memory_table],
637
+ )
638
+ record.click(
639
+ on_switch_input,
640
+ [state_input, text, audio],
641
+ [state_input, text, audio],
642
+ )
643
+
644
+ image_mask.select(
645
+ loop.save_point, [user_state, chatbot], [user_state, chatbot]
646
+ ).then(loop.plan, reply_inputs, reply_inputs).then(
647
+ loop.execute, reply_inputs, reply_inputs
648
+ ).then(
649
+ loop.reply, [user_state, chatbot], reply_outputs
650
+ )
651
+
652
+ image_mask.upload(
653
+ loop.add_file,
654
+ inputs=[user_state, chatbot, image_mask],
655
+ outputs=[user_state, image_mask, chatbot, memory_table],
656
+ )
657
+ image_submit_btn.click(
658
+ loop.add_file,
659
+ inputs=[user_state, chatbot, image_mask],
660
+ outputs=[user_state, image_mask, chatbot, memory_table],
661
+ )
662
+
663
+ if https:
664
+ demo.queue().launch(
665
+ server_name="0.0.0.0",
666
+ # ssl_certfile="./certificate/cert.pem",
667
+ # ssl_keyfile="./certificate/key.pem",
668
+ ssl_verify=False,
669
+ show_api=False,
670
+ allowed_paths=[
671
+ "assets/human.png",
672
+ "assets/assistant.png",
673
+ ],
674
+ **kwargs,
675
+ )
676
+ else:
677
+ demo.queue().launch(
678
+ server_name="0.0.0.0",
679
+ show_api=False,
680
+ allowed_paths=[
681
+ "assets/human.png",
682
+ "assets/assistant.png",
683
+ ],
684
+ **kwargs,
685
+ )
686
+
687
+
688
+ if __name__ == "__main__":
689
+ os.makedirs(RESOURCE_ROOT, exist_ok=True)
690
+ app(controller="cllm.agents.tog.Controller", server_port=10024)
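
The `__main__` block above is the intended entry point. A minimal sketch of launching the demo programmatically, assuming the packages in requirements.txt are installed and the OpenAI variables are set (the key below is a placeholder):

    # mirrors `python app.py`: create the resource root, then start the Gradio UI
    import os
    os.makedirs("./client_resources", exist_ok=True)   # default RESOURCE_ROOT
    os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # placeholder value

    from app import app
    app(controller="cllm.agents.tog.Controller", server_port=10024)  # serves on 0.0.0.0:10024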
assets/assistant.png ADDED
assets/human.png ADDED
builtin_plan.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "you know what I want": [
3
+ [
4
+ {
5
+ "tool_name": "text_to_image",
6
+ "inputs": {
7
+ "text": "a dog"
8
+ },
9
+ "outputs": [
10
+ "image"
11
+ ]
12
+ }
13
+ ]
14
+ ]
15
+ }
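
A plan file in this format maps a literal user query to a list of tool calls; it is parsed by `load_builtin_plans` in `cllm/agents/builtin/plans.py` (added below). A minimal loading sketch, assuming the file sits at the repository root:

    from cllm.agents.builtin.plans import load_builtin_plans

    # returns {query: [[Action, ...]]}; each inner list is one stage of tool calls
    plans = load_builtin_plans("builtin_plan.json")
    for query, stages in plans.items():
        print(query, [str(action) for action in stages[0]])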
cllm/agents/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .base import Tool, Action
2
+ from .container import *
cllm/agents/base.py ADDED
@@ -0,0 +1,173 @@
1
+ from dataclasses import dataclass, field
2
+ from enum import Enum
3
+ from typing import Callable, List
4
+ import json
5
+ from pathlib import Path
6
+ from collections import OrderedDict
7
+
8
+
9
+ @dataclass
10
+ class Action:
11
+ """The action represent an assignment.
12
+ `output = tool_name(**inputs)`
13
+
14
+ Examples:
15
+ >>> mask = segmentation_by_mask(image=image, prompt_mask=prompt_mask)
16
+ >>> image = image_inpainting(image=image, mask=mask)
17
+ """
18
+
19
+ tool_name: str = None
20
+ inputs: dict = None
21
+ outputs: List[str] = None
22
+
23
+ def __str__(self) -> str:
24
+ args = ", ".join([f"{k}={v}" for k, v in self.inputs.items()])
25
+ return "{} = {}(".format(", ".join(self.outputs), self.tool_name) + args + ")"
26
+
27
+ def dict(self):
28
+ args = {str(k): str(v) for k, v in self.inputs.items()}
29
+ # args = {str(item["name"]): str(item["value"]) for item in self.inputs}
30
+ rets = [o if isinstance(o, str) else str(o) for o in self.outputs]
31
+ return {
32
+ "tool": self.tool_name,
33
+ "inputs": args,
34
+ "outputs": rets,
35
+ }
36
+
37
+
38
+ class DataType(Enum):
39
+ TEXT = "text"
40
+ TAGS = "tags"
41
+ TITLE = "title"
42
+ # HTML = "text.html"
43
+ HTML = "html"
44
+ LOCATION = "location"
45
+ WEATHER = "weather"
46
+ TIME = "time"
47
+
48
+ IMAGE = "image"
49
+ VIDEO = "video"
50
+ AUDIO = "audio"
51
+ ANY = "any"
52
+ NONE = "none"
53
+
54
+ SEGMENTATION = "image.segmentation"
55
+ EDGE = "image.edge"
56
+ LINE = "image.line"
57
+ HED = "image.hed"
58
+ CANNY = "image.canny"
59
+ SCRIBBLE = "image.scribble"
60
+ POSE = "image.pose"
61
+ DEPTH = "image.depth"
62
+ NORMAL = "image.normal"
63
+
64
+ MASK = "image.mask" # SAM mask
65
+ POINT = "point"
66
+ BBOX = "bbox" # {'label': 'dog', 'box': [1,2,3,4], 'score': 0.9}
67
+ CATEGORY = "category"
68
+
69
+ LIST = "list"
70
+
71
+ def __str__(self):
72
+ return self.value
73
+
74
+ def __eq__(self, other):
75
+ if isinstance(other, str):
76
+ return self.value == other
77
+ elif isinstance(other, self.__class__):
78
+ return self.value == other.value
79
+ else:
80
+ return False
81
+
82
+
83
+ @dataclass
84
+ class Resource:
85
+ name: str
86
+ type: DataType
87
+ value: None
88
+ # description: str = None
89
+
90
+ def dict(self):
91
+ return {
92
+ "name": self.name,
93
+ "type": str(self.type),
94
+ "value": str(self.value),
95
+ # "description": self.description,
96
+ }
97
+
98
+
99
+ @dataclass
100
+ class Tool:
101
+ class Domain(Enum):
102
+ IMAGE_PERCEPTION = "image-perception"
103
+ IMAGE_GENERATION = "image-generation"
104
+ IMAGE_EDITING = "image-editing"
105
+ IMAGE_PROCESSING = "image-processing"
106
+ AUDIO_PERCEPTION = "audio-perception"
107
+ AUDIO_GENERATION = "audio-generation"
108
+ VIDEO_PERCEPTION = "video-perception"
109
+ VIDEO_GENERATION = "video-generation"
110
+ VIDEO_PROCESSING = "video-processing"
111
+ VIDEO_EDITING = "video-editing"
112
+ VIDEO_CUTTING = "video-cutting"
113
+ NATURAL_LANGUAGE_PROCESSING = "natural-language-processing"
114
+ CODE_GENERATION = "code-generation"
115
+ VISUAL_QUESTION_ANSWERING = "visual-question-answering"
116
+ QUESTION_ANSWERING = "question-answering"
117
+ GENERAL = "general"
118
+
119
+ def __str__(self):
120
+ return self.value
121
+
122
+ @dataclass
123
+ class Argument:
124
+ name: str
125
+ type: DataType
126
+ description: str
127
+
128
+ def dict(self):
129
+ return {
130
+ "name": self.name,
131
+ "type": str(self.type),
132
+ "description": self.description,
133
+ }
134
+
135
+ name: str
136
+ description: str
137
+ domain: Domain
138
+ model: Callable
139
+
140
+ usages: List[str] = field(default_factory=lambda: [])
141
+ args: List[Argument] = field(default_factory=lambda: [])
142
+ returns: List[Argument] = field(default_factory=lambda: [])
143
+
144
+ def dict(self):
145
+ return {
146
+ "name": self.name,
147
+ "description": self.description,
148
+ "domain": str(self.domain),
149
+ "args": [a.dict() for a in self.args],
150
+ "returns": [r.dict() for r in self.returns],
151
+ }
152
+
153
+
154
+ NON_FILE_TYPES = [
155
+ DataType.TAGS,
156
+ DataType.TEXT,
157
+ DataType.TITLE,
158
+ DataType.BBOX,
159
+ DataType.CATEGORY,
160
+ DataType.LIST,
161
+ DataType.LOCATION,
162
+ DataType.POINT,
163
+ DataType.WEATHER,
164
+ DataType.TIME,
165
+ ]
166
+
167
+
168
+ if __name__ == "__main__":
169
+ s = [
170
+ [Action("a", {"aa": [Path("/a/d/e/t.txt")]}, [Path("/a/aa.txt")])],
171
+ Action("b", {"bb": "bbb"}, ["bbb"]),
172
+ ]
173
+ print(json.dumps(s, indent=4, default=lambda o: o.dict()))
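
For a quick sense of how `Action` serializes, a sketch using a tool name that appears elsewhere in this commit (the argument values are illustrative placeholders):

    from cllm.agents.base import Action

    act = Action(
        tool_name="image_inpainting",
        inputs={"image": "image", "mask": "mask"},
        outputs=["<GENERATED>-0"],
    )
    print(act)         # <GENERATED>-0 = image_inpainting(image=image, mask=mask)
    print(act.dict())  # {'tool': 'image_inpainting', 'inputs': {...}, 'outputs': ['<GENERATED>-0']}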
cllm/agents/builtin/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from . import plans
2
+ from .plans import BUILTIN_PLANS, load_builtin_plans
3
+ from .tools import TOOLS
cllm/agents/builtin/plans.py ADDED
@@ -0,0 +1,634 @@
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.getcwd())
5
+
6
+ from cllm.agents.base import Action
7
+
8
+ BUILTIN_SEG_BY_POINTS = "Segment the given image based on the prompt points."
9
+ BUILTIN_SEG_BY_MASK = "Segment the given image based on the prompt mask."
10
+ # BUILTIN_REMOVE_BY_MASK = "Remove the object based on the given mask."
11
+ BUILTIN_IMAGE_TO_EDGE = "Generate the edge from the given image."
12
+ BUILTIN_GENERATE_SIMILAR_IMAGE = "Generate a new image similar to the input image"
13
+ # BUILTIN_GENERATE_SIMILAR_IMAGE2 = "Generate a similar image from the given image 2"
14
+ # BUILTIN_GENERATE_SIMILAR_IMAGE3 = "Image to image. 3"
15
+ BUILTIN_GENERATE_SIMILAR_IMAGE4 = "Generate a new image similar to image 4"
16
+ BUILTIN_GENERATE_IMAGE_HED = "Generate a new image based on HED result from input image"
17
+ BUILTIN_GENERATE_IMAGE_DEPTH = (
18
+ "Generate a new image based on depth map from input image"
19
+ )
20
+ BUILTIN_GENERATE_IMAGE_OCR = "Please extract the text from the image"
21
+ BUILTIN_TEXT_EDGE_TO_IMAGE = "Generate an image based on the given edge map."
22
+ BUILTIN_GENERATE_IMAGE = "Generate a new image that shows a woman is skiing"
23
+ BUILTIN_IMAGE_TO_VIDEO = "Generate a video from the image"
24
+ BUILTIN_COUNT_OBJECTS = "Provide me with the count of bears in the input image"
25
+ BUILTIN_VIDEO_TO_WEBPAGE = "Generate a web page for input video"
26
+ BUILTIN_TEXT_TO_MUSIC = "Please generate a piece of music based on given prompt. Here is the prompt: An 80s driving pop song with heavy drums and synth pads in the background"
27
+ BUILTIN_IMAGE_ERASING1 = "Erase the wine glass from the photo"
28
+ BUILTIN_IMAGE_ERASING2 = "Erase the cats in the photo"
29
+ BUILTIN_IMAGE_CROPPING = "Crop the cats from the photo"
30
+ BUILTIN_IMAGE_SEG = "give me the mask of elephant."
31
+ BUILTIN_IMAGE_HIGHLIGHT = "highlight the elephant."
32
+ BUILTIN_TEXT_SPEECH = "translate text into speech"
33
+ BUILTIN_DUBBING = "dub this video with the given audio"
34
+ BUILTIN_COUNT_OBJECTS2 = "Count the horse in the image."
35
+ BUILTIN_IMAGE_TO_VIDEO2 = "Generate an image that shows a serene and beautiful landscape with a calm lake reflecting the blue sky and white clouds. Then generate a video to introduce this image."
36
+ BUILTIN_IMAGE_TO_VIDEO3 = "Create a visual and auditory representation of a peaceful and scenic landscape. The image should depict a serene and beautiful landscape with a calm lake reflecting the blue sky. The music should match the image. Finally, combine the image and the music into a video that showcases the beauty of nature."
37
+ BUILTIN_VIDEO_CLS = "Recognize the action in the video"
38
+ BUILTIN_VIDEO_CLS = "Recognize the action in the video"
39
+ BUILTIN_AUDIO_CLS = "Recognize the event in this audio"
40
+ BUILTIN_IMAGE2MUSIC = "Generate a piece of music for this image"
41
+ BUILTIN_VIDEO2MUSIC = (
42
+ "Generate a piece of music for this video and dub the video with generated music"
43
+ )
44
+
45
+ BUILTIN_PLANS = {
46
+ # BUILTIN_REMOVE_BY_MASK: [
47
+ # [
48
+ # Action(
49
+ # tool_name="image_inpainting",
50
+ # inputs={"image": "image", "mask": "image.mask"},
51
+ # outputs=["<GENERATED>-0"],
52
+ # )
53
+ # ]
54
+ # ],
55
+ BUILTIN_IMAGE_TO_EDGE: [
56
+ [
57
+ Action(
58
+ tool_name="image_to_edge",
59
+ inputs={"image": "image"},
60
+ outputs=["<GENERATED>-0"],
61
+ )
62
+ ]
63
+ ],
64
+ BUILTIN_TEXT_EDGE_TO_IMAGE: [
65
+ [
66
+ Action(
67
+ tool_name="image_captioning",
68
+ inputs={"image": "image"},
69
+ outputs=["<TOOL-GENERATED>-prompt"],
70
+ ),
71
+ Action(
72
+ tool_name="edge_text_to_image",
73
+ inputs={
74
+ "edge": "image.edge",
75
+ "text": "<TOOL-GENERATED>-prompt",
76
+ },
77
+ outputs=["<GENERATED>-0"],
78
+ ),
79
+ ]
80
+ ],
81
+ BUILTIN_GENERATE_SIMILAR_IMAGE: [
82
+ [
83
+ Action(
84
+ tool_name="image_to_edge",
85
+ inputs={"image": "image"},
86
+ outputs=["<TOOL-GENERATED>-edge"],
87
+ ),
88
+ Action(
89
+ tool_name="image_captioning",
90
+ inputs={"image": "image"},
91
+ outputs=["<TOOL-GENERATED>-prompt"],
92
+ ),
93
+ Action(
94
+ tool_name="edge_text_to_image",
95
+ inputs={
96
+ "edge": "<TOOL-GENERATED>-edge",
97
+ "text": "<TOOL-GENERATED>-prompt",
98
+ },
99
+ outputs=["<GENERATED>-0"],
100
+ ),
101
+ ]
102
+ ],
103
+ # BUILTIN_GENERATE_SIMILAR_IMAGE2: [
104
+ # [
105
+ # Action(
106
+ # tool_name="image_captioning",
107
+ # inputs={"image": "image"},
108
+ # outputs=["<TOOL-GENERATED>-prompt"],
109
+ # ),
110
+ # Action(
111
+ # tool_name="text_to_image",
112
+ # inputs={"text": "<TOOL-GENERATED>-prompt"},
113
+ # outputs=["<GENERATED>-0"],
114
+ # ),
115
+ # ]
116
+ # ],
117
+ # BUILTIN_GENERATE_SIMILAR_IMAGE3: [
118
+ # [
119
+ # Action(
120
+ # tool_name="image_to_image",
121
+ # inputs={"image": "image"},
122
+ # outputs=["<GENERATED>-0"],
123
+ # ),
124
+ # ]
125
+ # ],
126
+ BUILTIN_GENERATE_IMAGE_HED: [
127
+ [
128
+ Action(
129
+ tool_name="image_to_hed",
130
+ inputs={"image": "image"},
131
+ outputs=["<TOOL-GENERATED>-image_to_hed-hed-0"],
132
+ ),
133
+ Action(
134
+ tool_name="hed_text_to_image",
135
+ inputs={
136
+ "text": "beautiful mountains and sunset",
137
+ "hed": "<TOOL-GENERATED>-image_to_hed-hed-0",
138
+ },
139
+ outputs=["<GENERATED>-0"],
140
+ ),
141
+ ]
142
+ ],
143
+ BUILTIN_GENERATE_IMAGE_DEPTH: [
144
+ [
145
+ Action(
146
+ tool_name="image_captioning",
147
+ inputs={
148
+ "image": "image",
149
+ },
150
+ outputs=["<TOOL-GENERATED>-image_captioning-text-0"],
151
+ ),
152
+ Action(
153
+ tool_name="image_to_depth",
154
+ inputs={"image": "image"},
155
+ outputs=["<TOOL-GENERATED>-image_to_depth-depth-0"],
156
+ ),
157
+ Action(
158
+ tool_name="depth_text_to_image",
159
+ inputs={
160
+ "text": "<TOOL-GENERATED>-image_captioning-text-0",
161
+ "depth": "<TOOL-GENERATED>-image_to_depth-depth-0",
162
+ },
163
+ outputs=["<GENERATED>-0"],
164
+ ),
165
+ ]
166
+ ],
167
+ BUILTIN_GENERATE_IMAGE_OCR: [
168
+ [
169
+ Action(
170
+ tool_name="optical_character_recognition",
171
+ inputs={"image": "image"},
172
+ outputs=["<GENERATED>-0"],
173
+ )
174
+ ]
175
+ ],
176
+ BUILTIN_COUNT_OBJECTS: [
177
+ [
178
+ Action(
179
+ tool_name="object_detection",
180
+ inputs={"image": "image"},
181
+ outputs=["<TOOL-GENERATED>-object_detection-bbox-0"],
182
+ ),
183
+ Action(
184
+ tool_name="select_bbox",
185
+ inputs={
186
+ "bbox_list": "<TOOL-GENERATED>-object_detection-bbox-0",
187
+ "condition": "bear",
188
+ },
189
+ outputs=["<TOOL-GENERATED>-select_bbox-bbox-0"],
190
+ ),
191
+ Action(
192
+ tool_name="count_objects",
193
+ inputs={"bbox_list": "<TOOL-GENERATED>-select_bbox-bbox-0"},
194
+ outputs=["<GENERATED>-0"],
195
+ ),
196
+ ],
197
+ [
198
+ Action(
199
+ tool_name="image_question_answering",
200
+ inputs={
201
+ "text": "Provide me with the count of bears in the input image",
202
+ "image": "image",
203
+ },
204
+ outputs=["<GENERATED>-1"],
205
+ )
206
+ ],
207
+ ],
208
+ BUILTIN_VIDEO_TO_WEBPAGE: [
209
+ [
210
+ Action(
211
+ tool_name="video_captioning",
212
+ inputs={"video": "video"},
213
+ outputs=["<TOOL-GENERATED>-text-0"],
214
+ ),
215
+ Action(
216
+ tool_name="text_to_music",
217
+ inputs={"text": "<TOOL-GENERATED>-text-0"},
218
+ outputs=["<TOOL-GENERATED>-text_to_music-audio-0"],
219
+ ),
220
+ Action(
221
+ tool_name="dub_video",
222
+ inputs={
223
+ "video": "video",
224
+ "audio": "<TOOL-GENERATED>-text_to_music-audio-0",
225
+ },
226
+ outputs=["<TOOL-GENERATED>-dub_video-video-0"],
227
+ ),
228
+ Action(
229
+ tool_name="title_generation",
230
+ inputs={"text": "<TOOL-GENERATED>-text-0"},
231
+ outputs=["<TOOL-GENERATED>-text-1"],
232
+ ),
233
+ Action(
234
+ tool_name="text_to_tags",
235
+ inputs={"text": "<TOOL-GENERATED>-text-0"},
236
+ outputs=["<TOOL-GENERATED>-tags-0"],
237
+ ),
238
+ Action(
239
+ tool_name="video_to_webpage",
240
+ inputs={
241
+ "video": "<TOOL-GENERATED>-dub_video-video-0",
242
+ "title": "<TOOL-GENERATED>-text-1",
243
+ "tags": "<TOOL-GENERATED>-tags-0",
244
+ "description": "<TOOL-GENERATED>-text-0",
245
+ },
246
+ outputs=["<GENERATED>-0"],
247
+ ),
248
+ ]
249
+ ],
250
+ BUILTIN_TEXT_TO_MUSIC: [
251
+ [
252
+ Action(
253
+ tool_name="text_to_music",
254
+ inputs={
255
+ "text": "An 80s driving pop song with heavy drums and synth pads in the background"
256
+ },
257
+ outputs=["<GENERATED>-audio-0"],
258
+ )
259
+ ]
260
+ ],
261
+ BUILTIN_IMAGE_ERASING1: [
262
+ [
263
+ Action(
264
+ tool_name="image_instance_segmentation",
265
+ inputs={"image": "image"},
266
+ outputs=["<TOOL-GENERATED>-image_instance_segmentation-mask-0"],
267
+ ),
268
+ Action(
269
+ tool_name="select_mask",
270
+ inputs={
271
+ "mask_list": "<TOOL-GENERATED>-image_instance_segmentation-mask-0",
272
+ "condition": "wine glass",
273
+ },
274
+ outputs=["<TOOL-GENERATED>-select_mask-mask-0"],
275
+ ),
276
+ Action(
277
+ tool_name="image_inpainting",
278
+ inputs={
279
+ "image": "image",
280
+ "mask": "<TOOL-GENERATED>-select_mask-mask-0",
281
+ },
282
+ outputs=["<GENERATED>-0"],
283
+ ),
284
+ ]
285
+ ],
286
+ BUILTIN_IMAGE_ERASING2: [
287
+ [
288
+ Action(
289
+ tool_name="image_instance_segmentation",
290
+ inputs={"image": "image"},
291
+ outputs=["<TOOL-GENERATED>-image_instance_segmentation-mask-0"],
292
+ ),
293
+ Action(
294
+ tool_name="select_mask",
295
+ inputs={
296
+ "mask_list": "<TOOL-GENERATED>-image_instance_segmentation-mask-0",
297
+ "condition": "cat",
298
+ },
299
+ outputs=["<TOOL-GENERATED>-select_mask-mask-0"],
300
+ ),
301
+ Action(
302
+ tool_name="image_inpainting",
303
+ inputs={
304
+ "image": "image",
305
+ "mask": "<TOOL-GENERATED>-select_mask-mask-0",
306
+ },
307
+ outputs=["<GENERATED>-0"],
308
+ ),
309
+ ]
310
+ ],
311
+ BUILTIN_IMAGE_CROPPING: [
312
+ [
313
+ Action(
314
+ tool_name="object_detection",
315
+ inputs={"image": "image"},
316
+ outputs=["<TOOL-GENERATED>-object_detection-bbox-0"],
317
+ ),
318
+ Action(
319
+ tool_name="select_bbox",
320
+ inputs={
321
+ "bbox_list": "<TOOL-GENERATED>-object_detection-bbox-0",
322
+ "condition": "cat",
323
+ },
324
+ outputs=["<TOOL-GENERATED>-select_bbox-bbox-0"],
325
+ ),
326
+ Action(
327
+ tool_name="image_cropping",
328
+ inputs={
329
+ "image": "image",
330
+ "object": "<TOOL-GENERATED>-select_bbox-bbox-0",
331
+ },
332
+ outputs=["<GENERATED>-0"],
333
+ ),
334
+ ]
335
+ ],
336
+ BUILTIN_IMAGE_SEG: [
337
+ [
338
+ Action(
339
+ tool_name="image_instance_segmentation",
340
+ inputs={"image": "image"},
341
+ outputs=["<TOOL-GENERATED>-image_instance_segmentation-mask-0"],
342
+ ),
343
+ Action(
344
+ tool_name="select_mask",
345
+ inputs={
346
+ "mask_list": "<TOOL-GENERATED>-image_instance_segmentation-mask-0",
347
+ "condition": "elephant",
348
+ },
349
+ outputs=["<GENERATED>-0"],
350
+ ),
351
+ ]
352
+ ],
353
+ BUILTIN_IMAGE_HIGHLIGHT: [
354
+ [
355
+ Action(
356
+ tool_name="object_detection",
357
+ inputs={"image": "image"},
358
+ outputs=["<TOOL-GENERATED>-object_detection-bbox-0"],
359
+ ),
360
+ Action(
361
+ tool_name="select_bbox",
362
+ inputs={
363
+ "bbox_list": "<TOOL-GENERATED>-object_detection-bbox-0",
364
+ "condition": "elephant",
365
+ },
366
+ outputs=["<TOOL-GENERATED>-select_bbox-bbox-0"],
367
+ ),
368
+ Action(
369
+ tool_name="highlight_object_on_image",
370
+ inputs={
371
+ "image": "image",
372
+ "bbox": "<TOOL-GENERATED>-select_bbox-bbox-0",
373
+ },
374
+ outputs=["<GENERATED>-0"],
375
+ ),
376
+ ]
377
+ ],
378
+ BUILTIN_TEXT_SPEECH: [
379
+ [
380
+ Action(
381
+ tool_name="text_to_speech",
382
+ inputs={
383
+ "text": "Hope is the thing with feathers That perches in the soul, And sings the tune without the words, And never stops at all"
384
+ },
385
+ outputs=["<GENERATED>-0"],
386
+ )
387
+ ]
388
+ ],
389
+ BUILTIN_DUBBING: [
390
+ [
391
+ Action(
392
+ tool_name="dub_video",
393
+ inputs={"video": "video", "audio": "audio"},
394
+ outputs=["<GENERATED>-0"],
395
+ )
396
+ ]
397
+ ],
398
+ BUILTIN_GENERATE_SIMILAR_IMAGE4: [
399
+ [
400
+ Action(
401
+ tool_name="segment_anything",
402
+ inputs={"image": "image"},
403
+ outputs=["<TOOL-GENERATED>-seg"],
404
+ ),
405
+ Action(
406
+ tool_name="image_captioning",
407
+ inputs={"image": "image"},
408
+ outputs=["<TOOL-GENERATED>-prompt"],
409
+ ),
410
+ Action(
411
+ tool_name="segmentation_text_to_image",
412
+ inputs={
413
+ "segmentation": "<TOOL-GENERATED>-seg",
414
+ "text": "<TOOL-GENERATED>-prompt",
415
+ },
416
+ outputs=["<GENERATED>-0"],
417
+ ),
418
+ ]
419
+ ],
420
+ BUILTIN_GENERATE_IMAGE: [
421
+ [
422
+ Action(
423
+ tool_name="text_to_image",
424
+ inputs={"text": "a woman is skiing"},
425
+ outputs=["<GENERATED>-0"],
426
+ )
427
+ ]
428
+ ],
429
+ BUILTIN_IMAGE_TO_VIDEO: [
430
+ [
431
+ Action(
432
+ tool_name="image_to_video",
433
+ inputs={"image": "image"},
434
+ outputs=["<GENERATED>-0"],
435
+ )
436
+ ]
437
+ ],
438
+ BUILTIN_COUNT_OBJECTS2: [
439
+ [
440
+ Action(
441
+ tool_name="object_detection",
442
+ inputs={"image": "image"},
443
+ outputs=["<TOOL-GENERATED>-object_detection-bbox-0"],
444
+ ),
445
+ Action(
446
+ tool_name="select_bbox",
447
+ inputs={
448
+ "bbox_list": "<TOOL-GENERATED>-object_detection-bbox-0",
449
+ "condition": "horse",
450
+ },
451
+ outputs=["<TOOL-GENERATED>-select_bbox-bbox-0"],
452
+ ),
453
+ Action(
454
+ tool_name="count_objects",
455
+ inputs={"bbox_list": "<TOOL-GENERATED>-select_bbox-bbox-0"},
456
+ outputs=["<GENERATED>-0"],
457
+ ),
458
+ ],
459
+ [
460
+ Action(
461
+ tool_name="image_question_answering",
462
+ inputs={
463
+ "text": "Provide me with the count of horses in the input image",
464
+ "image": "image",
465
+ },
466
+ outputs=["<GENERATED>-1"],
467
+ )
468
+ ],
469
+ ],
470
+ BUILTIN_IMAGE_TO_VIDEO2: [
471
+ [
472
+ Action(
473
+ tool_name="text_to_image",
474
+ inputs={
475
+ "text": "A serene and beautiful landscape with a calm lake reflecting the blue sky and white clouds."
476
+ },
477
+ outputs=["<GENERATED>-0"],
478
+ ),
479
+ ],
480
+ [
481
+ Action(
482
+ tool_name="image_captioning",
483
+ inputs={"image": "<GENERATED>-0"},
484
+ outputs=["<TOOL-GENERATED>-text-0"],
485
+ ),
486
+ Action(
487
+ tool_name="text_to_speech",
488
+ inputs={"text": "<TOOL-GENERATED>-text-0"},
489
+ outputs=["<TOOL-GENERATED>-text_to_speech-audio-0"],
490
+ ),
491
+ Action(
492
+ tool_name="image_audio_to_video",
493
+ inputs={
494
+ "image": "<GENERATED>-0",
495
+ "audio": "<TOOL-GENERATED>-text_to_speech-audio-0",
496
+ },
497
+ outputs=["<GENERATED>-1"],
498
+ ),
499
+ ],
500
+ ],
501
+ BUILTIN_IMAGE_TO_VIDEO3: [
502
+ [
503
+ Action(
504
+ tool_name="text_to_image",
505
+ inputs={
506
+ "text": "A serene and beautiful landscape with a calm lake reflecting the blue sky."
507
+ },
508
+ outputs=["<GENERATED>-0"],
509
+ ),
510
+ ],
511
+ [
512
+ Action(
513
+ tool_name="image_captioning",
514
+ inputs={"image": "<GENERATED>-0"},
515
+ outputs=["<TOOL-GENERATED>-text-0"],
516
+ ),
517
+ Action(
518
+ tool_name="text_to_music",
519
+ inputs={"text": "<TOOL-GENERATED>-text-0"},
520
+ outputs=["<GENERATED>-1"],
521
+ ),
522
+ ],
523
+ [
524
+ Action(
525
+ tool_name="image_to_video",
526
+ inputs={
527
+ "image": "<GENERATED>-0",
528
+ },
529
+ outputs=["<TOOL-GENERATED>-image_to_video-video-0"],
530
+ ),
531
+ Action(
532
+ tool_name="dub_video",
533
+ inputs={
534
+ "video": "<TOOL-GENERATED>-image_to_video-video-0",
535
+ "audio": "<GENERATED>-1",
536
+ },
537
+ outputs=["<GENERATED>-2"],
538
+ ),
539
+ ],
540
+ ],
541
+ BUILTIN_VIDEO_CLS: [
542
+ [
543
+ Action(
544
+ tool_name="video_classification",
545
+ inputs={"video": "video"},
546
+ outputs=["<GENERATED>-0"],
547
+ )
548
+ ]
549
+ ],
550
+ BUILTIN_AUDIO_CLS: [
551
+ [
552
+ Action(
553
+ tool_name="audio_classification",
554
+ inputs={"audio": "audio"},
555
+ outputs=["<GENERATED>-0"],
556
+ )
557
+ ]
558
+ ],
559
+ BUILTIN_IMAGE2MUSIC: [
560
+ [
561
+ Action(
562
+ tool_name="image_captioning",
563
+ inputs={"image": "image"},
564
+ outputs=["<TOOL-GENERATED>-text-0"],
565
+ ),
566
+ Action(
567
+ tool_name="text_to_music",
568
+ inputs={"text": "<TOOL-GENERATED>-text-0"},
569
+ outputs=["<GENERATED>-0"],
570
+ ),
571
+ ]
572
+ ],
573
+ BUILTIN_VIDEO2MUSIC: [
574
+ [
575
+ Action(
576
+ tool_name="video_captioning",
577
+ inputs={"video": "video"},
578
+ outputs=["<TOOL-GENERATED>-text-0"],
579
+ ),
580
+ Action(
581
+ tool_name="text_to_music",
582
+ inputs={"text": "<TOOL-GENERATED>-text-0"},
583
+ outputs=["<GENERATED>-0"],
584
+ ),
585
+ ],
586
+ [
587
+ Action(
588
+ tool_name="dub_video",
589
+ inputs={
590
+ "video": "video",
591
+ "audio": "<GENERATED>-0",
592
+ },
593
+ outputs=["<GENERATED>-1"],
594
+ ),
595
+ ],
596
+ ],
597
+ BUILTIN_SEG_BY_POINTS: [
598
+ [
599
+ Action(
600
+ tool_name="image_segmentation_by_points",
601
+ inputs={"image": "image", "prompt_points": "prompt_points"},
602
+ outputs=["<GENERATED>-0"],
603
+ )
604
+ ]
605
+ ],
606
+ # BUILTIN_SEG_BY_MASK: [
607
+ # [
608
+ # Action(
609
+ # tool_name='image_segmentation_by_mask',
610
+ # inputs={'image': 'image', 'prompt_mask': 'prompt_mask'},
611
+ # outputs=['<GENERATED>-0'],
612
+ # )
613
+ # ]
614
+ # ],
615
+ }
616
+
617
+
618
+ def load_builtin_plans(path):
619
+ import json
620
+
621
+ plans = json.load(open(path, "r"))
622
+ processed_plan = {}
623
+ for query, actions in plans.items():
624
+ actions2 = []
625
+ for ac in actions[0]:
626
+ actions2.append(
627
+ Action(
628
+ tool_name=ac["tool_name"],
629
+ inputs=ac["inputs"],
630
+ outputs=ac["outputs"],
631
+ ),
632
+ )
633
+ processed_plan[query] = [actions2]
634
+ return processed_plan
cllm/agents/builtin/tools.py ADDED
@@ -0,0 +1,1512 @@
1
+ import sys
2
+ import os
3
+
4
+ sys.path.append(os.getcwd())
5
+ from cllm.services.image_editing.api import (
6
+ inpainting_ldm,
7
+ inpainting_ldm_general,
8
+ partial_image_editing,
9
+ instruct_pix2pix,
10
+ image_cropping,
11
+ image_matting,
12
+ draw_bbox_on_image,
13
+ )
14
+ from cllm.services.image_generation.api import (
15
+ text2image,
16
+ image2image,
17
+ cannytext2image,
18
+ linetext2image,
19
+ hedtext2image,
20
+ scribbletext2image,
21
+ posetext2image,
22
+ segtext2image,
23
+ depthtext2image,
24
+ normaltext2image,
25
+ )
26
+
27
+ from cllm.services.image_processing.api import (
28
+ image2canny,
29
+ image2line,
30
+ image2hed,
31
+ image2scribble,
32
+ image2pose,
33
+ image2depth,
34
+ image2normal,
35
+ )
36
+ from cllm.services.image_perception.api import (
37
+ object_detection,
38
+ image_classification,
39
+ ocr,
40
+ segment_objects,
41
+ visual_grounding,
42
+ image_captioning,
43
+ segment_all,
44
+ seg_by_mask,
45
+ seg_by_points,
46
+ )
47
+ from cllm.services.video.api import (
48
+ video_classification,
49
+ video_captioning,
50
+ image_audio_to_video,
51
+ video_to_webpage,
52
+ dub_video,
53
+ image_to_video,
54
+ text_to_video,
55
+ )
56
+ from cllm.services.audio.api import (
57
+ text_to_music,
58
+ text_to_speech,
59
+ audio_classification,
60
+ )
61
+
62
+ # from cllm.services.sam.api import (
63
+ # segment_by_mask,
64
+ # segment_by_points,
65
+ # set_image,
66
+ # segment_all,
67
+ # )
68
+ from cllm.services.general.api import (
69
+ select,
70
+ count,
71
+ remote_logging,
72
+ )
73
+ from cllm.services.nlp.api import (
74
+ text_to_text_generation,
75
+ title_generation,
76
+ text_to_tags,
77
+ question_answering_with_context,
78
+ openai_chat_model,
79
+ summarization,
80
+ extract_location,
81
+ sentiment_analysis,
82
+ get_weather,
83
+ summarize_weather_condition,
84
+ get_time,
85
+ )
86
+ from cllm.services.vqa.api import image_qa
87
+ from cllm.agents.base import Tool, DataType
88
+
89
+
90
+ QUESTION_ANSWERING_TOOLS = [
91
+ Tool(
92
+ name="image_question_answering",
93
+ description="answers a question about an image",
94
+ domain=Tool.Domain.VISUAL_QUESTION_ANSWERING,
95
+ args=[
96
+ Tool.Argument(
97
+ name="image",
98
+ type=DataType.IMAGE,
99
+ description="the image containing the information",
100
+ ),
101
+ Tool.Argument(
102
+ name="text",
103
+ type=DataType.TEXT,
104
+ description="the question about the image",
105
+ ),
106
+ ],
107
+ returns=[
108
+ Tool.Argument(
109
+ name="response",
110
+ type=DataType.TEXT,
111
+ description="output response",
112
+ )
113
+ ],
114
+ model=image_qa,
115
+ ),
116
+ Tool(
117
+ name="get_weather",
118
+ description="Query the weather conditions by given location. For example: what is the weather in Beijing? how cold is in New York? etc.",
119
+ domain=Tool.Domain.QUESTION_ANSWERING,
120
+ args=[
121
+ Tool.Argument(
122
+ name="location",
123
+ type=DataType.LOCATION,
124
+ description="the location where the weather is to be queried",
125
+ ),
126
+ ],
127
+ returns=[
128
+ Tool.Argument(
129
+ name="result",
130
+ # type=DataType.WEATHER,
131
+ type=DataType.WEATHER,
132
+ description="weather information",
133
+ )
134
+ ],
135
+ model=get_weather,
136
+ ),
137
+ Tool(
138
+ name="get_time",
139
+ description="get current date",
140
+ domain=Tool.Domain.QUESTION_ANSWERING,
141
+ args=[
142
+ # Tool.Argument(
143
+ # name="location",
144
+ # type=DataType.LOCATION,
145
+ # description="location where the time is to be queried",
146
+ # ),
147
+ Tool.Argument(
148
+ name="text",
149
+ type=DataType.TEXT,
150
+ description="input text",
151
+ ),
152
+ ],
153
+ returns=[
154
+ Tool.Argument(
155
+ name="response",
156
+ type=DataType.TIME,
157
+ description="output time",
158
+ )
159
+ ],
160
+ model=get_time,
161
+ ),
162
+ # Tool(
163
+ # name="calculator",
164
+ # description="It can solve mathematics problems and support various mathematical expressions: from basic arithmetic to more complex expressions.",
165
+ # domain=Tool.Domain.QUESTION_ANSWERING,
166
+ # args=[
167
+ # Tool.Argument(
168
+ # name="text",
169
+ # type=DataType.TEXT,
170
+ # description="input instructions",
171
+ # ),
172
+ # ],
173
+ # returns=[
174
+ # Tool.Argument(
175
+ # name="result",
176
+ # type=DataType.TEXT,
177
+ # description="result about weather",
178
+ # )
179
+ # ],
180
+ # model=None,
181
+ # ),
182
+ ]
183
+
184
+ IMAGE_CAPTIONING_TOOLS = [
185
+ Tool(
186
+ name="image_captioning",
187
+ description='Generate a caption or description for the image. It can generate a detailed description that can be used for image perception and image generation. For example: a) you can use this tool when you want to know what is in the image; and b) when you want to generate a new image similar to or resembling input.png, you can use `image_captioning` to obtain a description of input.png.',
188
+ domain=Tool.Domain.IMAGE_PERCEPTION,
189
+ args=[
190
+ Tool.Argument(
191
+ name="image",
192
+ type=DataType.IMAGE,
193
+ description="the image to be captioned",
194
+ ),
195
+ ],
196
+ returns=[
197
+ Tool.Argument(
198
+ name="text",
199
+ type=DataType.TEXT,
200
+ description="the description for the input image",
201
+ )
202
+ ],
203
+ model=image_captioning,
204
+ ),
205
+ ]
206
+
207
+ IMAGE_EDITING_TOOLS = [
208
+ Tool(
209
+ name="partial_image_editing",
210
+ description="Given the mask denoting the region to edit, Edit the given image at local region. Useful when you want to replace an object via a mask image. "
211
+ "like: replace the masked object with a dog. ",
212
+ domain=Tool.Domain.IMAGE_EDITING,
213
+ args=[
214
+ Tool.Argument(
215
+ name="image",
216
+ type=DataType.IMAGE,
217
+ description="the image to be edited",
218
+ ),
219
+ Tool.Argument(
220
+ name="mask",
221
+ type=DataType.MASK,
222
+ description="the mask image representing the editing position",
223
+ ),
224
+ Tool.Argument(
225
+ name="prompt",
226
+ type=DataType.TEXT,
227
+ description="the prompt specified the edition",
228
+ ),
229
+ ],
230
+ returns=[
231
+ Tool.Argument(
232
+ name="image",
233
+ type=DataType.IMAGE,
234
+ description="the edited image",
235
+ )
236
+ ],
237
+ model=partial_image_editing,
238
+ ),
239
+ Tool(
240
+ name="text_image_editing",
241
+ description="Edit the given image based on the text prompt.",
242
+ domain=Tool.Domain.IMAGE_EDITING,
243
+ args=[
244
+ Tool.Argument(
245
+ name="image",
246
+ type=DataType.IMAGE,
247
+ description="the image to be edited",
248
+ ),
249
+ Tool.Argument(
250
+ name="text",
251
+ type=DataType.TEXT,
252
+ description="the prompt specified the edition",
253
+ ),
254
+ ],
255
+ returns=[
256
+ Tool.Argument(
257
+ name="image",
258
+ type=DataType.IMAGE,
259
+ description="the edited image",
260
+ )
261
+ ],
262
+ model=instruct_pix2pix,
263
+ ),
264
+ Tool(
265
+ name="image_inpainting",
266
+ description="inpaint the region of the image based on the given mask. For example: remove the dog in the image, erase the spoon in given image, etc.",
267
+ domain=Tool.Domain.IMAGE_EDITING,
268
+ usages=["remove some objects"],
269
+ args=[
270
+ Tool.Argument(
271
+ name="image",
272
+ type=DataType.IMAGE,
273
+ description="the image to be inpainted",
274
+ ),
275
+ Tool.Argument(
276
+ name="mask",
277
+ type=DataType.MASK,
278
+ description="the segmentation mask for the inpainting region",
279
+ ),
280
+ ],
281
+ returns=[
282
+ Tool.Argument(
283
+ name="image",
284
+ type=DataType.IMAGE,
285
+ description="the processed image",
286
+ )
287
+ ],
288
+ model=inpainting_ldm_general,
289
+ ),
290
+ Tool(
291
+ name="highlight_object_on_image",
292
+ description="This tool is usually used after `object_detection` `visual_grounding` and `select_bbox`. Useful when you want to: 1) highlight the region of interest on the image; 2) know where the object is. For example: highlight the elephant from image, locate the dog in the image, find the spoon in given image, detect if the object is present in the image, etc.",
293
+ domain=Tool.Domain.IMAGE_EDITING,
294
+ usages=["highlight the region of interest on the image"],
295
+ args=[
296
+ Tool.Argument(
297
+ name="image",
298
+ type=DataType.IMAGE,
299
+ description="the image to be processed",
300
+ ),
301
+ Tool.Argument(
302
+ name="bbox",
303
+ type=DataType.BBOX,
304
+ description="the bounding boxes that need to be drawn on the image",
305
+ ),
306
+ ],
307
+ returns=[
308
+ Tool.Argument(
309
+ name="result",
310
+ type=DataType.IMAGE,
311
+ description="the new image on which the tool highlight the the region of interest by bounding boxes",
312
+ )
313
+ ],
314
+ model=draw_bbox_on_image,
315
+ ),
316
+ Tool(
317
+ name="image_cropping",
318
+ description="Crop the image based on the given bounding box. Useful when you want to crop the dog in the image, crop the spoon in given image, etc.",
319
+ domain=Tool.Domain.IMAGE_EDITING,
320
+ args=[
321
+ Tool.Argument(
322
+ name="image",
323
+ type=DataType.IMAGE,
324
+ description="the image to be processed",
325
+ ),
326
+ Tool.Argument(
327
+ name="object",
328
+ type=DataType.BBOX,
329
+ description="the detected object",
330
+ ),
331
+ ],
332
+ returns=[
333
+ Tool.Argument(
334
+ name="image",
335
+ type=DataType.IMAGE,
336
+ description="the cropped image",
337
+ )
338
+ ],
339
+ model=image_cropping,
340
+ ),
341
+ # Tool(
342
+ # name="mask_image",
343
+ # description="Mask the background from the image based on the given mask. For example: mask anything except the dog in the image, extract the spoon from given image without any inpainting, etc.",
344
+ # domain=Tool.Domain.IMAGE_EDITING,
345
+ # args=[
346
+ # Tool.Argument(
347
+ # name="image",
348
+ # type=DataType.IMAGE,
349
+ # description="the image to be processed",
350
+ # ),
351
+ # Tool.Argument(
352
+ # name="mask",
353
+ # type=DataType.MASK,
354
+ # description="the mask of the matted region",
355
+ # ),
356
+ # ],
357
+ # returns=[
358
+ # Tool.Argument(
359
+ # name="image",
360
+ # type=DataType.IMAGE,
361
+ # description="the matted image",
362
+ # )
363
+ # ],
364
+ # model=image_matting,
365
+ # ),
366
+ ]
367
+
368
+ IMAGE_GENERATION_TOOLS = [
369
+ Tool(
370
+ name="text_to_image",
371
+ description="generate an image based on the given description.",
372
+ domain=Tool.Domain.IMAGE_GENERATION,
373
+ args=[
374
+ Tool.Argument(
375
+ name="text",
376
+ type=DataType.TEXT,
377
+ description="the text describing the image",
378
+ ),
379
+ ],
380
+ returns=[
381
+ Tool.Argument(
382
+ name="image",
383
+ type=DataType.IMAGE,
384
+ description="the generated image",
385
+ )
386
+ ],
387
+ model=text2image,
388
+ ),
389
+ Tool(
390
+ name="image_to_image",
391
+ description="generate an new image based on the given image.",
392
+ domain=Tool.Domain.IMAGE_GENERATION,
393
+ args=[
394
+ Tool.Argument(
395
+ name="image",
396
+ type=DataType.IMAGE,
397
+ description="the given image",
398
+ ),
399
+ ],
400
+ returns=[
401
+ Tool.Argument(
402
+ name="image",
403
+ type=DataType.IMAGE,
404
+ description="the generated image",
405
+ )
406
+ ],
407
+ model=image2image,
408
+ ),
409
+ Tool(
410
+ name="line_text_to_image",
411
+ description="generate an image based on the given description and line map.",
412
+ domain=Tool.Domain.IMAGE_GENERATION,
413
+ args=[
414
+ Tool.Argument(
415
+ name="text",
416
+ type=DataType.TEXT,
417
+ description="the text describing the image",
418
+ ),
419
+ Tool.Argument(
420
+ name="line",
421
+ type=DataType.LINE,
422
+ description="the line map outlining the line of the image",
423
+ ),
424
+ ],
425
+ returns=[
426
+ Tool.Argument(
427
+ name="image",
428
+ type=DataType.IMAGE,
429
+ description="the generated image",
430
+ )
431
+ ],
432
+ model=linetext2image,
433
+ ),
434
+ Tool(
435
+ name="hed_text_to_image",
436
+ description="generate an image based on the given description and HED map (holistically-nested edge detection).",
437
+ domain=Tool.Domain.IMAGE_GENERATION,
438
+ args=[
439
+ Tool.Argument(
440
+ name="text",
441
+ type=DataType.TEXT,
442
+ description="the text describing the image",
443
+ ),
444
+ Tool.Argument(
445
+ name="hed",
446
+ type=DataType.HED,
447
+ description="the HED map outlining the edge of the image",
448
+ ),
449
+ ],
450
+ returns=[
451
+ Tool.Argument(
452
+ name="image",
453
+ type=DataType.IMAGE,
454
+ description="the generated image",
455
+ )
456
+ ],
457
+ model=hedtext2image,
458
+ ),
459
+ Tool(
460
+ name="scribble_text_to_image",
461
+ description="generate an image based on the given description and the scribble.",
462
+ domain=Tool.Domain.IMAGE_GENERATION,
463
+ args=[
464
+ Tool.Argument(
465
+ name="text",
466
+ type=DataType.TEXT,
467
+ description="the text describing the image",
468
+ ),
469
+ Tool.Argument(
470
+ name="scribble",
471
+ type=DataType.SCRIBBLE,
472
+ description="the scribble outlining the image",
473
+ ),
474
+ ],
475
+ returns=[
476
+ Tool.Argument(
477
+ name="image",
478
+ type=DataType.IMAGE,
479
+ description="the generated image",
480
+ )
481
+ ],
482
+ model=scribbletext2image,
483
+ ),
484
+ Tool(
485
+ name="pose_text_to_image",
486
+ description="generate an image based on the given description and the pose.",
487
+ domain=Tool.Domain.IMAGE_GENERATION,
488
+ args=[
489
+ Tool.Argument(
490
+ name="text",
491
+ type=DataType.TEXT,
492
+ description="the text describing the image",
493
+ ),
494
+ Tool.Argument(
495
+ name="pose",
496
+ type=DataType.POSE,
497
+ description="the pose of the human in the image",
498
+ ),
499
+ ],
500
+ returns=[
501
+ Tool.Argument(
502
+ name="image",
503
+ type=DataType.IMAGE,
504
+ description="the generated image",
505
+ )
506
+ ],
507
+ model=posetext2image,
508
+ ),
509
+ Tool(
510
+ name="segmentation_text_to_image",
511
+ description="generate an image based on the given description and segmentation mask.",
512
+ domain=Tool.Domain.IMAGE_GENERATION,
513
+ args=[
514
+ Tool.Argument(
515
+ name="text",
516
+ type=DataType.TEXT,
517
+ description="the text describing the image",
518
+ ),
519
+ Tool.Argument(
520
+ name="segmentation",
521
+ type=DataType.SEGMENTATION,
522
+ description="the segmentation mask describing the structure of the image",
523
+ ),
524
+ ],
525
+ returns=[
526
+ Tool.Argument(
527
+ name="image",
528
+ type=DataType.IMAGE,
529
+ description="the generated image",
530
+ )
531
+ ],
532
+ model=segtext2image,
533
+ ),
534
+ Tool(
535
+ name="edge_text_to_image",
536
+ description="generate an image based on the given description and edge map.",
537
+ domain=Tool.Domain.IMAGE_GENERATION,
538
+ args=[
539
+ Tool.Argument(
540
+ name="text",
541
+ type=DataType.TEXT,
542
+ description="the text describing the image",
543
+ ),
544
+ Tool.Argument(
545
+ name="edge",
546
+ type=DataType.EDGE,
547
+ description="the edge map describing the structure of the image",
548
+ ),
549
+ ],
550
+ returns=[
551
+ Tool.Argument(
552
+ name="image",
553
+ type=DataType.IMAGE,
554
+ description="the generated image",
555
+ )
556
+ ],
557
+ model=cannytext2image,
558
+ ),
559
+ Tool(
560
+ name="depth_text_to_image",
561
+ description="generate an image based on the given description and depth map.",
562
+ domain=Tool.Domain.IMAGE_GENERATION,
563
+ args=[
564
+ Tool.Argument(
565
+ name="text",
566
+ type=DataType.TEXT,
567
+ description="the text describing the image",
568
+ ),
569
+ Tool.Argument(
570
+ name="depth",
571
+ type=DataType.DEPTH,
572
+ description="the depth map describing the structure of the image",
573
+ ),
574
+ ],
575
+ returns=[
576
+ Tool.Argument(
577
+ name="image",
578
+ type=DataType.IMAGE,
579
+ description="the generated image",
580
+ )
581
+ ],
582
+ model=depthtext2image,
583
+ ),
584
+ Tool(
585
+ name="normal_text_to_image",
586
+ description="generate an image based on the given description and normal map.",
587
+ domain=Tool.Domain.IMAGE_GENERATION,
588
+ args=[
589
+ Tool.Argument(
590
+ name="text",
591
+ type=DataType.TEXT,
592
+ description="the text describing the image",
593
+ ),
594
+ Tool.Argument(
595
+ name="normal",
596
+ type=DataType.NORMAL,
597
+ description="the normal map describing the structure of the image",
598
+ ),
599
+ ],
600
+ returns=[
601
+ Tool.Argument(
602
+ name="image",
603
+ type=DataType.IMAGE,
604
+ description="the generated image",
605
+ )
606
+ ],
607
+ model=normaltext2image,
608
+ ),
609
+ ]
610
+
611
+ IMAGE_TRANSFORM_TOOLS = [
612
+ Tool(
613
+ name="image_to_edge",
614
+ description="get the edge map of the image.",
615
+ domain=Tool.Domain.IMAGE_PROCESSING,
616
+ args=[
617
+ Tool.Argument(
618
+ name="image",
619
+ type=DataType.IMAGE,
620
+ description="the image to be processed",
621
+ ),
622
+ ],
623
+ returns=[
624
+ Tool.Argument(
625
+ name="edge",
626
+ type=DataType.EDGE,
627
+ description="the edge map of the image",
628
+ )
629
+ ],
630
+ model=image2canny,
631
+ ),
632
+ Tool(
633
+ name="image_to_line",
634
+ description="get the line map of the image.",
635
+ domain=Tool.Domain.IMAGE_PROCESSING,
636
+ args=[
637
+ Tool.Argument(
638
+ name="image",
639
+ type=DataType.IMAGE,
640
+ description="the image to be processed",
641
+ ),
642
+ ],
643
+ returns=[
644
+ Tool.Argument(
645
+ name="line",
646
+ type=DataType.LINE,
647
+ description="the line map of the image",
648
+ )
649
+ ],
650
+ model=image2line,
651
+ ),
652
+ Tool(
653
+ name="image_to_hed",
654
+ description="get the HED map of the image.",
655
+ domain=Tool.Domain.IMAGE_PROCESSING,
656
+ args=[
657
+ Tool.Argument(
658
+ name="image",
659
+ type=DataType.IMAGE,
660
+ description="the image to be processed",
661
+ ),
662
+ ],
663
+ returns=[
664
+ Tool.Argument(
665
+ name="hed",
666
+ type=DataType.HED,
667
+ description="the hed map of the image",
668
+ )
669
+ ],
670
+ model=image2hed,
671
+ ),
672
+ Tool(
673
+ name="image_to_scribble",
674
+ description="get the scribble of the image.",
675
+ domain=Tool.Domain.IMAGE_PROCESSING,
676
+ args=[
677
+ Tool.Argument(
678
+ name="image",
679
+ type=DataType.IMAGE,
680
+ description="the image to be processed",
681
+ ),
682
+ ],
683
+ returns=[
684
+ Tool.Argument(
685
+ name="scribble",
686
+ type=DataType.SCRIBBLE,
687
+ description="the scribble of the image",
688
+ )
689
+ ],
690
+ model=image2scribble,
691
+ ),
692
+ Tool(
693
+ name="image_to_pose",
694
+ description="Get the pose of the image. It is usually used in image generation conditioned on pose map from input image.",
695
+ domain=Tool.Domain.IMAGE_PROCESSING,
696
+ args=[
697
+ Tool.Argument(
698
+ name="image",
699
+ type=DataType.IMAGE,
700
+ description="the image to be processed",
701
+ ),
702
+ ],
703
+ returns=[
704
+ Tool.Argument(
705
+ name="pose",
706
+ type=DataType.POSE,
707
+ description="the pose of the image",
708
+ )
709
+ ],
710
+ model=image2pose,
711
+ ),
712
+ Tool(
713
+ name="image_to_depth",
714
+ description="get the depth map of the image.",
715
+ domain=Tool.Domain.IMAGE_PROCESSING,
716
+ args=[
717
+ Tool.Argument(
718
+ name="image",
719
+ type=DataType.IMAGE,
720
+ description="the image to be processed",
721
+ ),
722
+ ],
723
+ returns=[
724
+ Tool.Argument(
725
+ name="depth",
726
+ type=DataType.DEPTH,
727
+ description="the depth map",
728
+ )
729
+ ],
730
+ model=image2depth,
731
+ ),
732
+ Tool(
733
+ name="image_to_normal",
734
+ description="get the normal map of the image.",
735
+ domain=Tool.Domain.IMAGE_PROCESSING,
736
+ args=[
737
+ Tool.Argument(
738
+ name="image",
739
+ type=DataType.IMAGE,
740
+ description="the image to be processed",
741
+ ),
742
+ ],
743
+ returns=[
744
+ Tool.Argument(
745
+ name="normal",
746
+ type=DataType.NORMAL,
747
+ description="the normal map",
748
+ )
749
+ ],
750
+ model=image2normal,
751
+ ),
752
+ ]
753
+
754
+ IMAGE_PERCEPTION_TOOLS = [
755
+ Tool(
756
+ name="object_detection",
757
+ description="detect all the objects in the image.",
758
+ domain=Tool.Domain.IMAGE_PERCEPTION,
759
+ args=[
760
+ Tool.Argument(
761
+ name="image",
762
+ type=DataType.IMAGE,
763
+ description="the image that contains the objects",
764
+ ),
765
+ ],
766
+ returns=[
767
+ Tool.Argument(
768
+ name="object",
769
+ type=DataType.BBOX,
770
+ description="the detected objects in json format. "
771
+ "example output: [\{'score': 0.9994931221008301, 'label': 'dog', 'box': \{'xmin': 466, 'ymin': 301, 'xmax': 1045, 'ymax': 583\}\}]",
772
+ )
773
+ ],
774
+ model=object_detection,
775
+ ),
776
+ Tool(
777
+ name="image_classification",
778
+ description="classify the objects in the image.",
779
+ domain=Tool.Domain.IMAGE_PERCEPTION,
780
+ usages=["ask about the class of the image"],
781
+ args=[
782
+ Tool.Argument(
783
+ name="image",
784
+ type=DataType.IMAGE,
785
+ description="the image that contains the objects",
786
+ ),
787
+ ],
788
+ returns=[
789
+ Tool.Argument(
790
+ name="category",
791
+ type=DataType.CATEGORY,
792
+ description="the categories in json format. "
793
+ "example output: [\{'score': 0.9, 'label': 'dog'\}]",
794
+ )
795
+ ],
796
+ model=image_classification,
797
+ ),
798
+ Tool(
799
+ name="video_classification",
800
+ description="Classify the video and detect the actions in the video.",
801
+ domain=Tool.Domain.VIDEO_PERCEPTION,
802
+ usages=["ask about the class of the video"],
803
+ args=[
804
+ Tool.Argument(
805
+ name="video",
806
+ type=DataType.VIDEO,
807
+ description="the given video",
808
+ ),
809
+ ],
810
+ returns=[
811
+ Tool.Argument(
812
+ name="category",
813
+ type=DataType.CATEGORY,
814
+ description="the categories in json format. "
815
+ "example output: [\{'score': 0.9, 'label': 'Playing basketball'\}]",
816
+ )
817
+ ],
818
+ model=video_classification,
819
+ ),
820
+ Tool(
821
+ name="image_instance_segmentation",
822
+ description="segment the common objects in the given image.",
823
+ domain=Tool.Domain.IMAGE_PERCEPTION,
824
+ args=[
825
+ Tool.Argument(
826
+ name="image",
827
+ type=DataType.IMAGE,
828
+ description="the image that need to be segmented",
829
+ ),
830
+ ],
831
+ returns=[
832
+ Tool.Argument(
833
+ name="mask", type=DataType.MASK, description="the output mask"
834
+ )
835
+ ],
836
+ model=segment_objects,
837
+ ),
838
+ Tool(
839
+ name="image_segmentation_by_mask",
840
+ description="segment the given image with the prompt mask.",
841
+ domain=Tool.Domain.IMAGE_PERCEPTION,
842
+ args=[
843
+ Tool.Argument(
844
+ name="image",
845
+ type=DataType.IMAGE,
846
+ description="the image that need to be segmented",
847
+ ),
848
+ Tool.Argument(
849
+ name="prompt_mask",
850
+ type=DataType.MASK,
851
+ description="the prompt mask that guides the segmentation",
852
+ ),
853
+ ],
854
+ returns=[
855
+ Tool.Argument(
856
+ name="mask", type=DataType.MASK, description="the output mask"
857
+ )
858
+ ],
859
+ model=seg_by_mask,
860
+ ),
861
+ Tool(
862
+ name="image_segmentation_by_points",
863
+ description="segment the given image with the prompt points.",
864
+ domain=Tool.Domain.IMAGE_PERCEPTION,
865
+ args=[
866
+ Tool.Argument(
867
+ name="image",
868
+ type=DataType.IMAGE,
869
+ description="the image that need to be segmented",
870
+ ),
871
+ Tool.Argument(
872
+ name="prompt_points",
873
+ type=DataType.POINT,
874
+ description="the prompt points that guides the segmentation",
875
+ ),
876
+ ],
877
+ returns=[
878
+ Tool.Argument(
879
+ name="mask", type=DataType.MASK, description="the output mask"
880
+ )
881
+ ],
882
+ model=seg_by_points,
883
+ ),
884
+ Tool(
885
+ name="segment_anything",
886
+ description="Segment the given image without other inputs. This tool return the segmentation map for input image. The segmentation can be used to generate a new image.",
887
+ domain=Tool.Domain.IMAGE_PERCEPTION,
888
+ args=[
889
+ Tool.Argument(
890
+ name="image",
891
+ type=DataType.IMAGE,
892
+ description="the image that need to be segmented",
893
+ ),
894
+ ],
895
+ returns=[
896
+ Tool.Argument(
897
+ name="segmentation",
898
+ type=DataType.SEGMENTATION,
899
+ description="the output segmentation",
900
+ )
901
+ ],
902
+ model=segment_all,
903
+ ),
904
+ Tool(
905
+ name="visual_grounding",
906
+ description="Visual Grounding (VG) aims to locate the most relevant object or region in an image, based on a natural language query. The query can be a phrase, a sentence or even a multi-round dialogue.",
907
+ domain=Tool.Domain.IMAGE_PERCEPTION,
908
+ args=[
909
+ Tool.Argument(
910
+ name="image",
911
+ type=DataType.IMAGE,
912
+ description="the image that need to be processed",
913
+ ),
914
+ Tool.Argument(
915
+ name="query",
916
+ type=DataType.TEXT,
917
+ description="a query that can be a phrase, a sentence",
918
+ ),
919
+ ],
920
+ returns=[
921
+ Tool.Argument(
922
+ name="bbox",
923
+ type=DataType.BBOX,
924
+ description="the detected bounding boxes for ",
925
+ )
926
+ ],
927
+ model=visual_grounding,
928
+ ),
929
+ Tool(
930
+ name="optical_character_recognition",
931
+ description="Optical Character Recognition (OCR) is the process that converts an image of text into a machine-readable text format.",
932
+ domain=Tool.Domain.IMAGE_PERCEPTION,
933
+ args=[
934
+ Tool.Argument(
935
+ name="image",
936
+ type=DataType.IMAGE,
937
+ description="the image that need to be processed",
938
+ )
939
+ ],
940
+ returns=[
941
+ Tool.Argument(
942
+ name="text",
943
+ type=DataType.TEXT,
944
+ description="the recognized text",
945
+ )
946
+ ],
947
+ model=ocr,
948
+ ),
949
+ ]
950
+
951
+ GENERAL_TOOLS = [
952
+ Tool(
953
+ name="select_category",
954
+ description="select the target classes in category list with the given condition.",
955
+ domain=Tool.Domain.GENERAL,
956
+ usages=["pick out the objects with the same type"],
957
+ args=[
958
+ Tool.Argument(
959
+ name="category_list",
960
+ type=DataType.CATEGORY,
961
+ description="the list to be processed",
962
+ ),
963
+ Tool.Argument(
964
+ name="condition",
965
+ type=DataType.TEXT,
966
+ description="the condition to select objects",
967
+ ),
968
+ ],
969
+ returns=[
970
+ Tool.Argument(
971
+ name="target_category_result",
972
+ type=DataType.CATEGORY,
973
+ description="the selected list",
974
+ )
975
+ ],
976
+ model=select,
977
+ ),
978
+ Tool(
979
+ name="select_bbox",
980
+ description="select the bounding boxes with the given condition.",
981
+ domain=Tool.Domain.GENERAL,
982
+ usages=["filter out the bounding boxes with the same type"],
983
+ args=[
984
+ Tool.Argument(
985
+ name="bbox_list",
986
+ type=DataType.BBOX,
987
+ description="the bounding box list to be processed",
988
+ ),
989
+ Tool.Argument(
990
+ name="condition",
991
+ type=DataType.TEXT,
992
+ description="the condition to select objects",
993
+ ),
994
+ ],
995
+ returns=[
996
+ Tool.Argument(
997
+ name="result",
998
+ type=DataType.BBOX,
999
+ description="the selected bbox list",
1000
+ )
1001
+ ],
1002
+ model=select,
1003
+ ),
1004
+ Tool(
1005
+ name="select_mask",
1006
+ description="select the masks with the given condition.",
1007
+ domain=Tool.Domain.GENERAL,
1008
+ args=[
1009
+ Tool.Argument(
1010
+ name="mask_list",
1011
+ type=DataType.MASK,
1012
+ description="the list to be processed",
1013
+ ),
1014
+ Tool.Argument(
1015
+ name="condition",
1016
+ type=DataType.TEXT,
1017
+ description="the condition to select objects",
1018
+ ),
1019
+ ],
1020
+ returns=[
1021
+ Tool.Argument(
1022
+ name="result",
1023
+ type=DataType.MASK,
1024
+ description="the selected mask list",
1025
+ )
1026
+ ],
1027
+ model=select,
1028
+ ),
1029
+ Tool(
1030
+ name="count_categories",
1031
+ description="count target categories in the given list.",
1032
+ domain=Tool.Domain.GENERAL,
1033
+ args=[
1034
+ Tool.Argument(
1035
+ name="category_list",
1036
+ type=DataType.CATEGORY,
1037
+ description="the list to be processed",
1038
+ ),
1039
+ ],
1040
+ returns=[
1041
+ Tool.Argument(
1042
+ name="length",
1043
+ type=DataType.TEXT,
1044
+ description="the length of the given list, return in the string format."
1045
+ "Example: The length of the given list is 10",
1046
+ )
1047
+ ],
1048
+ model=count,
1049
+ ),
1050
+ Tool(
1051
+ name="count_objects",
1052
+ description="count target objects in the given list. It is useful when you want to count the number of objects in the image",
1053
+ domain=Tool.Domain.GENERAL,
1054
+ args=[
1055
+ Tool.Argument(
1056
+ name="bbox_list",
1057
+ type=DataType.BBOX,
1058
+ description="the bounding box list to be counted",
1059
+ ),
1060
+ ],
1061
+ returns=[
1062
+ Tool.Argument(
1063
+ name="length",
1064
+ type=DataType.TEXT,
1065
+ description="the length of the given list, return in the string format."
1066
+ "Example: The length of the given list is 10",
1067
+ )
1068
+ ],
1069
+ model=count,
1070
+ ),
1071
+ Tool(
1072
+ name="count_masks",
1073
+ description="count target mask in the given list.",
1074
+ domain=Tool.Domain.GENERAL,
1075
+ args=[
1076
+ Tool.Argument(
1077
+ name="mask_list",
1078
+ type=DataType.MASK,
1079
+ description="the list to be processed",
1080
+ ),
1081
+ ],
1082
+ returns=[
1083
+ Tool.Argument(
1084
+ name="length",
1085
+ type=DataType.TEXT,
1086
+ description="the length of the given list, return in the string format."
1087
+ "Example: The length of the given list is 10",
1088
+ )
1089
+ ],
1090
+ model=count,
1091
+ ),
1092
+ ]
1093
+
1094
+ VIDEO_TOOLS = [
1095
+ # VIDEO
1096
+ Tool(
1097
+ name="video_captioning",
1098
+ description='Generate a caption or description for the video. It can generate a detailed description that can be used for video perception and video generation. For example: a) you can use this tool when you want to know what happened in the video; and b) when you want to generate tags for the input video, you can translate the description obtained from `video_captioning` into tags.',
1099
+ domain=Tool.Domain.VIDEO_PERCEPTION,
1100
+ args=[
1101
+ Tool.Argument(
1102
+ name="video",
1103
+ type=DataType.VIDEO,
1104
+ description="the video to be captioned.",
1105
+ ),
1106
+ ],
1107
+ returns=[
1108
+ Tool.Argument(
1109
+ name="caption",
1110
+ type=DataType.TEXT,
1111
+ description="the caption or description of input video.",
1112
+ )
1113
+ ],
1114
+ model=video_captioning,
1115
+ ),
1116
+ Tool(
1117
+ name="image_audio_to_video",
1118
+ description="Generate a video with speech to introduce the image.",
1119
+ domain=Tool.Domain.VIDEO_GENERATION,
1120
+ args=[
1121
+ Tool.Argument(
1122
+ name="image",
1123
+ type=DataType.IMAGE,
1124
+ description="The input image to be introduced.",
1125
+ ),
1126
+ Tool.Argument(
1127
+ name="audio",
1128
+ type=DataType.AUDIO,
1129
+ description="The audio contained the speech of image description.",
1130
+ ),
1131
+ ],
1132
+ returns=[
1133
+ Tool.Argument(
1134
+ name="video",
1135
+ type=DataType.VIDEO,
1136
+ description="Generated video that can introduce the image with speech",
1137
+ )
1138
+ ],
1139
+ model=image_audio_to_video,
1140
+ ),
1141
+ Tool(
1142
+ name="image_to_video",
1143
+ description="Generate a video based on image.",
1144
+ domain=Tool.Domain.VIDEO_GENERATION,
1145
+ args=[
1146
+ Tool.Argument(
1147
+ name="image",
1148
+ type=DataType.IMAGE,
1149
+ description="The input image.",
1150
+ ),
1151
+ ],
1152
+ returns=[
1153
+ Tool.Argument(
1154
+ name="video",
1155
+ type=DataType.VIDEO,
1156
+ description="Generated video from the input image.",
1157
+ )
1158
+ ],
1159
+ model=image_to_video,
1160
+ ),
1161
+ Tool(
1162
+ name="video_to_webpage",
1163
+ description="Generate a web page to promote and introduce the video.",
1164
+ domain=Tool.Domain.VIDEO_PROCESSING,
1165
+ args=[
1166
+ Tool.Argument(
1167
+ name="video",
1168
+ type=DataType.VIDEO,
1169
+ description="The input image to be introduced.",
1170
+ ),
1171
+ Tool.Argument(
1172
+ name="title",
1173
+ type=DataType.TITLE,
1174
+ description="The title of video.",
1175
+ ),
1176
+ Tool.Argument(
1177
+ name="tags",
1178
+ type=DataType.TAGS,
1179
+ description="The tags of video.",
1180
+ ),
1181
+ Tool.Argument(
1182
+ name="description",
1183
+ type=DataType.TEXT,
1184
+ description="The description of video.",
1185
+ ),
1186
+ ],
1187
+ returns=[
1188
+ Tool.Argument(
1189
+ name="html_code",
1190
+ type=DataType.HTML,
1191
+ description="Generated HTML webpage with code that can introduce the video with speech.",
1192
+ )
1193
+ ],
1194
+ model=video_to_webpage,
1195
+ ),
1196
+ Tool(
1197
+ name="dub_video",
1198
+ description="Dub the input video with given audio track.",
1199
+ domain=Tool.Domain.VIDEO_EDITING,
1200
+ args=[
1201
+ Tool.Argument(
1202
+ name="video",
1203
+ type=DataType.VIDEO,
1204
+ description="The input image to be introduced.",
1205
+ ),
1206
+ Tool.Argument(
1207
+ name="audio",
1208
+ type=DataType.AUDIO,
1209
+ description="The audio of video.",
1210
+ ),
1211
+ ],
1212
+ returns=[
1213
+ Tool.Argument(
1214
+ name="video",
1215
+ type=DataType.VIDEO,
1216
+ description="Output video with designated audio.",
1217
+ )
1218
+ ],
1219
+ model=dub_video,
1220
+ ),
1221
+ Tool(
1222
+ name="text_to_video",
1223
+ description="It takes as input a natural language description and produces a video matching that description",
1224
+ domain=Tool.Domain.VIDEO_GENERATION,
1225
+ args=[
1226
+ Tool.Argument(
1227
+ name="prompt",
1228
+ type=DataType.TEXT,
1229
+ description="the text describing the image",
1230
+ )
1231
+ ],
1232
+ returns=[
1233
+ Tool.Argument(
1234
+ name="video",
1235
+ type=DataType.VIDEO,
1236
+ description="the generated video",
1237
+ )
1238
+ ],
1239
+ model=text_to_video,
1240
+ ),
1241
+ ]
1242
+
1243
+ AUDIO_TOOLS = [
1244
+ # AUDIO
1245
+ Tool(
1246
+ name="text_to_music",
1247
+ description="Generate music condioned on input text/prompt. For example, you can use this tool when you want to generate music for a poem, generate a piece of music from image.",
1248
+ domain=Tool.Domain.AUDIO_GENERATION,
1249
+ args=[
1250
+ Tool.Argument(
1251
+ name="text",
1252
+ type=DataType.TEXT,
1253
+ description="Input text for music generation.",
1254
+ ),
1255
+ ],
1256
+ returns=[
1257
+ Tool.Argument(
1258
+ name="music",
1259
+ type=DataType.AUDIO,
1260
+ description="Generated music conditioned on text.",
1261
+ )
1262
+ ],
1263
+ model=text_to_music,
1264
+ ),
1265
+ Tool(
1266
+ name="text_to_speech",
1267
+ description="Create natural-sounding speech from text, where the speech can be generated in multiple languages and for multiple speakers",
1268
+ domain=Tool.Domain.AUDIO_GENERATION,
1269
+ args=[
1270
+ Tool.Argument(
1271
+ name="text",
1272
+ type=DataType.TEXT,
1273
+ description="The input text that will be translated into speech.",
1274
+ ),
1275
+ ],
1276
+ returns=[
1277
+ Tool.Argument(
1278
+ name="speech",
1279
+ type=DataType.AUDIO,
1280
+ description="Generated speech or voice conditioned on text.",
1281
+ )
1282
+ ],
1283
+ model=text_to_speech,
1284
+ ),
1285
+ Tool(
1286
+ name="audio_classification",
1287
+ description="Audio classification is the task of assigning a label or class to a given audio. It can be used for recognizing which command a user is giving or the emotion of a statement, as well as identifying a speaker.",
1288
+ domain=Tool.Domain.AUDIO_PERCEPTION,
1289
+ args=[
1290
+ Tool.Argument(
1291
+ name="audio",
1292
+ type=DataType.AUDIO,
1293
+ description="The input audio that will be classified.",
1294
+ ),
1295
+ ],
1296
+ returns=[
1297
+ Tool.Argument(
1298
+ name="speech",
1299
+ type=DataType.CATEGORY,
1300
+ description="The recognized categories in json format.",
1301
+ )
1302
+ ],
1303
+ model=audio_classification,
1304
+ ),
1305
+ ]
1306
+
1307
+ NLP_TOOLS = [
1308
+ # Text
1309
+ Tool(
1310
+ name="text_to_text_generation",
1311
+ description="Text to text generation. It can be used for sentence acceptability judgment, Sentiment analysis, Paraphrasing/sentence similarity, Natural language inference, Sentence completion, Word sense disambiguation, Question answering.",
1312
+ domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
1313
+ args=[
1314
+ Tool.Argument(
1315
+ name="text",
1316
+ type=DataType.TEXT,
1317
+ description="The input text",
1318
+ ),
1319
+ ],
1320
+ returns=[
1321
+ Tool.Argument(
1322
+ name="answer",
1323
+ type=DataType.TEXT,
1324
+ description="Generated answer for given input.",
1325
+ )
1326
+ ],
1327
+ model=text_to_text_generation,
1328
+ ),
1329
+ Tool(
1330
+ name="title_generation",
1331
+ description="Generate a title for given text.",
1332
+ domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
1333
+ args=[
1334
+ Tool.Argument(
1335
+ name="text",
1336
+ type=DataType.TEXT,
1337
+ description="The input text",
1338
+ ),
1339
+ ],
1340
+ returns=[
1341
+ Tool.Argument(
1342
+ name="title",
1343
+ type=DataType.TITLE,
1344
+ description="Generated title based given sentences.",
1345
+ )
1346
+ ],
1347
+ model=title_generation,
1348
+ ),
1349
+ Tool(
1350
+ name="openai_chat_model",
1351
+ description="Answer the question by Large Language Model.",
1352
+ domain=Tool.Domain.QUESTION_ANSWERING,
1353
+ args=[
1354
+ Tool.Argument(
1355
+ name="input_msg",
1356
+ type=DataType.TEXT,
1357
+ description="The input text",
1358
+ )
1359
+ ],
1360
+ returns=[
1361
+ Tool.Argument(
1362
+ name="answer",
1363
+ type=DataType.TEXT,
1364
+ description="Generated answer based given text.",
1365
+ )
1366
+ ],
1367
+ model=openai_chat_model,
1368
+ ),
1369
+ Tool(
1370
+ name="summarization",
1371
+ description="Summarize sentences, long narratives, articles, papers, textbooks.",
1372
+ domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
1373
+ args=[
1374
+ Tool.Argument(
1375
+ name="text",
1376
+ type=DataType.TEXT,
1377
+ description="The input text to be Summarized.",
1378
+ ),
1379
+ ],
1380
+ returns=[
1381
+ Tool.Argument(
1382
+ name="summarized_text",
1383
+ type=DataType.TEXT,
1384
+ description="Summarized text.",
1385
+ )
1386
+ ],
1387
+ model=summarization,
1388
+ ),
1389
+ Tool(
1390
+ name="text_to_tags",
1391
+ description="Predict the tags of text, article and papers by using the their textual content as input",
1392
+ domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
1393
+ args=[
1394
+ Tool.Argument(
1395
+ name="text",
1396
+ type=DataType.TEXT,
1397
+ description="The input text to be Summarized.",
1398
+ ),
1399
+ ],
1400
+ returns=[
1401
+ Tool.Argument(
1402
+ name="tags",
1403
+ type=DataType.TAGS,
1404
+ description="The extracted tags from input text",
1405
+ )
1406
+ ],
1407
+ model=text_to_tags,
1408
+ ),
1409
+ Tool(
1410
+ name="named_entity_recognition",
1411
+ description="Named-entity recognition (NER) (also known as (named) entity identification, entity chunking, and entity extraction) is a subtask of information extraction that seeks to locate and classify named entities mentioned in unstructured text into pre-defined categories such as person names, organizations, locations, medical codes, time expressions, quantities, monetary values, percentages, etc.",
1412
+ domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
1413
+ args=[
1414
+ Tool.Argument(
1415
+ name="text",
1416
+ type=DataType.TEXT,
1417
+ description="The input text from which the named entities are extracted",
1418
+ ),
1419
+ ],
1420
+ returns=[
1421
+ Tool.Argument(
1422
+ name="tags",
1423
+ type=DataType.TAGS,
1424
+ description="The extracted entities",
1425
+ )
1426
+ ],
1427
+ model=None,
1428
+ ),
1429
+ Tool(
1430
+ name="sentiment_analysis",
1431
+ description="Sentiment analysis is the process of analyzing digital text to determine if the emotional tone of the message is positive, negative, or neutral.",
1432
+ domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
1433
+ args=[
1434
+ Tool.Argument(
1435
+ name="text",
1436
+ type=DataType.TEXT,
1437
+ description="The input text to be analyzed",
1438
+ ),
1439
+ ],
1440
+ returns=[
1441
+ Tool.Argument(
1442
+ name="text",
1443
+ type=DataType.TEXT,
1444
+ description="The sentiment of text",
1445
+ )
1446
+ ],
1447
+ model=sentiment_analysis,
1448
+ ),
1449
+ Tool(
1450
+ name="extract_location",
1451
+ description="Extracts the locale name from the text. For example, if the text is 'what is the weather in Beijing', the tool will return 'Beijing'. If the text is 'Samuel ppops in a happy plce called Berlin which happens to be Kazakhstan', the tool will return 'Berlin,Kazakhstan'.",
1452
+ domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
1453
+ args=[
1454
+ Tool.Argument(
1455
+ name="text",
1456
+ type=DataType.TEXT,
1457
+ description="The input text to be analyzed",
1458
+ ),
1459
+ ],
1460
+ returns=[
1461
+ Tool.Argument(
1462
+ name="location",
1463
+ type=DataType.LOCATION,
1464
+ description="The sentiment of text",
1465
+ )
1466
+ ],
1467
+ model=extract_location,
1468
+ ),
1469
+ Tool(
1470
+ name="summarize_weather_condition",
1471
+ description="Translate the json formatted weather information into the text that human can understand. For example, when you want to generate a new image based on weather information",
1472
+ domain=Tool.Domain.NATURAL_LANGUAGE_PROCESSING,
1473
+ args=[
1474
+ Tool.Argument(
1475
+ name="weather",
1476
+ type=DataType.WEATHER,
1477
+ description="weather condition",
1478
+ )
1479
+ ],
1480
+ returns=[
1481
+ Tool.Argument(
1482
+ name="weather_summary",
1483
+ type=DataType.TEXT,
1484
+ description="the weather summary",
1485
+ )
1486
+ ],
1487
+ model=summarize_weather_condition,
1488
+ ),
1489
+ ]
1490
+
1491
+ TOOLS = (
1492
+ QUESTION_ANSWERING_TOOLS
1493
+ + IMAGE_CAPTIONING_TOOLS
1494
+ + IMAGE_EDITING_TOOLS
1495
+ + IMAGE_GENERATION_TOOLS
1496
+ + IMAGE_TRANSFORM_TOOLS
1497
+ + IMAGE_PERCEPTION_TOOLS
1498
+ + GENERAL_TOOLS
1499
+ + VIDEO_TOOLS
1500
+ + AUDIO_TOOLS
1501
+ + NLP_TOOLS
1502
+ )
1503
+ TOOLS = {tool.name: tool for tool in TOOLS}
1504
+
1505
+ if __name__ == "__main__":
1506
+ tools = []
1507
+ for tool in TOOLS.values():
1508
+ tools.append(tool.dict())
1509
+ import json
1510
+
1511
+ with open("tools.json", "w") as f:
1512
+ json.dump(tools, f, indent=4)
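A minimal sketch of how the TOOLS registry defined above might be queried, assuming the Tool class in cllm/agents/base.py exposes its constructor fields (name, description, args, returns) as attributes:

    from cllm.agents.builtin.tools import TOOLS

    tool = TOOLS["image_captioning"]
    print(tool.description)
    for arg in tool.args:
        # each Tool.Argument declares a name, a DataType and a description
        print(arg.name, arg.type, arg.description)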
cllm/agents/container.py ADDED
@@ -0,0 +1,98 @@
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.getcwd())
5
+ import os.path as osp
6
+ from pathlib import Path
7
+ import json
8
+ from .base import DataType
9
+ from cllm.utils import get_real_path
10
+
11
+
12
+ # sys.path.insert(0, sys.path[0] + "/../")
13
+ FILE_EXT = {
14
+ "image": ["png", "jpeg", "jpg", "gif", "bmp", "tiff", "webp"],
15
+ "video": ["mp4", "mov", "avi", "mkv"],
16
+ "audio": ["wav", "mp3"],
17
+ }
18
+
19
+
20
+ class Container:
21
+ def __init__(self, name, rtype, value) -> None:
22
+ self.name = name
23
+ self.rtype = rtype
24
+ self.value = value
25
+
26
+ def to_chatbot(self):
27
+ pass
28
+
29
+ def __str__(self):
30
+ pass
31
+
32
+ def __repr__(self) -> str:
33
+ return str(self)
34
+
35
+
36
+ class File(Container):
37
+ def to_chatbot(self):
38
+ return str(self.value)
39
+
40
+ @property
41
+ def filename(self):
42
+ return os.path.basename(self.value)
43
+
44
+ def __str__(self):
45
+ return f"`{self.filename}`"
46
+
47
+
48
+ class HTML(File):
49
+ def to_chatbot(self):
50
+ return str(self.value)
51
+
52
+ def __str__(self):
53
+ return f"`{self.filename}`"
54
+
55
+
56
+ class Image(File):
57
+ def __str__(self):
58
+ return f"`{self.filename}`"
59
+
60
+
61
+ class Video(File):
62
+ def __str__(self):
63
+ return f"`{self.filename}`"
64
+
65
+
66
+ class Audio(File):
67
+ def __str__(self):
68
+ return f"`{self.filename}`"
69
+
70
+
71
+ class Text(Container):
72
+ def to_chatbot(self):
73
+ if isinstance(self.value, str):
74
+ return self.value
75
+ elif isinstance(self.value, (list, tuple, dict)):
76
+ return json.dumps(self.value, indent=2)
77
+ return self.value
78
+
79
+ def __str__(self):
80
+ if isinstance(self.value, (list, dict)):
81
+ return json.dumps(self.value)
82
+ elif isinstance(self.value, str):
83
+ return self.value
84
+ return str(self.value)
85
+
86
+
87
+ def auto_type(name, rtype, value):
88
+ if value is None:
89
+ return None
90
+ if "image" in str(rtype):
91
+ return Image(name, rtype, get_real_path(value))
92
+ if DataType.VIDEO == rtype:
93
+ return Video(name, rtype, get_real_path(value))
94
+ if DataType.AUDIO == rtype:
95
+ return Audio(name, rtype, get_real_path(value))
96
+ if DataType.HTML == rtype:
97
+ return HTML(name, rtype, get_real_path(value))
98
+ return Text(name, rtype, value)
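A rough usage sketch for auto_type, which wraps raw tool outputs into typed containers for the chat UI; the file name and values below are illustrative only:

    from cllm.agents.base import DataType
    from cllm.agents.container import auto_type

    image = auto_type("image", DataType.IMAGE, "3fa2b1_IMAGE.png")   # -> Image (a File subclass)
    answer = auto_type("response", DataType.TEXT, {"score": 0.9})    # -> Text
    print(image.to_chatbot())   # resolved path to the image file
    print(answer.to_chatbot())  # JSON-formatted string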
cllm/agents/tog/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .planner import Planner
2
+ from .controller import Controller
cllm/agents/tog/compiler.py ADDED
@@ -0,0 +1,62 @@
1
+ from typing import List, Union
2
+ import ast
3
+ import sys
4
+ import os
5
+
6
+ sys.path.append(os.getcwd())
7
+ from .cllm.agents.base import Action
8
+
9
+
10
+ class Parser:
11
+ def parse(self, plan) -> List[Action]:
12
+ # ignore indent
13
+ input = "\n".join([line.strip() for line in plan.split("\n")])
14
+ actions = []
15
+ for stmt in ast.parse(input).body:
16
+ if isinstance(stmt, ast.Assign):
17
+ assign: ast.Assign = stmt
18
+ output: ast.Name = assign.targets[0]
19
+ func_call: ast.Call = assign.value
20
+ func_name: ast.Name = func_call.func
21
+ kwargs: List[ast.keyword] = func_call.keywords
22
+ args = {}
23
+ for kwarg in kwargs:
24
+ k = kwarg.arg
25
+ if isinstance(kwarg.value, ast.Name):
26
+ v = kwarg.value.id
27
+ else:
28
+ v = ast.literal_eval(kwarg.value)
29
+ args[k] = v
30
+ action = Action(
31
+ tool_name=func_name.id, outputs=[output.id], inputs=args
32
+ )
33
+ actions.append(action)
34
+ return actions
35
+
36
+
37
+ class Compiler:
38
+ def __init__(self):
39
+ self.parser = Parser()
40
+
41
+ def compile(self, plan: Union[str, List[Union[Action, str]]]) -> List[Action]:
42
+ """The input could be a plain string, a list of structured `Action`
43
+ or combination of structured `Action` or unstructured action string.
44
+ """
45
+ actions = self.parse(plan)
46
+ actions = self.correct(actions)
47
+ return actions
48
+
49
+ def parse(self, plan) -> List[Action]:
50
+ if isinstance(plan, str):
51
+ return self.parser.parse(plan)
52
+
53
+ actions = []
54
+ for action in plan:
55
+ if isinstance(action, str):
56
+ action = self.parser.parse(action)[0]
57
+ actions.append(action)
58
+
59
+ return actions
60
+
61
+ def correct(self, actions):
62
+ return actions
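A sketch of what the Parser above produces; the single-assignment, keyword-argument plan syntax is an assumption inferred from the AST handling:

    from cllm.agents.tog.compiler import Parser

    plan = 'caption = image_captioning(image="input.png")'
    actions = Parser().parse(plan)
    # roughly: [Action(tool_name="image_captioning", inputs={"image": "input.png"}, outputs=["caption"])]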
cllm/agents/tog/controller.py ADDED
@@ -0,0 +1,157 @@
1
+ import traceback
2
+ import logging
3
+ from typing import Tuple, List
4
+ import copy
5
+ from pathlib import Path
6
+
7
+ import json
8
+ from collections import OrderedDict
9
+ import os
10
+ import sys
11
+
12
+ sys.path.append(os.getcwd())
13
+ from cllm.agents import container
14
+ from cllm.agents.builtin import BUILTIN_PLANS, load_builtin_plans
15
+ from cllm.agents.container import auto_type
16
+ from cllm.agents.base import DataType, NON_FILE_TYPES
17
+
18
+ from .interpretor import Interpretor
19
+ from .planner import Planner
20
+ from .responser import generate_response
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class Controller:
26
+ def __init__(self, stream=True, interpretor_kwargs={}):
27
+ self.stream = stream
28
+ self.planner = Planner(self.stream)
29
+ self.interpretor = Interpretor(**interpretor_kwargs)
30
+ self.SHORTCUT = "**Using builtin shortcut solution.**"
31
+ BUILTIN_PLANS.update(load_builtin_plans("builtin_plan.json"))
32
+ logger.info(BUILTIN_PLANS)
33
+
34
+ def plan(self, request: str, state: dict):
35
+ logger.info(request)
36
+
37
+ resource_memory = state.get("resources", {})
38
+ raw_solution = None
39
+ # shortcut for builtin plan
40
+ for trigger_prompt, _ in BUILTIN_PLANS.items():
41
+ if request == trigger_prompt:
42
+ return self.SHORTCUT
43
+
44
+ # dynamic execution
45
+ if raw_solution is None:
46
+ raw_solution = self.planner.plan(request, resource_memory)
47
+ return raw_solution
48
+
49
+ def parse_solution_from_stream(self, raw_solution):
50
+ return self.planner.parse(raw_solution)
51
+
52
+ def execute(self, raw_solution: str, state: dict):
53
+ resource_memory = state.get("resources")
54
+ request = state["request"]
55
+ solution = None
56
+ if raw_solution == self.SHORTCUT:
57
+ for trigger_prompt, builtin_plan in BUILTIN_PLANS.items():
58
+ if request == trigger_prompt:
59
+ solution = builtin_plan
60
+ solution = self._fill_args(solution, resource_memory)
61
+ else:
62
+ solution = self.planner.parse(raw_solution)
63
+
64
+ if not solution:
65
+ return None
66
+ try:
67
+ history_msgs = state.get("history_msgs")
68
+ return self.interpretor.interpret(solution, history_msgs)
69
+ except Exception as e:
70
+ traceback.print_exc()
71
+ return None
72
+
73
+ def reply(self, executed_plan: dict, outputs: list, state: dict):
74
+ error_response = [
75
+ auto_type(
76
+ "response",
77
+ DataType.TEXT,
78
+ "Sorry, I cannot understand your request due to an internal error.",
79
+ )
80
+ ]
81
+ state = copy.deepcopy(state)
82
+ if (
83
+ executed_plan is None
84
+ or len(executed_plan) == 0
85
+ or outputs is None
86
+ or len(outputs) == 0
87
+ ):
88
+ return error_response, state
89
+ resources = state.get("resources", OrderedDict())
90
+ for o in outputs:
91
+ if isinstance(o, container.File):
92
+ resources[str(o.filename)] = str(o.rtype)
93
+ state["resources"] = resources
94
+ response = generate_response(state["request"], executed_plan, outputs)
95
+ if len(response) == 0:
96
+ return error_response, state
97
+ logger.info(response)
98
+ return response, state
99
+
100
+ def run(self, task: str, state: dict) -> Tuple[List, str]:
101
+ try:
102
+ return self._run(task, state)
103
+ except:
104
+ traceback.print_exc()
105
+ logger.info(traceback.format_exc())
106
+ return [
107
+ auto_type(
108
+ "response",
109
+ DataType.TEXT,
110
+ "Sorry, I cannot understand your request due to an internal error.",
111
+ )
112
+ ], "Internal Error"
113
+
114
+ def _run(self, task: str, state: dict) -> Tuple[List, str]:
115
+ logger.info(task)
116
+ BUILTIN_PLANS.update(load_builtin_plans("builtin_plan.json"))
117
+ logger.info(BUILTIN_PLANS)
118
+ resource_memory = state.get("resources", OrderedDict())
119
+ history_msgs = state.get("history_msgs", [])
120
+ plan = None
121
+
122
+ # shortcut for builtin plan
123
+ for trigger_prompt, builtin_plan in BUILTIN_PLANS.items():
124
+ if task == trigger_prompt:
125
+ plan = builtin_plan
126
+ plan = self._fill_args(plan, resource_memory)
127
+
128
+ # dynamic executation
129
+ if plan is None:
130
+ plan = self.planner.planning(task, resource_memory)
131
+ logger.info(plan)
132
+
133
+ executed_plan, output_files = self.interpretor.interpret(
134
+ plan, resource_memory, history_msgs
135
+ )
136
+ logger.info(output_files)
137
+ for o in output_files:
138
+ if isinstance(o, container.File):
139
+ resource_memory[o.filename] = str(o.rtype)
140
+
141
+ outputs = generate_response(task, executed_plan, output_files)
142
+
143
+ logger.info(outputs)
144
+ return outputs, executed_plan
145
+
146
+ def _fill_args(self, plan, memory):
147
+ plan = copy.deepcopy(plan)
148
+ latest_resource = OrderedDict()
149
+ for key, val in memory.items():
150
+ latest_resource[val] = key
151
+
152
+ for actions in plan:
153
+ for action in actions:
154
+ for key, val in action.inputs.items():
155
+ if "<TOOL-GENERATED>" not in val:
156
+ action.inputs[key] = latest_resource.get(val, val)
157
+ return plan
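A hedged sketch of the plan/execute round trip through the Controller; the exact layout of `state` (request, resources, history_msgs) is inferred from the code above, and the planner call is served by the remote ToG service:

    from cllm.agents.tog.controller import Controller

    controller = Controller(stream=False)
    state = {
        "request": "describe a.png",
        "resources": {"a.png": "DataType.IMAGE"},  # filename -> str(DataType), as written back by reply()
        "history_msgs": [],
    }
    raw_solution = controller.plan(state["request"], state)
    for executed_action, solution, outputs in controller.execute(raw_solution, state) or []:
        print(executed_action)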
cllm/agents/tog/interpretor.py ADDED
@@ -0,0 +1,262 @@
1
+ import logging
2
+ from traceback import print_exc
3
+ from typing import List, Dict
4
+ import os.path as osp
5
+ import io
6
+ import copy
7
+ import re
8
+ import uuid
9
+ from matplotlib.pyplot import isinteractive
10
+
11
+ from numpy import isin
12
+ import sys
13
+ import os
14
+
15
+ sys.path.append(os.getcwd())
16
+ from cllm.agents.base import Action, DataType, Tool, NON_FILE_TYPES
17
+ from cllm.agents.builtin import TOOLS
18
+ from cllm.agents.container import auto_type
19
+ from cllm.utils import get_real_path, get_root_dir, transform_msgs
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ def code(source, type="py"):
25
+ return f"```{type}\n{source}\n```"
26
+
27
+
28
+ class Interpretor:
29
+ def __init__(self):
30
+ self.tools = TOOLS
31
+ self.non_file_types = NON_FILE_TYPES
32
+
33
+ def interpret(self, stages: List[List[Action]], history_msgs: List = []):
34
+ memory = {}
35
+ solution = copy.deepcopy(stages)
36
+ history_msgs = copy.deepcopy(history_msgs)
37
+ history_msgs = transform_msgs(history_msgs)
38
+ has_error = False
39
+ for actions in solution:
40
+ for action in actions:
41
+ tool = self.load_tool(name=action.tool_name)
42
+ tool_inputs = self.load_args(tool, action.inputs, memory)
43
+ tool_inputs["history_msgs"] = history_msgs
44
+ tool_inputs["root_dir"] = get_root_dir()
45
+ try:
46
+ tool_outputs = tool.model(**tool_inputs)
47
+ action.inputs = self._update_inputs(memory, action.inputs)
48
+ action.outputs, wrapped_outputs = self._update_output(
49
+ memory, action, tool_outputs, tool
50
+ )
51
+ logger.info(
52
+ "Call {}, args {}, return {}".format(
53
+ action.tool_name, action.inputs, action.outputs
54
+ )
55
+ )
56
+ executed_action = (
57
+ action.tool_name,
58
+ action.inputs,
59
+ action.outputs,
60
+ )
61
+ except FileNotFoundError as e:
62
+ print_exc()
63
+ tool_outputs = None
64
+ logger.error(f"Error when executing {action.tool_name}: {e}")
65
+ has_error = True
66
+ wrapped_outputs = []
67
+ executed_action = (
68
+ action.tool_name,
69
+ action.inputs,
70
+ f"FileNotFoundError: No such file or directory: {osp.basename(e.filename)}",
71
+ )
72
+ except Exception as e:
73
+ print_exc()
74
+ tool_outputs = None
75
+ has_error = True
76
+ logger.error(f"Error when executing {action.tool_name}: {e}")
77
+ wrapped_outputs = []
78
+ executed_action = (
79
+ action.tool_name,
80
+ action.inputs,
81
+ f"Internal error: {e}",
82
+ )
83
+ yield executed_action, solution, wrapped_outputs
84
+ if has_error:
85
+ return
86
+
87
+ def _update_output(self, memory, action, tool_outputs, tool):
88
+ outputs = []
89
+ wrapped_outputs = []
90
+ if action.outputs is not None:
91
+ if len(action.outputs) == 1:
92
+ tool_outputs = [tool_outputs]
93
+ for i, (arg_name, arg_value) in enumerate(
94
+ zip(action.outputs, tool_outputs)
95
+ ):
96
+ memory[arg_name] = arg_value
97
+ if arg_value is None:
98
+ outputs.append(arg_value)
99
+ wrapped_outputs.append(
100
+ auto_type(
101
+ arg_name,
102
+ DataType.TEXT,
103
+ None,
104
+ )
105
+ )
106
+ continue
107
+
108
+ if isinstance(arg_value, (dict, list)):
109
+ arg_value = self.pretty_floats(arg_value)
110
+
111
+ if tool.returns[i].type in self.non_file_types:
112
+ outputs.append(arg_value)
113
+ wrapped_outputs.append(
114
+ auto_type(
115
+ arg_name,
116
+ tool.returns[i].type,
117
+ arg_value,
118
+ )
119
+ )
120
+
121
+ continue
122
+
123
+ transformed_output = self.transform_output(
124
+ action.inputs,
125
+ tool.name,
126
+ tool.args,
127
+ arg_value,
128
+ tool.returns[i].type,
129
+ )
130
+
131
+ outputs.append(transformed_output)
132
+ memory[arg_name] = transformed_output
133
+ if not isinstance(transformed_output, list):
134
+ wrapped_outputs.append(
135
+ auto_type(
136
+ arg_name,
137
+ tool.returns[i].type,
138
+ transformed_output,
139
+ )
140
+ )
141
+ continue
142
+
143
+ for output in transformed_output:
144
+ if DataType.MASK == tool.returns[i].type:
145
+ output = output if isinstance(output, str) else output["mask"]
146
+ wrapped_outputs.append(
147
+ auto_type(
148
+ arg_name,
149
+ tool.returns[i].type,
150
+ output if isinstance(output, str) else output["mask"],
151
+ )
152
+ )
153
+ return outputs, wrapped_outputs
154
+
155
+ def pretty_floats(self, obj):
156
+ if isinstance(obj, float):
157
+ return round(obj, 4)
158
+ elif isinstance(obj, dict):
159
+ return dict((k, self.pretty_floats(v)) for k, v in obj.items())
160
+ elif isinstance(obj, (list, tuple)):
161
+ return list(map(self.pretty_floats, obj))
162
+ return obj
163
+
164
+ def _update_inputs(self, memory, action_inputs):
165
+ action_inputs = copy.deepcopy(action_inputs)
166
+ for key, value in action_inputs.items():
167
+ if "<TOOL-GENERATED>" in value:
168
+ action_inputs[key] = memory.get(value, value)
169
+ elif "<GENERATED>" in value:
170
+ action_inputs[key] = memory.get(value, value)
171
+
172
+ return action_inputs
173
+
174
+ def gen_filename(self, too_name, resource_type):
175
+ def to_camelcase(s):
176
+ res = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), s)
177
+ res = res[0].upper() + res[1:]
178
+ return res
179
+
180
+ if resource_type == DataType.VIDEO:
181
+ ext = "mp4"
182
+ elif resource_type == DataType.AUDIO:
183
+ ext = "wav"
184
+ elif resource_type == DataType.HTML:
185
+ ext = "html"
186
+ else:
187
+ ext = "png"
188
+ too_name = too_name.replace("_to_", "2_")
189
+ too_name = to_camelcase(too_name)
190
+ this_file_id = str(uuid.uuid4())[:6]
191
+ type_str = str(resource_type).split(".")[-1]
192
+ return f"{this_file_id}_{type_str}.{ext}"
193
+
194
+ def _save_resource(self, file_name, resource, resource_type):
195
+ if isinstance(resource, dict):
196
+ if "mask" in resource:
197
+ resource = resource["mask"]
198
+ if resource_type == DataType.HTML:
199
+ with open(get_real_path(file_name), "w") as fout:
200
+ fout.write(resource)
201
+ elif resource is not None:
202
+ if isinstance(resource, io.BufferedReader):
203
+ resource = resource.read()
204
+ with open(get_real_path(file_name), "wb") as fout:
205
+ fout.write(resource)
206
+ else:
207
+ return None
208
+
209
+ def transform_output(
210
+ self, action_inputs, tool_name, tool_args, tool_output, output_type
211
+ ):
212
+ if output_type != DataType.MASK:
213
+ if isinstance(tool_output, list):
214
+ results = []
215
+ for output in tool_output:
216
+ file_name = self.gen_filename(tool_name, output_type)
217
+ self._save_resource(file_name, output, output_type)
218
+ results.append(file_name)
219
+ return results
220
+ else:
221
+ file_name = self.gen_filename(tool_name, output_type)
222
+ self._save_resource(file_name, tool_output, output_type)
223
+ return file_name
224
+
225
+ tool_output = copy.deepcopy(tool_output)
226
+ if isinstance(tool_output, list):
227
+ for output in tool_output:
228
+ if isinstance(output["mask"], str):
229
+ continue
230
+
231
+ file_name = self.gen_filename(tool_name, output_type)
232
+ self._save_resource(file_name, output, output_type)
233
+ output["mask"] = file_name
234
+ elif isinstance(tool_output, bytes):
235
+ file_name = self.gen_filename(tool_name, output_type)
236
+ self._save_resource(file_name, tool_output, output_type)
237
+ tool_output = file_name
238
+ elif tool_output is None:
239
+ pass
240
+ else:
241
+ raise RuntimeError("Wrong type.")
242
+
243
+ return tool_output
244
+
245
+ def load_tool(self, name):
246
+ return self.tools[name]
247
+
248
+ def load_args(self, tool: Tool, action_inputs, memory):
249
+ real_args = {}
250
+ for item in tool.args:
251
+ arg_name = item.name
252
+ arg_value = action_inputs[arg_name]
253
+ if "<GENERATED>" in arg_value or "<TOOL-GENERATED>" in arg_value:
254
+ assert arg_value in memory, f"Unknown {arg_name}: {arg_value}"
255
+ real_args[arg_name] = memory[arg_value]
256
+ else:
257
+ real_args[arg_name] = arg_value
258
+ return real_args
259
+
260
+ @property
261
+ def variables(self):
262
+ return {k: v for k, v in self.memory.items() if k not in TOOLS and k != "print"}
cllm/agents/tog/planner.py ADDED
@@ -0,0 +1,156 @@
1
+ import json
2
+ from typing import List
3
+ import logging
4
+ import os
5
+ import sys
6
+
7
+ sys.path.append(os.getcwd())
8
+ from ..base import Action, NON_FILE_TYPES
9
+
10
+ # from cllm.services.tog import TaskSolver, TaskDecomposer, config
11
+ # from cllm.services.nlp.llms import ChatOpenAI, MessageMemory
12
+ from cllm.services.tog.api import tog, task_decomposer
13
+ from collections import OrderedDict
14
+ import copy
15
+
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class Planner:
21
+ def __init__(
22
+ self, streaming=False, backend="remote", device="cuda:0", **llm_kwargs
23
+ ):
24
+ self.streaming = streaming
25
+ if backend == "local":
26
+ pass
27
+ # self.cfg = config
28
+ # self.device = device
29
+ # self.mem = MessageMemory(**self.cfg.memory)
30
+ # self.llm = ChatOpenAI(temperature=0.2, **llm_kwargs)
31
+ # self.tog = TaskSolver(self.llm, self.cfg.task_solver_config, device).solve
32
+ # self.decomposer = TaskDecomposer(device, self.cfg.task_decomposer_cfg).solve
33
+ elif backend == "remote":
34
+ self.decomposer = task_decomposer
35
+ self.tog = tog
36
+ else:
37
+ raise ValueError("Backend should be chosen from [remote, local]")
38
+
39
+ def _find_latest_resource(self, resources, type):
40
+ for key, val in list(resources.items())[::-1]:
41
+ if val == type:
42
+ return key
43
+ return None
44
+
45
+ def _check_task_decomposition(
46
+ self, task_decomposition: str | list, available_resources: dict
47
+ ):
48
+ copy_task_decomposition = copy.deepcopy(task_decomposition)
49
+ available_resources = copy.deepcopy(available_resources)
50
+ if isinstance(copy_task_decomposition, str):
51
+ copy_task_decomposition = json.loads(copy_task_decomposition)
52
+
53
+ for subtask in copy_task_decomposition:
54
+ for arg in subtask["args"]:
55
+ if arg["type"] in NON_FILE_TYPES:
56
+ continue
57
+
58
+ r_type = available_resources.get(arg["value"], "None").split(".")[-1]
59
+ if arg["value"] not in available_resources or arg["type"] != r_type:
60
+ new_value = self._find_latest_resource(
61
+ available_resources, arg["type"]
62
+ )
63
+ if new_value is None:
64
+ logger.error(
65
+ f"No available resource for {arg['value']} with type {arg['type']}"
66
+ )
67
+ return None
68
+
69
+ arg["value"] = new_value
70
+
71
+ available_resources[subtask["returns"][0]["value"]] = subtask["returns"][0][
72
+ "type"
73
+ ]
74
+ return json.dumps(copy_task_decomposition, indent=2, ensure_ascii=False)
75
+
76
+ def wrap_request(self, request, memory):
77
+ logger.info(memory)
78
+ resource_list = {k: v.split(".")[-1] for k, v in memory.items()}
79
+ request = f"Resource list: {resource_list}\n{request}"
80
+ logger.info(f"Input: {request}")
81
+ return request
82
+
83
+ def solve_streaming(self, request: str, memory: dict = OrderedDict()):
84
+ request = self.wrap_request(request, memory)
85
+ sub_tasks = self.decomposer(request, streaming=self.streaming)
86
+ logger.info(f"Task decomposition: \n{sub_tasks}")
87
+ sub_tasks = self._check_task_decomposition(sub_tasks, memory)
88
+ yield sub_tasks
89
+ if sub_tasks in [None, "", []]:
90
+ yield None
91
+ else:
92
+ solutions = self.tog(request, sub_tasks, streaming=self.streaming)
93
+ yield solutions
94
+
95
+ def solve(self, request: str, memory: dict = OrderedDict()) -> List:
96
+ request = self.wrap_request(request, memory)
97
+ sub_tasks = self.decomposer(request)
98
+ solutions = self.tog(request, sub_tasks)
99
+ print(f"solutions: {solutions}")
100
+ return sub_tasks, solutions
101
+
102
+ def plan(self, task, memory: dict = OrderedDict()) -> List:
103
+ if self.streaming:
104
+ return self.solve_streaming(task, memory)
105
+ else:
106
+ return self.solve(task, memory)
107
+
108
+ def _check_solutions(self, solution: List | str) -> bool:
109
+ if isinstance(solution, str):
110
+ solution = json.loads(solution)
111
+ if len(solution) == 0:
112
+ return False
113
+
114
+ valid = True
115
+ for i, stage_candidate in enumerate(solution):
116
+ if len(stage_candidate) == 0:
117
+ logger.error(f"No solution is found in {i+1}-th subtask.")
118
+ valid = False
119
+ elif (
120
+ "solution" in stage_candidate[0]
121
+ and len(stage_candidate[0]["solution"]) == 0
122
+ ):
123
+ logger.error(f"No solution is found in {i+1}-th subtask.")
124
+ valid = False
125
+ else:
126
+ logger.info(f"Solutions for {i+1}-th subtask:\n{stage_candidate}")
127
+ return valid
128
+
129
+ def parse(self, solution: List | str) -> List[List[Action]]:
130
+ if isinstance(solution, str):
131
+ solution = json.loads(solution)
132
+
133
+ if not self._check_solutions(solution):
134
+ return None
135
+
136
+ if isinstance(solution[0], Action):
137
+ return solution
138
+
139
+ stages = []
140
+ for i, stage_candidate in enumerate(solution):
141
+ stage = stage_candidate[0]["solution"]
142
+ actions = []
143
+ for action in stage:
144
+ inputs = {arg["name"]: arg["value"] for arg in action["args"]}
145
+ outputs = [r["value"] for r in action["returns"]]
146
+ actions.append(
147
+ Action(action["tool_name"], inputs=inputs, outputs=outputs)
148
+ )
149
+ stages.append(actions)
150
+ return stages
151
+
152
+ def __call__(
153
+ self, request: str, memory: dict = OrderedDict()
154
+ ) -> List[List[Action]]:
155
+ solution = self.solve(request, memory)
156
+ return self.parse(solution)
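A minimal usage sketch of the Planner above (illustrative, not part of the diff), assuming the remote tog services are reachable; the request string is made up and the non-streaming path is shown:

from collections import OrderedDict
from cllm.agents.tog.planner import Planner

planner = Planner(streaming=False, backend="remote")
# plan() in non-streaming mode returns the task decomposition and per-subtask solutions.
sub_tasks, solutions = planner.plan("generate an image of a cat playing piano", OrderedDict())
stages = planner.parse(solutions)  # list of Action lists per subtask, or None if no solution was found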
cllm/agents/tog/responser.py ADDED
@@ -0,0 +1,66 @@
1
+ import openai
2
+ import logging
3
+ import os
4
+ import sys
5
+
6
+ sys.path.append(os.getcwd())
7
+ from cllm.services.nlp.llms.chat_models import ChatOpenAI
8
+
9
+ # from cllm.services.nlp.llms.memory import MessageMemory
10
+ from langchain.schema import SystemMessage
11
+
12
+ from cllm.agents.base import DataType
13
+ from cllm.agents import container
14
+
15
+
16
+ RESPONSE_GENERATION_PROMPT = """Your name is ControlLLM, an AI-powered assistant developed by OpenGVLab from Shanghai Artificial Intelligence Laboratory. For the user's request, the system executes the solution and collects the results based on the following workflow. You need to respond to user requests based on the following information.
17
+ Here is the information for your reference.
18
+
19
+ ## User Request
20
+ {}
21
+
22
+ ## Workflow and Execution Results
23
+ {}
24
+
25
+ Now you should pay attention to the collected results. You must first answer the user's request in a straightforward manner. Then you need to summarize the workflow and intermediate results in a friendly way. Some of the results may not be accurate, so use your judgement when making decisions. If the results contain file names, you have to output the file names directly. Only if nothing is returned by the tools should you tell the user that you cannot finish the task. Now, please summarize the results in a friendly way and answer the question for the user request `{}`.
26
+ """.strip()
27
+
28
+
29
+ SIMPLE_RESPONSE_GENERATION_PROMPT = """Your name is ControlLLM, an AI-powered assistant developed by OpenGVLab from Shanghai Artificial Intelligence Laboratory. You need to respond to user requests based on the following information.
30
+ Here is the information for your reference.
31
+
32
+ ## User Request
33
+ {}
34
+
35
+ ## Workflow and Execution Results
36
+ {}
37
+
38
+ Now, please summarize the results in a friendly way and answer the question for the user request `{}`.
39
+ """.strip()
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ def generate_response(user_input, solution, output_files):
45
+ if (
46
+ len(solution) <= 1
47
+ and len(solution[0]) <= 1
48
+ and solution[0][0].tool_name == "question_answering"
49
+ ):
50
+ content = SIMPLE_RESPONSE_GENERATION_PROMPT.format(
51
+ user_input, solution, user_input
52
+ )
53
+ else:
54
+ content = RESPONSE_GENERATION_PROMPT.format(user_input, solution, user_input)
55
+
56
+ logger.info("##### Response Generation #####")
57
+ logger.info(content)
58
+
59
+ chat = ChatOpenAI(model_name="gpt-3.5-turbo-1106")
60
+ messages = [SystemMessage(content=content)]
61
+ output = chat(messages)
62
+ logger.info(output)
63
+
64
+ # files = [output for output in output_files if isinstance(output, container.File)]
65
+ # return [container.Text('Response', DataType.TEXT, output)] + files
66
+ return [container.Text("Response", DataType.TEXT, output)]
cllm/services/audio/__init__.py ADDED
File without changes
cllm/services/audio/api.py ADDED
@@ -0,0 +1,140 @@
1
+ import io
2
+ import os
3
+ import uuid
4
+ import requests
5
+
6
+ from cllm.services.nlp.api import openai_chat_model
7
+ from cllm.services.utils import get_bytes_value
8
+
9
+ __ALL__ = [
10
+ "audio_classification",
11
+ "automatic_speech_recognition",
12
+ "text_to_speech",
13
+ ]
14
+
15
+
16
+ HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
17
+ PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
18
+
19
+
20
+ def setup(host="localhost", port=10057):
21
+ global HOST, PORT
22
+ HOST = host
23
+ PORT = port
24
+
25
+
26
+ def audio_classification(audio, **kwargs):
27
+ host = kwargs.get("host", HOST)
28
+ port = kwargs.get("port", PORT)
29
+ url = f"http://{host}:{port}/audio_classification"
30
+ if isinstance(audio, str):
31
+ audio = open(audio, "rb").read()
32
+ files = {"audio": (audio, get_bytes_value(audio))}
33
+ response = requests.post(url, files=files)
34
+ return response.json()
35
+
36
+
37
+ def automatic_speech_recognition(audio: str, **kwargs):
38
+ host = kwargs.get("host", HOST)
39
+ port = kwargs.get("port", PORT)
40
+ url = f"http://{host}:{port}/automatic_speech_recognition"
41
+ # audio_file = open(audio, "rb")
42
+ files = {"audio": (audio, get_bytes_value(audio))}
43
+ response = requests.post(url, files=files)
44
+ return response.json()
45
+
46
+
47
+ def text_to_speech(text: str, **kwargs):
48
+ host = kwargs.get("host", HOST)
49
+ port = kwargs.get("port", PORT)
50
+ human_msg = f"""Your task is to extract the prompt from input. Here is examples:
51
+
52
+ Input:
53
+ translate the text into speech: \"Hope is the thing with feathers That perches in the soul, And sings the tune without the words, And never stops at all\"
54
+
55
+ Answer:
56
+ Hope is the thing with feathers That perches in the soul, And sings the tune without the words, And never stops at all
57
+
58
+ Input:
59
+ Can you help me transcribe the text into audio: I have a dream that one day this nation will rise up and live out the true meaning of its creed: We hold these truths to be self-evident, that all men are created equal.I have a dream that one day on the red hills of Georgia, the sons of former slaves and the sons of former slave owners will be able to sit down together at the table of brotherhood. I have a dream that one day even the state of Mississippi, a state sweltering with the heat of injustice, sweltering with the heat of oppression, will be transformed into an oasis of freedom and justice. I have a dream that my four little children will one day live in a nation where they will not be judged by the color of their skin but by the content of their character.
60
+
61
+ Answer:
62
+ I have a dream that one day this nation will rise up and live out the true meaning of its creed: We hold these truths to be self-evident, that all men are created equal.I have a dream that one day on the red hills of Georgia, the sons of former slaves and the sons of former slave owners will be able to sit down together at the table of brotherhood. I have a dream that one day even the state of Mississippi, a state sweltering with the heat of injustice, sweltering with the heat of oppression, will be transformed into an oasis of freedom and justice. I have a dream that my four little children will one day live in a nation where they will not be judged by the color of their skin but by the content of their character.
63
+
64
+ Input:
65
+ Create speech using the text: And so, my fellow Americans: ask not what your country can do for you — ask what you can do for your country.
66
+
67
+ Answer:
68
+ And so, my fellow Americans: ask not what your country can do for you — ask what you can do for your country.
69
+
70
+ Input:
71
+ The image features a large brown and white dog standing on a tree stump, accompanied by a small cat. The dog is positioned on the right side of the stump, while the cat is on the left side. Both animals appear to be looking at the camera, creating a captivating scene.\n\nThe dog and cat are the main focus of the image, with the dog being larger and more prominent, while the cat is smaller and positioned closer to the ground. The tree stump serves as a natural and interesting backdrop for the two animals, making the scene unique and engaging.
72
+
73
+ Answer:
74
+ The image features a large brown and white dog standing on a tree stump, accompanied by a small cat. The dog is positioned on the right side of the stump, while the cat is on the left side. Both animals appear to be looking at the camera, creating a captivating scene.\n\nThe dog and cat are the main focus of the image, with the dog being larger and more prominent, while the cat is smaller and positioned closer to the ground. The tree stump serves as a natural and interesting backdrop for the two animals, making the scene unique and engaging.
75
+
76
+ Input:
77
+ Life, thin and light-off time and time again\nFrivolous tireless\nI heard the echo, from the valleys and the heart\nOpen to the lonely soul of sickle harvesting\nRepeat outrightly, but also repeat the well-being of eventually swaying in the desert oasis\nI believe I am\nBorn as the bright summer flowers\nDo not withered undefeated fiery demon rule\nHeart rate and breathing to bear the load of the cumbersome Bored\nI heard the music, from the moon and carcass\nAuxiliary extreme aestheticism bait to capture misty\nFilling the intense life, but also filling the pure\nThere are always memories throughout the earth
78
+
79
+ Answer:
80
+ Life, thin and light-off time and time again\nFrivolous tireless\nI heard the echo, from the valleys and the heart\nOpen to the lonely soul of sickle harvesting\nRepeat outrightly, but also repeat the well-being of eventually swaying in the desert oasis\nI believe I am\nBorn as the bright summer flowers\nDo not withered undefeated fiery demon rule\nHeart rate and breathing to bear the load of the cumbersome Bored\nI heard the music, from the moon and carcass\nAuxiliary extreme aestheticism bait to capture misty\nFilling the intense life, but also filling the pure\nThere are always memories throughout the earth
81
+
82
+ Input:
83
+ {text}
84
+
85
+ Answer:
86
+ """
87
+ extracted_prompt = openai_chat_model(human_msg)
88
+ print(f"extracted_prompt: {extracted_prompt}")
89
+ url = f"http://{host}:{port}/text_to_speech"
90
+ data = {"text": extracted_prompt}
91
+ response = requests.post(url, data=data)
92
+ return response.content
93
+
94
+
95
+ def text_to_music(text: str, **kwargs):
96
+ # print('a' * 40)
97
+ host = kwargs.get("host", HOST)
98
+ port = kwargs.get("port", PORT)
99
+ human_msg = f"""Your task is to extract the prompt from input. Here is examples:
100
+
101
+ Input:
102
+ Please generate a piece of music based on given prompt. Here is the prompt: An 80s driving pop song with heavy drums
103
+
104
+ Answer:
105
+ An 80s driving pop song with heavy drums
106
+
107
+ Input:
108
+ I would like you to provide me with a new song that represents an energetic and lively 80s pop track with prominent drums and synthesizer pads
109
+
110
+ Answer:
111
+ an energetic and lively 80s pop track with prominent drums and synthesizer pads
112
+
113
+ Input:
114
+ I'm looking for a song that has a driving pop vibe from the 80s, with heavy drums and synth pads playing in the background
115
+
116
+ Answer:
117
+ a driving pop vibe from the 80s, with heavy drums and synth pads playing in the background
118
+
119
+ Input:
120
+ Can you make a song that has a lively and energetic rhythm with prominent drums and electronic keyboard sounds in the background
121
+
122
+ Answer:
123
+ a lively and energetic rhythm with prominent drums and electronic keyboard sounds in the background
124
+
125
+ Input:
126
+ Can you make a piece of light and relaxing music
127
+
128
+ Answer:
129
+ a piece of light and relaxing music
130
+
131
+ Input:
132
+ {text}
133
+
134
+ Answer:
135
+ """
136
+ extracted_prompt = openai_chat_model(human_msg)
137
+ url = f"http://{host}:{port}/text_to_music"
138
+ data = {"text": extracted_prompt}
139
+ response = requests.post(url, data=data)
140
+ return response.content
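A minimal client-side sketch (illustrative, not part of the diff) of calling the speech service above, assuming it is reachable at CLLM_SERVICES_HOST:CLLM_SERVICES_PORT and an OpenAI key is configured for the prompt-extraction step; the output path is hypothetical:

from cllm.services.audio.api import text_to_speech

# The helper first asks the chat model to extract the raw text, then posts it to /text_to_speech.
wav_bytes = text_to_speech('translate the text into speech: "Hope is the thing with feathers"')
with open("speech_output.wav", "wb") as f:  # illustrative output path
    f.write(wav_bytes)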
cllm/services/general/__init__.py ADDED
File without changes
cllm/services/general/api.py ADDED
@@ -0,0 +1,65 @@
1
+ from re import I
2
+ from typing import List
3
+ from pathlib import Path
4
+ import os
5
+ import requests
6
+
7
+ __ALL__ = ["remote_logging", "select", "count"]
8
+
9
+ HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
10
+ PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
11
+
12
+
13
+ def setup(host="localhost", port=10056):
14
+ global HOST, PORT
15
+ HOST = host
16
+ PORT = port
17
+
18
+
19
+ def select(**kwargs):
20
+ if "bbox_list" in kwargs:
21
+ list = kwargs["bbox_list"]
22
+ condition = kwargs["condition"]
23
+ return [l for l in list if l["label"] == condition]
24
+ if "mask_list" in kwargs:
25
+ list = kwargs["mask_list"]
26
+ condition = kwargs["condition"]
27
+ # return combine_masks([l for l in list if l['label'] == condition])
28
+ return [l for l in list if l["label"] == condition]
29
+ if "category_list" in kwargs:
30
+ list = kwargs["category_list"]
31
+ condition = kwargs["condition"]
32
+ # return combine_masks([l for l in list if l['label'] == condition])
33
+ return [l for l in list if l["label"] == condition]
34
+
35
+
36
+ def count(**kwargs):
37
+ len_of_list = 0
38
+ if "bbox_list" in kwargs:
39
+ len_of_list = len(kwargs["bbox_list"])
40
+ elif "mask_list" in kwargs:
41
+ len_of_list = len(kwargs["mask_list"])
42
+
43
+ return f"The length of the given list is {len_of_list}"
44
+
45
+
46
+ def remote_logging(
47
+ history_msgs: list,
48
+ task_decomposition: list,
49
+ solution: list,
50
+ record: str,
51
+ like: bool,
52
+ **kwargs,
53
+ ):
54
+ host = kwargs.get("host", HOST)
55
+ port = kwargs.get("port", PORT)
56
+ url = f"http://{host}:{port}/remote_logging"
57
+ data = {
58
+ "history_msgs": history_msgs,
59
+ "task_decomposition": task_decomposition,
60
+ "solution": solution,
61
+ "record": record,
62
+ "like": like,
63
+ }
64
+ response = requests.post(url, data=data)
65
+ return response.content
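A small illustrative example (not part of the diff) of how `select` and `count` operate on detection-style results; the boxes below are made up:

from cllm.services.general.api import select, count

boxes = [
    {"score": 0.98, "label": "dog", "box": {"xmin": 10, "ymin": 20, "xmax": 110, "ymax": 220}},
    {"score": 0.91, "label": "cat", "box": {"xmin": 200, "ymin": 40, "xmax": 320, "ymax": 260}},
]
dogs = select(bbox_list=boxes, condition="dog")
print(count(bbox_list=dogs))  # -> "The length of the given list is 1"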
cllm/services/image_editing/__init__.py ADDED
File without changes
cllm/services/image_editing/api.py ADDED
@@ -0,0 +1,277 @@
1
+ import copy
2
+ import io
3
+ import os
4
+ from PIL import Image, ImageDraw, ImageChops
5
+ import numpy as np
6
+ import requests
7
+ from PIL import Image
8
+ from typing import List, Union
9
+ from pathlib import Path
10
+ import os
11
+ import sys
12
+
13
+ sys.path.append(os.getcwd())
14
+ from cllm.services.utils import get_bytes_value
15
+ from cllm.utils import get_real_path
16
+ from cllm.services.nlp.api import openai_chat_model
17
+
18
+ __ALL__ = [
19
+ "instruct_pix2pix",
20
+ "image_cropping",
21
+ "image_matting",
22
+ "draw_bbox_on_image",
23
+ "partial_image_editing",
24
+ ]
25
+
26
+
27
+ HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
28
+ PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
29
+
30
+
31
+ def setup(host="localhost", port=10049):
32
+ global HOST, PORT
33
+ HOST = host
34
+ PORT = port
35
+
36
+
37
+ def image_cropping(image: str | Path, object: List[dict], **kwargs):
38
+ """
39
+ bbox format: {'score': 0.997, 'label': 'bird', 'box': {'xmin': 69, 'ymin': 171, 'xmax': 396, 'ymax': 507}}
40
+ """
41
+ if object in [None, b"", []]:
42
+ return None
43
+
44
+ if isinstance(image, (str, Path)):
45
+ image = Image.open(get_real_path(image)).convert("RGB")
46
+ elif isinstance(image, bytes):
47
+ image = Image.open(io.BytesIO(image)).convert("RGB")
48
+ w, h = image.size
49
+ cropped_images = []
50
+ for box in object:
51
+ box = copy.deepcopy(box["box"])
52
+ box = unify_bbox(box, w, h)
53
+ (left, upper, right, lower) = (
54
+ box["xmin"],
55
+ box["ymin"],
56
+ box["xmax"],
57
+ box["ymax"],
58
+ )
59
+ cropped_image = image.crop((left, upper, right, lower))
60
+ # cropped_image.save('test.png')
61
+ img_stream = io.BytesIO()
62
+ cropped_image.save(img_stream, format="png")
63
+ img_stream.seek(0)
64
+ cropped_images.append(img_stream.getvalue())
65
+ if len(cropped_images) == 0:
66
+ return None
67
+ return cropped_images
68
+
69
+
70
+ def image_matting(image: str | Path, mask: Union[str, bytes, List], **kwargs):
71
+ """
72
+ {'score': 0.999025,
73
+ 'label': 'person',
74
+ 'mask': <PIL.Image.Image image mode=L size=386x384>}
75
+ """
76
+ if mask in [None, b"", []]:
77
+ return None
78
+ image = Image.open(io.BytesIO(get_bytes_value(image))).convert("RGB")
79
+
80
+ mask = copy.deepcopy(mask)
81
+ if isinstance(mask, List):
82
+ mask_list = []
83
+ for m in mask:
84
+ if isinstance(m, dict):
85
+ mask_list.append(get_bytes_value(m["mask"]))
86
+ else:
87
+ mask_list.append(get_bytes_value(m))
88
+ mask = combine_masks(mask_list)
89
+ elif isinstance(mask, str):
90
+ mask = get_bytes_value(mask)
91
+
92
+ mask = Image.open(io.BytesIO(mask)).convert("L")
93
+
94
+ mask = np.array(mask) > 0
95
+ image = np.array(image) * np.expand_dims(mask, -1)
96
+ image = Image.fromarray(image.astype(np.uint8))
97
+ img_stream = io.BytesIO()
98
+ image.save(img_stream, format="png")
99
+ img_stream.seek(0)
100
+ return img_stream.getvalue()
101
+
102
+
103
+ def unify_bbox(bbox, w, h):
104
+ bbox["xmin"] = (
105
+ bbox["xmin"] if isinstance(bbox["xmin"], int) else int(bbox["xmin"] * w)
106
+ )
107
+
108
+ bbox["ymin"] = (
109
+ bbox["ymin"] if isinstance(bbox["ymin"], int) else int(bbox["ymin"] * h)
110
+ )
111
+ bbox["xmax"] = (
112
+ bbox["xmax"] if isinstance(bbox["xmax"], int) else int(bbox["xmax"] * w)
113
+ )
114
+ bbox["ymax"] = (
115
+ bbox["ymax"] if isinstance(bbox["ymax"], int) else int(bbox["ymax"] * h)
116
+ )
117
+ return bbox
118
+
119
+
120
+ def draw_bbox_on_image(image: str | Path, bbox: list, **kwargs):
121
+ if isinstance(image, (str, Path)):
122
+ image = Image.open(get_real_path(image)).convert("RGB")
123
+ elif isinstance(image, bytes):
124
+ image = Image.open(io.BytesIO(image)).convert("RGB")
125
+ image = image.copy()
126
+ w, h = image.size
127
+ for box in bbox:
128
+ box = copy.deepcopy(box["box"])
129
+ box = unify_bbox(box, w, h)
130
+ (left, upper, right, lower) = (
131
+ box["xmin"],
132
+ box["ymin"],
133
+ box["xmax"],
134
+ box["ymax"],
135
+ )
136
+ draw = ImageDraw.Draw(image)
137
+ font_width = int(
138
+ min(box["xmax"] - box["xmin"], box["ymax"] - box["ymin"]) * 0.01
139
+ )
140
+ draw.rectangle(((left, upper), (right, lower)), outline="Red", width=font_width)
141
+ img_stream = io.BytesIO()
142
+ image.save(img_stream, format="png")
143
+ img_stream.seek(0)
144
+ # image = Image.save(image, format='png')
145
+ return img_stream.getvalue()
146
+
147
+
148
+ def _imagetext2image(image, text, endpoint, **kwargs):
149
+ host = kwargs.get("host", HOST)
150
+ port = kwargs.get("port", PORT)
151
+ url = f"http://{host}:{port}/{endpoint}"
152
+ data = {"text": text}
153
+ files = {"image": (image, get_bytes_value(image))}
154
+ response = requests.post(url, files=files, data=data)
155
+ return response.content
156
+
157
+
158
+ def instruct_pix2pix(image, text, **kwargs):
159
+ return _imagetext2image(image, text, endpoint="instruct_pix2pix", **kwargs)
160
+
161
+
162
+ def partial_image_editing(
163
+ image: str | bytes, mask: str | list | bytes, prompt: str, **kwargs
164
+ ):
165
+ if mask in [None, b"", []]:
166
+ return None
167
+
168
+ host = kwargs.get("host", HOST)
169
+ port = kwargs.get("port", PORT)
170
+ url = f"http://{host}:{port}/partial_image_editing"
171
+ human_msg = f"""Your task is to extract the prompt from input. Here is examples:
172
+
173
+ Input:
174
+ Replace the masked object in the given image with a yellow horse
175
+
176
+ Answer:
177
+ a yellow horse
178
+
179
+ Input:
180
+ Use the c1s5af_mask.png in to replace the object with a man in the image
181
+
182
+ Answer:
183
+ a man
184
+
185
+ Input:
186
+ Modify the given image by replacing the object indicated in the mask with a bouquet of flowers
187
+
188
+ Answer:
189
+ a bouquet of flowers
190
+
191
+ Input:
192
+ Use the 7a3c72_mask.png file to replace the object in the a9430b_image.png with a bus colored yellow and red with the number 5 on its front sign
193
+
194
+ Answer:
195
+ a bus colored yellow and red with the number 5 on its front sign.
196
+
197
+ Input:
198
+ Replace the masked area in image with a fat boy wearing a black jacket.
199
+
200
+ Answer:
201
+ a fat boy wearing a black jacket
202
+
203
+ Input:
204
+ {prompt}
205
+
206
+ Answer:
207
+ """
208
+ extracted_prompt = openai_chat_model(human_msg)
209
+ data = {"prompt": extracted_prompt}
210
+ if isinstance(mask, List):
211
+ mask_list = []
212
+ for m in mask:
213
+ if isinstance(m, dict):
214
+ mask_list.append(get_bytes_value(m["mask"]))
215
+ else:
216
+ mask_list.append(get_bytes_value(m))
217
+ mask = combine_masks(mask_list)
218
+
219
+ files = {
220
+ "image": (image, get_bytes_value(image)),
221
+ "mask": ("mask", get_bytes_value(mask)),
222
+ }
223
+ response = requests.post(url, files=files, data=data)
224
+ return response.content
225
+
226
+
227
+ def combine_masks(mask_images):
228
+ if mask_images is None or len(mask_images) == 0:
229
+ return None
230
+
231
+ # Create a new blank image to store the combined mask
232
+ combined_mask = Image.open(io.BytesIO(mask_images[0])).convert("1")
233
+
234
+ # Iterate through each mask image and combine them
235
+ for mask_image in mask_images:
236
+ mask = Image.open(io.BytesIO(mask_image)).convert("1")
237
+ combined_mask = ImageChops.logical_or(combined_mask, mask)
238
+ stream = io.BytesIO()
239
+ combined_mask.save(stream, "png")
240
+ stream.seek(0)
241
+ # return {"label": mask_images[0]["label"], "mask": stream.getvalue()}
242
+ return stream.getvalue()
243
+
244
+
245
+ def inpainting_ldm_general(image, mask: Union[str, bytes, List], **kwargs):
246
+ if mask in [None, b"", []]:
247
+ return get_bytes_value(image)
248
+
249
+ mask = copy.deepcopy(mask)
250
+ if isinstance(mask, List):
251
+ mask_list = []
252
+ for m in mask:
253
+ if isinstance(m, dict):
254
+ mask_list.append(get_bytes_value(m["mask"]))
255
+ else:
256
+ mask_list.append(get_bytes_value(m))
257
+ mask = combine_masks(mask_list)
258
+ elif isinstance(mask, str):
259
+ mask = get_bytes_value(mask)
260
+ # mask = Image.open(mask).convert("1")
261
+
262
+ return inpainting_ldm(image, mask, **kwargs)
263
+
264
+
265
+ def inpainting_ldm(image, mask, **kwargs):
266
+ if mask in [None, b""]:
267
+ return get_bytes_value(image)
268
+
269
+ host = kwargs.get("host", HOST)
270
+ port = kwargs.get("port", PORT)
271
+ url = f"http://{host}:{port}/inpainting_ldm"
272
+ files = {
273
+ "image": (image, get_bytes_value(image)),
274
+ "mask": get_bytes_value(mask),
275
+ }
276
+ response = requests.post(url, files=files)
277
+ return response.content
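A short sketch (illustrative, not part of the diff) of feeding object_detection-style boxes into draw_bbox_on_image; the file names and box values are hypothetical and the image path is assumed to resolve via get_real_path:

from cllm.services.image_editing.api import draw_bbox_on_image

bbox = [{"score": 0.97, "label": "bird", "box": {"xmin": 69, "ymin": 171, "xmax": 396, "ymax": 507}}]
png_bytes = draw_bbox_on_image("bird_photo.png", bbox)  # returns PNG bytes with the box drawn
with open("bird_with_box.png", "wb") as f:
    f.write(png_bytes)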
cllm/services/image_generation/__init__.py ADDED
File without changes
cllm/services/image_generation/api.py ADDED
@@ -0,0 +1,96 @@
1
+ import io
2
+ import os
3
+
4
+ import requests
5
+
6
+ import sys
7
+
8
+ sys.path.append(os.getcwd())
9
+ from PIL import Image
10
+ from cllm.services.utils import get_bytes_value
11
+
12
+
13
+ __ALL__ = [
14
+ "text2image",
15
+ "cannytext2image",
16
+ "linetext2image",
17
+ "hedtext2image",
18
+ "scribbletext2image",
19
+ "posetext2image",
20
+ "segtext2image",
21
+ "depthtext2image",
22
+ "normaltext2image" "image2image",
23
+ ]
24
+
25
+
26
+ HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
27
+ PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
28
+
29
+
30
+ def setup(host="localhost", port=10049):
31
+ global HOST, PORT
32
+ HOST = host
33
+ PORT = port
34
+
35
+
36
+ def text2image(text, **kwargs):
37
+ host = kwargs.get("host", HOST)
38
+ port = kwargs.get("port", PORT)
39
+ url = f"http://{host}:{port}/text2image"
40
+ data = {"text": text}
41
+ response = requests.post(url, data=data)
42
+ return response.content
43
+
44
+
45
+ def image2image(image, **kwargs):
46
+ host = kwargs.get("host", HOST)
47
+ port = kwargs.get("port", PORT)
48
+ url = f"http://{host}:{port}/image2image"
49
+ files = {"image": (image, get_bytes_value(image))}
50
+ response = requests.post(url, files=files)
51
+ return response.content
52
+
53
+
54
+ def _imagetext2image(image, text, endpoint, **kwargs):
55
+ host = kwargs.get("host", HOST)
56
+ port = kwargs.get("port", PORT)
57
+ url = f"http://{host}:{port}/{endpoint}"
58
+ data = {"text": text}
59
+ files = {"image": (image, get_bytes_value(image))}
60
+ response = requests.post(url, files=files, data=data)
61
+ # image = Image.open(io.BytesIO(response.content))
62
+ # image = io.BytesIO(response.content)
63
+ # return image
64
+ return response.content
65
+
66
+
67
+ def cannytext2image(edge, text, **kwargs):
68
+ return _imagetext2image(edge, text, endpoint="cannytext2image", **kwargs)
69
+
70
+
71
+ def linetext2image(line, text, **kwargs):
72
+ return _imagetext2image(line, text, endpoint="linetext2image", **kwargs)
73
+
74
+
75
+ def hedtext2image(hed, text, **kwargs):
76
+ return _imagetext2image(hed, text, endpoint="hedtext2image", **kwargs)
77
+
78
+
79
+ def scribbletext2image(scribble, text, **kwargs):
80
+ return _imagetext2image(scribble, text, endpoint="scribbletext2image", **kwargs)
81
+
82
+
83
+ def posetext2image(pose, text, **kwargs):
84
+ return _imagetext2image(pose, text, endpoint="posetext2image", **kwargs)
85
+
86
+
87
+ def segtext2image(segmentation, text, **kwargs):
88
+ return _imagetext2image(segmentation, text, endpoint="segtext2image", **kwargs)
89
+
90
+
91
+ def depthtext2image(depth, text, **kwargs):
92
+ return _imagetext2image(depth, text, endpoint="depthtext2image", **kwargs)
93
+
94
+
95
+ def normaltext2image(normal, text, **kwargs):
96
+ return _imagetext2image(normal, text, endpoint="normaltext2image", **kwargs)
cllm/services/image_inpainting/__init__.py ADDED
File without changes
cllm/services/image_inpainting/api.py ADDED
@@ -0,0 +1,76 @@
1
+ import copy
2
+ from typing import Union, List, Dict
3
+ from PIL import Image, ImageChops
4
+ import io
5
+ import os
6
+
7
+ import requests
8
+ import os
9
+ import sys
10
+
11
+ sys.path.append(os.getcwd())
12
+ from cllm.services.utils import get_bytes_value
13
+
14
+ __ALL__ = [
15
+ "inpainting_ldm",
16
+ ]
17
+
18
+
19
+ HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
20
+ PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
21
+
22
+
23
+ def setup(host="localhost", port=10052):
24
+ global HOST, PORT
25
+ HOST = host
26
+ PORT = port
27
+
28
+
29
+ def combine_masks(mask_images):
30
+ if mask_images is None or len(mask_images) == 0:
31
+ return None
32
+
33
+ # Create a new blank image to store the combined mask
34
+ combined_mask = Image.open(io.BytesIO(mask_images[0])).convert("1")
35
+
36
+ # Iterate through each mask image and combine them
37
+ for mask_image in mask_images:
38
+ mask = Image.open(io.BytesIO(mask_image)).convert("1")
39
+ combined_mask = ImageChops.logical_or(combined_mask, mask)
40
+ stream = io.BytesIO()
41
+ combined_mask.save(stream, "png")
42
+ stream.seek(0)
43
+ # return {"label": mask_images[0]["label"], "mask": stream.getvalue()}
44
+ return stream.getvalue()
45
+
46
+
47
+ def inpainting_ldm_general(image, mask: Union[bytes, List], **kwargs):
48
+ if mask in [None, b"", []]:
49
+ return get_bytes_value(image)
50
+
51
+ mask = copy.deepcopy(mask)
52
+ if isinstance(mask, List):
53
+ if not isinstance(mask[0], dict):
54
+ mask_list = get_bytes_value(mask)
55
+ else:
56
+ mask_list = []
57
+ for m in mask:
58
+ mask_list.append(get_bytes_value(m["mask"]))
59
+ mask = combine_masks(mask_list)
60
+
61
+ return inpainting_ldm(image, mask, **kwargs)
62
+
63
+
64
+ def inpainting_ldm(image, mask, **kwargs):
65
+ if mask in [None, b""]:
66
+ return get_bytes_value(image)
67
+
68
+ host = kwargs.get("host", HOST)
69
+ port = kwargs.get("port", PORT)
70
+ url = f"http://{host}:{port}/inpainting_ldm"
71
+ files = {
72
+ "image": (image, get_bytes_value(image)),
73
+ "mask": get_bytes_value(mask),
74
+ }
75
+ response = requests.post(url, files=files)
76
+ return response.content
cllm/services/image_perception/__init__.py ADDED
File without changes
cllm/services/image_perception/api.py ADDED
@@ -0,0 +1,202 @@
1
+ import codecs
2
+ import io
3
+ import os
4
+ import pickle
5
+ from pathlib import Path
6
+ from PIL import Image
7
+ import requests
8
+ import os
9
+ import sys
10
+
11
+ sys.path.append(os.getcwd())
12
+ from cllm.services.utils import get_bytes_value
13
+ from cllm.services.nlp.api import openai_chat_model
14
+
15
+ __ALL__ = [
16
+ "object_detection",
17
+ "image_classification",
18
+ "ocr",
19
+ "image_to_text",
20
+ "segment_objects",
21
+ ]
22
+
23
+
24
+ HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
25
+ PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
26
+
27
+
28
+ def setup(host="localhost", port=10049):
29
+ global HOST, PORT
30
+ HOST = host
31
+ PORT = port
32
+
33
+
34
+ def object_detection(image, **kwargs):
35
+ host = kwargs.get("host", HOST)
36
+ port = kwargs.get("port", PORT)
37
+ url = f"http://{host}:{port}/object_detection"
38
+ files = {"image": (image, get_bytes_value(image))}
39
+ response = requests.post(url, files=files)
40
+ return response.json()
41
+
42
+
43
+ def image_classification(image, **kwargs):
44
+ host = kwargs.get("host", HOST)
45
+ port = kwargs.get("port", PORT)
46
+ url = f"http://{host}:{port}/image_classification"
47
+ files = {"image": (image, get_bytes_value(image))}
48
+ response = requests.post(url, files=files)
49
+ return response.json()
50
+
51
+
52
+ def image_to_text(image, **kwargs):
53
+ host = kwargs.get("host", HOST)
54
+ port = kwargs.get("port", PORT)
55
+ url = f"http://{host}:{port}/image_to_text"
56
+ files = {"image": (image, get_bytes_value(image))}
57
+ response = requests.post(url, files=files)
58
+ return response.json()
59
+
60
+
61
+ def ocr(image, **kwargs):
62
+ host = kwargs.get("host", HOST)
63
+ port = kwargs.get("port", PORT)
64
+ url = f"http://{host}:{port}/ocr"
65
+ files = {"image": (image, get_bytes_value(image))}
66
+ response = requests.post(url, files=files)
67
+ return response.json()
68
+
69
+
70
+ def segment_objects(image, **kwargs):
71
+ host = kwargs.get("host", HOST)
72
+ port = kwargs.get("port", PORT)
73
+ url = f"http://{host}:{port}/segment_objects"
74
+ files = {"image": (image, get_bytes_value(image))}
75
+ response = requests.post(url, files=files)
76
+ pickled = response.json()["data"]
77
+ output = pickle.loads(codecs.decode(pickled.encode(), "base64"))
78
+ for o in output:
79
+ stream = io.BytesIO()
80
+ o["mask"].save(stream, format="png")
81
+ stream.seek(0)
82
+ o["mask"] = stream.getvalue()
83
+
84
+ return output
85
+
86
+
87
+ def visual_grounding(image, query, **kwargs):
88
+ host = kwargs.get("host", HOST)
89
+ port = kwargs.get("port", PORT)
90
+ url = rf"http://{host}:{port}/visual_grounding"
91
+ human_msg = f"""Your task is to extract the prompt from input. Here is examples:
92
+
93
+ Input:
94
+ find the region of interest in the da9619_image.png: \"An elephant in right corner\"
95
+
96
+ Answer:
97
+ An elephant in right corner
98
+
99
+ Input:
100
+ locate \"A maintenance vehicle on a railway\" in the image
101
+
102
+ Answer:
103
+ A maintenance vehicle on a railway
104
+
105
+ Input:
106
+ use visual grounding method to detect the region of interest in the 1ba6e2_image.png: The motorcycle with the rainbow flag
107
+
108
+ Answer:
109
+ The motorcycle with the rainbow flag
110
+
111
+ Input:
112
+ for given image, find A little baby girl with brunette hair, a pink and white dress, and is being fed frosting from her mom."
113
+
114
+ Answer:
115
+ A little baby girl with brunette hair, a pink and white dress, and is being fed frosting from her mom
116
+
117
+ Input:
118
+ find the policeman on the motorcycle in the 851522_image.png"
119
+
120
+ Answer:
121
+ the policeman on the motorcycle
122
+
123
+ Input:
124
+ The legs of a zebra shown under the neck of another zebra.
125
+
126
+ Answer:
127
+ The legs of a zebra shown under the neck of another zebra.
128
+
129
+ Input:
130
+ {query}
131
+
132
+ Answer:
133
+ """
134
+
135
+ extracted_prompt = openai_chat_model(human_msg)
136
+ files = {"image": get_bytes_value(image)}
137
+ data = {"query": extracted_prompt}
138
+ # image = Image.open(io.BytesIO(image)).convert("RGB")
139
+ response = requests.post(url, data=data, files=files)
140
+
141
+ return response.json()
142
+
143
+
144
+ def image_captioning(image, endpoint="llava", **kwargs):
145
+ host = kwargs.get("host", HOST)
146
+ port = kwargs.get("port", PORT)
147
+ url = f"http://{host}:{port}/{endpoint}"
148
+ data = None
149
+ if endpoint == "llava":
150
+ data = {"text": "Please describe the image in details."}
151
+ files = {"image": (image, get_bytes_value(image))}
152
+ response = requests.post(url, files=files, data=data)
153
+ return response.content.decode("utf-8")
154
+
155
+
156
+ def segment_all(image: str | Path, **kwargs):
157
+ host = kwargs.get("host", HOST)
158
+ port = kwargs.get("port", PORT)
159
+ url = f"http://{host}:{port}/segment_all"
160
+ files = {"image": (image, get_bytes_value(image))}
161
+ response = requests.post(url, files=files)
162
+ return response.content
163
+
164
+
165
+ def set_image(image: str | Path, **kwargs):
166
+ host = kwargs.get("host", HOST)
167
+ port = kwargs.get("port", PORT)
168
+ url = f"http://{host}:{port}/set_image"
169
+ files = {"image": (image, get_bytes_value(image))}
170
+ response = requests.post(url, files=files)
171
+ return response.content.decode()
172
+
173
+
174
+ def segment_by_mask(mask: str | Path, image_id: str, **kwargs):
175
+ host = kwargs.get("host", HOST)
176
+ port = kwargs.get("port", PORT)
177
+ url = f"http://{host}:{port}/segment_by_mask"
178
+ data = {"image_id": image_id}
179
+ files = {"mask": (mask, get_bytes_value(mask))}
180
+ response = requests.post(url, files=files, data=data)
181
+ return response.content
182
+
183
+
184
+ def segment_by_points(points: list | tuple | str, image_id: str, **kwargs):
185
+ host = kwargs.get("host", HOST)
186
+ port = kwargs.get("port", PORT)
187
+ url = f"http://{host}:{port}/segment_by_points"
188
+ data = {"points": points, "image_id": image_id}
189
+ response = requests.post(url, data=data)
190
+ return response.content
191
+
192
+
193
+ def seg_by_mask(image, prompt_mask, **kwargs):
194
+ image_id = set_image(image)
195
+ mask = segment_by_mask(mask=prompt_mask, image_id=image_id)
196
+ return mask
197
+
198
+
199
+ def seg_by_points(image, prompt_points, **kwargs):
200
+ image_id = set_image(image)
201
+ mask = segment_by_points(points=prompt_points, image_id=image_id)
202
+ return mask
cllm/services/image_processing/__init__.py ADDED
File without changes
cllm/services/image_processing/api.py ADDED
@@ -0,0 +1,63 @@
1
+ import io
2
+ import os
3
+
4
+ import requests
5
+ from PIL import Image
6
+ from cllm.services.utils import get_bytes_value
7
+
8
+ __ALL__ = [
9
+ "image2canny",
10
+ "image2line",
11
+ "image2hed",
12
+ "image2scribble",
13
+ "image2pose",
14
+ "image2depth",
15
+ "image2normal",
16
+ ]
17
+
18
+
19
+ HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
20
+ PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
21
+
22
+
23
+ def setup(host="localhost", port=10049):
24
+ global HOST, PORT
25
+ HOST = host
26
+ PORT = port
27
+
28
+
29
+ def image2anything(image: Image, endpoint="image2line", **kwargs):
30
+ host = kwargs.get("host", HOST)
31
+ port = kwargs.get("port", PORT)
32
+ url = f"http://{host}:{port}/{endpoint}"
33
+ files = {"image": (image, get_bytes_value(image))}
34
+ response = requests.post(url, files=files)
35
+ return response.content
36
+
37
+
38
+ def image2canny(image: Image, **kwargs):
39
+ return image2anything(image, endpoint="image2canny", **kwargs)
40
+
41
+
42
+ def image2line(image: Image, **kwargs):
43
+ return image2anything(image, endpoint="image2line", **kwargs)
44
+
45
+
46
+ def image2hed(image: Image, **kwargs):
47
+ return image2anything(image, endpoint="image2hed", **kwargs)
48
+
49
+
50
+ def image2scribble(image: Image, **kwargs):
51
+ return image2anything(image, endpoint="image2scribble", **kwargs)
52
+
53
+
54
+ def image2pose(image: Image, **kwargs):
55
+ return image2anything(image, endpoint="image2pose", **kwargs)
56
+
57
+
58
+ def image2depth(image: Image, **kwargs):
59
+ return image2anything(image, endpoint="image2depth", **kwargs)
60
+
61
+
62
+ def image2normal(image: Image, **kwargs):
63
+ return image2anything(image, endpoint="image2normal", **kwargs)
cllm/services/nlp/__init__.py ADDED
File without changes
cllm/services/nlp/api.py ADDED
@@ -0,0 +1,163 @@
1
+ import io
2
+ import os
3
+ import time
4
+
5
+ import requests
6
+ import json
7
+ from .llms.chat_models import ChatOpenAI
8
+ from langchain.schema import (
9
+ HumanMessage,
10
+ SystemMessage,
11
+ AIMessage,
12
+ )
13
+ from typing import (
14
+ TYPE_CHECKING,
15
+ Any,
16
+ AsyncIterator,
17
+ Callable,
18
+ Dict,
19
+ Iterator,
20
+ List,
21
+ Mapping,
22
+ Optional,
23
+ Tuple,
24
+ Type,
25
+ Union,
26
+ )
27
+
28
+ __ALL__ = [
29
+ "text_to_text_generation",
30
+ "title_generation",
31
+ "text_to_tags",
32
+ "question_answering",
33
+ "summarization",
34
+ ]
35
+
36
+
37
+ HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
38
+ PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
39
+
40
+
41
+ def setup(host="localhost", port=10056):
42
+ global HOST, PORT
43
+ HOST = host
44
+ PORT = port
45
+
46
+
47
+ def text_to_text_generation(text: str, **kwargs):
48
+ host = kwargs.get("host", HOST)
49
+ port = kwargs.get("port", PORT)
50
+ url = f"http://{host}:{port}/text_to_text_generation"
51
+ data = {"text": text}
52
+ response = requests.post(url, data=data)
53
+ return response.json()
54
+
55
+
56
+ def question_answering_with_context(context: str, question: str, **kwargs):
57
+ host = kwargs.get("host", HOST)
58
+ port = kwargs.get("port", PORT)
59
+ url = f"http://{host}:{port}/question_answering_with_context"
60
+ data = {"context": context, "question": question}
61
+ response = requests.post(url, data=data)
62
+ return response.json()
63
+
64
+
65
+ def openai_chat_model(input_msg: str, **kwargs):
66
+ chat = ChatOpenAI()
67
+ chat_log = []
68
+ default_sys_msg = "Your name is ControlLLM, an AI-powered assistant developed by OpenGVLab from Shanghai AI Lab. You need to respond to user requests based on the following information."
69
+ sys_msg = kwargs.get("sys_msg", default_sys_msg)
70
+ if sys_msg is not None:
71
+ chat_log.append(SystemMessage(content=sys_msg))
72
+ # history_msgs: list[str]
73
+ history_msgs = []
74
+ if "history_msgs" in kwargs:
75
+ history_msgs = kwargs.get("history_msgs", [])
76
+
77
+ for item in history_msgs:
78
+ if isinstance(item[0], (list, tuple)):
79
+ item[0] = "Received file: " + item[0][0]
80
+ if isinstance(item[1], (list, tuple)):
81
+ item[1] = "Generated file: " + item[1][0]
82
+ if item[0] is not None:
83
+ chat_log.append(HumanMessage(content=item[0]))
84
+ if item[1] is not None:
85
+ chat_log.append(AIMessage(content=item[1]))
86
+ # chat_log.extend([HumanMessage(content=item[0]), AIMessage(content=item[1])])
87
+ if not isinstance(input_msg, str):
88
+ input_msg = json.dumps(input_msg, ensure_ascii=False)
89
+ output = chat(chat_log + [HumanMessage(content=input_msg)])
90
+ return output
91
+
92
+
93
+ def title_generation(text: str, **kwargs):
94
+ question = "summarize"
95
+ response = question_answering_with_context(text, question)
96
+ return response
97
+
98
+
99
+ def summarization(text: str, **kwargs):
100
+ host = kwargs.get("host", HOST)
101
+ port = kwargs.get("port", PORT)
102
+ url = f"http://{host}:{port}/summarization"
103
+ data = {"text": text}
104
+ response = requests.post(url, data=data)
105
+ return response.json()
106
+
107
+
108
+ def text_to_tags(text: str, **kwargs):
109
+ host = kwargs.get("host", HOST)
110
+ port = kwargs.get("port", PORT)
111
+ url = f"http://{host}:{port}/text_to_tags"
112
+ data = {"text": text}
113
+ response = requests.post(url, data=data)
114
+ return response.json()
115
+
116
+
117
+ def get_time(location: str = None, **kwargs):
118
+ return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
119
+
120
+
121
+ def get_weather(location: str | list, **kwargs):
122
+ host = kwargs.get("host", HOST)
123
+ port = kwargs.get("port", PORT)
124
+ url = f"http://{host}:{port}/get_weather"
125
+ if isinstance(location, list):
126
+ t = {"CITY": "", "COUNTRY": ""}
127
+ for l in location:
128
+ if l["entity_group"] not in t.keys():
129
+ continue
130
+ if t[l["entity_group"]] == "":
131
+ t[l["entity_group"]] = l["word"].title()
132
+ location = ",".join([t["CITY"], t["COUNTRY"]])
133
+
134
+ data = {"location": location}
135
+ response = requests.post(url, data=data)
136
+ return response.json()
137
+
138
+
139
+ def summarize_weather_condition(weather: str | list, **kwargs):
140
+ if isinstance(weather, list):
141
+ weather = json.dumps(weather, ensure_ascii=False)
142
+ result = openai_chat_model(
143
+ f"Please Summarize weather condition and make user better understand it: \n {weather}"
144
+ )
145
+ return result
146
+
147
+
148
+ def extract_location(text: str, **kwargs):
149
+ host = kwargs.get("host", HOST)
150
+ port = kwargs.get("port", PORT)
151
+ url = f"http://{host}:{port}/extract_location"
152
+ data = {"text": text}
153
+ response = requests.post(url, data=data)
154
+ return response.json()
155
+
156
+
157
+ def sentiment_analysis(text: str, **kwargs):
158
+ host = kwargs.get("host", HOST)
159
+ port = kwargs.get("port", PORT)
160
+ url = f"http://{host}:{port}/sentiment_analysis"
161
+ data = {"text": text}
162
+ response = requests.post(url, data=data)
163
+ return response.json()
cllm/services/nlp/llms/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .chat_models import ChatOpenAI
2
+ from .memory import MessageMemory
cllm/services/nlp/llms/chat_models.py ADDED
@@ -0,0 +1,219 @@
1
+ import os
2
+ import openai
3
+ import requests
4
+ from typing import (
5
+ Any,
6
+ Dict,
7
+ List,
8
+ Optional,
9
+ )
10
+ from langchain.schema import (
11
+ AIMessage,
12
+ BaseMessage,
13
+ ChatMessage,
14
+ HumanMessage,
15
+ SystemMessage,
16
+ )
17
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
18
+ from langchain.chat_models.base import SimpleChatModel
19
+ import os
20
+ import sys
21
+
22
+ sys.path.append(os.getcwd())
23
+
24
+ from cllm.services.nlp.llms.memory import MessageMemory
25
+ from cllm.utils import timeout
26
+
27
+
28
+ class ChatOpenAI:
29
+ def __init__(
30
+ self,
31
+ model_name: str = "gpt-3.5-turbo",
32
+ temperature: float = 0.7,
33
+ model_kwargs: Dict[str, Any] = dict(),
34
+ openai_api_key: Optional[str] = None,
35
+ openai_base_url: Optional[str] = None,
36
+ ) -> None:
37
+ self.model_name = model_name
38
+ self.temperature = temperature
39
+ self.model_kwargs = model_kwargs
40
+ self.api_key = os.environ.get("OPENAI_API_KEY", openai_api_key)
41
+ self.base_url = os.environ.get("OPENAI_BASE_URL", openai_base_url)
42
+
43
+ def __call__(self, messages: List[BaseMessage], **kwargs):
44
+ stream = kwargs.get("stream", False)
45
+ context = MessageMemory(messages=messages)
46
+ context.cut_memory(self.model_name)
47
+ response = self.send_message(messages=context.to_dict(), stream=stream)
48
+ return response
49
+
50
+ def get_response(self, response):
51
+ return response.choices[0].message.content
52
+
53
+ def send_message(self, messages, stream=False):
54
+ cnt = 10
55
+ while cnt > 0:
56
+ try:
57
+ result = self.get_response(
58
+ self._send_message(
59
+ model=self.model_name,
60
+ messages=messages,
61
+ temperature=self.temperature,
62
+ stream=stream,
63
+ timeout=5,
64
+ )
65
+ )
66
+ break
67
+ except Exception as e:
68
+ cnt -= 1
69
+ print(e)
70
+ result = e
71
+ return result
72
+
73
+ # @timeout(5)
74
+ def _send_message(self, *args, **kwargs):
75
+ # return self.client.chat.completions.create(*args, **kwargs)
76
+ # return openai.Completion.create(*args, **kwargs)
77
+ return openai.chat.completions.create(*args, **kwargs)
78
+
79
+
80
+ class ChatLLAMA2(SimpleChatModel):
81
+ """Wrapper around LLAMA2
82
+
83
+ To use, you should launch your local model as a web service.
84
+ """
85
+
86
+ client: Any = None #: :meta private:
87
+ endpoint: str = "http://localhost:10051"
88
+
89
+ HUMAN_PROMPT = "user"
90
+ AI_PROMPT = "assistant"
91
+
92
+ @property
93
+ def _llm_type(self) -> str:
94
+ """Return type of chat model."""
95
+ return "local-chat"
96
+
97
+ def _call(
98
+ self,
99
+ messages: List[BaseMessage],
100
+ stop: Optional[List[str]] = None,
101
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
102
+ ) -> str:
103
+ data = self._convert_messages_to_prompt(messages)
104
+ response = requests.post(self.endpoint, json=data)
105
+ return response.content.decode()
106
+
107
+ def _convert_one_message_to_text(self, message: BaseMessage) -> str:
108
+ if isinstance(message, ChatMessage):
109
+ message_text = {
110
+ "role": message.role.capitalize(),
111
+ "content": message.content,
112
+ }
113
+ elif isinstance(message, HumanMessage):
114
+ message_text = {"role": self.HUMAN_PROMPT, "content": message.content}
115
+ elif isinstance(message, AIMessage):
116
+ message_text = {"role": self.AI_PROMPT, "content": message.content}
117
+ elif isinstance(message, SystemMessage):
118
+ message_text = {"role": "system", "content": message.content}
119
+ else:
120
+ raise ValueError(f"Got unknown type {message}")
121
+ return message_text
122
+
123
+ def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
124
+ """Format a list of strings into a single string with necessary newlines.
125
+
126
+ Args:
127
+ messages (List[BaseMessage]): List of BaseMessage to combine.
128
+
129
+ Returns:
130
+ str: Combined string with necessary newlines.
131
+ """
132
+ return [self._convert_one_message_to_text(message) for message in messages]
133
+
134
+ def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
135
+ """Format a list of messages into a full prompt for the Anthropic model
136
+
137
+ Args:
138
+ messages (List[BaseMessage]): List of BaseMessage to combine.
139
+
140
+ Returns:
141
+ str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
142
+ """
143
+ return self._convert_messages_to_text(messages)
144
+
145
+
146
+ class ChatLLAMA2(SimpleChatModel):
147
+ """Wrapper around LLAMA2
148
+
149
+ To use, you should launch your local model as a web service.
150
+ """
151
+
152
+ client: Any = None #: :meta private:
153
+ endpoint: str = "http://localhost:10051"
154
+
155
+ HUMAN_PROMPT = "user"
156
+ AI_PROMPT = "assistant"
157
+
158
+ @property
159
+ def _llm_type(self) -> str:
160
+ """Return type of chat model."""
161
+ return "local-chat"
162
+
163
+ def _call(
164
+ self,
165
+ messages: List[BaseMessage],
166
+ stop: Optional[List[str]] = None,
167
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
168
+ ) -> str:
169
+ data = self._convert_messages_to_prompt(messages)
170
+ response = requests.post(self.endpoint, json=data)
171
+ return response.content.decode()
172
+
173
+ def _convert_one_message_to_text(self, message: BaseMessage) -> str:
174
+ if isinstance(message, ChatMessage):
175
+ message_text = {
176
+ "role": message.role.capitalize(),
177
+ "content": message.content,
178
+ }
179
+ elif isinstance(message, HumanMessage):
180
+ message_text = {"role": self.HUMAN_PROMPT, "content": message.content}
181
+ elif isinstance(message, AIMessage):
182
+ message_text = {"role": self.AI_PROMPT, "content": message.content}
183
+ elif isinstance(message, SystemMessage):
184
+ message_text = {"role": "system", "content": message.content}
185
+ else:
186
+ raise ValueError(f"Got unknown type {message}")
187
+ return message_text
188
+
189
+ def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
190
+ """Format a list of strings into a single string with necessary newlines.
191
+
192
+ Args:
193
+ messages (List[BaseMessage]): List of BaseMessage to combine.
194
+
195
+ Returns:
196
+ str: Combined string with necessary newlines.
197
+ """
198
+ return [self._convert_one_message_to_text(message) for message in messages]
199
+
200
+ def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
201
+ """Format a list of messages into a full prompt for the Anthropic model
202
+
203
+ Args:
204
+ messages (List[BaseMessage]): List of BaseMessage to combine.
205
+
206
+ Returns:
207
+ str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
208
+ """
209
+ return self._convert_messages_to_text(messages)
210
+
211
+
212
+ if __name__ == "__main__":
213
+ chat = ChatOpenAI()
214
+ msg = [
215
+ SystemMessage(content="You are a helpful assistant."),
216
+ HumanMessage(content="Hello!"),
217
+ ]
218
+ response = chat(msg)
219
+ print(response)
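As a sanity check for the local wrapper, a ChatLLAMA2 instance can be exercised the same way as ChatOpenAI above. The snippet below is only an illustrative sketch: it assumes a LLAMA2 web service is already listening on the default endpoint, and the message contents are invented.

from cllm.services.nlp.llms.chat_models import ChatLLAMA2
from langchain.schema import HumanMessage, SystemMessage

# Assumes the local LLAMA2 service is running on port 10051 (hypothetical setup).
llama_chat = ChatLLAMA2(endpoint="http://localhost:10051")
messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Explain what a tool-augmented LLM is in one sentence."),
]
reply = llama_chat(messages)  # an AIMessage whose content is the decoded service response
print(reply.content)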
cllm/services/nlp/llms/memory/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .message_memory import MessageMemory
cllm/services/nlp/llms/memory/message_memory.py ADDED
@@ -0,0 +1,131 @@
1
+ from typing import List, Optional, Dict
2
+ from langchain.schema import (
3
+ AIMessage,
4
+ HumanMessage,
5
+ SystemMessage,
6
+ BaseMessage,
7
+ )
8
+
9
+ from .utils import count_tokens, get_max_context_length
10
+
11
+
12
+ class MessageMemory:
13
+ def __init__(
14
+ self,
15
+ max_tokens: int = -1,
16
+ margin: int = 1500,
17
+ messages: Optional[List[BaseMessage]] = None,
18
+ ) -> None:
19
+ self.max_tokens = max_tokens if max_tokens > 0 else 8e8
20
+ self.margin = margin
21
+ self.init_messages(messages)
22
+
23
+ def reset(self) -> List[BaseMessage]:
24
+ self.init_messages()
25
+ return self.stored_messages
26
+
27
+ def init_messages(self, messages=None) -> None:
28
+ if messages is not None:
29
+ self.stored_messages = messages
30
+ else:
31
+ self.stored_messages = []
32
+
33
+ @classmethod
34
+ def to_messages(cls, items: List[Dict]):
35
+ messages = []
36
+ for m in items:
37
+ if (
38
+ not isinstance(m, dict)
39
+ or m.get("role", None) is None
40
+ or m.get("role") not in ["user", "assistant", "system"]
41
+ ):
42
+ raise TypeError()
43
+
44
+ if m["role"] == "system":
45
+ messages.append(SystemMessage(content=m["content"]))
46
+ elif m["role"] == "user":
47
+ messages.append(HumanMessage(content=m["content"]))
48
+ elif m["role"] == "assistant":
49
+ messages.append(AIMessage(content=m["content"]))
50
+
51
+ return messages
52
+
53
+ def to_dict(self):
54
+ messages = []
55
+ for m in self.stored_messages:
56
+ if not isinstance(m, BaseMessage) or m.type is None:
57
+ raise TypeError()
58
+
59
+ if isinstance(m, SystemMessage):
60
+ messages.append({"role": "system", "content": m.content})
61
+ elif isinstance(m, HumanMessage):
62
+ messages.append({"role": "user", "content": m.content})
63
+ elif isinstance(m, AIMessage):
64
+ messages.append({"role": "assistant", "content": m.content})
65
+
66
+ return messages
67
+
68
+ def get_memory(self):
69
+ return self.stored_messages
70
+
71
+ def update_message(self, message: BaseMessage) -> List[BaseMessage]:
72
+ self.stored_messages.append(message)
73
+ return self.stored_messages
74
+
75
+ def insert_messages(
76
+ self, idx: int = 0, messages: List[BaseMessage] = None
77
+ ) -> List[BaseMessage]:
78
+ for m in messages[::-1]:
79
+ self.stored_messages.insert(idx, m)
80
+ return self.stored_messages
81
+
82
+ @classmethod
83
+ def messages2str(cls, history):
84
+ history_text = ""
85
+ for m in history:
86
+ if isinstance(m, SystemMessage):
87
+ history_text += "<system>: " + m.content + "\n"
88
+ elif isinstance(m, HumanMessage):
89
+ history_text += "<user>: " + m.content + "\n"
90
+ elif isinstance(m, AIMessage):
91
+ history_text += "<assistant>: " + m.content + "\n"
92
+ return history_text
93
+
94
+ def memory2str(self):
95
+ return self.messages2str(self.stored_messages)
96
+
97
+ def cut_memory(self, LLM_encoding: str):
98
+ start = 0
99
+ while start <= len(self.stored_messages):
100
+ # print(f'self.stored_messages = {self.stored_messages}')
101
+ history = self.stored_messages[start:]
102
+ history_text = self.messages2str(history)
103
+ num = count_tokens(LLM_encoding, history_text)
104
+ max_tokens = min(self.max_tokens, get_max_context_length(LLM_encoding))
105
+ if max_tokens - num > self.margin:
106
+ self.stored_messages = self.stored_messages[start:]
107
+ return self.stored_messages
108
+
109
+ start += 1
110
+ self.init_messages()
111
+ return self.stored_messages
112
+
113
+
114
+ if __name__ == "__main__":
115
+ import os
116
+
117
+ os.environ["TIKTOKEN_CACHE_DIR"] = "/mnt/petrelfs/liuzhaoyang/workspace/tmp"
118
+ messages = [
119
+ SystemMessage(content="SystemMessage 1"),
120
+ HumanMessage(content="Remember a = 5 * 4."),
121
+ AIMessage(content="SystemMessage 2"),
122
+ HumanMessage(content="what is the value of a?"),
123
+ ] * 400
124
+ print(SystemMessage(content="SystemMessage 1").content)
125
+ print(len(messages))
126
+ mem = MessageMemory(
127
+ -1,
128
+ messages=messages,
129
+ )
130
+ messages = mem.cut_memory("gpt-3.5-turbo")
131
+ print(len(messages))
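Beyond the stress test above, the memory class is meant to round-trip between plain role/content dicts and LangChain messages while keeping the history inside the model's context window. A minimal sketch (not part of the commit; the conversation content is invented):

from cllm.services.nlp.llms.memory import MessageMemory

raw = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Remember that a = 20."},
    {"role": "assistant", "content": "Noted."},
    {"role": "user", "content": "What is a?"},
]
memory = MessageMemory(max_tokens=4096, messages=MessageMemory.to_messages(raw))
memory.cut_memory("gpt-3.5-turbo")  # drops oldest messages until the margin fits
print(memory.memory2str())          # "<system>: ..." style transcript
print(memory.to_dict())             # back to role/content dicts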
cllm/services/nlp/llms/memory/utils.py ADDED
@@ -0,0 +1,52 @@
1
+ import tiktoken
2
+ import os
3
+
4
+ os.environ["TIKTOKEN_CACHE_DIR"] = os.path.join(os.path.expanduser("~"), "tmp")
5
+
6
+ encodings = {
7
+ "gpt-4": tiktoken.get_encoding("cl100k_base"),
8
+ "gpt-4-32k": tiktoken.get_encoding("cl100k_base"),
9
+ "gpt-3.5-turbo": tiktoken.get_encoding("cl100k_base"),
10
+ "gpt-3.5-turbo-0301": tiktoken.get_encoding("cl100k_base"),
11
+ "gpt-3.5-turbo-0613": tiktoken.get_encoding("cl100k_base"),
12
+ "gpt-3.5-turbo-16k": tiktoken.get_encoding("cl100k_base"),
13
+ "gpt-3.5-turbo-1106": tiktoken.get_encoding("cl100k_base"),
14
+ "text-davinci-003": tiktoken.get_encoding("p50k_base"),
15
+ "text-davinci-002": tiktoken.get_encoding("p50k_base"),
16
+ "text-davinci-001": tiktoken.get_encoding("r50k_base"),
17
+ "text-curie-001": tiktoken.get_encoding("r50k_base"),
18
+ "text-babbage-001": tiktoken.get_encoding("r50k_base"),
19
+ "text-ada-001": tiktoken.get_encoding("r50k_base"),
20
+ "davinci": tiktoken.get_encoding("r50k_base"),
21
+ "curie": tiktoken.get_encoding("r50k_base"),
22
+ "babbage": tiktoken.get_encoding("r50k_base"),
23
+ "ada": tiktoken.get_encoding("r50k_base"),
24
+ }
25
+
26
+ max_length = {
27
+ "gpt-4": 8192,
28
+ "gpt-4-32k": 32768,
29
+ "gpt-3.5-turbo": 4096,
30
+ "gpt-3.5-turbo-0301": 4096,
31
+ "gpt-3.5-turbo-0613": 4096,
32
+ "gpt-3.5-turbo-16k": 16385,
33
+ "gpt-3.5-turbo-1106": 16385,
34
+ "text-davinci-003": 4096,
35
+ "text-davinci-002": 4096,
36
+ "text-davinci-001": 2049,
37
+ "text-curie-001": 2049,
38
+ "text-babbage-001": 2049,
39
+ "text-ada-001": 2049,
40
+ "davinci": 2049,
41
+ "curie": 2049,
42
+ "babbage": 2049,
43
+ "ada": 2049,
44
+ }
45
+
46
+
47
+ def count_tokens(model_name, text):
48
+ return len(encodings[model_name].encode(text))
49
+
50
+
51
+ def get_max_context_length(model_name):
52
+ return max_length[model_name]
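These two helpers supply the token budget used by MessageMemory.cut_memory. A small illustrative check, assuming tiktoken can load the cl100k_base encoding locally:

from cllm.services.nlp.llms.memory.utils import count_tokens, get_max_context_length

prompt = "Translate the following sentence into French: 'Hello, world.'"
model = "gpt-3.5-turbo"
used = count_tokens(model, prompt)
remaining = get_max_context_length(model) - used
print(f"{used} tokens in the prompt, {remaining} tokens left in the context window")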
cllm/services/tog/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # from .tool import TaskSolver, TaskDecomposer
2
+ # from .configs.tog_config import config
cllm/services/tog/api.py ADDED
@@ -0,0 +1,40 @@
1
+ import os
2
+ import requests
3
+
4
+ __ALL__ = ["tog", "task_decomposer"]
5
+
6
+
7
+ HOST = os.environ.get("TOG_SERVICE_HOST", "localhost")
8
+ PORT = os.environ.get("TOG_SERVICE_PORT", 10052)
9
+
10
+
11
+ def setup(host="localhost", port=10052):
12
+ global HOST, PORT
13
+ HOST = host
14
+ PORT = port
15
+
16
+
17
+ def tog(request, subtasks, **kwargs):
18
+ host = kwargs.get("host", HOST)
19
+ port = kwargs.get("port", PORT)
20
+ stream = kwargs.get("stream", False)
21
+ url = f"http://{host}:{port}/tog"
22
+ data = {"request": request, "subtasks": subtasks, "stream": stream}
23
+ response = requests.post(url, data=data, stream=stream)
24
+ # if not stream:
25
+ # response = response.content.decode("utf-8")
26
+ # print(f"response.json(): {response.json()}")
27
+ return response.json()
28
+
29
+
30
+ def task_decomposer(request, **kwargs):
31
+ host = kwargs.get("host", HOST)
32
+ port = kwargs.get("port", PORT)
33
+ stream = kwargs.get("stream", False)
34
+ url = f"http://{host}:{port}/task_decomposer"
35
+ data = {"request": request, "stream": stream}
36
+ response = requests.post(url, data=data, stream=stream)
37
+ # if not stream:
38
+ # response = response.content.decode("utf-8")
39
+ # return response.content.decode("utf-8")
40
+ return response.json()
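A hedged client-side sketch of how these two endpoints might be chained; it assumes the ToG service is reachable on its default port and that the decomposer's JSON output can be passed straight back in as the subtasks payload:

from cllm.services.tog.api import setup, task_decomposer, tog

setup(host="localhost", port=10052)
request = "Describe the image and read any text in it"
subtasks = task_decomposer(request)  # JSON-decoded subtask list from the service
plan = tog(request, subtasks)        # thought-on-graph solution for the request
print(plan)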
cllm/services/utils.py ADDED
@@ -0,0 +1,50 @@
1
+ import os
2
+ import io
3
+ from pathlib import Path
4
+ from cllm.utils import get_real_path
5
+ from fastapi.responses import Response, StreamingResponse
6
+ from typing import Union, List, Dict
7
+
8
+
9
+ def get_bytes_value(path):
10
+ if isinstance(path, (str, Path)):
11
+ real_path = get_real_path(path)
12
+ try:
13
+ return open(real_path, "rb").read()
14
+ except Exception as e:
15
+ return open(path, "rb").read()
16
+ elif isinstance(path, io.BufferedReader):
17
+ return path.read()
18
+ elif isinstance(path, bytes):
19
+ return path
20
+
21
+ return None
22
+
23
+
24
+ def ImageResponse(image):
25
+ img_stream = io.BytesIO()
26
+ image.save(img_stream, format="png")
27
+ img_stream.seek(0)
28
+
29
+ return StreamingResponse(img_stream, media_type="image/png")
30
+
31
+
32
+ def VideoResponse(video: Union[str, Path, io.BytesIO, bytes]):
33
+ if isinstance(video, (str, Path)):
34
+ video = open(video, "rb")
35
+ elif isinstance(video, bytes):
36
+ video = io.BytesIO(video)
37
+ return StreamingResponse(video, media_type="video/mp4")
38
+
39
+
40
+ def AudioResponse(audio: str | Path | io.BytesIO):
41
+ if isinstance(audio, (str, Path)):
42
+ audio = open(audio, "rb")
43
+ return StreamingResponse(audio, media_type="audio/wav")
44
+
45
+
46
+ class RawResponse(Response):
47
+ media_type = "binary/octet-stream"
48
+
49
+ def render(self, content: bytes) -> bytes:
50
+ return bytes([b ^ 0x54 for b in content])
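A sketch of how these helpers might be wired into a FastAPI route; the /flip_image endpoint and the 180-degree rotation are invented for illustration, not part of this commit:

import io
from fastapi import FastAPI, UploadFile
from PIL import Image
from cllm.services.utils import ImageResponse, get_bytes_value

app = FastAPI()

@app.post("/flip_image")
async def flip_image(image: UploadFile):
    data = get_bytes_value(await image.read())  # raw bytes pass straight through
    rotated = Image.open(io.BytesIO(data)).rotate(180)
    return ImageResponse(rotated)  # streamed back as image/png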
cllm/services/video/__init__.py ADDED
File without changes
cllm/services/video/api.py ADDED
@@ -0,0 +1,135 @@
1
+ import io
2
+ import os
3
+ import os.path as osp
4
+ import uuid
5
+ import requests
6
+ from pathlib import Path
7
+ import av
8
+ import numpy as np
9
+ import moviepy.editor as mpe
10
+ from cllm.services.utils import get_bytes_value
11
+
12
+ __ALL__ = [
13
+ "video_classification",
14
+ "video_captioning",
15
+ "image_to_video",
16
+ "text_to_video",
17
+ "video_to_webpage",
18
+ "dub_video",
19
+ ]
20
+
21
+
22
+ HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
23
+ PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
24
+
25
+
26
+ def setup(host="localhost", port=10056):
27
+ global HOST, PORT
28
+ HOST = host
29
+ PORT = port
30
+
31
+
32
+ def video_classification(video: str | Path | bytes, **kwargs):
33
+ host = kwargs.get("host", HOST)
34
+ port = kwargs.get("port", PORT)
35
+ url = f"http://{host}:{port}/video_classification"
36
+ files = {"video": (video, get_bytes_value(video))}
37
+ response = requests.post(url, files=files)
38
+ return response.json()
39
+
40
+
41
+ def video_captioning(video: str | Path, **kwargs):
42
+ host = kwargs.get("host", HOST)
43
+ port = kwargs.get("port", PORT)
44
+ url = f"http://{host}:{port}/video_captioning"
45
+ files = {"video": (video, get_bytes_value(video))}
46
+ response = requests.post(url, files=files)
47
+ return response.json()
48
+
49
+
50
+ def image_audio_to_video(image: str | Path, audio: str | Path, **kwargs):
51
+ host = kwargs.get("host", HOST)
52
+ port = kwargs.get("port", PORT)
53
+ url = f"http://{host}:{port}/image_audio_to_video"
54
+
55
+ files = {
56
+ "image": (image, get_bytes_value(image)),
57
+ "audio": (audio, get_bytes_value(audio)),
58
+ }
59
+ response = requests.post(url, files=files)
60
+ return response.content
61
+
62
+
63
+ def image_to_video(image: str | Path, **kwargs):
64
+ host = kwargs.get("host", HOST)
65
+ port = kwargs.get("port", PORT)
66
+ url = f"http://{host}:{port}/image_to_video"
67
+ files = {"image": (image, get_bytes_value(image))}
68
+ response = requests.post(url, files=files)
69
+ return response.content
70
+
71
+
72
+ def text_to_video(prompt: str, **kwargs):
73
+ host = kwargs.get("host", HOST)
74
+ port = kwargs.get("port", PORT)
75
+ url = f"http://{host}:{port}/text_to_video"
76
+ data = {"prompt": prompt}
77
+ response = requests.post(url, data=data)
78
+ return response.content
79
+
80
+
81
+ def video_to_webpage(
82
+ video: str | Path,
83
+ title: str,
84
+ tags: list[str],
85
+ description: str,
86
+ **kwargs,
87
+ ):
88
+ host = kwargs.get("host", HOST)
89
+ port = kwargs.get("port", PORT)
90
+ url = f"http://{host}:{port}/video_to_webpage"
91
+
92
+ files = {"video": (video, get_bytes_value(video))}
93
+ data = {
94
+ "title": title,
95
+ "tags": tags,
96
+ "description": description,
97
+ }
98
+ response = requests.post(url, files=files, data=data)
99
+ return response.json()
100
+
101
+
102
+ def dub_video(video: str | Path | bytes, audio: str | Path | bytes, **kwargs):
103
+ root_dir = kwargs["root_dir"]
104
+ vid_file_location = osp.join(root_dir, video)
105
+ aud_file_location = osp.join(root_dir, audio)
106
+ video = mpe.VideoFileClip(vid_file_location)
107
+
108
+ # read audio file
109
+ audio = mpe.AudioFileClip(aud_file_location)
110
+
111
+ # set audio for video
112
+ new_video = video.set_audio(audio)
113
+
114
+ # export the video file
115
+ save_path = osp.join(root_dir, f"new_{str(uuid.uuid4())[:6]}.mp4")
116
+ new_video.write_videofile(save_path)
117
+ return open(save_path, "rb").read()
118
+
119
+
120
+ def decoding_key_frames(video: str | Path | bytes, **kwargs):
121
+ video = io.BytesIO(get_bytes_value(video))
122
+ container = av.open(video)
123
+ # extract evenly spaced frames from video
124
+ seg_len = container.streams.video[0].frames
125
+ indices = set(np.linspace(0, seg_len, num=4, endpoint=False).astype(np.int64))
126
+ frames = []
127
+ container.seek(0)
128
+ for i, frame in enumerate(container.decode(video=0)):
129
+ if i in indices:
130
+ stream = io.BytesIO()
131
+ # frame = frame.to_image().save(f"frame_{i}.png")
132
+ frame.to_image().save(stream, format="PNG")
133
+ frames.append(stream.getvalue())
134
+
135
+ return frames
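An illustrative client session against this module; it assumes the video services are running on localhost:10056, and the output filename is invented:

from cllm.services.video.api import setup, text_to_video, video_captioning

setup(host="localhost", port=10056)
clip = text_to_video("a red panda walking through snow")  # raw mp4 bytes
with open("generated.mp4", "wb") as f:
    f.write(clip)
caption = video_captioning("generated.mp4")
print(caption)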
cllm/services/vqa/__init__.py ADDED
File without changes
cllm/services/vqa/api.py ADDED
@@ -0,0 +1,28 @@
1
+ import io
2
+ import os
3
+ from pathlib import Path
4
+ import requests
5
+ from PIL import Image
6
+ from cllm.services.utils import get_bytes_value
7
+
8
+ __ALL__ = ["image_qa"]
9
+
10
+
11
+ HOST = os.environ.get("CLLM_SERVICES_HOST", "localhost")
12
+ PORT = os.environ.get("CLLM_SERVICES_PORT", 10056)
13
+
14
+
15
+ def setup(host="localhost", port=10049):
16
+ global HOST, PORT
17
+ HOST = host
18
+ PORT = port
19
+
20
+
21
+ def image_qa(image, text, endpoint="llava", **kwargs):
22
+ host = kwargs.get("host", HOST)
23
+ port = kwargs.get("port", PORT)
24
+ url = f"http://{host}:{port}/{endpoint}"
25
+ files = {"image": (image, get_bytes_value(image))}
26
+ data = {"text": text}
27
+ response = requests.post(url, files=files, data=data)
28
+ return response.json()
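A minimal, hypothetical call to the VQA client; it assumes the service (with its llava endpoint) is up and that street.png exists in the working directory or resource root:

from cllm.services.vqa.api import setup, image_qa

setup(host="localhost", port=10056)
answer = image_qa("street.png", "How many cars are visible?")
print(answer)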
cllm/utils.py ADDED
@@ -0,0 +1,79 @@
1
+ import os
2
+ import functools
3
+ import signal
4
+ from pathlib import Path
5
+
6
+ RESOURCE_ROOT = os.environ.get("RESOURCE_ROOT", "./client_resources")
7
+
8
+
9
+ def get_real_path(path):
10
+ if path is None:
11
+ return None
12
+ if RESOURCE_ROOT in path:
13
+ return path
14
+ return os.path.join(RESOURCE_ROOT, path)
15
+
16
+
17
+ def get_root_dir():
18
+ return RESOURCE_ROOT
19
+
20
+
21
+ def md2plain(md):
22
+ plain_text = md.replace("&nbsp;", " ")
23
+ plain_text = plain_text.replace("<br>", "\n")
24
+ plain_text = plain_text.replace("\<", "<")
25
+ plain_text = plain_text.replace("\>", ">")
26
+ return plain_text
27
+
28
+
29
+ def plain2md(plain_text: str):
30
+ md_text = plain_text.replace("<", "\<")
31
+ md_text = md_text.replace(">", "\>")
32
+ md_text = md_text.replace("\n", "<br>")
33
+ # md_text = md_text + "<br>"
34
+ md_text = md_text.replace(" ", "&nbsp;")
35
+ return md_text
36
+
37
+
38
+ def transform_msgs(history_msgs: list = []):
39
+ if history_msgs is None:
40
+ return []
41
+ filtered_msg = []
42
+ for item in history_msgs:
43
+ if isinstance(item[0], str):
44
+ item[0] = md2plain(item[0])
45
+ if isinstance(item[1], str):
46
+ item[1] = md2plain(item[1])
47
+ if isinstance(item[1], str) and item[1].startswith(
48
+ "The whole process will take some time, please be patient."
49
+ ):
50
+ item[1] = None
51
+
52
+ filtered_msg.append(item)
53
+ return filtered_msg
54
+
55
+
56
+ def timeout(sec):
57
+ """
58
+ timeout decorator
59
+ :param sec: function raise TimeoutError after ? seconds
60
+ """
61
+
62
+ def decorator(func):
63
+ @functools.wraps(func)
64
+ def wrapped_func(*args, **kwargs):
65
+ def _handle_timeout(signum, frame):
66
+ err_msg = f"Function {func.__name__} timed out after {sec} seconds"
67
+ raise TimeoutError(err_msg)
68
+
69
+ signal.signal(signal.SIGALRM, _handle_timeout)
70
+ signal.alarm(sec)
71
+ try:
72
+ result = func(*args, **kwargs)
73
+ finally:
74
+ signal.alarm(0)
75
+ return result
76
+
77
+ return wrapped_func
78
+
79
+ return decorator
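The timeout decorator relies on SIGALRM, so it only works in the main thread on Unix-like systems. A small self-contained sketch of the intended behaviour:

import time
from cllm.utils import timeout

@timeout(2)
def slow_task():
    time.sleep(5)  # longer than the 2-second budget
    return "done"

try:
    slow_task()
except TimeoutError as err:
    print(err)  # Function slow_task timed out after 2 seconds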
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ av==10.0.0
2
+ torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118
3
+ torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118
4
+ torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
5
+ openai==1.3.7
6
+ openai-whisper==20230918
7
+ fire==0.5.0
8
+ fastapi==0.104.1
9
+ numpy==1.25.2
10
+ pillow==10.0.1
11
+ langchain==0.0.348
12
+ transformers==4.34.1
13
+ moviepy==1.0.3
14
+