Spaces:
Runtime error
Runtime error
camille-vanhoffelen
commited on
Commit
•
b3d3593
1
Parent(s):
1ce354a
First working gradio app for langchain-HuggingGPT
Browse files- .gitignore +181 -0
- LICENSE +21 -0
- app.py +222 -0
- hugginggpt/__init__.py +4 -0
- hugginggpt/exceptions.py +55 -0
- hugginggpt/history.py +29 -0
- hugginggpt/huggingface_api.py +13 -0
- hugginggpt/llm_factory.py +82 -0
- hugginggpt/log.py +10 -0
- hugginggpt/model_inference.py +410 -0
- hugginggpt/model_scraper.py +90 -0
- hugginggpt/model_selection.py +97 -0
- hugginggpt/resources.py +104 -0
- hugginggpt/response_generation.py +43 -0
- hugginggpt/task_parsing.py +149 -0
- hugginggpt/task_planning.py +61 -0
- logging-config.toml +26 -0
- logs/.gitkeep +0 -0
- main.py +138 -0
- output/.gitkeep +0 -0
- output/audios/.gitkeep +0 -0
- output/images/.gitkeep +0 -0
- output/videos/.gitkeep +0 -0
- pdm.lock +0 -0
- pyproject.toml +50 -0
- requirements.txt +0 -0
- resources/banner.txt +4 -0
- resources/huggingface-models-metadata.jsonl +0 -0
- resources/prompt-templates/model-selection-prompt.json +9 -0
- resources/prompt-templates/openai-model-inference-prompt.json +9 -0
- resources/prompt-templates/response-generation-prompt.json +8 -0
- resources/prompt-templates/task-planning-example-prompt.json +8 -0
- resources/prompt-templates/task-planning-examples.json +42 -0
- resources/prompt-templates/task-planning-few-shot-prompt.json +11 -0
.gitignore
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
.idea/
|
161 |
+
|
162 |
+
# Logs
|
163 |
+
*.log
|
164 |
+
*.log*
|
165 |
+
|
166 |
+
# Outputs
|
167 |
+
output/images/*
|
168 |
+
!output/images/.gitkeep
|
169 |
+
output/videos/*
|
170 |
+
!output/videos/.gitkeep
|
171 |
+
output/audios/*
|
172 |
+
!output/audios/.gitkeep
|
173 |
+
|
174 |
+
# PDM
|
175 |
+
.pdm-python
|
176 |
+
|
177 |
+
# macos
|
178 |
+
*.DS_Store
|
179 |
+
|
180 |
+
# Examples
|
181 |
+
!.env.example
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Camille Van Hoffelen
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
app.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
|
8 |
+
from hugginggpt.history import ConversationHistory
|
9 |
+
from hugginggpt.llm_factory import create_llms
|
10 |
+
from hugginggpt.log import setup_logging
|
11 |
+
from hugginggpt.resources import (
|
12 |
+
GENERATED_RESOURCES_DIR,
|
13 |
+
get_resource_url,
|
14 |
+
init_resource_dirs,
|
15 |
+
load_audio,
|
16 |
+
load_image,
|
17 |
+
save_audio,
|
18 |
+
save_image,
|
19 |
+
)
|
20 |
+
from main import compute
|
21 |
+
|
22 |
+
load_dotenv()
|
23 |
+
setup_logging()
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
init_resource_dirs()
|
26 |
+
|
27 |
+
OPENAI_KEY = os.environ.get("OPENAI_API_KEY")
|
28 |
+
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
|
29 |
+
|
30 |
+
|
31 |
+
class Client:
|
32 |
+
def __init__(self) -> None:
|
33 |
+
self.llms = None
|
34 |
+
self.llm_history = ConversationHistory()
|
35 |
+
self.last_user_input = ""
|
36 |
+
|
37 |
+
@property
|
38 |
+
def is_init(self) -> bool:
|
39 |
+
return (
|
40 |
+
os.environ.get("OPENAI_API_KEY")
|
41 |
+
and os.environ.get("OPENAI_API_KEY").startswith("sk-")
|
42 |
+
and os.environ.get("HUGGINGFACEHUB_API_TOKEN")
|
43 |
+
and os.environ.get("HUGGINGFACEHUB_API_TOKEN").startswith("hf_")
|
44 |
+
)
|
45 |
+
|
46 |
+
def add_text(self, user_input, messages):
|
47 |
+
if not self.is_init:
|
48 |
+
return (
|
49 |
+
"Please set your OpenAI API key and Hugging Face token first!!!",
|
50 |
+
messages,
|
51 |
+
)
|
52 |
+
if not self.llms:
|
53 |
+
self.llms = create_llms()
|
54 |
+
|
55 |
+
messages = display_message(
|
56 |
+
role="user", message=user_input, messages=messages, save_media=True
|
57 |
+
)
|
58 |
+
self.last_user_input = user_input
|
59 |
+
return "", messages
|
60 |
+
|
61 |
+
def bot(self, messages):
|
62 |
+
if not self.is_init:
|
63 |
+
return {}, messages
|
64 |
+
user_input = self.last_user_input
|
65 |
+
response, task_summaries = compute(
|
66 |
+
user_input=user_input,
|
67 |
+
history=self.llm_history,
|
68 |
+
llms=self.llms,
|
69 |
+
)
|
70 |
+
messages = display_message(
|
71 |
+
role="assistant", message=response, messages=messages, save_media=False
|
72 |
+
)
|
73 |
+
self.llm_history.add(role="user", content=user_input)
|
74 |
+
self.llm_history.add(role="assistant", content="")
|
75 |
+
return task_summaries, messages
|
76 |
+
|
77 |
+
|
78 |
+
css = ".json {height: 527px; overflow: scroll;} .json-holder {height: 527px; overflow: scroll;}"
|
79 |
+
with gr.Blocks(css=css) as demo:
|
80 |
+
gr.Markdown("<h1><center>langchain-HuggingGPT</center></h1>")
|
81 |
+
gr.Markdown(
|
82 |
+
"<p align='center'><img src='https://i.ibb.co/qNH3Jym/logo.png' height='25' width='95'></p>"
|
83 |
+
)
|
84 |
+
gr.Markdown(
|
85 |
+
"<p align='center' style='font-size: 20px;'>A lightweight implementation of <a href='https://arxiv.org/abs/2303.17580'>HuggingGPT</a> with <a href='https://docs.langchain.com/docs/'>langchain</a>. No local inference, only models available on the <a href='https://huggingface.co/inference-api'>Hugging Face Inference API</a> are used.</p>"
|
86 |
+
)
|
87 |
+
gr.HTML(
|
88 |
+
"""<center><a href="https://huggingface.co/spaces/camillevanhoffelen/langchain-HuggingGPT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space and run securely with your OpenAI API Key and Hugging Face Token</center>"""
|
89 |
+
)
|
90 |
+
if not OPENAI_KEY:
|
91 |
+
with gr.Row().style():
|
92 |
+
with gr.Column(scale=0.85):
|
93 |
+
openai_api_key = gr.Textbox(
|
94 |
+
show_label=False,
|
95 |
+
placeholder="Set your OpenAI API key here and press Enter",
|
96 |
+
lines=1,
|
97 |
+
type="password",
|
98 |
+
).style(container=False)
|
99 |
+
with gr.Column(scale=0.15, min_width=0):
|
100 |
+
btn1 = gr.Button("Submit").style(full_height=True)
|
101 |
+
|
102 |
+
if not HUGGINGFACE_TOKEN:
|
103 |
+
with gr.Row().style():
|
104 |
+
with gr.Column(scale=0.85):
|
105 |
+
hugging_face_token = gr.Textbox(
|
106 |
+
show_label=False,
|
107 |
+
placeholder="Set your Hugging Face Token here and press Enter",
|
108 |
+
lines=1,
|
109 |
+
type="password",
|
110 |
+
).style(container=False)
|
111 |
+
with gr.Column(scale=0.15, min_width=0):
|
112 |
+
btn3 = gr.Button("Submit").style(full_height=True)
|
113 |
+
|
114 |
+
with gr.Row().style():
|
115 |
+
with gr.Column(scale=0.6):
|
116 |
+
chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
|
117 |
+
with gr.Column(scale=0.4):
|
118 |
+
results = gr.JSON(elem_classes="json")
|
119 |
+
|
120 |
+
with gr.Row().style():
|
121 |
+
with gr.Column(scale=0.85):
|
122 |
+
txt = gr.Textbox(
|
123 |
+
show_label=False,
|
124 |
+
placeholder="Enter text and press enter. The url must contain the media type. e.g, https://example.com/example.jpg",
|
125 |
+
lines=1,
|
126 |
+
).style(container=False)
|
127 |
+
with gr.Column(scale=0.15, min_width=0):
|
128 |
+
btn2 = gr.Button("Send").style(full_height=True)
|
129 |
+
|
130 |
+
def set_key(openai_api_key):
|
131 |
+
os.environ["OPENAI_API_KEY"] = openai_api_key
|
132 |
+
return openai_api_key
|
133 |
+
|
134 |
+
def set_token(hugging_face_token):
|
135 |
+
os.environ["HUGGINGFACEHUB_API_TOKEN"] = hugging_face_token
|
136 |
+
return hugging_face_token
|
137 |
+
|
138 |
+
def add_text(state, user_input, messages):
|
139 |
+
return state["client"].add_text(user_input, messages)
|
140 |
+
|
141 |
+
def bot(state, messages):
|
142 |
+
return state["client"].bot(messages)
|
143 |
+
|
144 |
+
if not OPENAI_KEY or not HUGGINGFACE_TOKEN:
|
145 |
+
openai_api_key.submit(set_key, [openai_api_key], [openai_api_key])
|
146 |
+
btn1.click(set_key, [openai_api_key], [openai_api_key])
|
147 |
+
hugging_face_token.submit(set_token, [hugging_face_token], [hugging_face_token])
|
148 |
+
btn3.click(set_token, [hugging_face_token], [hugging_face_token])
|
149 |
+
|
150 |
+
state = gr.State(value={"client": Client()})
|
151 |
+
|
152 |
+
txt.submit(add_text, [state, txt, chatbot], [txt, chatbot]).then(
|
153 |
+
bot, [state, chatbot], [results, chatbot]
|
154 |
+
)
|
155 |
+
btn2.click(add_text, [state, txt, chatbot], [txt, chatbot]).then(
|
156 |
+
bot, [state, chatbot], [results, chatbot]
|
157 |
+
)
|
158 |
+
|
159 |
+
gr.Examples(
|
160 |
+
examples=[
|
161 |
+
"Draw me a sheep",
|
162 |
+
"Write a poem about sheep, then read it to me",
|
163 |
+
"Transcribe the audio file found at /audios/499e.flac. Then tell me how similar the transcription is to the following sentence: Sheep are nice.",
|
164 |
+
"Show me a joke and an image of sheep",
|
165 |
+
],
|
166 |
+
inputs=txt,
|
167 |
+
)
|
168 |
+
|
169 |
+
|
170 |
+
def display_message(role: str, message: str, messages: list, save_media: bool):
|
171 |
+
# Text
|
172 |
+
messages.append(format_message(role=role, message=message))
|
173 |
+
|
174 |
+
# Media
|
175 |
+
image_urls, audio_urls = extract_medias(message)
|
176 |
+
for image_url in image_urls:
|
177 |
+
image_url = get_resource_url(image_url)
|
178 |
+
if save_media:
|
179 |
+
image = load_image(image_url)
|
180 |
+
image_url = save_image(image)
|
181 |
+
image_url = GENERATED_RESOURCES_DIR + image_url
|
182 |
+
messages.append(format_message(role=role, message=(image_url,)))
|
183 |
+
|
184 |
+
for audio_url in audio_urls:
|
185 |
+
audio_url = get_resource_url(audio_url)
|
186 |
+
if save_media:
|
187 |
+
audio = load_audio(audio_url)
|
188 |
+
audio_url = save_audio(audio)
|
189 |
+
audio_url = GENERATED_RESOURCES_DIR + audio_url
|
190 |
+
messages.append(format_message(role=role, message=(audio_url,)))
|
191 |
+
|
192 |
+
return messages
|
193 |
+
|
194 |
+
|
195 |
+
def format_message(role, message):
|
196 |
+
if role == "user":
|
197 |
+
return message, None
|
198 |
+
if role == "assistant":
|
199 |
+
return None, message
|
200 |
+
else:
|
201 |
+
raise ValueError("role must be either user or assistant")
|
202 |
+
|
203 |
+
|
204 |
+
def extract_medias(message: str):
|
205 |
+
image_pattern = re.compile(
|
206 |
+
r"(http(s?):|\/)?([\.\/_\w:-])*?\.(jpg|jpeg|tiff|gif|png)"
|
207 |
+
)
|
208 |
+
image_urls = []
|
209 |
+
for match in image_pattern.finditer(message):
|
210 |
+
if match.group(0) not in image_urls:
|
211 |
+
image_urls.append(match.group(0))
|
212 |
+
|
213 |
+
audio_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav)")
|
214 |
+
audio_urls = []
|
215 |
+
for match in audio_pattern.finditer(message):
|
216 |
+
if match.group(0) not in audio_urls:
|
217 |
+
audio_urls.append(match.group(0))
|
218 |
+
|
219 |
+
return image_urls, audio_urls
|
220 |
+
|
221 |
+
|
222 |
+
demo.launch()
|
hugginggpt/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .model_inference import infer
|
2 |
+
from .model_selection import select_model
|
3 |
+
from .response_generation import generate_response
|
4 |
+
from .task_planning import plan_tasks
|
hugginggpt/exceptions.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
|
4 |
+
def wrap_exceptions(exception_cls, message=None):
|
5 |
+
"""Wrap exceptions raised by a function with a custom exception class."""
|
6 |
+
def decorated(f):
|
7 |
+
@functools.wraps(f)
|
8 |
+
def wrapped(*args, **kwargs):
|
9 |
+
try:
|
10 |
+
return f(*args, **kwargs)
|
11 |
+
except Exception as e:
|
12 |
+
raise exception_cls(message) from e
|
13 |
+
|
14 |
+
return wrapped
|
15 |
+
|
16 |
+
return decorated
|
17 |
+
|
18 |
+
|
19 |
+
def async_wrap_exceptions(exception_cls, message=None):
|
20 |
+
"""Wrap exceptions raised by an async function with a custom exception class."""
|
21 |
+
def decorated(f):
|
22 |
+
@functools.wraps(f)
|
23 |
+
async def wrapped(*args, **kwargs):
|
24 |
+
try:
|
25 |
+
return await f(*args, **kwargs)
|
26 |
+
except Exception as e:
|
27 |
+
raise exception_cls(message) from e
|
28 |
+
|
29 |
+
return wrapped
|
30 |
+
|
31 |
+
return decorated
|
32 |
+
|
33 |
+
|
34 |
+
class TaskPlanningException(Exception):
|
35 |
+
pass
|
36 |
+
|
37 |
+
|
38 |
+
class TaskParsingException(Exception):
|
39 |
+
pass
|
40 |
+
|
41 |
+
|
42 |
+
class ModelScrapingException(Exception):
|
43 |
+
pass
|
44 |
+
|
45 |
+
|
46 |
+
class ModelSelectionException(Exception):
|
47 |
+
pass
|
48 |
+
|
49 |
+
|
50 |
+
class ModelInferenceException(Exception):
|
51 |
+
pass
|
52 |
+
|
53 |
+
|
54 |
+
class ResponseGenerationException(Exception):
|
55 |
+
pass
|
hugginggpt/history.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
|
4 |
+
class ConversationHistory:
|
5 |
+
"""Stores previous user and assistant messages. Used as additional context for task planning."""
|
6 |
+
def __init__(self):
|
7 |
+
self.history = []
|
8 |
+
|
9 |
+
def add(self, role: str, content: str):
|
10 |
+
self.history.append({"role": role, "content": content})
|
11 |
+
|
12 |
+
def __str__(self):
|
13 |
+
return json.dumps(self.history)
|
14 |
+
|
15 |
+
def __repr__(self):
|
16 |
+
return str(self)
|
17 |
+
|
18 |
+
def __len__(self):
|
19 |
+
return len(self.history)
|
20 |
+
|
21 |
+
def __getitem__(self, item):
|
22 |
+
return self.history[item]
|
23 |
+
|
24 |
+
def __setitem__(self, key, value):
|
25 |
+
self.history[key] = value
|
26 |
+
|
27 |
+
def __delitem__(self, key):
|
28 |
+
del self.history[key]
|
29 |
+
|
hugginggpt/huggingface_api.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
|
5 |
+
load_dotenv()
|
6 |
+
|
7 |
+
HUGGINGFACE_INFERENCE_API_URL = "https://api-inference.huggingface.co/models/"
|
8 |
+
HUGGINGFACE_INFERENCE_API_STATUS_URL = f"https://api-inference.huggingface.co/status/"
|
9 |
+
|
10 |
+
|
11 |
+
def get_hf_headers():
|
12 |
+
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
|
13 |
+
return {"Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}"}
|
hugginggpt/llm_factory.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from collections import namedtuple
|
3 |
+
|
4 |
+
import tiktoken
|
5 |
+
from langchain import OpenAI
|
6 |
+
|
7 |
+
LLM_NAME = "text-davinci-003"
|
8 |
+
# Encoding for text-davinci-003
|
9 |
+
ENCODING_NAME = "p50k_base"
|
10 |
+
ENCODING = tiktoken.get_encoding(ENCODING_NAME)
|
11 |
+
# Max input tokens for text-davinci-003
|
12 |
+
LLM_MAX_TOKENS = 4096
|
13 |
+
|
14 |
+
# As specified in huggingGPT paper
|
15 |
+
TASK_PLANNING_LOGIT_BIAS = 0.1
|
16 |
+
MODEL_SELECTION_LOGIT_BIAS = 5
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
LLMs = namedtuple(
|
21 |
+
"LLMs",
|
22 |
+
[
|
23 |
+
"task_planning_llm",
|
24 |
+
"model_selection_llm",
|
25 |
+
"model_inference_llm",
|
26 |
+
"response_generation_llm",
|
27 |
+
"output_fixing_llm",
|
28 |
+
],
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
def create_llms() -> LLMs:
|
33 |
+
"""Create various LLM agents according to the huggingGPT paper's specifications."""
|
34 |
+
logger.info(f"Creating {LLM_NAME} LLMs")
|
35 |
+
|
36 |
+
task_parsing_highlight_ids = get_token_ids_for_task_parsing()
|
37 |
+
choose_model_highlight_ids = get_token_ids_for_choose_model()
|
38 |
+
|
39 |
+
task_planning_llm = OpenAI(
|
40 |
+
model_name=LLM_NAME,
|
41 |
+
temperature=0,
|
42 |
+
logit_bias={
|
43 |
+
token_id: TASK_PLANNING_LOGIT_BIAS
|
44 |
+
for token_id in task_parsing_highlight_ids
|
45 |
+
},
|
46 |
+
)
|
47 |
+
model_selection_llm = OpenAI(
|
48 |
+
model_name=LLM_NAME,
|
49 |
+
temperature=0,
|
50 |
+
logit_bias={
|
51 |
+
token_id: MODEL_SELECTION_LOGIT_BIAS
|
52 |
+
for token_id in choose_model_highlight_ids
|
53 |
+
},
|
54 |
+
)
|
55 |
+
model_inference_llm = OpenAI(model_name=LLM_NAME, temperature=0)
|
56 |
+
response_generation_llm = OpenAI(model_name=LLM_NAME, temperature=0)
|
57 |
+
output_fixing_llm = OpenAI(model_name=LLM_NAME, temperature=0)
|
58 |
+
return LLMs(
|
59 |
+
task_planning_llm=task_planning_llm,
|
60 |
+
model_selection_llm=model_selection_llm,
|
61 |
+
model_inference_llm=model_inference_llm,
|
62 |
+
response_generation_llm=response_generation_llm,
|
63 |
+
output_fixing_llm=output_fixing_llm,
|
64 |
+
)
|
65 |
+
|
66 |
+
|
67 |
+
def get_token_ids_for_task_parsing() -> list[int]:
|
68 |
+
text = """{"task": "text-classification", "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "visual-question-answering", "document-question-answering", "image-segmentation", "text-to-speech", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "args", "text", "path", "dep", "id", "<GENERATED>-"}"""
|
69 |
+
res = ENCODING.encode(text)
|
70 |
+
res = list(set(res))
|
71 |
+
return res
|
72 |
+
|
73 |
+
|
74 |
+
def get_token_ids_for_choose_model() -> list[int]:
|
75 |
+
text = """{"id": "reason"}"""
|
76 |
+
res = ENCODING.encode(text)
|
77 |
+
res = list(set(res))
|
78 |
+
return res
|
79 |
+
|
80 |
+
|
81 |
+
def count_tokens(text: str) -> int:
|
82 |
+
return len(ENCODING.encode(text))
|
hugginggpt/log.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging.config
|
2 |
+
import tomllib
|
3 |
+
|
4 |
+
LOGGING_CONFIG_FILE = "logging-config.toml"
|
5 |
+
|
6 |
+
|
7 |
+
def setup_logging():
|
8 |
+
with open("logging-config.toml", "rb") as f:
|
9 |
+
config = tomllib.load(f)
|
10 |
+
logging.config.dictConfig(config)
|
hugginggpt/model_inference.py
ADDED
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import random
|
5 |
+
from io import BytesIO
|
6 |
+
from typing import Any
|
7 |
+
|
8 |
+
import requests
|
9 |
+
from PIL import Image, ImageDraw
|
10 |
+
from langchain import LLMChain
|
11 |
+
from langchain.llms.base import BaseLLM
|
12 |
+
from langchain.prompts import load_prompt
|
13 |
+
from pydantic import BaseModel, Json
|
14 |
+
|
15 |
+
from hugginggpt.exceptions import ModelInferenceException, wrap_exceptions
|
16 |
+
from hugginggpt.huggingface_api import (HUGGINGFACE_INFERENCE_API_URL, get_hf_headers)
|
17 |
+
from hugginggpt.model_selection import Model
|
18 |
+
from hugginggpt.resources import (
|
19 |
+
audio_from_bytes,
|
20 |
+
encode_audio,
|
21 |
+
encode_image,
|
22 |
+
get_prompt_resource,
|
23 |
+
get_resource_url,
|
24 |
+
image_from_bytes,
|
25 |
+
load_image,
|
26 |
+
save_audio,
|
27 |
+
save_image,
|
28 |
+
)
|
29 |
+
from hugginggpt.task_parsing import Task
|
30 |
+
|
31 |
+
logger = logging.getLogger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
@wrap_exceptions(ModelInferenceException, "Error during model inference")
|
35 |
+
def infer(task: Task, model_id: str, llm: BaseLLM, session: requests.Session):
|
36 |
+
"""Execute a task either with LLM or huggingface inference API."""
|
37 |
+
if model_id == "openai":
|
38 |
+
return infer_openai(task=task, llm=llm)
|
39 |
+
else:
|
40 |
+
return infer_huggingface(task=task, model_id=model_id, session=session)
|
41 |
+
|
42 |
+
|
43 |
+
def infer_openai(task: Task, llm: BaseLLM):
|
44 |
+
logger.info("Starting OpenAI inference")
|
45 |
+
prompt_template = load_prompt(
|
46 |
+
get_prompt_resource("openai-model-inference-prompt.json")
|
47 |
+
)
|
48 |
+
llm_chain = LLMChain(prompt=prompt_template, llm=llm)
|
49 |
+
# Need to replace double quotes with single quotes for correct response generation
|
50 |
+
output = llm_chain.predict(
|
51 |
+
task=task.json(), task_name=task.task, args=task.args, stop=["<im_end>"]
|
52 |
+
)
|
53 |
+
result = {"generated text": output}
|
54 |
+
logger.debug(f"Inference result: {result}")
|
55 |
+
return result
|
56 |
+
|
57 |
+
|
58 |
+
def infer_huggingface(task: Task, model_id: str, session: requests.Session):
|
59 |
+
logger.info("Starting huggingface inference")
|
60 |
+
url = HUGGINGFACE_INFERENCE_API_URL + model_id
|
61 |
+
huggingface_task = create_huggingface_task(task=task)
|
62 |
+
data = huggingface_task.inference_inputs
|
63 |
+
headers = get_hf_headers()
|
64 |
+
response = session.post(url, headers=headers, data=data)
|
65 |
+
response.raise_for_status()
|
66 |
+
result = huggingface_task.parse_response(response)
|
67 |
+
logger.debug(f"Inference result: {result}")
|
68 |
+
return result
|
69 |
+
|
70 |
+
|
71 |
+
# NLP Tasks
|
72 |
+
|
73 |
+
|
74 |
+
# deepset/roberta-base-squad2 was removed from huggingface_models-metadata.jsonl because it is currently broken
|
75 |
+
# Example added to task-planning-examples.json compared to original paper
|
76 |
+
class QuestionAnswering:
|
77 |
+
def __init__(self, task: Task):
|
78 |
+
self.task = task
|
79 |
+
|
80 |
+
@property
|
81 |
+
def inference_inputs(self):
|
82 |
+
data = {
|
83 |
+
"inputs": {
|
84 |
+
"question": self.task.args["question"],
|
85 |
+
"context": self.task.args["context"]
|
86 |
+
if "context" in self.task.args
|
87 |
+
else "",
|
88 |
+
}
|
89 |
+
}
|
90 |
+
return json.dumps(data)
|
91 |
+
|
92 |
+
def parse_response(self, response):
|
93 |
+
return response.json()
|
94 |
+
|
95 |
+
|
96 |
+
# Example added to task-planning-examples.json compared to original paper
|
97 |
+
class SentenceSimilarity:
|
98 |
+
def __init__(self, task: Task):
|
99 |
+
self.task = task
|
100 |
+
|
101 |
+
@property
|
102 |
+
def inference_inputs(self):
|
103 |
+
data = {
|
104 |
+
"inputs": {
|
105 |
+
"source_sentence": self.task.args["text1"],
|
106 |
+
"sentences": [self.task.args["text2"]],
|
107 |
+
}
|
108 |
+
}
|
109 |
+
# Using string to bypass requests' form encoding
|
110 |
+
return json.dumps(data)
|
111 |
+
|
112 |
+
def parse_response(self, response):
|
113 |
+
return response.json()
|
114 |
+
|
115 |
+
|
116 |
+
# Example added to task-planning-examples.json compared to original paper
|
117 |
+
class TextClassification:
|
118 |
+
def __init__(self, task: Task):
|
119 |
+
self.task = task
|
120 |
+
|
121 |
+
@property
|
122 |
+
def inference_inputs(self):
|
123 |
+
return self.task.args["text"]
|
124 |
+
# return {"inputs": self.task.args["text"]}
|
125 |
+
|
126 |
+
def parse_response(self, response):
|
127 |
+
return response.json()
|
128 |
+
|
129 |
+
|
130 |
+
class TokenClassification:
|
131 |
+
def __init__(self, task: Task):
|
132 |
+
self.task = task
|
133 |
+
|
134 |
+
@property
|
135 |
+
def inference_inputs(self):
|
136 |
+
return self.task.args["text"]
|
137 |
+
|
138 |
+
def parse_response(self, response):
|
139 |
+
return response.json()
|
140 |
+
|
141 |
+
|
142 |
+
# CV Tasks
|
143 |
+
class VisualQuestionAnswering:
|
144 |
+
def __init__(self, task: Task):
|
145 |
+
self.task = task
|
146 |
+
|
147 |
+
@property
|
148 |
+
def inference_inputs(self):
|
149 |
+
img_data = encode_image(self.task.args["image"])
|
150 |
+
img_base64 = base64.b64encode(img_data).decode("utf-8")
|
151 |
+
data = {
|
152 |
+
"inputs": {
|
153 |
+
"question": self.task.args["text"],
|
154 |
+
"image": img_base64,
|
155 |
+
}
|
156 |
+
}
|
157 |
+
return json.dumps(data)
|
158 |
+
|
159 |
+
def parse_response(self, response):
|
160 |
+
return response.json()
|
161 |
+
|
162 |
+
|
163 |
+
class DocumentQuestionAnswering:
|
164 |
+
def __init__(self, task: Task):
|
165 |
+
self.task = task
|
166 |
+
|
167 |
+
@property
|
168 |
+
def inference_inputs(self):
|
169 |
+
img_data = encode_image(self.task.args["image"])
|
170 |
+
img_base64 = base64.b64encode(img_data).decode("utf-8")
|
171 |
+
data = {
|
172 |
+
"inputs": {
|
173 |
+
"question": self.task.args["text"],
|
174 |
+
"image": img_base64,
|
175 |
+
}
|
176 |
+
}
|
177 |
+
return json.dumps(data)
|
178 |
+
|
179 |
+
def parse_response(self, response):
|
180 |
+
return response.json()
|
181 |
+
|
182 |
+
|
183 |
+
class TextToImage:
|
184 |
+
def __init__(self, task: Task):
|
185 |
+
self.task = task
|
186 |
+
|
187 |
+
@property
|
188 |
+
def inference_inputs(self):
|
189 |
+
return self.task.args["text"]
|
190 |
+
|
191 |
+
def parse_response(self, response):
|
192 |
+
image = image_from_bytes(response.content)
|
193 |
+
path = save_image(image)
|
194 |
+
return {"generated image": path}
|
195 |
+
|
196 |
+
|
197 |
+
class ImageSegmentation:
|
198 |
+
def __init__(self, task: Task):
|
199 |
+
self.task = task
|
200 |
+
|
201 |
+
@property
|
202 |
+
def inference_inputs(self):
|
203 |
+
return encode_image(self.task.args["image"])
|
204 |
+
|
205 |
+
def parse_response(self, response):
|
206 |
+
image_url = get_resource_url(self.task.args["image"])
|
207 |
+
image = load_image(image_url)
|
208 |
+
colors = []
|
209 |
+
for i in range(len(response.json())):
|
210 |
+
colors.append(
|
211 |
+
(
|
212 |
+
random.randint(100, 255),
|
213 |
+
random.randint(100, 255),
|
214 |
+
random.randint(100, 255),
|
215 |
+
155,
|
216 |
+
)
|
217 |
+
)
|
218 |
+
predicted_results = []
|
219 |
+
for i, pred in enumerate(response.json()):
|
220 |
+
mask = pred.pop("mask").encode("utf-8")
|
221 |
+
mask = base64.b64decode(mask)
|
222 |
+
mask = Image.open(BytesIO(mask), mode="r")
|
223 |
+
mask = mask.convert("L")
|
224 |
+
|
225 |
+
layer = Image.new("RGBA", mask.size, colors[i])
|
226 |
+
image.paste(layer, (0, 0), mask)
|
227 |
+
predicted_results.append(pred)
|
228 |
+
path = save_image(image)
|
229 |
+
return {
|
230 |
+
"generated image with segmentation mask": path,
|
231 |
+
"predicted": predicted_results,
|
232 |
+
}
|
233 |
+
|
234 |
+
|
235 |
+
# Not yet implemented in huggingface inference API
|
236 |
+
class ImageToImage:
|
237 |
+
def __init__(self, task: Task):
|
238 |
+
self.task = task
|
239 |
+
|
240 |
+
@property
|
241 |
+
def inference_inputs(self):
|
242 |
+
img_data = encode_image(self.task.args["image"])
|
243 |
+
img_base64 = base64.b64encode(img_data).decode("utf-8")
|
244 |
+
data = {
|
245 |
+
"inputs": {
|
246 |
+
"image": img_base64,
|
247 |
+
}
|
248 |
+
}
|
249 |
+
if "text" in self.task.args:
|
250 |
+
data["inputs"]["prompt"] = self.task.args["text"]
|
251 |
+
return json.dumps(data)
|
252 |
+
|
253 |
+
def parse_response(self, response):
|
254 |
+
image = image_from_bytes(response.content)
|
255 |
+
path = save_image(image)
|
256 |
+
return {"generated image": path}
|
257 |
+
|
258 |
+
|
259 |
+
class ObjectDetection:
|
260 |
+
def __init__(self, task: Task):
|
261 |
+
self.task = task
|
262 |
+
|
263 |
+
@property
|
264 |
+
def inference_inputs(self):
|
265 |
+
return encode_image(self.task.args["image"])
|
266 |
+
|
267 |
+
def parse_response(self, response):
|
268 |
+
image_url = get_resource_url(self.task.args["image"])
|
269 |
+
image = load_image(image_url)
|
270 |
+
draw = ImageDraw.Draw(image)
|
271 |
+
labels = list(item["label"] for item in response.json())
|
272 |
+
color_map = {}
|
273 |
+
for label in labels:
|
274 |
+
if label not in color_map:
|
275 |
+
color_map[label] = (
|
276 |
+
random.randint(0, 255),
|
277 |
+
random.randint(0, 100),
|
278 |
+
random.randint(0, 255),
|
279 |
+
)
|
280 |
+
for item in response.json():
|
281 |
+
box = item["box"]
|
282 |
+
draw.rectangle(
|
283 |
+
((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])),
|
284 |
+
outline=color_map[item["label"]],
|
285 |
+
width=2,
|
286 |
+
)
|
287 |
+
draw.text(
|
288 |
+
(box["xmin"] + 5, box["ymin"] - 15),
|
289 |
+
item["label"],
|
290 |
+
fill=color_map[item["label"]],
|
291 |
+
)
|
292 |
+
path = save_image(image)
|
293 |
+
return {
|
294 |
+
"generated image with predicted box": path,
|
295 |
+
"predicted": response.json(),
|
296 |
+
}
|
297 |
+
|
298 |
+
|
299 |
+
# Example added to task-planning-examples.json compared to original paper
|
300 |
+
class ImageClassification:
|
301 |
+
def __init__(self, task: Task):
|
302 |
+
self.task = task
|
303 |
+
|
304 |
+
@property
|
305 |
+
def inference_inputs(self):
|
306 |
+
return encode_image(self.task.args["image"])
|
307 |
+
|
308 |
+
def parse_response(self, response):
|
309 |
+
return response.json()
|
310 |
+
|
311 |
+
|
312 |
+
class ImageToText:
|
313 |
+
def __init__(self, task: Task):
|
314 |
+
self.task = task
|
315 |
+
|
316 |
+
@property
|
317 |
+
def inference_inputs(self):
|
318 |
+
return encode_image(self.task.args["image"])
|
319 |
+
|
320 |
+
def parse_response(self, response):
|
321 |
+
return {"generated text": response.json()[0].get("generated_text", "")}
|
322 |
+
|
323 |
+
|
324 |
+
# Audio Tasks
|
325 |
+
class TextToSpeech:
|
326 |
+
def __init__(self, task: Task):
|
327 |
+
self.task = task
|
328 |
+
|
329 |
+
@property
|
330 |
+
def inference_inputs(self):
|
331 |
+
return self.task.args["text"]
|
332 |
+
|
333 |
+
def parse_response(self, response):
|
334 |
+
audio = audio_from_bytes(response.content)
|
335 |
+
path = save_audio(audio)
|
336 |
+
return {"generated audio": path}
|
337 |
+
|
338 |
+
|
339 |
+
class AudioToAudio:
|
340 |
+
def __init__(self, task: Task):
|
341 |
+
self.task = task
|
342 |
+
|
343 |
+
@property
|
344 |
+
def inference_inputs(self):
|
345 |
+
return encode_audio(self.task.args["audio"])
|
346 |
+
|
347 |
+
def parse_response(self, response):
|
348 |
+
result = response.json()
|
349 |
+
blob = result[0].items()["blob"]
|
350 |
+
content = base64.b64decode(blob.encode("utf-8"))
|
351 |
+
audio = audio_from_bytes(content)
|
352 |
+
path = save_audio(audio)
|
353 |
+
return {"generated audio": path}
|
354 |
+
|
355 |
+
|
356 |
+
class AutomaticSpeechRecognition:
|
357 |
+
def __init__(self, task: Task):
|
358 |
+
self.task = task
|
359 |
+
|
360 |
+
@property
|
361 |
+
def inference_inputs(self):
|
362 |
+
return encode_audio(self.task.args["audio"])
|
363 |
+
|
364 |
+
def parse_response(self, response):
|
365 |
+
return response.json()
|
366 |
+
|
367 |
+
|
368 |
+
class AudioClassification:
|
369 |
+
def __init__(self, task: Task):
|
370 |
+
self.task = task
|
371 |
+
|
372 |
+
@property
|
373 |
+
def inference_inputs(self):
|
374 |
+
return encode_audio(self.task.args["audio"])
|
375 |
+
|
376 |
+
def parse_response(self, response):
|
377 |
+
return response.json()
|
378 |
+
|
379 |
+
|
380 |
+
HUGGINGFACE_TASKS = {
|
381 |
+
"question-answering": QuestionAnswering,
|
382 |
+
"sentence-similarity": SentenceSimilarity,
|
383 |
+
"text-classification": TextClassification,
|
384 |
+
"token-classification": TokenClassification,
|
385 |
+
"visual-question-answering": VisualQuestionAnswering,
|
386 |
+
"document-question-answering": DocumentQuestionAnswering,
|
387 |
+
"text-to-image": TextToImage,
|
388 |
+
"image-segmentation": ImageSegmentation,
|
389 |
+
"image-to-image": ImageToImage,
|
390 |
+
"object-detection": ObjectDetection,
|
391 |
+
"image-classification": ImageClassification,
|
392 |
+
"image-to-text": ImageToText,
|
393 |
+
"text-to-speech": TextToSpeech,
|
394 |
+
"automatic-speech-recognition": AutomaticSpeechRecognition,
|
395 |
+
"audio-to-audio": AudioToAudio,
|
396 |
+
"audio-classification": AudioClassification,
|
397 |
+
}
|
398 |
+
|
399 |
+
|
400 |
+
def create_huggingface_task(task: Task):
|
401 |
+
if task.task in HUGGINGFACE_TASKS:
|
402 |
+
return HUGGINGFACE_TASKS[task.task](task)
|
403 |
+
else:
|
404 |
+
raise NotImplementedError(f"Task {task.task} not supported")
|
405 |
+
|
406 |
+
|
407 |
+
class TaskSummary(BaseModel):
|
408 |
+
task: Task
|
409 |
+
inference_result: Json[Any]
|
410 |
+
model: Model
|
hugginggpt/model_scraper.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
from collections import defaultdict
|
5 |
+
|
6 |
+
from aiohttp import ClientSession
|
7 |
+
|
8 |
+
from hugginggpt.exceptions import ModelScrapingException, async_wrap_exceptions
|
9 |
+
from hugginggpt.huggingface_api import (HUGGINGFACE_INFERENCE_API_STATUS_URL, get_hf_headers)
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
def read_huggingface_models_metadata():
|
15 |
+
"""Reads the metadata of all huggingface models from the local models cache file."""
|
16 |
+
with open("resources/huggingface-models-metadata.jsonl") as f:
|
17 |
+
models = [json.loads(line) for line in f]
|
18 |
+
models_map = defaultdict(list)
|
19 |
+
for model in models:
|
20 |
+
models_map[model["task"]].append(model)
|
21 |
+
return models_map
|
22 |
+
|
23 |
+
|
24 |
+
HUGGINGFACE_MODELS_MAP = read_huggingface_models_metadata()
|
25 |
+
|
26 |
+
|
27 |
+
@async_wrap_exceptions(
|
28 |
+
ModelScrapingException,
|
29 |
+
"Failed to find compatible models already loaded in the huggingface inference API.",
|
30 |
+
)
|
31 |
+
async def get_top_k_models(
|
32 |
+
task: str, top_k: int, max_description_length: int, session: ClientSession
|
33 |
+
):
|
34 |
+
"""Returns the best k available huggingface models for a given task, sorted by number of likes."""
|
35 |
+
# Number of potential candidates changed from top 10 to top_k*2
|
36 |
+
candidates = HUGGINGFACE_MODELS_MAP[task][: top_k * 2]
|
37 |
+
logger.debug(f"Task: {task}; All candidate models: {[c['id'] for c in candidates]}")
|
38 |
+
available_models = await filter_available_models(
|
39 |
+
candidates=candidates, session=session
|
40 |
+
)
|
41 |
+
logger.debug(
|
42 |
+
f"Task: {task}; Available models: {[c['id'] for c in available_models]}"
|
43 |
+
)
|
44 |
+
top_k_available_models = available_models[:top_k]
|
45 |
+
if not top_k_available_models:
|
46 |
+
raise Exception(f"No available models for task: {task}")
|
47 |
+
logger.debug(
|
48 |
+
f"Task: {task}; Top {top_k} available models: {[c['id'] for c in top_k_available_models]}"
|
49 |
+
)
|
50 |
+
top_k_models_info = [
|
51 |
+
{
|
52 |
+
"id": model["id"],
|
53 |
+
"likes": model.get("likes"),
|
54 |
+
"description": model.get("description", "")[:max_description_length],
|
55 |
+
"tags": model.get("meta").get("tags") if model.get("meta") else None,
|
56 |
+
}
|
57 |
+
for model in top_k_available_models
|
58 |
+
]
|
59 |
+
return top_k_models_info
|
60 |
+
|
61 |
+
|
62 |
+
async def filter_available_models(candidates, session: ClientSession):
|
63 |
+
"""Filters out models that are not available or loaded in the huggingface inference API.
|
64 |
+
Runs concurrently."""
|
65 |
+
async with asyncio.TaskGroup() as tg:
|
66 |
+
tasks = [
|
67 |
+
tg.create_task(model_status(model_id=c["id"], session=session))
|
68 |
+
for c in candidates
|
69 |
+
]
|
70 |
+
results = await asyncio.gather(*tasks)
|
71 |
+
available_model_ids = [model_id for model_id, status in results if status]
|
72 |
+
return [c for c in candidates if c["id"] in available_model_ids]
|
73 |
+
|
74 |
+
|
75 |
+
async def model_status(model_id: str, session: ClientSession) -> tuple[str, bool]:
|
76 |
+
url = HUGGINGFACE_INFERENCE_API_STATUS_URL + model_id
|
77 |
+
headers = get_hf_headers()
|
78 |
+
r = await session.get(url, headers=headers)
|
79 |
+
status = r.status
|
80 |
+
json_response = await r.json()
|
81 |
+
logger.debug(f"Model {model_id} status: {status}, response: {json_response}")
|
82 |
+
return (
|
83 |
+
(model_id, True)
|
84 |
+
if model_is_available(status=status, json_response=json_response)
|
85 |
+
else (model_id, False)
|
86 |
+
)
|
87 |
+
|
88 |
+
|
89 |
+
def model_is_available(status: int, json_response: dict[str, any]):
|
90 |
+
return status == 200 and "loaded" in json_response and json_response["loaded"]
|
hugginggpt/model_selection.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
|
5 |
+
import aiohttp
|
6 |
+
from langchain import LLMChain
|
7 |
+
from langchain.llms.base import BaseLLM
|
8 |
+
from langchain.output_parsers import OutputFixingParser, PydanticOutputParser
|
9 |
+
from langchain.prompts import load_prompt
|
10 |
+
from pydantic import BaseModel, Field
|
11 |
+
|
12 |
+
from hugginggpt.exceptions import ModelSelectionException, async_wrap_exceptions
|
13 |
+
from hugginggpt.model_scraper import get_top_k_models
|
14 |
+
from hugginggpt.resources import get_prompt_resource
|
15 |
+
from hugginggpt.task_parsing import Task
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
class Model(BaseModel):
|
21 |
+
id: str = Field(description="ID of the model")
|
22 |
+
reason: str = Field(description="Reason for selecting this model")
|
23 |
+
|
24 |
+
|
25 |
+
async def select_hf_models(
|
26 |
+
user_input: str,
|
27 |
+
tasks: list[Task],
|
28 |
+
model_selection_llm: BaseLLM,
|
29 |
+
output_fixing_llm: BaseLLM,
|
30 |
+
) -> dict[int, Model]:
|
31 |
+
"""Use LLM agent to select the best available HuggingFace model for each task, given model metadata.
|
32 |
+
Runs concurrently."""
|
33 |
+
async with aiohttp.ClientSession() as session:
|
34 |
+
async with asyncio.TaskGroup() as tg:
|
35 |
+
aio_tasks = []
|
36 |
+
for task in tasks:
|
37 |
+
aio_tasks.append(
|
38 |
+
tg.create_task(
|
39 |
+
select_model(
|
40 |
+
user_input=user_input,
|
41 |
+
task=task,
|
42 |
+
model_selection_llm=model_selection_llm,
|
43 |
+
output_fixing_llm=output_fixing_llm,
|
44 |
+
session=session,
|
45 |
+
)
|
46 |
+
)
|
47 |
+
)
|
48 |
+
results = await asyncio.gather(*aio_tasks)
|
49 |
+
return {task_id: model for task_id, model in results}
|
50 |
+
|
51 |
+
|
52 |
+
@async_wrap_exceptions(ModelSelectionException, "Failed to select model")
|
53 |
+
async def select_model(
|
54 |
+
user_input: str,
|
55 |
+
task: Task,
|
56 |
+
model_selection_llm: BaseLLM,
|
57 |
+
output_fixing_llm: BaseLLM,
|
58 |
+
session: aiohttp.ClientSession,
|
59 |
+
) -> (int, Model):
|
60 |
+
logger.info(f"Starting model selection for task: {task.task}")
|
61 |
+
|
62 |
+
top_k_models = await get_top_k_models(
|
63 |
+
task=task.task, top_k=5, max_description_length=100, session=session
|
64 |
+
)
|
65 |
+
|
66 |
+
if task.task in [
|
67 |
+
"summarization",
|
68 |
+
"translation",
|
69 |
+
"conversational",
|
70 |
+
"text-generation",
|
71 |
+
"text2text-generation",
|
72 |
+
]:
|
73 |
+
model = Model(
|
74 |
+
id="openai",
|
75 |
+
reason="Text generation tasks are best handled by OpenAI models",
|
76 |
+
)
|
77 |
+
else:
|
78 |
+
prompt_template = load_prompt(
|
79 |
+
get_prompt_resource("model-selection-prompt.json")
|
80 |
+
)
|
81 |
+
llm_chain = LLMChain(prompt=prompt_template, llm=model_selection_llm)
|
82 |
+
# Need to replace double quotes with single quotes for correct response generation
|
83 |
+
task_str = task.json().replace('"', "'")
|
84 |
+
models_str = json.dumps(top_k_models).replace('"', "'")
|
85 |
+
output = await llm_chain.apredict(
|
86 |
+
user_input=user_input, task=task_str, models=models_str, stop=["<im_end>"]
|
87 |
+
)
|
88 |
+
logger.debug(f"Model selection raw output: {output}")
|
89 |
+
|
90 |
+
parser = PydanticOutputParser(pydantic_object=Model)
|
91 |
+
fixing_parser = OutputFixingParser.from_llm(
|
92 |
+
parser=parser, llm=output_fixing_llm
|
93 |
+
)
|
94 |
+
model = fixing_parser.parse(output)
|
95 |
+
|
96 |
+
logger.info(f"For task: {task.task}, selected model: {model}")
|
97 |
+
return task.id, model
|
hugginggpt/resources.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import uuid
|
3 |
+
from io import BytesIO
|
4 |
+
|
5 |
+
import requests
|
6 |
+
from PIL import Image
|
7 |
+
from diffusers.utils.testing_utils import load_image
|
8 |
+
from pydub import AudioSegment
|
9 |
+
|
10 |
+
RESOURCES_DIR = "resources"
|
11 |
+
PROMPT_TEMPLATES_DIR = "prompt-templates"
|
12 |
+
GENERATED_RESOURCES_DIR = "output"
|
13 |
+
|
14 |
+
|
15 |
+
def get_prompt_resource(prompt_name: str) -> str:
|
16 |
+
return os.path.join(RESOURCES_DIR, PROMPT_TEMPLATES_DIR, prompt_name)
|
17 |
+
|
18 |
+
|
19 |
+
def get_resource_url(resource_arg: str) -> str:
|
20 |
+
if resource_arg.startswith("http"):
|
21 |
+
return resource_arg
|
22 |
+
else:
|
23 |
+
return GENERATED_RESOURCES_DIR + resource_arg
|
24 |
+
|
25 |
+
|
26 |
+
# Images
|
27 |
+
def image_to_bytes(image: Image) -> bytes:
|
28 |
+
image_byte = BytesIO()
|
29 |
+
image.save(image_byte, format="png")
|
30 |
+
image_data = image_byte.getvalue()
|
31 |
+
return image_data
|
32 |
+
|
33 |
+
|
34 |
+
def image_from_bytes(img_data: bytes) -> Image:
|
35 |
+
return Image.open(BytesIO(img_data))
|
36 |
+
|
37 |
+
|
38 |
+
def encode_image(image_arg: str) -> bytes:
|
39 |
+
image_url = get_resource_url(image_arg)
|
40 |
+
image = load_image(image_url)
|
41 |
+
img_data = image_to_bytes(image)
|
42 |
+
return img_data
|
43 |
+
|
44 |
+
|
45 |
+
def save_image(img: Image) -> str:
|
46 |
+
name = str(uuid.uuid4())[:4]
|
47 |
+
path = f"/images/{name}.png"
|
48 |
+
img.save(GENERATED_RESOURCES_DIR + path)
|
49 |
+
return path
|
50 |
+
|
51 |
+
|
52 |
+
# Audios
|
53 |
+
def load_audio(audio_path: str) -> AudioSegment:
|
54 |
+
if audio_path.startswith("http://") or audio_path.startswith("https://"):
|
55 |
+
audio_data = requests.get(audio_path).content
|
56 |
+
audio = AudioSegment.from_file(BytesIO(audio_data))
|
57 |
+
elif os.path.isfile(audio_path):
|
58 |
+
audio = AudioSegment.from_file(audio_path)
|
59 |
+
else:
|
60 |
+
raise ValueError(
|
61 |
+
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {audio_path} is not a valid path"
|
62 |
+
)
|
63 |
+
return audio
|
64 |
+
|
65 |
+
|
66 |
+
def audio_to_bytes(audio: AudioSegment) -> bytes:
|
67 |
+
audio_byte = BytesIO()
|
68 |
+
audio.export(audio_byte, format="flac")
|
69 |
+
audio_data = audio_byte.getvalue()
|
70 |
+
return audio_data
|
71 |
+
|
72 |
+
|
73 |
+
def audio_from_bytes(audio_data: bytes) -> AudioSegment:
|
74 |
+
return AudioSegment.from_file(BytesIO(audio_data))
|
75 |
+
|
76 |
+
|
77 |
+
def encode_audio(audio_arg: str) -> bytes:
|
78 |
+
audio_url = get_resource_url(audio_arg)
|
79 |
+
audio = load_audio(audio_url)
|
80 |
+
audio_data = audio_to_bytes(audio)
|
81 |
+
return audio_data
|
82 |
+
|
83 |
+
|
84 |
+
def save_audio(audio: AudioSegment) -> str:
|
85 |
+
name = str(uuid.uuid4())[:4]
|
86 |
+
path = f"/audios/{name}.flac"
|
87 |
+
with open(GENERATED_RESOURCES_DIR + path, "wb") as f:
|
88 |
+
audio.export(f, format="flac")
|
89 |
+
return path
|
90 |
+
|
91 |
+
|
92 |
+
def prepend_resource_dir(s: str) -> str:
|
93 |
+
"""Prepend the resource dir to all resource paths in the string"""
|
94 |
+
for resource_type in ["images", "audios", "videos"]:
|
95 |
+
s = s.replace(
|
96 |
+
f" /{resource_type}/", f" {GENERATED_RESOURCES_DIR}/{resource_type}/"
|
97 |
+
)
|
98 |
+
return s
|
99 |
+
|
100 |
+
|
101 |
+
def init_resource_dirs():
|
102 |
+
os.makedirs(GENERATED_RESOURCES_DIR + "/images", exist_ok=True)
|
103 |
+
os.makedirs(GENERATED_RESOURCES_DIR + "/audios", exist_ok=True)
|
104 |
+
os.makedirs(GENERATED_RESOURCES_DIR + "/videos", exist_ok=True)
|
hugginggpt/response_generation.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from langchain import LLMChain
|
5 |
+
from langchain.llms.base import BaseLLM
|
6 |
+
from langchain.prompts import load_prompt
|
7 |
+
|
8 |
+
from hugginggpt.exceptions import ResponseGenerationException, wrap_exceptions
|
9 |
+
from hugginggpt.model_inference import TaskSummary
|
10 |
+
from hugginggpt.resources import get_prompt_resource, prepend_resource_dir
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
@wrap_exceptions(ResponseGenerationException, "Failed to generate assistant response")
|
16 |
+
def generate_response(
|
17 |
+
user_input: str, task_summaries: list[TaskSummary], llm: BaseLLM
|
18 |
+
) -> str:
|
19 |
+
"""Use LLM agent to generate a response to the user's input, given task results."""
|
20 |
+
logger.info("Starting response generation")
|
21 |
+
sorted_task_summaries = sorted(task_summaries, key=lambda ts: ts.task.id)
|
22 |
+
task_results_str = task_summaries_to_json(sorted_task_summaries)
|
23 |
+
prompt_template = load_prompt(
|
24 |
+
get_prompt_resource("response-generation-prompt.json")
|
25 |
+
)
|
26 |
+
llm_chain = LLMChain(prompt=prompt_template, llm=llm)
|
27 |
+
response = llm_chain.predict(
|
28 |
+
user_input=user_input, task_results=task_results_str, stop=["<im_end>"]
|
29 |
+
)
|
30 |
+
logger.info(f"Generated response: {response}")
|
31 |
+
return response
|
32 |
+
|
33 |
+
|
34 |
+
def format_response(response: str) -> str:
|
35 |
+
"""Format the response to be more readable for user."""
|
36 |
+
response = response.strip()
|
37 |
+
response = prepend_resource_dir(response)
|
38 |
+
return response
|
39 |
+
|
40 |
+
|
41 |
+
def task_summaries_to_json(task_summaries: list[TaskSummary]) -> str:
|
42 |
+
dicts = [ts.dict() for ts in task_summaries]
|
43 |
+
return json.dumps(dicts)
|
hugginggpt/task_parsing.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from pydantic import BaseModel, Field
|
5 |
+
|
6 |
+
from hugginggpt.exceptions import TaskParsingException, wrap_exceptions
|
7 |
+
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
GENERATED_TOKEN = "<GENERATED>"
|
11 |
+
|
12 |
+
|
13 |
+
class Task(BaseModel):
|
14 |
+
# This field is called 'task' and not 'name' to help with prompt engineering
|
15 |
+
task: str = Field(description="Name of the Machine Learning task")
|
16 |
+
id: int = Field(description="ID of the task")
|
17 |
+
dep: list[int] = Field(
|
18 |
+
description="List of IDs of the tasks that this task depends on"
|
19 |
+
)
|
20 |
+
args: dict[str, str] = Field(description="Arguments for the task")
|
21 |
+
|
22 |
+
def depends_on_generated_resources(self) -> bool:
|
23 |
+
"""Returns True if the task args contains <GENERATED> placeholder tokens, False otherwise"""
|
24 |
+
return self.dep != [-1] and any(
|
25 |
+
GENERATED_TOKEN in v for v in self.args.values()
|
26 |
+
)
|
27 |
+
|
28 |
+
@wrap_exceptions(TaskParsingException, "Failed to replace generated resources")
|
29 |
+
def replace_generated_resources(self, task_summaries: list):
|
30 |
+
"""Replaces <GENERATED> placeholder tokens in args with the generated resources from the task summaries"""
|
31 |
+
logger.info("Replacing generated resources")
|
32 |
+
generated_resources = {
|
33 |
+
k: parse_task_id(v) for k, v in self.args.items() if GENERATED_TOKEN in v
|
34 |
+
}
|
35 |
+
logger.info(
|
36 |
+
f"Resources to replace, resource type -> task id: {generated_resources}"
|
37 |
+
)
|
38 |
+
for resource_type, task_id in generated_resources.items():
|
39 |
+
matches = [
|
40 |
+
v
|
41 |
+
for k, v in task_summaries[task_id].inference_result.items()
|
42 |
+
if self.is_matching_generated_resource(k, resource_type)
|
43 |
+
]
|
44 |
+
if len(matches) == 1:
|
45 |
+
logger.info(
|
46 |
+
f"Match for generated {resource_type} in inference result of task {task_id}"
|
47 |
+
)
|
48 |
+
generated_resource = matches[0]
|
49 |
+
logger.info(f"Replacing {resource_type} with {generated_resource}")
|
50 |
+
self.args[resource_type] = generated_resource
|
51 |
+
return self
|
52 |
+
else:
|
53 |
+
raise Exception(
|
54 |
+
f"Cannot find unique required generated {resource_type} in inference result of task {task_id}"
|
55 |
+
)
|
56 |
+
|
57 |
+
def is_matching_generated_resource(self, arg_key: str, resource_type: str) -> bool:
|
58 |
+
"""Returns True if arg_key contains generated resource of the correct type"""
|
59 |
+
# If text, then match all arg keys that contain "text"
|
60 |
+
if resource_type.startswith("text"):
|
61 |
+
return "text" in arg_key
|
62 |
+
# If not text, then arg key must start with "generated" and the correct resource type
|
63 |
+
else:
|
64 |
+
return arg_key.startswith("generated " + resource_type)
|
65 |
+
|
66 |
+
|
67 |
+
class Tasks(BaseModel):
|
68 |
+
__root__: list[Task] = Field(description="List of Machine Learning tasks")
|
69 |
+
|
70 |
+
def __iter__(self):
|
71 |
+
return iter(self.__root__)
|
72 |
+
|
73 |
+
def __getitem__(self, item):
|
74 |
+
return self.__root__[item]
|
75 |
+
|
76 |
+
def __len__(self):
|
77 |
+
return len(self.__root__)
|
78 |
+
|
79 |
+
|
80 |
+
@wrap_exceptions(TaskParsingException, "Failed to parse tasks")
|
81 |
+
def parse_tasks(tasks_str: str) -> list[Task]:
|
82 |
+
"""Parses tasks from task planning json string"""
|
83 |
+
if tasks_str == "[]":
|
84 |
+
raise ValueError("Task string empty, cannot parse")
|
85 |
+
logger.info(f"Parsing tasks string: {tasks_str}")
|
86 |
+
tasks_str = tasks_str.strip()
|
87 |
+
# Cannot use PydanticOutputParser because it fails when parsing top level list JSON string
|
88 |
+
tasks = Tasks.parse_raw(tasks_str)
|
89 |
+
# __root__ extracts list[Task] from Tasks object
|
90 |
+
tasks = unfold(tasks.__root__)
|
91 |
+
tasks = fix_dependencies(tasks)
|
92 |
+
logger.info(f"Parsed tasks: {tasks}")
|
93 |
+
return tasks
|
94 |
+
|
95 |
+
|
96 |
+
def parse_task_id(resource_str: str) -> int:
|
97 |
+
"""Parse task id from generated resource string, e.g. <GENERATED>-4 -> 4"""
|
98 |
+
return int(resource_str.split("-")[1])
|
99 |
+
|
100 |
+
|
101 |
+
def fix_dependencies(tasks: list[Task]) -> list[Task]:
|
102 |
+
"""Ignores parsed tasks dependencies, and instead infers from task arguments"""
|
103 |
+
for task in tasks:
|
104 |
+
task.dep = infer_deps_from_args(task)
|
105 |
+
return tasks
|
106 |
+
|
107 |
+
|
108 |
+
def infer_deps_from_args(task: Task) -> list[int]:
|
109 |
+
"""If GENERATED arg value, add to list of unique deps. If none, deps = [-1]"""
|
110 |
+
deps = [parse_task_id(v) for v in task.args.values() if GENERATED_TOKEN in v]
|
111 |
+
if not deps:
|
112 |
+
deps = [-1]
|
113 |
+
# deduplicate
|
114 |
+
return list(set(deps))
|
115 |
+
|
116 |
+
|
117 |
+
def unfold(tasks: list[Task]) -> list[Task]:
|
118 |
+
"""A folded task has several generated resources folded into a single argument"""
|
119 |
+
unfolded_tasks = []
|
120 |
+
for task in tasks:
|
121 |
+
folded_args = find_folded_args(task)
|
122 |
+
if folded_args:
|
123 |
+
unfolded_tasks.extend(split(task, folded_args))
|
124 |
+
else:
|
125 |
+
unfolded_tasks.append(task)
|
126 |
+
return unfolded_tasks
|
127 |
+
|
128 |
+
|
129 |
+
def split(task: Task, folded_args: tuple[str, str]) -> list[Task]:
|
130 |
+
"""Split folded task into two same tasks, but separated generated resource arguments"""
|
131 |
+
key, value = folded_args
|
132 |
+
generated_items = value.split(",")
|
133 |
+
split_tasks = []
|
134 |
+
for item in generated_items:
|
135 |
+
new_task = copy.deepcopy(task)
|
136 |
+
dep_task_id = parse_task_id(item)
|
137 |
+
new_task.dep = [dep_task_id]
|
138 |
+
new_task.args[key] = item.strip()
|
139 |
+
split_tasks.append(new_task)
|
140 |
+
return split_tasks
|
141 |
+
|
142 |
+
|
143 |
+
def find_folded_args(task: Task) -> tuple[str, str] | None:
|
144 |
+
"""Finds folded args, e.g: 'image': '<GENERATED>-1,<GENERATED>-2'"""
|
145 |
+
for key, value in task.args.items():
|
146 |
+
if value.count(GENERATED_TOKEN) > 1:
|
147 |
+
logger.debug(f"Task {task.id} is folded")
|
148 |
+
return key, value
|
149 |
+
return None
|
hugginggpt/task_planning.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
from langchain import LLMChain
|
4 |
+
from langchain.llms.base import BaseLLM
|
5 |
+
from langchain.prompts import load_prompt
|
6 |
+
|
7 |
+
from hugginggpt.exceptions import TaskPlanningException, wrap_exceptions
|
8 |
+
from hugginggpt.history import ConversationHistory
|
9 |
+
from hugginggpt.llm_factory import LLM_MAX_TOKENS, count_tokens
|
10 |
+
from hugginggpt.resources import get_prompt_resource
|
11 |
+
from hugginggpt.task_parsing import Task, parse_tasks
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
MAIN_PROMPT_TOKENS = 800
|
16 |
+
MAX_HISTORY_TOKENS = LLM_MAX_TOKENS - MAIN_PROMPT_TOKENS
|
17 |
+
|
18 |
+
|
19 |
+
@wrap_exceptions(TaskPlanningException, "Failed to plan tasks")
|
20 |
+
def plan_tasks(
|
21 |
+
user_input: str, history: ConversationHistory, llm: BaseLLM
|
22 |
+
) -> list[Task]:
|
23 |
+
"""Use LLM agent to plan tasks in order solve user request."""
|
24 |
+
logger.info("Starting task planning")
|
25 |
+
task_planning_prompt_template = load_prompt(
|
26 |
+
get_prompt_resource("task-planning-few-shot-prompt.json")
|
27 |
+
)
|
28 |
+
llm_chain = LLMChain(prompt=task_planning_prompt_template, llm=llm)
|
29 |
+
history_truncated = truncate_history(history)
|
30 |
+
output = llm_chain.predict(
|
31 |
+
user_input=user_input, history=history_truncated, stop=["<im_end>"]
|
32 |
+
)
|
33 |
+
logger.info(f"Task planning raw output: {output}")
|
34 |
+
tasks = parse_tasks(output)
|
35 |
+
return tasks
|
36 |
+
|
37 |
+
|
38 |
+
def truncate_history(history: ConversationHistory) -> ConversationHistory:
|
39 |
+
"""Truncate history to fit within the max token limit for the task planning LLM"""
|
40 |
+
example_prompt_template = load_prompt(
|
41 |
+
get_prompt_resource("task-planning-example-prompt.json")
|
42 |
+
)
|
43 |
+
token_counter = 0
|
44 |
+
n_messages = 0
|
45 |
+
# Iterate through history backwards in pairs, to ensure most recent messages are kept
|
46 |
+
for i in range(0, len(history), 2):
|
47 |
+
user_message = history[-(i + 2)]
|
48 |
+
assistant_message = history[-(i + 1)]
|
49 |
+
# Turn messages into LLM prompt string
|
50 |
+
history_text = example_prompt_template.format(
|
51 |
+
example_input=user_message["content"],
|
52 |
+
example_output=assistant_message["content"],
|
53 |
+
)
|
54 |
+
n_tokens = count_tokens(history_text)
|
55 |
+
if token_counter + n_tokens <= MAX_HISTORY_TOKENS:
|
56 |
+
n_messages += 2
|
57 |
+
token_counter += n_tokens
|
58 |
+
else:
|
59 |
+
break
|
60 |
+
start = len(history) - n_messages
|
61 |
+
return history[start:]
|
logging-config.toml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version = 1
|
2 |
+
|
3 |
+
disable_existing_loggers = false
|
4 |
+
|
5 |
+
[root]
|
6 |
+
level = "DEBUG"
|
7 |
+
handlers = ["debug_file", "errors_file"]
|
8 |
+
|
9 |
+
[formatters.simple]
|
10 |
+
format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
11 |
+
|
12 |
+
[handlers.debug_file]
|
13 |
+
class = "logging.handlers.TimedRotatingFileHandler"
|
14 |
+
level = "DEBUG"
|
15 |
+
formatter = "simple"
|
16 |
+
filename = "logs/debug.log"
|
17 |
+
when = "midnight"
|
18 |
+
encoding = "utf8"
|
19 |
+
|
20 |
+
[handlers.errors_file]
|
21 |
+
class = "logging.handlers.TimedRotatingFileHandler"
|
22 |
+
level = "ERROR"
|
23 |
+
formatter = "simple"
|
24 |
+
filename = "logs/errors.log"
|
25 |
+
when = "midnight"
|
26 |
+
encoding = "utf8"
|
logs/.gitkeep
ADDED
File without changes
|
main.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
|
5 |
+
import click
|
6 |
+
import requests
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
|
9 |
+
from hugginggpt import generate_response, infer, plan_tasks
|
10 |
+
from hugginggpt.history import ConversationHistory
|
11 |
+
from hugginggpt.llm_factory import LLMs, create_llms
|
12 |
+
from hugginggpt.log import setup_logging
|
13 |
+
from hugginggpt.model_inference import TaskSummary
|
14 |
+
from hugginggpt.model_selection import select_hf_models
|
15 |
+
from hugginggpt.response_generation import format_response
|
16 |
+
|
17 |
+
load_dotenv()
|
18 |
+
setup_logging()
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
@click.command()
|
23 |
+
@click.option("-p", "--prompt", type=str, help="Prompt for huggingGPT")
|
24 |
+
def main(prompt):
|
25 |
+
_print_banner()
|
26 |
+
llms = create_llms()
|
27 |
+
if prompt:
|
28 |
+
standalone_mode(user_input=prompt, llms=llms)
|
29 |
+
|
30 |
+
else:
|
31 |
+
interactive_mode(llms=llms)
|
32 |
+
|
33 |
+
|
34 |
+
def standalone_mode(user_input: str, llms: LLMs) -> str:
|
35 |
+
try:
|
36 |
+
response, task_summaries = compute(
|
37 |
+
user_input=user_input,
|
38 |
+
history=ConversationHistory(),
|
39 |
+
llms=llms,
|
40 |
+
)
|
41 |
+
pretty_response = format_response(response)
|
42 |
+
print(pretty_response)
|
43 |
+
return pretty_response
|
44 |
+
except Exception as e:
|
45 |
+
logger.exception("")
|
46 |
+
print(
|
47 |
+
f"Sorry, encountered error: {e}. Please try again. Check logs if problem persists."
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
def interactive_mode(llms: LLMs):
|
52 |
+
print("Please enter your request. End the conversation with 'exit'")
|
53 |
+
history = ConversationHistory()
|
54 |
+
while True:
|
55 |
+
try:
|
56 |
+
user_input = click.prompt("User")
|
57 |
+
if user_input.lower() == "exit":
|
58 |
+
break
|
59 |
+
|
60 |
+
logger.info(f"User input: {user_input}")
|
61 |
+
response, task_summaries = compute(
|
62 |
+
user_input=user_input,
|
63 |
+
history=history,
|
64 |
+
llms=llms,
|
65 |
+
)
|
66 |
+
pretty_response = format_response(response)
|
67 |
+
print(f"Assistant:{pretty_response}")
|
68 |
+
|
69 |
+
history.add(role="user", content=user_input)
|
70 |
+
history.add(role="assistant", content=response)
|
71 |
+
except Exception as e:
|
72 |
+
logger.exception("")
|
73 |
+
print(
|
74 |
+
f"Sorry, encountered error: {e}. Please try again. Check logs if problem persists."
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
def compute(
|
79 |
+
user_input: str,
|
80 |
+
history: ConversationHistory,
|
81 |
+
llms: LLMs,
|
82 |
+
) -> (str, list[TaskSummary]):
|
83 |
+
tasks = plan_tasks(
|
84 |
+
user_input=user_input, history=history, llm=llms.task_planning_llm
|
85 |
+
)
|
86 |
+
|
87 |
+
sorted(tasks, key=lambda t: max(t.dep))
|
88 |
+
logger.info(f"Sorted tasks: {tasks}")
|
89 |
+
|
90 |
+
hf_models = asyncio.run(
|
91 |
+
select_hf_models(
|
92 |
+
user_input=user_input,
|
93 |
+
tasks=tasks,
|
94 |
+
model_selection_llm=llms.model_selection_llm,
|
95 |
+
output_fixing_llm=llms.output_fixing_llm,
|
96 |
+
)
|
97 |
+
)
|
98 |
+
|
99 |
+
task_summaries = []
|
100 |
+
with requests.Session() as session:
|
101 |
+
for task in tasks:
|
102 |
+
logger.info(f"Starting task: {task}")
|
103 |
+
if task.depends_on_generated_resources():
|
104 |
+
task = task.replace_generated_resources(task_summaries=task_summaries)
|
105 |
+
model = hf_models[task.id]
|
106 |
+
inference_result = infer(
|
107 |
+
task=task,
|
108 |
+
model_id=model.id,
|
109 |
+
llm=llms.model_inference_llm,
|
110 |
+
session=session,
|
111 |
+
)
|
112 |
+
task_summaries.append(
|
113 |
+
TaskSummary(
|
114 |
+
task=task,
|
115 |
+
model=model,
|
116 |
+
inference_result=json.dumps(inference_result),
|
117 |
+
)
|
118 |
+
)
|
119 |
+
logger.info(f"Finished task: {task}")
|
120 |
+
logger.info("Finished all tasks")
|
121 |
+
logger.debug(f"Task summaries: {task_summaries}")
|
122 |
+
|
123 |
+
response = generate_response(
|
124 |
+
user_input=user_input,
|
125 |
+
task_summaries=task_summaries,
|
126 |
+
llm=llms.response_generation_llm,
|
127 |
+
)
|
128 |
+
return response, task_summaries
|
129 |
+
|
130 |
+
|
131 |
+
def _print_banner():
|
132 |
+
with open("resources/banner.txt", "r") as f:
|
133 |
+
banner = f.read()
|
134 |
+
logger.info("\n" + banner)
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == "__main__":
|
138 |
+
main()
|
output/.gitkeep
ADDED
File without changes
|
output/audios/.gitkeep
ADDED
File without changes
|
output/images/.gitkeep
ADDED
File without changes
|
output/videos/.gitkeep
ADDED
File without changes
|
pdm.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pyproject.toml
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.pdm]
|
2 |
+
[tool.pdm.dev-dependencies]
|
3 |
+
dev = []
|
4 |
+
test = [
|
5 |
+
"pytest>=7.3.0",
|
6 |
+
"pytest-cov>=4.0.0",
|
7 |
+
"pytest-asyncio>=0.21.0",
|
8 |
+
"aioresponses>=0.7.4",
|
9 |
+
"responses>=0.23.1",
|
10 |
+
]
|
11 |
+
ide = [
|
12 |
+
"setuptools>=67.6.1",
|
13 |
+
]
|
14 |
+
|
15 |
+
[tool.pdm.scripts]
|
16 |
+
hugginggpt = "python main.py"
|
17 |
+
|
18 |
+
[tool.pytest]
|
19 |
+
[tool.pytest.ini_options]
|
20 |
+
asyncio_mode = "auto"
|
21 |
+
norecursedirs = "tests/helpers"
|
22 |
+
|
23 |
+
[project]
|
24 |
+
name = "langchain-huggingGPT"
|
25 |
+
version = "0.1.0"
|
26 |
+
description = ""
|
27 |
+
authors = [
|
28 |
+
{name = "camille-vanhoffelen", email = "camille-vanhoffelen@users.noreply.github.com"},
|
29 |
+
]
|
30 |
+
dependencies = [
|
31 |
+
"click>=8.1.3",
|
32 |
+
"python-dotenv>=1.0.0",
|
33 |
+
"langchain>=0.0.137",
|
34 |
+
"openai>=0.27.4",
|
35 |
+
"huggingface-hub>=0.13.4",
|
36 |
+
"tiktoken>=0.3.3",
|
37 |
+
"diffusers>=0.15.1",
|
38 |
+
"Pillow>=9.5.0",
|
39 |
+
"pydub>=0.25.1",
|
40 |
+
"aiohttp>=3.8.4",
|
41 |
+
"aiodns>=3.0.0",
|
42 |
+
"gradio>=3.32.0",
|
43 |
+
]
|
44 |
+
requires-python = ">=3.11"
|
45 |
+
readme = "README.md"
|
46 |
+
license = {text = "MIT"}
|
47 |
+
|
48 |
+
[build-system]
|
49 |
+
requires = ["pdm-backend"]
|
50 |
+
build-backend = "pdm.backend"
|
requirements.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
resources/banner.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__ _ ___
|
2 |
+
| _. ._ _ _ |_ _. o ._ |_| _ _ o ._ _ /__ |_) |
|
3 |
+
| (_| | | (_| (_ | | (_| | | | | | |_| (_| (_| | | | (_| \_| | |
|
4 |
+
_| _| _| _|
|
resources/huggingface-models-metadata.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
resources/prompt-templates/model-selection-prompt.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_type": "prompt",
|
3 |
+
"input_variables": [
|
4 |
+
"user_input",
|
5 |
+
"models",
|
6 |
+
"task"
|
7 |
+
],
|
8 |
+
"template": "#2 Model Selection Stage: Given the user request and the parsed tasks, the AI assistant helps the user to select a suitable model from a list of models to process the user request. The assistant should focus more on the description of the model and find the model that has the most potential to solve requests and tasks. Also, prefer models with local inference endpoints for speed and stability.\n<im_start>user\n{user_input}<im_end>\n<im_start>assistant\n{task}<im_end>\n<im_start>user\nPlease choose the most suitable model from {models} for the task {task}. The output must be in a strict JSON format: {{\"id\": \"id\", \"reason\": \"your detail reasons for the choice\"}}.<im_end>\n<im_start>assistant\n"
|
9 |
+
}
|
resources/prompt-templates/openai-model-inference-prompt.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_type": "prompt",
|
3 |
+
"input_variables": [
|
4 |
+
"task",
|
5 |
+
"task_name",
|
6 |
+
"args"
|
7 |
+
],
|
8 |
+
"template": "Model Inference Stage: the AI assistant needs to execute a task for the user.\n<im_start>user\nHere is the task in JSON format {task}. Now you are a {task_name} system, the arguments are {args}. Just help me do {task_name} and give me the result. The result must be in text form without any urls.<im_end>\n<im_start>assistant"
|
9 |
+
}
|
resources/prompt-templates/response-generation-prompt.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_type": "prompt",
|
3 |
+
"input_variables": [
|
4 |
+
"user_input",
|
5 |
+
"task_results"
|
6 |
+
],
|
7 |
+
"template": "#4 Response Generation Stage: With the task execution logs, the AI assistant needs to describe the process and inference results.\n<im_start>user\n{user_input}<im_end>\n<im_start>assistant\nBefore give you a response, I want to introduce my workflow for your request, which is shown in the following JSON data: {task_results}. Do you have any demands regarding my response?<im_end>\n<im_start>user\nYes. Please first think carefully and directly answer my request based on the inference results. Some of the inferences may not always turn out to be correct and require you to make careful consideration in making decisions. Then please detail your workflow including the used models and inference results for my request in your friendly tone. Please filter out information that is not relevant to my request. Tell me the complete path or urls of files in inference results. If there is nothing in the results, please tell me you can't make it.<im_end>\n<im_start>assistant"
|
8 |
+
}
|
resources/prompt-templates/task-planning-example-prompt.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_type": "prompt",
|
3 |
+
"input_variables": [
|
4 |
+
"example_input",
|
5 |
+
"example_output"
|
6 |
+
],
|
7 |
+
"template": "<im_start>user\n{example_input}<im_end>\n<im_start>assistant\n{example_output}<im_end>"
|
8 |
+
}
|
resources/prompt-templates/task-planning-examples.json
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"example_input": "Give you some pictures e1.jpg, e2.png, e3.jpg, help me count the number of sheep?",
|
4 |
+
"example_output": "[{{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {{\"image\": \"e1.jpg\" }}}}, {{\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {{\"image\": \"e1.jpg\" }}}}, {{\"task\": \"visual-question-answering\", \"id\": 2, \"dep\": [1], \"args\": {{\"image\": \"<GENERATED>-1\", \"text\": \"How many sheep in the picture\"}}}}, {{\"task\": \"image-to-text\", \"id\": 3, \"dep\": [-1], \"args\": {{\"image\": \"e2.png\" }}}}, {{\"task\": \"object-detection\", \"id\": 4, \"dep\": [-1], \"args\": {{\"image\": \"e2.png\" }}}}, {{\"task\": \"visual-question-answering\", \"id\": 5, \"dep\": [4], \"args\": {{\"image\": \"<GENERATED>-4\", \"text\": \"How many sheep in the picture\"}}}}, {{\"task\": \"image-to-text\", \"id\": 6, \"dep\": [-1], \"args\": {{\"image\": \"e3.jpg\" }}}}, {{\"task\": \"object-detection\", \"id\": 7, \"dep\": [-1], \"args\": {{\"image\": \"e3.jpg\" }}}}, {{\"task\": \"visual-question-answering\", \"id\": 8, \"dep\": [7], \"args\": {{\"image\": \"<GENERATED>-7\", \"text\": \"How many sheep in the picture\"}}}}]"
|
5 |
+
},
|
6 |
+
{
|
7 |
+
"example_input":"Look at /e.jpg, can you tell me how many objects in the picture? Give me a picture similar to this one.",
|
8 |
+
"example_output":"[{{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {{\"image\": \"/e.jpg\" }}}}, {{\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {{\"image\": \"/e.jpg\" }}}}, {{\"task\": \"visual-question-answering\", \"id\": 2, \"dep\": [1], \"args\": {{\"image\": \"<GENERATED>-1\", \"text\": \"how many objects in the picture?\" }}}}, {{\"task\": \"text-to-image\", \"id\": 3, \"dep\": [0], \"args\": {{\"text\": \"<GENERATED-0>\" }}}}, {{\"task\": \"image-to-image\", \"id\": 4, \"dep\": [-1], \"args\": {{\"image\": \"/e.jpg\" }}}}]"
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"example_input":"given a document /images/e.jpeg, answer me what is the student amount? And describe the image with your voice",
|
12 |
+
"example_output":"{{\"task\": \"document-question-answering\", \"id\": 0, \"dep\": [-1], \"args\": {{\"image\": \"/images/e.jpeg\", \"text\": \"what is the student amount?\" }}}}, {{\"task\": \"visual-question-answering\", \"id\": 1, \"dep\": [-1], \"args\": {{\"image\": \"/images/e.jpeg\", \"text\": \"what is the student amount?\" }}}}, {{\"task\": \"image-to-text\", \"id\": 2, \"dep\": [-1], \"args\": {{\"image\": \"/images/e.jpg\" }}}}, {{\"task\": \"text-to-speech\", \"id\": 3, \"dep\": [2], \"args\": {{\"text\": \"<GENERATED>-2\" }}}}]"
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"example_input": "Given an image /example.jpg, generate a new image where instead of reading a book, the girl is using her phone",
|
16 |
+
"example_output": "[{{\"task\": \"image-to-image\", \"id\": 0, \"dep\": [-1], \"args\": {{\"image\": \"/example.jpg\", \"text\": \"instead of reading a book, the girl is using her phone\"}}}}]"
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"example_input": "please show me an image of (based on the text) 'a boy is running' and dub it",
|
20 |
+
"example_output": "[{{\"task\": \"text-to-image\", \"id\": 0, \\\"dep\\\": [-1], \"args\": {{\"text\": \"a boy is running\" }}}}, {{\"task\": \"text-to-speech\", \"id\": 1, \"dep\": [-1], \"args\": {{\"text\": \"a boy is running\" }}}}]"
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"example_input": "please show me a joke and an image of cat",
|
24 |
+
"example_output": "[{{\"task\": \"conversational\", \"id\": 0, \"dep\": [-1], \"args\": {{\"text\": \"please show me a joke of cat\" }}}}, {{\"task\": \"text-to-image\", \"id\": 1, \"dep\": [-1], \"args\": {{\"text\": \"a photo of cat\" }}}}]"
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"example_input": "Please tell me how similar the two following sentences are: I like to surf. Surfing is my passion.",
|
28 |
+
"example_output": "[{{\"task\": \"sentence-similarity\", \"id\": 0, \"dep\": [-1], \"args\": {{\"text1\": \"I like to surf.\", \"text2\": \"Surfing is my passion.\"}}}}]"
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"example_input": "Please tell me if the following sentence is positive or negative: Surfing is the best!",
|
32 |
+
"example_output": "[{{\"task\": \"text-classification\", \"id\": 0, \"dep\": [-1], \"args\": {{\"text\": \"Surfing is the best!\"}}}}]"
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"example_input": "Classify the image found at /images/de78.png.",
|
36 |
+
"example_output": "[{{\"task\": \"image-classification\", \"id\": 0, \"dep\": [-1], \"args\": {{\"image\": \"/images/de78.png\"}}}}]"
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"example_input": "Paris is the capital of France. Berlin is the capital of Germany. Based on the previous facts, answer the following question: What is the capital of France?",
|
40 |
+
"example_output": "[{{\"task\": \"question-answering\", \"id\": 0, \"dep\": [-1], \"args\": {{\"question\": \"What is the capital of France?\", \"context\": \"Paris is the capital of France. Berlin is the capital of Germany.\"}}}}]"
|
41 |
+
}
|
42 |
+
]
|
resources/prompt-templates/task-planning-few-shot-prompt.json
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_type": "few_shot",
|
3 |
+
"input_variables": [
|
4 |
+
"user_input",
|
5 |
+
"history"
|
6 |
+
],
|
7 |
+
"prefix": "#1 Task Planning Stage: The AI assistant can parse user input to several tasks: [{{\"task\": task, \"id\": task_id, \"dep\": dependency_task_id, \"args\": {{\"text\": text or <GENERATED>-dep_id, \"image\": image_url or <GENERATED>-dep_id, \"audio\": audio_url or <GENERATED>-dep_id}}}}]. The special tag \"<GENERATED>-dep_id\" refer to the one generated text/image/audio in the dependency task (Please consider whether the dependency task generates resources of this type.) and \"dep_id\" must be in \"dep\" list. The \"dep\" field denotes the ids of the previous prerequisite tasks which generate a new resource that the current task relies on. The \"args\" field must in [\"text\", \"image\", \"audio\"], nothing else. The task MUST be selected from the following options: \"token-classification\", \"text2text-generation\", \"summarization\", \"translation\", \"question-answering\", \"conversational\", \"text-generation\", \"sentence-similarity\", \"tabular-classification\", \"object-detection\", \"image-classification\", \"image-to-image\", \"image-to-text\", \"text-to-image\", \"visual-question-answering\", \"document-question-answering\", \"image-segmentation\", \"depth-estimation\", \"text-to-speech\", \"automatic-speech-recognition\", \"audio-to-audio\", \"audio-classification\". There may be multiple tasks of the same type. Think step by step about all the tasks needed to resolve the user's request. Parse out as few tasks as possible while ensuring that the user request can be resolved. Pay attention to the dependencies and order among tasks. If the user input can't be parsed, you need to reply empty JSON [].",
|
8 |
+
"example_prompt_path": "resources/prompt-templates/task-planning-example-prompt.json",
|
9 |
+
"examples": "resources/prompt-templates/task-planning-examples.json",
|
10 |
+
"suffix": "<im_start>user\nThe chat log [ {history} ] may contain the resources I mentioned. Now I input {{ {user_input} }}. Pay attention to the input and output types of tasks and the dependencies between tasks.<im_end>\n<im_start>assistant\n"
|
11 |
+
}
|