BMukhtar committed
Commit 0370cf5 • 1 Parent(s): 63fe45c

initial commit

.gitignore ADDED
@@ -0,0 +1,173 @@
1
+ logs/
2
+ .idea/
3
+ models
4
+ public/*
5
+ *.pyc
6
+ !public/examples
7
+
8
+ .idea/
9
+
10
+ # ALL
11
+ *.dev
12
+
13
+ # for server
14
+ server/models/*
15
+ !server/models/download.sh
16
+ !server/models/download.ps1
17
+ server/logs/
18
+ server/models_dev
19
+ server/public/*
20
+ !server/public/examples/
21
+ server/public/examples/*
22
+ !server/public/examples/a.jpg
23
+ !server/public/examples/b.jpg
24
+ !server/public/examples/c.jpg
25
+ !server/public/examples/d.jpg
26
+ !server/public/examples/e.jpg
27
+ !server/public/examples/f.jpg
28
+ !server/public/examples/g.jpg
29
+
30
+ # docker
31
+ Dockerfile
32
+ docker-compose.yml
33
+
34
+ # for gradio
35
+ # server/run_gradio.py
36
+
37
+ # for web
38
+ web/node_modules
39
+ web/package-lock.json
40
+ web/dist
41
+ web/electron-dist
42
+ web/yarn.lock
43
+
44
+ # Byte-compiled / optimized / DLL files
45
+ __pycache__/
46
+ *.py[cod]
47
+ *$py.class
48
+
49
+ # C extensions
50
+ *.so
51
+
52
+ # Distribution / packaging
53
+ .Python
54
+ build/
55
+ develop-eggs/
56
+ dist/
57
+ downloads/
58
+ eggs/
59
+ .eggs/
60
+ lib/
61
+ lib64/
62
+ parts/
63
+ sdist/
64
+ var/
65
+ wheels/
66
+ pip-wheel-metadata/
67
+ share/python-wheels/
68
+ *.egg-info/
69
+ .installed.cfg
70
+ *.egg
71
+ MANIFEST
72
+
73
+ # PyInstaller
74
+ # Usually these files are written by a python script from a template
75
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
76
+ *.manifest
77
+ *.spec
78
+
79
+ # Installer logs
80
+ pip-log.txt
81
+ pip-delete-this-directory.txt
82
+
83
+ # Unit test / coverage reports
84
+ htmlcov/
85
+ .tox/
86
+ .nox/
87
+ .coverage
88
+ .coverage.*
89
+ .cache
90
+ nosetests.xml
91
+ coverage.xml
92
+ *.cover
93
+ *.py,cover
94
+ .hypothesis/
95
+ .pytest_cache/
96
+
97
+ # Translations
98
+ *.mo
99
+ *.pot
100
+
101
+ # Django stuff:
102
+ *.log
103
+ local_settings.py
104
+ db.sqlite3
105
+ db.sqlite3-journal
106
+
107
+ # Flask stuff:
108
+ instance/
109
+ .webassets-cache
110
+
111
+ # Scrapy stuff:
112
+ .scrapy
113
+
114
+ # Sphinx documentation
115
+ docs/_build/
116
+
117
+ # PyBuilder
118
+ target/
119
+
120
+ # Jupyter Notebook
121
+ .ipynb_checkpoints
122
+
123
+ # IPython
124
+ profile_default/
125
+ ipython_config.py
126
+
127
+ # pyenv
128
+ .python-version
129
+
130
+ # pipenv
131
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
132
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
133
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
134
+ # install all needed dependencies.
135
+ #Pipfile.lock
136
+
137
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
138
+ __pypackages__/
139
+
140
+ # Celery stuff
141
+ celerybeat-schedule
142
+ celerybeat.pid
143
+
144
+ # SageMath parsed files
145
+ *.sage.py
146
+
147
+ # Environments
148
+ .env
149
+ .venv
150
+ env/
151
+ venv/
152
+ ENV/
153
+ env.bak/
154
+ venv.bak/
155
+
156
+ # Spyder project settings
157
+ .spyderproject
158
+ .spyproject
159
+
160
+ # Rope project settings
161
+ .ropeproject
162
+
163
+ # mkdocs documentation
164
+ /site
165
+
166
+ # mypy
167
+ .mypy_cache/
168
+ .dmypy.json
169
+ dmypy.json
170
+
171
+ # Pyre type checker
172
+ .pyre/
173
+
app.py ADDED
@@ -0,0 +1,202 @@
1
+ import uuid
2
+ import gradio as gr
3
+ import re
4
+ from diffusers.utils import load_image
5
+ import requests
6
+ from awesome_chat import chat_huggingface
7
+ import os
8
+
9
+ os.makedirs("public/images", exist_ok=True)
10
+ os.makedirs("public/audios", exist_ok=True)
11
+ os.makedirs("public/videos", exist_ok=True)
12
+
13
+ class Client:
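+ # Holds the per-session OpenAI key, Hugging Face token and chat history;
+ # one Client instance is stored in the gr.State of each Gradio session.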
14
+ def __init__(self) -> None:
15
+ self.OPENAI_KEY = ""
16
+ self.HUGGINGFACE_TOKEN = ""
17
+ self.all_messages = []
18
+
19
+ def set_key(self, openai_key):
20
+ self.OPENAI_KEY = openai_key
21
+ if len(self.HUGGINGFACE_TOKEN)>0:
22
+ gr.update(visible = True)
23
+ return self.OPENAI_KEY
24
+
25
+ def set_token(self, huggingface_token):
26
+ self.HUGGINGFACE_TOKEN = huggingface_token
27
+ if len(self.OPENAI_KEY)>0:
28
+ gr.update(visible = True)
29
+ return self.HUGGINGFACE_TOKEN
30
+
31
+ def add_message(self, content, role):
32
+ message = {"role":role, "content":content}
33
+ self.all_messages.append(message)
34
+
35
+ def extract_medias(self, message):
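+ # Pull image/audio/video references (URLs or local paths) out of a chat
+ # message by matching file extensions, so they can be shown as attachments.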
36
+ # url_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?")
37
+ urls = []
38
+ # for match in url_pattern.finditer(message):
39
+ # if match.group(0) not in urls:
40
+ # urls.append(match.group(0))
41
+
42
+ image_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(jpg|jpeg|tiff|gif|png)")
43
+ image_urls = []
44
+ for match in image_pattern.finditer(message):
45
+ if match.group(0) not in image_urls:
46
+ image_urls.append(match.group(0))
47
+
48
+ audio_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav)")
49
+ audio_urls = []
50
+ for match in audio_pattern.finditer(message):
51
+ if match.group(0) not in audio_urls:
52
+ audio_urls.append(match.group(0))
53
+
54
+ video_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(mp4)")
55
+ video_urls = []
56
+ for match in video_pattern.finditer(message):
57
+ if match.group(0) not in video_urls:
58
+ video_urls.append(match.group(0))
59
+
60
+ return urls, image_urls, audio_urls, video_urls
61
+
62
+ def add_text(self, messages, message):
63
+ if len(self.OPENAI_KEY) == 0 or not self.OPENAI_KEY.startswith("sk-") or len(self.HUGGINGFACE_TOKEN) == 0 or not self.HUGGINGFACE_TOKEN.startswith("hf_"):
64
+ return messages, "Please set your OpenAI API key and Hugging Face token first!!!"
65
+ self.add_message(message, "user")
66
+ messages = messages + [(message, None)]
67
+ urls, image_urls, audio_urls, video_urls = self.extract_medias(message)
68
+
69
+ for image_url in image_urls:
70
+ if not image_url.startswith("http") and not image_url.startswith("public"):
71
+ image_url = "public/" + image_url
72
+ image = load_image(image_url)
73
+ name = f"public/images/{str(uuid.uuid4())[:4]}.jpg"
74
+ image.save(name)
75
+ messages = messages + [((f"{name}",), None)]
76
+ for audio_url in audio_urls:
77
+ if not audio_url.startswith("http") and not audio_url.startswith("public"):
78
+ audio_url = "public/" + audio_url
79
+ ext = audio_url.split(".")[-1]
80
+ name = f"public/audios/{str(uuid.uuid4()[:4])}.{ext}"
81
+ response = requests.get(audio_url)
82
+ with open(name, "wb") as f:
83
+ f.write(response.content)
84
+ messages = messages + [((f"{name}",), None)]
85
+ for video_url in video_urls:
86
+ if not video_url.startswith("http") and not video_url.startswith("public"):
87
+ video_url = "public/" + video_url
88
+ ext = video_url.split(".")[-1]
89
+ name = f"public/audios/{str(uuid.uuid4()[:4])}.{ext}"
90
+ response = requests.get(video_url)
91
+ with open(name, "wb") as f:
92
+ f.write(response.content)
93
+ messages = messages + [((f"{name}",), None)]
94
+ return messages, ""
95
+
96
+ def bot(self, messages):
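+ # Forward the accumulated conversation to chat_huggingface, append the
+ # assistant reply, and attach any media files referenced in it.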
97
+ if len(self.OPENAI_KEY) == 0 or not self.OPENAI_KEY.startswith("sk-") or len(self.HUGGINGFACE_TOKEN) == 0 or not self.HUGGINGFACE_TOKEN.startswith("hf_"):
98
+ return messages, {}
99
+ message, results = chat_huggingface(self.all_messages, self.OPENAI_KEY, self.HUGGINGFACE_TOKEN)
100
+ urls, image_urls, audio_urls, video_urls = self.extract_medias(message)
101
+ self.add_message(message, "assistant")
102
+ messages[-1][1] = message
103
+ for image_url in image_urls:
104
+ if not image_url.startswith("http"):
105
+ image_url = image_url.replace("public/", "")
106
+ messages = messages + [((None, (f"public/{image_url}",)))]
107
+ # else:
108
+ # messages = messages + [((None, (f"{image_url}",)))]
109
+ for audio_url in audio_urls:
110
+ if not audio_url.startswith("http"):
111
+ audio_url = audio_url.replace("public/", "")
112
+ messages = messages + [((None, (f"public/{audio_url}",)))]
113
+ # else:
114
+ # messages = messages + [((None, (f"{audio_url}",)))]
115
+ for video_url in video_urls:
116
+ if not video_url.startswith("http"):
117
+ video_url = video_url.replace("public/", "")
118
+ messages = messages + [((None, (f"public/{video_url}",)))]
119
+ # else:
120
+ # messages = messages + [((None, (f"{video_url}",)))]
121
+ # replace int key to string key
122
+ results = {str(k): v for k, v in results.items()}
123
+ return messages, results
124
+
125
+ css = ".json {height: 527px; overflow: scroll;} .json-holder {height: 527px; overflow: scroll;}"
126
+ with gr.Blocks(css=css) as demo:
127
+ state = gr.State(value={"client": Client()})
128
+ gr.Markdown("<h1><center>HuggingGPT</center></h1>")
129
+ gr.Markdown("<p align='center'><img src='https://i.ibb.co/qNH3Jym/logo.png' height='25' width='95'></p>")
130
+ gr.Markdown("<p align='center' style='font-size: 20px;'>A system to connect LLMs with ML community. See our <a href='https://github.com/microsoft/JARVIS'>Project</a> and <a href='http://arxiv.org/abs/2303.17580'>Paper</a>.</p>")
131
+ gr.HTML('''<center><a href="https://huggingface.co/spaces/microsoft/HuggingGPT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space and run securely with your OpenAI API Key and Hugging Face Token</center>''')
132
+ with gr.Row().style():
133
+ with gr.Column(scale=0.85):
134
+ openai_api_key = gr.Textbox(
135
+ show_label=False,
136
+ placeholder="Set your OpenAI API key here and press Enter",
137
+ lines=1,
138
+ type="password"
139
+ ).style(container=False)
140
+ with gr.Column(scale=0.15, min_width=0):
141
+ btn1 = gr.Button("Submit").style(full_height=True)
142
+
143
+ with gr.Row().style():
144
+ with gr.Column(scale=0.85):
145
+ hugging_face_token = gr.Textbox(
146
+ show_label=False,
147
+ placeholder="Set your Hugging Face Token here and press Enter",
148
+ lines=1,
149
+ type="password"
150
+ ).style(container=False)
151
+ with gr.Column(scale=0.15, min_width=0):
152
+ btn3 = gr.Button("Submit").style(full_height=True)
153
+
154
+
155
+ with gr.Row().style():
156
+ with gr.Column(scale=0.6):
157
+ chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
158
+ with gr.Column(scale=0.4):
159
+ results = gr.JSON(elem_classes="json")
160
+
161
+
162
+ with gr.Row().style():
163
+ with gr.Column(scale=0.85):
164
+ txt = gr.Textbox(
165
+ show_label=False,
166
+ placeholder="Enter text and press enter. The url of the multimedia resource must contain the extension name.",
167
+ lines=1,
168
+ ).style(container=False)
169
+ with gr.Column(scale=0.15, min_width=0):
170
+ btn2 = gr.Button("Send").style(full_height=True)
171
+
172
+ def set_key(state, openai_api_key):
173
+ return state["client"].set_key(openai_api_key)
174
+
175
+ def add_text(state, chatbot, txt):
176
+ return state["client"].add_text(chatbot, txt)
177
+
178
+ def set_token(state, hugging_face_token):
179
+ return state["client"].set_token(hugging_face_token)
180
+
181
+ def bot(state, chatbot):
182
+ return state["client"].bot(chatbot)
183
+
184
+ openai_api_key.submit(set_key, [state, openai_api_key], [openai_api_key])
185
+ txt.submit(add_text, [state, chatbot, txt], [chatbot, txt]).then(bot, [state, chatbot], [chatbot, results])
186
+ hugging_face_token.submit(set_token, [state, hugging_face_token], [hugging_face_token])
187
+ btn1.click(set_key, [state, openai_api_key], [openai_api_key])
188
+ btn2.click(add_text, [state, chatbot, txt], [chatbot, txt]).then(bot, [state, chatbot], [chatbot, results])
189
+ btn3.click(set_token, [state, hugging_face_token], [hugging_face_token])
190
+
191
+ gr.Examples(
192
+ examples=["Given a collection of image A: /examples/a.jpg, B: /examples/b.jpg, C: /examples/c.jpg, please tell me how many zebras in these picture?",
193
+ "Please generate a canny image based on /examples/f.jpg",
194
+ "show me a joke and an image of cat",
195
+ "what is in the examples/a.jpg",
196
+ "based on the /examples/a.jpg, please generate a video and audio",
197
+ "based on pose of /examples/d.jpg and content of /examples/e.jpg, please show me a new image",
198
+ ],
199
+ inputs=txt
200
+ )
201
+
202
+ demo.launch()
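+ # Usage sketch (assumptions: dependencies installed, default Gradio port):
+ #   python app.py   # then open http://127.0.0.1:7860 and paste your
+ #                   # OpenAI key ("sk-...") and Hugging Face token ("hf_...")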
awesome_chat.py ADDED
@@ -0,0 +1,920 @@
1
+ import base64
2
+ import copy
3
+ from io import BytesIO
4
+ import io
5
+ import os
6
+ import random
7
+ import time
8
+ import traceback
9
+ import uuid
10
+ import requests
11
+ import re
12
+ import json
13
+ import logging
14
+ import argparse
15
+ import yaml
16
+ from PIL import Image, ImageDraw
17
+ from diffusers.utils import load_image
18
+ from pydub import AudioSegment
19
+ import threading
20
+ from queue import Queue
21
+ from get_token_ids import get_token_ids_for_task_parsing, get_token_ids_for_choose_model, count_tokens, get_max_context_length
22
+ from huggingface_hub.inference_api import InferenceApi
23
+ from huggingface_hub.inference_api import ALL_TASKS
24
+ from models_server import models, status
25
+ from functools import partial
26
+
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument("--config", type=str, default="config.yaml.dev")
29
+ parser.add_argument("--mode", type=str, default="cli")
30
+ args = parser.parse_args()
31
+
32
+ if __name__ != "__main__":
33
+ args.config = "config.gradio.yaml"
34
+
35
+ config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)
36
+
37
+ if not os.path.exists("logs"):
38
+ os.mkdir("logs")
39
+
40
+ logger = logging.getLogger(__name__)
41
+ logger.setLevel(logging.DEBUG)
42
+
43
+ handler = logging.StreamHandler()
44
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
45
+ handler.setFormatter(formatter)
46
+ if not config["debug"]:
47
+ handler.setLevel(logging.INFO)
48
+ logger.addHandler(handler)
49
+
50
+ log_file = config["log_file"]
51
+ if log_file:
52
+ filehandler = logging.FileHandler(log_file)
53
+ filehandler.setLevel(logging.DEBUG)
54
+ filehandler.setFormatter(formatter)
55
+ logger.addHandler(filehandler)
56
+
57
+ LLM = config["model"]
58
+ use_completion = config["use_completion"]
59
+
60
+ # keep token accounting consistent: count gpt-3.5-turbo prompts with the text-davinci-003 encoding
61
+ LLM_encoding = LLM
62
+ if LLM == "gpt-3.5-turbo":
63
+ LLM_encoding = "text-davinci-003"
64
+ task_parsing_highlight_ids = get_token_ids_for_task_parsing(LLM_encoding)
65
+ choose_model_highlight_ids = get_token_ids_for_choose_model(LLM_encoding)
66
+
67
+ # ENDPOINT MODEL NAME
68
+ # /v1/chat/completions gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301
69
+ # /v1/completions text-davinci-003, text-davinci-002, text-curie-001, text-babbage-001, text-ada-001, davinci, curie, babbage, ada
70
+
71
+ if use_completion:
72
+ api_name = "completions"
73
+ else:
74
+ api_name = "chat/completions"
75
+
76
+ if not config["dev"]:
77
+ if not config["openai"]["key"].startswith("sk-") and not config["openai"]["key"]=="gradio":
78
+ raise ValueError("Incrorrect OpenAI key. Please check your config.yaml file.")
79
+ OPENAI_KEY = config["openai"]["key"]
80
+ endpoint = f"https://api.openai.com/v1/{api_name}"
81
+ if OPENAI_KEY.startswith("sk-"):
82
+ HEADER = {
83
+ "Authorization": f"Bearer {OPENAI_KEY}"
84
+ }
85
+ else:
86
+ HEADER = None
87
+ else:
88
+ endpoint = f"{config['local']['endpoint']}/v1/{api_name}"
89
+ HEADER = None
90
+
91
+ PROXY = None
92
+ if config["proxy"]:
93
+ PROXY = {
94
+ "https": config["proxy"],
95
+ }
96
+
97
+ inference_mode = config["inference_mode"]
98
+
99
+ parse_task_demos_or_presteps = open(config["demos_or_presteps"]["parse_task"], "r").read()
100
+ choose_model_demos_or_presteps = open(config["demos_or_presteps"]["choose_model"], "r").read()
101
+ response_results_demos_or_presteps = open(config["demos_or_presteps"]["response_results"], "r").read()
102
+
103
+ parse_task_prompt = config["prompt"]["parse_task"]
104
+ choose_model_prompt = config["prompt"]["choose_model"]
105
+ response_results_prompt = config["prompt"]["response_results"]
106
+
107
+ parse_task_tprompt = config["tprompt"]["parse_task"]
108
+ choose_model_tprompt = config["tprompt"]["choose_model"]
109
+ response_results_tprompt = config["tprompt"]["response_results"]
110
+
111
+ MODELS = [json.loads(line) for line in open("data/p0_models.jsonl", "r").readlines()]
112
+ MODELS_MAP = {}
113
+ for model in MODELS:
114
+ tag = model["task"]
115
+ if tag not in MODELS_MAP:
116
+ MODELS_MAP[tag] = []
117
+ MODELS_MAP[tag].append(model)
118
+ METADATAS = {}
119
+ for model in MODELS:
120
+ METADATAS[model["id"]] = model
121
+
122
+ def convert_chat_to_completion(data):
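+ # Flatten a chat-style message list into a single ChatML-like prompt using
+ # <im_start>/<im_end> markers so it can be sent to the /v1/completions endpoint.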
123
+ messages = data.pop('messages', [])
124
+ tprompt = ""
125
+ if messages[0]['role'] == "system":
126
+ tprompt = messages[0]['content']
127
+ messages = messages[1:]
128
+ final_prompt = ""
129
+ for message in messages:
130
+ if message['role'] == "user":
131
+ final_prompt += ("<im_start>"+ "user" + "\n" + message['content'] + "<im_end>\n")
132
+ elif message['role'] == "assistant":
133
+ final_prompt += ("<im_start>"+ "assistant" + "\n" + message['content'] + "<im_end>\n")
134
+ else:
135
+ final_prompt += ("<im_start>"+ "system" + "\n" + message['content'] + "<im_end>\n")
136
+ final_prompt = tprompt + final_prompt
137
+ final_prompt = final_prompt + "<im_start>assistant"
138
+ data["prompt"] = final_prompt
139
+ data['stop'] = data.get('stop', ["<im_end>"])
140
+ data['max_tokens'] = data.get('max_tokens', max(get_max_context_length(LLM) - count_tokens(LLM_encoding, final_prompt), 1))
141
+ return data
142
+
143
+ def send_request(data):
144
+ global HEADER
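+ # NOTE: HEADER is module-global, so concurrent requests carrying different
+ # keys may race; acceptable for a single-user Gradio demo.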
145
+ openaikey = data.pop("openaikey")
146
+ if use_completion:
147
+ data = convert_chat_to_completion(data)
148
+ if openaikey and openaikey.startswith("sk-"):
149
+ HEADER = {
150
+ "Authorization": f"Bearer {openaikey}"
151
+ }
152
+
153
+ response = requests.post(endpoint, json=data, headers=HEADER, proxies=PROXY)
154
+ logger.debug(response.text.strip())
155
+ if "choices" not in response.json():
156
+ return response.json()
157
+ if use_completion:
158
+ return response.json()["choices"][0]["text"].strip()
159
+ else:
160
+ return response.json()["choices"][0]["message"]["content"].strip()
161
+
162
+ def replace_slot(text, entries):
163
+ for key, value in entries.items():
164
+ if not isinstance(value, str):
165
+ value = str(value)
166
+ text = text.replace("{{" + key +"}}", value.replace('"', "'").replace('\n', ""))
167
+ return text
168
+
169
+ def find_json(s):
170
+ s = s.replace("\'", "\"")
171
+ start = s.find("{")
172
+ end = s.rfind("}")
173
+ res = s[start:end+1]
174
+ res = res.replace("\n", "")
175
+ return res
176
+
177
+ def field_extract(s, field):
178
+ try:
179
+ field_rep = re.compile(f'{field}.*?:.*?"(.*?)"', re.IGNORECASE)
180
+ extracted = field_rep.search(s).group(1).replace("\"", "\'")
181
+ except Exception:
182
+ field_rep = re.compile(f'{field}:\ *"(.*?)"', re.IGNORECASE)
183
+ extracted = field_rep.search(s).group(1).replace("\"", "\'")
184
+ return extracted
185
+
186
+ def get_id_reason(choose_str):
187
+ reason = field_extract(choose_str, "reason")
188
+ id = field_extract(choose_str, "id")
189
+ choose = {"id": id, "reason": reason}
190
+ return id.strip(), reason.strip(), choose
191
+
192
+ def record_case(success, **args):
193
+ if success:
194
+ f = open("logs/log_success.jsonl", "a")
195
+ else:
196
+ f = open("logs/log_fail.jsonl", "a")
197
+ log = args
198
+ f.write(json.dumps(log) + "\n")
199
+ f.close()
200
+
201
+ def image_to_bytes(img_url):
202
+ img_byte = io.BytesIO()
203
+ type = img_url.split(".")[-1]
204
+ load_image(img_url).save(img_byte, format="png")
205
+ img_data = img_byte.getvalue()
206
+ return img_data
207
+
208
+ def resource_has_dep(command):
209
+ args = command["args"]
210
+ for _, v in args.items():
211
+ if "<GENERATED>" in v:
212
+ return True
213
+ return False
214
+
215
+ def fix_dep(tasks):
216
+ for task in tasks:
217
+ args = task["args"]
218
+ task["dep"] = []
219
+ for k, v in args.items():
220
+ if "<GENERATED>" in v:
221
+ dep_task_id = int(v.split("-")[1])
222
+ if dep_task_id not in task["dep"]:
223
+ task["dep"].append(dep_task_id)
224
+ if len(task["dep"]) == 0:
225
+ task["dep"] = [-1]
226
+ return tasks
227
+
228
+ def unfold(tasks):
229
+ flag_unfold_task = False
230
+ try:
231
+ for task in tasks:
232
+ for key, value in task["args"].items():
233
+ if "<GENERATED>" in value:
234
+ generated_items = value.split(",")
235
+ if len(generated_items) > 1:
236
+ flag_unfold_task = True
237
+ for item in generated_items:
238
+ new_task = copy.deepcopy(task)
239
+ dep_task_id = int(item.split("-")[1])
240
+ new_task["dep"] = [dep_task_id]
241
+ new_task["args"][key] = item
242
+ tasks.append(new_task)
243
+ tasks.remove(task)
244
+ except Exception as e:
245
+ print(e)
246
+ traceback.print_exc()
247
+ logger.debug("unfold task failed.")
248
+
249
+ if flag_unfold_task:
250
+ logger.debug(f"unfold tasks: {tasks}")
251
+
252
+ return tasks
253
+
254
+ def chitchat(messages, openaikey=None):
255
+ data = {
256
+ "model": LLM,
257
+ "messages": messages,
258
+ "openaikey": openaikey
259
+ }
260
+ return send_request(data)
261
+
262
+ def parse_task(context, input, openaikey=None):
263
+ demos_or_presteps = parse_task_demos_or_presteps
264
+ messages = json.loads(demos_or_presteps)
265
+ messages.insert(0, {"role": "system", "content": parse_task_tprompt})
266
+
267
+ # trim the oldest chat turns until the prompt leaves at least 800 tokens of headroom in the model's context window
268
+ start = 0
269
+ while start <= len(context):
270
+ history = context[start:]
271
+ prompt = replace_slot(parse_task_prompt, {
272
+ "input": input,
273
+ "context": history
274
+ })
275
+ messages.append({"role": "user", "content": prompt})
276
+ history_text = "<im_end>\nuser<im_start>".join([m["content"] for m in messages])
277
+ num = count_tokens(LLM_encoding, history_text)
278
+ if get_max_context_length(LLM) - num > 800:
279
+ break
280
+ messages.pop()
281
+ start += 2
282
+
283
+ logger.debug(messages)
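+ # logit_bias boosts the token ids returned by get_token_ids_for_task_parsing,
+ # nudging the LLM toward the whitelisted task names and JSON structure.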
284
+ data = {
285
+ "model": LLM,
286
+ "messages": messages,
287
+ "temperature": 0,
288
+ "logit_bias": {item: config["logit_bias"]["parse_task"] for item in task_parsing_highlight_ids},
289
+ "openaikey": openaikey
290
+ }
291
+ return send_request(data)
292
+
293
+ def choose_model(input, task, metas, openaikey = None):
294
+ prompt = replace_slot(choose_model_prompt, {
295
+ "input": input,
296
+ "task": task,
297
+ "metas": metas,
298
+ })
299
+ demos_or_presteps = replace_slot(choose_model_demos_or_presteps, {
300
+ "input": input,
301
+ "task": task,
302
+ "metas": metas
303
+ })
304
+ messages = json.loads(demos_or_presteps)
305
+ messages.insert(0, {"role": "system", "content": choose_model_tprompt})
306
+ messages.append({"role": "user", "content": prompt})
307
+ logger.debug(messages)
308
+ data = {
309
+ "model": LLM,
310
+ "messages": messages,
311
+ "temperature": 0,
312
+ "logit_bias": {item: config["logit_bias"]["choose_model"] for item in choose_model_highlight_ids}, # 5
313
+ "openaikey": openaikey
314
+ }
315
+ return send_request(data)
316
+
317
+
318
+ def response_results(input, results, openaikey=None):
319
+ results = [v for k, v in sorted(results.items(), key=lambda item: item[0])]
320
+ prompt = replace_slot(response_results_prompt, {
321
+ "input": input,
322
+ })
323
+ demos_or_presteps = replace_slot(response_results_demos_or_presteps, {
324
+ "input": input,
325
+ "processes": results
326
+ })
327
+ messages = json.loads(demos_or_presteps)
328
+ messages.insert(0, {"role": "system", "content": response_results_tprompt})
329
+ messages.append({"role": "user", "content": prompt})
330
+ logger.debug(messages)
331
+ data = {
332
+ "model": LLM,
333
+ "messages": messages,
334
+ "temperature": 0,
335
+ "openaikey": openaikey
336
+ }
337
+ return send_request(data)
338
+
339
+ def huggingface_model_inference(model_id, data, task, huggingfacetoken=None):
340
+ if huggingfacetoken is None:
341
+ HUGGINGFACE_HEADERS = {}
342
+ else:
343
+ HUGGINGFACE_HEADERS = {
344
+ "Authorization": f"Bearer {huggingfacetoken}",
345
+ }
346
+ task_url = f"https://api-inference.huggingface.co/models/{model_id}" # InferenceApi does not yet support some tasks
347
+ inference = InferenceApi(repo_id=model_id, token=huggingfacetoken)
348
+
349
+ # NLP tasks
350
+ if task == "question-answering":
351
+ inputs = {"question": data["text"], "context": (data["context"] if "context" in data else "" )}
352
+ result = inference(inputs)
353
+ if task == "sentence-similarity":
354
+ inputs = {"source_sentence": data["text1"], "target_sentence": data["text2"]}
355
+ result = inference(inputs)
356
+ if task in ["text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation"]:
357
+ inputs = data["text"]
358
+ result = inference(inputs)
359
+
360
+ # CV tasks
361
+ if task == "visual-question-answering" or task == "document-question-answering":
362
+ img_url = data["image"]
363
+ text = data["text"]
364
+ img_data = image_to_bytes(img_url)
365
+ img_base64 = base64.b64encode(img_data).decode("utf-8")
366
+ json_data = {}
367
+ json_data["inputs"] = {}
368
+ json_data["inputs"]["question"] = text
369
+ json_data["inputs"]["image"] = img_base64
370
+ result = requests.post(task_url, headers=HUGGINGFACE_HEADERS, json=json_data).json()
371
+ # result = inference(inputs) # not support
372
+
373
+ if task == "image-to-image":
374
+ img_url = data["image"]
375
+ img_data = image_to_bytes(img_url)
376
+ # result = inference(data=img_data) # not support
377
+ HUGGINGFACE_HEADERS["Content-Length"] = str(len(img_data))
378
+ r = requests.post(task_url, headers=HUGGINGFACE_HEADERS, data=img_data)
379
+ result = r.json()
380
+ if "path" in result:
381
+ result["generated image"] = result.pop("path")
382
+
383
+ if task == "text-to-image":
384
+ inputs = data["text"]
385
+ img = inference(inputs)
386
+ name = str(uuid.uuid4())[:4]
387
+ img.save(f"public/images/{name}.png")
388
+ result = {}
389
+ result["generated image"] = f"/images/{name}.png"
390
+
391
+ if task == "image-segmentation":
392
+ img_url = data["image"]
393
+ img_data = image_to_bytes(img_url)
394
+ image = Image.open(BytesIO(img_data))
395
+ predicted = inference(data=img_data)
396
+ colors = []
397
+ for i in range(len(predicted)):
398
+ colors.append((random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 155))
399
+ for i, pred in enumerate(predicted):
400
+ label = pred["label"]
401
+ mask = pred.pop("mask").encode("utf-8")
402
+ mask = base64.b64decode(mask)
403
+ mask = Image.open(BytesIO(mask), mode='r')
404
+ mask = mask.convert('L')
405
+
406
+ layer = Image.new('RGBA', mask.size, colors[i])
407
+ image.paste(layer, (0, 0), mask)
408
+ name = str(uuid.uuid4())[:4]
409
+ image.save(f"public/images/{name}.jpg")
410
+ result = {}
411
+ result["generated image with segmentation mask"] = f"/images/{name}.jpg"
412
+ result["predicted"] = predicted
413
+
414
+ if task == "object-detection":
415
+ img_url = data["image"]
416
+ img_data = image_to_bytes(img_url)
417
+ predicted = inference(data=img_data)
418
+ image = Image.open(BytesIO(img_data))
419
+ draw = ImageDraw.Draw(image)
420
+ labels = list(item['label'] for item in predicted)
421
+ color_map = {}
422
+ for label in labels:
423
+ if label not in color_map:
424
+ color_map[label] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255))
425
+ for label in predicted:
426
+ box = label["box"]
427
+ draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=color_map[label["label"]], width=2)
428
+ draw.text((box["xmin"]+5, box["ymin"]-15), label["label"], fill=color_map[label["label"]])
429
+ name = str(uuid.uuid4())[:4]
430
+ image.save(f"public/images/{name}.jpg")
431
+ result = {}
432
+ result["generated image with predicted box"] = f"/images/{name}.jpg"
433
+ result["predicted"] = predicted
434
+
435
+ if task in ["image-classification"]:
436
+ img_url = data["image"]
437
+ img_data = image_to_bytes(img_url)
438
+ result = inference(data=img_data)
439
+
440
+ if task == "image-to-text":
441
+ img_url = data["image"]
442
+ img_data = image_to_bytes(img_url)
443
+ HUGGINGFACE_HEADERS["Content-Length"] = str(len(img_data))
444
+ r = requests.post(task_url, headers=HUGGINGFACE_HEADERS, data=img_data)
445
+ result = {}
446
+ if "generated_text" in r.json()[0]:
447
+ result["generated text"] = r.json()[0].pop("generated_text")
448
+
449
+ # AUDIO tasks
450
+ if task == "text-to-speech":
451
+ inputs = data["text"]
452
+ response = inference(inputs, raw_response=True)
453
+ # response = requests.post(task_url, headers=HUGGINGFACE_HEADERS, json={"inputs": text})
454
+ name = str(uuid.uuid4())[:4]
455
+ with open(f"public/audios/{name}.flac", "wb") as f:
456
+ f.write(response.content)
457
+ result = {"generated audio": f"/audios/{name}.flac"}
458
+ if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]:
459
+ audio_url = data["audio"]
460
+ audio_data = requests.get(audio_url, timeout=10).content
461
+ response = inference(data=audio_data, raw_response=True)
462
+ result = response.json()
463
+ if task == "audio-to-audio":
464
+ content = None
465
+ type = None
466
+ for k, v in result[0].items():
467
+ if k == "blob":
468
+ content = base64.b64decode(v.encode("utf-8"))
469
+ if k == "content-type":
470
+ type = "audio/flac".split("/")[-1]
471
+ audio = AudioSegment.from_file(BytesIO(content))
472
+ name = str(uuid.uuid4())[:4]
473
+ audio.export(f"public/audios/{name}.{type}", format=type)
474
+ result = {"generated audio": f"/audios/{name}.{type}"}
475
+ return result
476
+
477
+ def local_model_inference(model_id, data, task):
478
+ inference = partial(models, model_id)
479
+ # ControlNet models
480
+ if model_id.startswith("lllyasviel/sd-controlnet-"):
481
+ img_url = data["image"]
482
+ text = data["text"]
483
+ results = inference({"img_url": img_url, "text": text})
484
+ if "path" in results:
485
+ results["generated image"] = results.pop("path")
486
+ return results
487
+ if model_id.endswith("-control"):
488
+ img_url = data["image"]
489
+ results = inference({"img_url": img_url})
490
+ if "path" in results:
491
+ results["generated image"] = results.pop("path")
492
+ return results
493
+
494
+ if task == "text-to-video":
495
+ results = inference(data)
496
+ if "path" in results:
497
+ results["generated video"] = results.pop("path")
498
+ return results
499
+
500
+ # NLP tasks
501
+ if task == "question-answering" or task == "sentence-similarity":
502
+ results = inference(json=data)
503
+ return results
504
+ if task in ["text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation"]:
505
+ results = inference(json=data)
506
+ return results
507
+
508
+ # CV tasks
509
+ if task == "depth-estimation":
510
+ img_url = data["image"]
511
+ results = inference({"img_url": img_url})
512
+ if "path" in results:
513
+ results["generated depth image"] = results.pop("path")
514
+ return results
515
+ if task == "image-segmentation":
516
+ img_url = data["image"]
517
+ results = inference({"img_url": img_url})
518
+ results["generated image with segmentation mask"] = results.pop("path")
519
+ return results
520
+ if task == "image-to-image":
521
+ img_url = data["image"]
522
+ results = inference({"img_url": img_url})
523
+ if "path" in results:
524
+ results["generated image"] = results.pop("path")
525
+ return results
526
+ if task == "text-to-image":
527
+ results = inference(data)
528
+ if "path" in results:
529
+ results["generated image"] = results.pop("path")
530
+ return results
531
+ if task == "object-detection":
532
+ img_url = data["image"]
533
+ predicted = inference({"img_url": img_url})
534
+ if "error" in predicted:
535
+ return predicted
536
+ image = load_image(img_url)
537
+ draw = ImageDraw.Draw(image)
538
+ labels = list(item['label'] for item in predicted)
539
+ color_map = {}
540
+ for label in labels:
541
+ if label not in color_map:
542
+ color_map[label] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255))
543
+ for label in predicted:
544
+ box = label["box"]
545
+ draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=color_map[label["label"]], width=2)
546
+ draw.text((box["xmin"]+5, box["ymin"]-15), label["label"], fill=color_map[label["label"]])
547
+ name = str(uuid.uuid4())[:4]
548
+ image.save(f"public/images/{name}.jpg")
549
+ results = {}
550
+ results["generated image with predicted box"] = f"/images/{name}.jpg"
551
+ results["predicted"] = predicted
552
+ return results
553
+ if task in ["image-classification", "image-to-text", "document-question-answering", "visual-question-answering"]:
554
+ img_url = data["image"]
555
+ text = None
556
+ if "text" in data:
557
+ text = data["text"]
558
+ results = inference({"img_url": img_url, "text": text})
559
+ return results
560
+ # AUDIO tasks
561
+ if task == "text-to-speech":
562
+ results = inference(data)
563
+ if "path" in results:
564
+ results["generated audio"] = results.pop("path")
565
+ return results
566
+ if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]:
567
+ audio_url = data["audio"]
568
+ results = inference({"audio_url": audio_url})
569
+ return results
570
+
571
+
572
+ def model_inference(model_id, data, hosted_on, task, huggingfacetoken=None):
573
+ if huggingfacetoken:
574
+ HUGGINGFACE_HEADERS = {
575
+ "Authorization": f"Bearer {huggingfacetoken}",
576
+ }
577
+ else:
578
+ HUGGINGFACE_HEADERS = None
579
+ if hosted_on == "unknown":
580
+ r = status(model_id)
581
+ logger.debug("Local Server Status: " + str(r))
582
+ if "loaded" in r and r["loaded"]:
583
+ hosted_on = "local"
584
+ else:
585
+ huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
586
+ r = requests.get(huggingfaceStatusUrl, headers=HUGGINGFACE_HEADERS, proxies=PROXY)
587
+ logger.debug("Huggingface Status: " + str(r.json()))
588
+ if "loaded" in r and r["loaded"]:
589
+ hosted_on = "huggingface"
590
+ try:
591
+ if hosted_on == "local":
592
+ inference_result = local_model_inference(model_id, data, task)
593
+ elif hosted_on == "huggingface":
594
+ inference_result = huggingface_model_inference(model_id, data, task, huggingfacetoken)
595
+ except Exception as e:
596
+ print(e)
597
+ traceback.print_exc()
598
+ inference_result = {"error":{"message": str(e)}}
599
+ return inference_result
600
+
601
+
602
+ def get_model_status(model_id, url, headers, queue = None):
603
+ endpoint_type = "huggingface" if "huggingface" in url else "local"
604
+ if "huggingface" in url:
605
+ r = requests.get(url, headers=headers, proxies=PROXY).json()
606
+ else:
607
+ r = status(model_id)
608
+ if "loaded" in r and r["loaded"]:
609
+ if queue:
610
+ queue.put((model_id, True, endpoint_type))
611
+ return True
612
+ else:
613
+ if queue:
614
+ queue.put((model_id, False, None))
615
+ return False
616
+
617
+ def get_avaliable_models(candidates, topk=10, huggingfacetoken = None):
618
+ all_available_models = {"local": [], "huggingface": []}
619
+ threads = []
620
+ result_queue = Queue()
621
+ HUGGINGFACE_HEADERS = {
622
+ "Authorization": f"Bearer {huggingfacetoken}",
623
+ }
624
+ for candidate in candidates:
625
+ model_id = candidate["id"]
626
+
627
+ if inference_mode != "local":
628
+ huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
629
+ thread = threading.Thread(target=get_model_status, args=(model_id, huggingfaceStatusUrl, HUGGINGFACE_HEADERS, result_queue))
630
+ threads.append(thread)
631
+ thread.start()
632
+
633
+ if inference_mode != "huggingface" and config["local_deployment"] != "minimal":
634
+ thread = threading.Thread(target=get_model_status, args=(model_id, "", {}, result_queue))
635
+ threads.append(thread)
636
+ thread.start()
637
+
638
+ result_count = len(threads)
639
+ while result_count:
640
+ model_id, status, endpoint_type = result_queue.get()
641
+ if status and model_id not in all_available_models[endpoint_type]:
642
+ all_available_models[endpoint_type].append(model_id)
643
+ if len(all_available_models["local"] + all_available_models["huggingface"]) >= topk:
644
+ break
645
+ result_count -= 1
646
+
647
+ for thread in threads:
648
+ thread.join()
649
+
650
+ return all_available_models
651
+
652
+ def collect_result(command, choose, inference_result):
653
+ result = {"task": command}
654
+ result["inference result"] = inference_result
655
+ result["choose model result"] = choose
656
+ logger.debug(f"inference result: {inference_result}")
657
+ return result
658
+
659
+
660
+ def run_task(input, command, results, openaikey = None, huggingfacetoken = None):
661
+ id = command["id"]
662
+ args = command["args"]
663
+ task = command["task"]
664
+ deps = command["dep"]
665
+ if deps[0] != -1:
666
+ dep_tasks = [results[dep] for dep in deps]
667
+ else:
668
+ dep_tasks = []
669
+
670
+ logger.debug(f"Run task: {id} - {task}")
671
+ logger.debug("Deps: " + json.dumps(dep_tasks))
672
+
673
+ if deps[0] != -1:
674
+ if "image" in args and "<GENERATED>-" in args["image"]:
675
+ resource_id = int(args["image"].split("-")[1])
676
+ if "generated image" in results[resource_id]["inference result"]:
677
+ args["image"] = results[resource_id]["inference result"]["generated image"]
678
+ if "audio" in args and "<GENERATED>-" in args["audio"]:
679
+ resource_id = int(args["audio"].split("-")[1])
680
+ if "generated audio" in results[resource_id]["inference result"]:
681
+ args["audio"] = results[resource_id]["inference result"]["generated audio"]
682
+ if "text" in args and "<GENERATED>-" in args["text"]:
683
+ resource_id = int(args["text"].split("-")[1])
684
+ if "generated text" in results[resource_id]["inference result"]:
685
+ args["text"] = results[resource_id]["inference result"]["generated text"]
686
+
687
+ text = image = audio = None
688
+ for dep_task in dep_tasks:
689
+ if "generated text" in dep_task["inference result"]:
690
+ text = dep_task["inference result"]["generated text"]
691
+ logger.debug("Detect the generated text of dependency task (from results):" + text)
692
+ elif "text" in dep_task["task"]["args"]:
693
+ text = dep_task["task"]["args"]["text"]
694
+ logger.debug("Detect the text of dependency task (from args): " + text)
695
+ if "generated image" in dep_task["inference result"]:
696
+ image = dep_task["inference result"]["generated image"]
697
+ logger.debug("Detect the generated image of dependency task (from results): " + image)
698
+ elif "image" in dep_task["task"]["args"]:
699
+ image = dep_task["task"]["args"]["image"]
700
+ logger.debug("Detect the image of dependency task (from args): " + image)
701
+ if "generated audio" in dep_task["inference result"]:
702
+ audio = dep_task["inference result"]["generated audio"]
703
+ logger.debug("Detect the generated audio of dependency task (from results): " + audio)
704
+ elif "audio" in dep_task["task"]["args"]:
705
+ audio = dep_task["task"]["args"]["audio"]
706
+ logger.debug("Detect the audio of dependency task (from args): " + audio)
707
+
708
+ if "image" in args and "<GENERATED>" in args["image"]:
709
+ if image:
710
+ args["image"] = image
711
+ if "audio" in args and "<GENERATED>" in args["audio"]:
712
+ if audio:
713
+ args["audio"] = audio
714
+ if "text" in args and "<GENERATED>" in args["text"]:
715
+ if text:
716
+ args["text"] = text
717
+
718
+ for resource in ["image", "audio"]:
719
+ if resource in args and not args[resource].startswith("public/") and len(args[resource]) > 0 and not args[resource].startswith("http"):
720
+ args[resource] = f"public/{args[resource]}"
721
+
722
+ if "-text-to-image" in command['task'] and "text" not in args:
723
+ logger.debug("control-text-to-image task, but text is empty, so we use control-generation instead.")
724
+ control = task.split("-")[0]
725
+
726
+ if control == "seg":
727
+ task = "image-segmentation"
728
+ command['task'] = task
729
+ elif control == "depth":
730
+ task = "depth-estimation"
731
+ command['task'] = task
732
+ else:
733
+ task = f"{control}-control"
734
+
735
+ command["args"] = args
736
+ logger.debug(f"parsed task: {command}")
737
+
738
+ if task.endswith("-text-to-image") or task.endswith("-control"):
739
+ if inference_mode != "huggingface":
740
+ if task.endswith("-text-to-image"):
741
+ control = task.split("-")[0]
742
+ best_model_id = f"lllyasviel/sd-controlnet-{control}"
743
+ else:
744
+ best_model_id = task
745
+ hosted_on = "local"
746
+ reason = "ControlNet is the best model for this task."
747
+ choose = {"id": best_model_id, "reason": reason}
748
+ logger.debug(f"chosen model: {choose}")
749
+ else:
750
+ logger.warning(f"Task {command['task']} is not available. ControlNet need to be deployed locally.")
751
+ record_case(success=False, **{"input": input, "task": command, "reason": f"Task {command['task']} is not available. ControlNet need to be deployed locally.", "op":"message"})
752
+ inference_result = {"error": f"service related to ControlNet is not available."}
753
+ results[id] = collect_result(command, "", inference_result)
754
+ return False
755
+ elif task in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]: # ChatGPT Can do
756
+ best_model_id = "ChatGPT"
757
+ reason = "ChatGPT performs well on some NLP tasks as well."
758
+ choose = {"id": best_model_id, "reason": reason}
759
+ messages = [{
760
+ "role": "user",
761
+ "content": f"[ {input} ] contains a task in JSON format {command}, 'task' indicates the task type and 'args' indicates the arguments required for the task. Don't explain the task to me, just help me do it and give me the result. The result must be in text form without any urls."
762
+ }]
763
+ response = chitchat(messages, openaikey)
764
+ results[id] = collect_result(command, choose, {"response": response})
765
+ return True
766
+ else:
767
+ if task not in MODELS_MAP:
768
+ logger.warning(f"no available models on {task} task.")
769
+ record_case(success=False, **{"input": input, "task": command, "reason": f"task not support: {command['task']}", "op":"message"})
770
+ inference_result = {"error": f"{command['task']} not found in available tasks."}
771
+ results[id] = collect_result(command, "", inference_result)
772
+ return False
773
+
774
+ candidates = MODELS_MAP[task][:20]
775
+ all_avaliable_models = get_avaliable_models(candidates, config["num_candidate_models"], huggingfacetoken)
776
+ all_avaliable_model_ids = all_avaliable_models["local"] + all_avaliable_models["huggingface"]
777
+ logger.debug(f"avaliable models on {command['task']}: {all_avaliable_models}")
778
+
779
+ if len(all_avaliable_model_ids) == 0:
780
+ logger.warning(f"no available models on {command['task']}")
781
+ record_case(success=False, **{"input": input, "task": command, "reason": f"no available models: {command['task']}", "op":"message"})
782
+ inference_result = {"error": f"no available models on {command['task']} task."}
783
+ results[id] = collect_result(command, "", inference_result)
784
+ return False
785
+
786
+ if len(all_avaliable_model_ids) == 1:
787
+ best_model_id = all_avaliable_model_ids[0]
788
+ hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
789
+ reason = "Only one model available."
790
+ choose = {"id": best_model_id, "reason": reason}
791
+ logger.debug(f"chosen model: {choose}")
792
+ else:
793
+ cand_models_info = [
794
+ {
795
+ "id": model["id"],
796
+ "inference endpoint": all_avaliable_models.get(
797
+ "local" if model["id"] in all_avaliable_models["local"] else "huggingface"
798
+ ),
799
+ "likes": model.get("likes"),
800
+ "description": model.get("description", "")[:config["max_description_length"]],
801
+ "language": model.get("language"),
802
+ "tags": model.get("tags"),
803
+ }
804
+ for model in candidates
805
+ if model["id"] in all_avaliable_model_ids
806
+ ]
807
+
808
+ choose_str = choose_model(input, command, cand_models_info, openaikey)
809
+ logger.debug(f"chosen model: {choose_str}")
810
+ try:
811
+ choose = json.loads(choose_str)
812
+ reason = choose["reason"]
813
+ best_model_id = choose["id"]
814
+ hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
815
+ except Exception as e:
816
+ logger.warning(f"the response [ {choose_str} ] is not a valid JSON, try to find the model id and reason in the response.")
817
+ choose_str = find_json(choose_str)
818
+ best_model_id, reason, choose = get_id_reason(choose_str)
819
+ hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
820
+ inference_result = model_inference(best_model_id, args, hosted_on, command['task'], huggingfacetoken)
821
+
822
+ if "error" in inference_result:
823
+ logger.warning(f"Inference error: {inference_result['error']}")
824
+ record_case(success=False, **{"input": input, "task": command, "reason": f"inference error: {inference_result['error']}", "op":"message"})
825
+ results[id] = collect_result(command, choose, inference_result)
826
+ return False
827
+
828
+ results[id] = collect_result(command, choose, inference_result)
829
+ return True
830
+
831
+ def chat_huggingface(messages, openaikey = None, huggingfacetoken = None, return_planning = False, return_results = False):
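+ # End-to-end pipeline: parse the request into tasks, schedule them by
+ # dependency, run each on a local or Hugging Face-hosted model, then have the
+ # LLM summarize the collected results into the final reply.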
832
+ start = time.time()
833
+ context = messages[:-1]
834
+ input = messages[-1]["content"]
835
+ logger.info("*"*80)
836
+ logger.info(f"input: {input}")
837
+
838
+ task_str = parse_task(context, input, openaikey)
839
+ logger.info(task_str)
840
+
841
+ if "error" in task_str:
842
+ return str(task_str), {}
843
+ else:
844
+ task_str = task_str.strip()
845
+
846
+ try:
847
+ tasks = json.loads(task_str)
848
+ except Exception as e:
849
+ logger.debug(e)
850
+ response = chitchat(messages, openaikey)
851
+ record_case(success=False, **{"input": input, "task": task_str, "reason": "task parsing fail", "op":"chitchat"})
852
+ return response, {}
853
+
854
+ if task_str == "[]": # using LLM response for empty task
855
+ record_case(success=False, **{"input": input, "task": [], "reason": "task parsing fail: empty", "op": "chitchat"})
856
+ response = chitchat(messages, openaikey)
857
+ return response, {}
858
+
859
+ if len(tasks)==1 and tasks[0]["task"] in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]:
860
+ record_case(success=True, **{"input": input, "task": tasks, "reason": "task parsing fail: empty", "op": "chitchat"})
861
+ response = chitchat(messages, openaikey)
862
+ best_model_id = "ChatGPT"
863
+ reason = "ChatGPT performs well on some NLP tasks as well."
864
+ choose = {"id": best_model_id, "reason": reason}
865
+ return response, collect_result(tasks[0], choose, {"response": response})
866
+
867
+
868
+ tasks = unfold(tasks)
869
+ tasks = fix_dep(tasks)
870
+ logger.debug(tasks)
871
+
872
+ if return_planning:
873
+ return tasks
874
+
875
+ results = {}
876
+ threads = []
877
+ tasks = tasks[:]
878
+ d = dict()
879
+ retry = 0
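+ # Simple dependency scheduler: start a thread for every task whose
+ # dependencies are already in d; give up after ~80 s (160 retries x 0.5 s)
+ # without progress.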
880
+ while True:
881
+ num_threads = len(threads)
882
+ for task in tasks:
883
+ dep = task["dep"]
884
+ # logger.debug(f"d.keys(): {d.keys()}, dep: {dep}")
885
+ for dep_id in dep:
886
+ if dep_id >= task["id"]:
887
+ task["dep"] = [-1]
888
+ dep = [-1]
889
+ break
890
+ if len(list(set(dep).intersection(d.keys()))) == len(dep) or dep[0] == -1:
891
+ tasks.remove(task)
892
+ thread = threading.Thread(target=run_task, args=(input, task, d, openaikey, huggingfacetoken))
893
+ thread.start()
894
+ threads.append(thread)
895
+ if num_threads == len(threads):
896
+ time.sleep(0.5)
897
+ retry += 1
898
+ if retry > 160:
899
+ logger.debug("User has waited too long, Loop break.")
900
+ break
901
+ if len(tasks) == 0:
902
+ break
903
+ for thread in threads:
904
+ thread.join()
905
+
906
+ results = d.copy()
907
+
908
+ logger.debug(results)
909
+ if return_results:
910
+ return results
911
+
912
+ response = response_results(input, results, openaikey).strip()
913
+
914
+ end = time.time()
915
+ during = end - start
916
+
917
+ answer = {"message": response}
918
+ record_case(success=True, **{"input": input, "task": task_str, "results": results, "response": response, "during": during, "op":"response"})
919
+ logger.info(f"response: {response}")
920
+ return response, results
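+ # Example call (sketch, assuming valid keys and a user-supplied image path):
+ #   msg, results = chat_huggingface(
+ #       [{"role": "user", "content": "what is in the examples/a.jpg"}],
+ #       openaikey="sk-...", huggingfacetoken="hf_...")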
config.gradio.yaml ADDED
@@ -0,0 +1,34 @@
1
+ openai:
2
+ key: gradio # "gradio" (the key is then supplied per request from the UI) or your personal OpenAI key
3
+ huggingface:
4
+ token: # required: huggingface token @ https://huggingface.co/settings/tokens
5
+ dev: false
6
+ debug: true
7
+ log_file: logs/debug.log
8
+ model: text-davinci-003 # text-davinci-003
9
+ use_completion: true
10
+ inference_mode: hybrid # local, huggingface or hybrid
11
+ local_deployment: standard # minimal, standard or full
12
+ num_candidate_models: 5
13
+ max_description_length: 100
14
+ proxy:
15
+ logit_bias:
16
+ parse_task: 0.5
17
+ choose_model: 5
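+ # (these biases are added to the logits of the task-name / model-id tokens in
+ # the corresponding LLM calls; see parse_task and choose_model in awesome_chat.py)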
18
+ tprompt:
19
+ parse_task: >-
20
+ #1 Task Planning Stage: The AI assistant can parse user input into several tasks: [{"task": task, "id": task_id, "dep": dependency_task_id, "args": {"text": text or <GENERATED>-dep_id, "image": image_url or <GENERATED>-dep_id, "audio": audio_url or <GENERATED>-dep_id}}]. The special tag "<GENERATED>-dep_id" refers to the generated text/image/audio in the dependency task (please consider whether the dependency task generates a resource of this type) and "dep_id" must be in the "dep" list. The "dep" field denotes the ids of the previous prerequisite tasks which generate a new resource that the current task relies on. The keys of the "args" field must be in ["text", "image", "audio"], nothing else. The task MUST be selected from the following options: "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "text-to-video", "visual-question-answering", "document-question-answering", "image-segmentation", "depth-estimation", "text-to-speech", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image". There may be multiple tasks of the same type. Think step by step about all the tasks needed to resolve the user's request. Parse out as few tasks as possible while ensuring that the user request can be resolved. Pay attention to the dependencies and order among tasks. If the user input can't be parsed, you need to reply with an empty JSON array [].
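+ # e.g. "show me a joke and an image of cat" is expected to parse into a
+ # conversational task plus a text-to-image task (see demos/demo_parse_task.json)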
21
+ choose_model: >-
22
+ #2 Model Selection Stage: Given the user request and the parsed tasks, the AI assistant helps the user to select a suitable model from a list of models to process the user request. The assistant should focus more on the description of the model and find the model that has the most potential to solve requests and tasks. Also, prefer models with local inference endpoints for speed and stability.
23
+ response_results: >-
24
+ #4 Response Generation Stage: With the task execution logs, the AI assistant needs to describe the process and inference results.
25
+ demos_or_presteps:
26
+ parse_task: demos/demo_parse_task.json
27
+ choose_model: demos/demo_choose_model.json
28
+ response_results: demos/demo_response_results.json
29
+ prompt:
30
+ parse_task: The chat log [ {{context}} ] may contain the resources I mentioned. Now I input { {{input}} }. Pay attention to the input and output types of tasks and the dependencies between tasks.
31
+ choose_model: >-
32
+ Please choose the most suitable model from {{metas}} for the task {{task}}. The output must be in a strict JSON format: {"id": "id", "reason": "your detail reasons for the choice"}.
33
+ response_results: >-
34
+ Yes. Please first think carefully and directly answer my request based on the inference results. Some of the inference results may be incorrect, so consider them carefully when making decisions. Then please detail your workflow including the used models and inference results for my request in your friendly tone. Please filter out information that is not relevant to my request. Tell me the complete path or urls of files in inference results. If there is nothing in the results, please tell me you can't make it.
data/p0_models.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
demos/demo_choose_model.json ADDED
@@ -0,0 +1,10 @@
1
+ [
2
+ {
3
+ "role": "user",
4
+ "content": "{{input}}"
5
+ },
6
+ {
7
+ "role": "assistant",
8
+ "content": "{{task}}"
9
+ }
10
+ ]
demos/demo_parse_task.json ADDED
@@ -0,0 +1,65 @@
1
+ [
2
+ {
3
+ "role": "user",
4
+ "content": "Give you some pictures e1.jpg, e2.png, e3.jpg, help me count the number of sheep?"
5
+ },
6
+ {
7
+ "role": "assistant",
8
+ "content": "[{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"e1.jpg\" }}, {\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"e1.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 2, \"dep\": [1], \"args\": {\"image\": \"<GENERATED>-1\", \"text\": \"How many sheep in the picture\"}} }}, {\"task\": \"image-to-text\", \"id\": 3, \"dep\": [-1], \"args\": {\"image\": \"e2.png\" }}, {\"task\": \"object-detection\", \"id\": 4, \"dep\": [-1], \"args\": {\"image\": \"e2.png\" }}, {\"task\": \"visual-question-answering\", \"id\": 5, \"dep\": [4], \"args\": {\"image\": \"<GENERATED>-4\", \"text\": \"How many sheep in the picture\"}} }}, {\"task\": \"image-to-text\", \"id\": 6, \"dep\": [-1], \"args\": {\"image\": \"e3.jpg\" }}, {\"task\": \"object-detection\", \"id\": 7, \"dep\": [-1], \"args\": {\"image\": \"e3.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 8, \"dep\": [7], \"args\": {\"image\": \"<GENERATED>-7\", \"text\": \"How many sheep in the picture\"}}]"
9
+ },
10
+
11
+ {
12
+ "role":"user",
13
+ "content":"Look at /e.jpg, can you tell me how many objects in the picture? Give me a picture and video similar to this one."
14
+ },
15
+ {
16
+ "role":"assistant",
17
+ "content":"[{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 2, \"dep\": [1], \"args\": {\"image\": \"<GENERATED>-1\", \"text\": \"how many objects in the picture?\" }}, {\"task\": \"text-to-image\", \"id\": 3, \"dep\": [0], \"args\": {\"text\": \"<GENERATED>-0\" }}, {\"task\": \"image-to-image\", \"id\": 4, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"text-to-video\", \"id\": 5, \"dep\": [0], \"args\": {\"text\": \"<GENERATED>-0\" }}]"
18
+ },
19
+
20
+ {
21
+ "role":"user",
22
+ "content":"given a document /images/e.jpeg, answer me what is the student amount? And describe the image with your voice"
23
+ },
24
+ {
25
+ "role":"assistant",
26
+ "content":"[{\"task\": \"document-question-answering\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\", \"text\": \"what is the student amount?\" }}, {\"task\": \"visual-question-answering\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\", \"text\": \"what is the student amount?\" }}, {\"task\": \"image-to-text\", \"id\": 2, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\" }}, {\"task\": \"text-to-speech\", \"id\": 3, \"dep\": [2], \"args\": {\"text\": \"<GENERATED>-2\" }}]"
27
+ },
28
+
29
+ {
30
+ "role": "user",
31
+ "content": "Given an image /example.jpg, first generate a hed image, then based on the hed image generate a new image where a girl is reading a book"
32
+ },
33
+ {
34
+ "role": "assistant",
35
+ "content": "[{\"task\": \"hed-control\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/example.jpg\" }}, {\"task\": \"hed-text-to-image\", \"id\": 1, \"dep\": [0], \"args\": {\"text\": \"a girl is reading a book\", \"image\": \"<GENERATED>-0\" }}]"
36
+ },
37
+
38
+ {
39
+ "role": "user",
40
+ "content": "please show me a video and an image of (based on the text) 'a boy is running' and dub it"
41
+ },
42
+ {
43
+ "role": "assistant",
44
+ "content": "[{\"task\": \"text-to-video\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}, {\"task\": \"text-to-speech\", \"id\": 1, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}, {\"task\": \"text-to-image\", \"id\": 2, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}]"
45
+ },
46
+
47
+
48
+ {
49
+ "role": "user",
50
+ "content": "please show me a joke and an image of cat"
51
+ },
52
+ {
53
+ "role": "assistant",
54
+ "content": "[{\"task\": \"conversational\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"please show me a joke of cat\" }}, {\"task\": \"text-to-image\", \"id\": 1, \"dep\": [-1], \"args\": {\"text\": \"a photo of cat\" }}]"
55
+ },
56
+
57
+ {
58
+ "role": "user",
59
+ "content": "give me a picture of a cute dog, then describe the image to me and tell a story about it"
60
+ },
61
+ {
62
+ "role": "assistant",
63
+ "content": "[{\"task\": \"text-to-image\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"a picture of a cute dog\" }}, {\"task\": \"image-to-text\", \"id\": 1, \"dep\": [0], \"args\": {\"image\": \"<GENERATED>-0\" }}, {\"task\": \"text-generation\", \"id\": 2, \"dep\": [1], \"args\": {\"text\": \"<GENERATED>-1\" }}, {\"task\": \"text-to-speech\", \"id\": 3, \"dep\": [2], \"args\": {\"text\": \"<GENERATED>-2\" }}]"
64
+ }
65
+ ]
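The demos above wire task outputs into later tasks through the <GENERATED>-dep_id placeholder. Below is a minimal sketch of how such placeholders could be resolved at execution time, assuming each finished task stores its generated resources by id; the helper is illustrative and not part of this commit.

```python
import re

GENERATED = re.compile(r"<GENERATED>-(\d+)")

def resolve_args(args, results):
    """Replace <GENERATED>-dep_id placeholders with resources from finished tasks.

    `results` is assumed to map task id -> {"text"/"image"/"audio": resource}.
    """
    resolved = {}
    for key, value in args.items():
        match = GENERATED.fullmatch(value)
        # take the resource of the same type produced by the dependency task
        resolved[key] = results[int(match.group(1))][key] if match else value
    return resolved

print(resolve_args({"text": "<GENERATED>-0"}, {0: {"text": "two sheep on a hill"}}))
# -> {'text': 'two sheep on a hill'}
```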
demos/demo_response_results.json ADDED
@@ -0,0 +1,10 @@
1
+ [
2
+ {
3
+ "role": "user",
4
+ "content": "{{input}}"
5
+ },
6
+ {
7
+ "role": "assistant",
8
+ "content": "Before giving you a response, I want to introduce my workflow for your request, which is shown in the following JSON data: {{processes}}. Do you have any requirements regarding my response?"
9
+ }
10
+ ]
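The demos_or_presteps and prompt entries in the config above presumably get stitched together with these demo files into a few-shot chat request. The following is only a rough sketch of that assembly; the function name and structure are assumptions, not code from this commit.

```python
import json

def build_messages(system_text, demo_path, prompt_template, **slots):
    """Assemble system text + few-shot demo turns + a filled-in user prompt."""
    messages = [{"role": "system", "content": system_text}]
    with open(demo_path) as f:
        messages += json.load(f)  # e.g. demos/demo_parse_task.json
    user_prompt = prompt_template
    for name, value in slots.items():
        user_prompt = user_prompt.replace("{{" + name + "}}", str(value))
    messages.append({"role": "user", "content": user_prompt})
    return messages
```

Note that demo_choose_model.json and demo_response_results.json themselves contain {{...}} slots, so a real implementation would also need to substitute placeholders inside those demo turns.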
get_token_ids.py ADDED
@@ -0,0 +1,49 @@
1
+ import tiktoken
2
+
3
+ encodings = {
4
+ "gpt-3.5-turbo": tiktoken.get_encoding("cl100k_base"),
5
+ "gpt-3.5-turbo-0301": tiktoken.get_encoding("cl100k_base"),
6
+ "text-davinci-003": tiktoken.get_encoding("p50k_base"),
7
+ "text-davinci-002": tiktoken.get_encoding("p50k_base"),
8
+ "text-davinci-001": tiktoken.get_encoding("r50k_base"),
9
+ "text-curie-001": tiktoken.get_encoding("r50k_base"),
10
+ "text-babbage-001": tiktoken.get_encoding("r50k_base"),
11
+ "text-ada-001": tiktoken.get_encoding("r50k_base"),
12
+ "davinci": tiktoken.get_encoding("r50k_base"),
13
+ "curie": tiktoken.get_encoding("r50k_base"),
14
+ "babbage": tiktoken.get_encoding("r50k_base"),
15
+ "ada": tiktoken.get_encoding("r50k_base"),
16
+ }
17
+
18
+ max_length = {
19
+ "gpt-3.5-turbo": 4096,
20
+ "gpt-3.5-turbo-0301": 4096,
21
+ "text-davinci-003": 4096,
22
+ "text-davinci-002": 4096,
23
+ "text-davinci-001": 2049,
24
+ "text-curie-001": 2049,
25
+ "text-babbage-001": 2049,
26
+ "text-ada-001": 2049,
27
+ "davinci": 2049,
28
+ "curie": 2049,
29
+ "babbage": 2049,
30
+ "ada": 2049
31
+ }
32
+
33
+ def count_tokens(model_name, text):
34
+ return len(encodings[model_name].encode(text))
35
+
36
+ def get_max_context_length(model_name):
37
+ return max_length[model_name]
38
+
39
+ def get_token_ids_for_task_parsing(model_name):
40
+ text = '''{"task": "text-classification", "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "visual-question-answering", "document-question-answering", "image-segmentation", "text-to-speech", "text-to-video", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image", "args", "text", "path", "dep", "id", "<GENERATED>-"}'''
41
+ res = encodings[model_name].encode(text)
42
+ res = list(set(res))
43
+ return res
44
+
45
+ def get_token_ids_for_choose_model(model_name):
46
+ text = '''{"id": "reason"}'''
47
+ res = encodings[model_name].encode(text)
48
+ res = list(set(res))
49
+ return res
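The helpers above count tokens and report each model's context window; the two get_token_ids_for_* functions return the token ids of the fixed output vocabulary, presumably so a caller can bias generation toward it. One illustrative use of the counting helpers, assuming this file is importable as get_token_ids, is trimming an over-long chat history; the function below is a sketch, not code from this commit.

```python
from get_token_ids import count_tokens, get_max_context_length

def trim_context(messages, model_name="gpt-3.5-turbo", reserve_for_reply=1500):
    """Drop the oldest turns until the serialized context fits the model's window.

    The 1500-token reserve for the reply is an arbitrary choice for this sketch.
    """
    budget = get_max_context_length(model_name) - reserve_for_reply
    while messages and count_tokens(model_name, str(messages)) > budget:
        messages = messages[1:]  # discard the oldest turn first
    return messages
```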
models_server.py ADDED
@@ -0,0 +1,618 @@
1
+ import argparse
2
+ import logging
3
+ import random
4
+ import uuid
5
+ import numpy as np
6
+ from transformers import pipeline
7
+ from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
8
+ from diffusers.utils import load_image
9
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
10
+ from diffusers.utils import export_to_video
11
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5ForSpeechToSpeech
12
+ from transformers import BlipProcessor, BlipForConditionalGeneration
13
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
14
+ from datasets import load_dataset
15
+ from PIL import Image
16
+ import io
17
+ from torchvision import transforms
18
+ import torch
19
+ import torchaudio
20
+ from speechbrain.pretrained import WaveformEnhancement
21
+ import joblib
22
+ from huggingface_hub import hf_hub_url, cached_download
23
+ from transformers import AutoImageProcessor, TimesformerForVideoClassification
24
+ from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, AutoFeatureExtractor
25
+ from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector, CannyDetector, MidasDetector
26
+ from controlnet_aux.open_pose.body import Body
27
+ from controlnet_aux.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large
28
+ from controlnet_aux.hed import Network
29
+ from transformers import DPTForDepthEstimation, DPTFeatureExtractor
30
+ import warnings
31
+ import time
32
+ from espnet2.bin.tts_inference import Text2Speech
33
+ import soundfile as sf
34
+ from asteroid.models import BaseModel
35
+ import traceback
36
+ import os
37
+ import yaml
38
+
39
+ warnings.filterwarnings("ignore")
40
+
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--config", type=str, default="config.yaml")
43
+ args = parser.parse_args()
44
+
45
+ if __name__ != "__main__":
46
+ args.config = "config.gradio.yaml"
47
+
48
+ logger = logging.getLogger(__name__)
49
+ logger.setLevel(logging.INFO)
50
+ handler = logging.StreamHandler()
51
+ handler.setLevel(logging.INFO)
52
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
53
+ handler.setFormatter(formatter)
54
+ logger.addHandler(handler)
55
+
56
+ config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)
57
+
58
+ local_deployment = config["local_deployment"]
59
+ if config["inference_mode"] == "huggingface":
60
+ local_deployment = "none"
61
+
62
+ PROXY = None
63
+ if config["proxy"]:
64
+ PROXY = {
65
+ "https": config["proxy"],
66
+ }
67
+
68
+ start = time.time()
69
+
70
+ # local_models = "models/"
71
+ local_models = ""
72
+
73
+
74
+ def load_pipes(local_deployment):
75
+ other_pipes = {}
76
+ standard_pipes = {}
77
+ controlnet_sd_pipes = {}
78
+ if local_deployment in ["full"]:
79
+ other_pipes = {
80
+
81
+ # "Salesforce/blip-image-captioning-large": {
82
+ # "model": BlipForConditionalGeneration.from_pretrained(f"Salesforce/blip-image-captioning-large"),
83
+ # "processor": BlipProcessor.from_pretrained(f"Salesforce/blip-image-captioning-large"),
84
+ # "device": "cuda:0"
85
+ # },
86
+ "damo-vilab/text-to-video-ms-1.7b": {
87
+ "model": DiffusionPipeline.from_pretrained(f"{local_models}damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"),
88
+ "device": "cuda:0"
89
+ },
90
+ # "facebook/maskformer-swin-large-ade": {
91
+ # "model": MaskFormerForInstanceSegmentation.from_pretrained(f"facebook/maskformer-swin-large-ade"),
92
+ # "feature_extractor" : AutoFeatureExtractor.from_pretrained("facebook/maskformer-swin-large-ade"),
93
+ # "device": "cuda:0"
94
+ # },
95
+ # "microsoft/trocr-base-printed": {
96
+ # "processor": TrOCRProcessor.from_pretrained(f"microsoft/trocr-base-printed"),
97
+ # "model": VisionEncoderDecoderModel.from_pretrained(f"microsoft/trocr-base-printed"),
98
+ # "device": "cuda:0"
99
+ # },
100
+ # "microsoft/trocr-base-handwritten": {
101
+ # "processor": TrOCRProcessor.from_pretrained(f"microsoft/trocr-base-handwritten"),
102
+ # "model": VisionEncoderDecoderModel.from_pretrained(f"microsoft/trocr-base-handwritten"),
103
+ # "device": "cuda:0"
104
+ # },
105
+ "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k": {
106
+ "model": BaseModel.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k"),
107
+ "device": "cuda:0"
108
+ },
109
+
110
+ # "CompVis/stable-diffusion-v1-4": {
111
+ # "model": DiffusionPipeline.from_pretrained(f"CompVis/stable-diffusion-v1-4"),
112
+ # "device": "cuda:0"
113
+ # },
114
+ # "stabilityai/stable-diffusion-2-1": {
115
+ # "model": DiffusionPipeline.from_pretrained(f"stabilityai/stable-diffusion-2-1"),
116
+ # "device": "cuda:0"
117
+ # },
118
+
119
+ # "microsoft/speecht5_tts":{
120
+ # "processor": SpeechT5Processor.from_pretrained(f"microsoft/speecht5_tts"),
121
+ # "model": SpeechT5ForTextToSpeech.from_pretrained(f"microsoft/speecht5_tts"),
122
+ # "vocoder": SpeechT5HifiGan.from_pretrained(f"microsoft/speecht5_hifigan"),
123
+ # "embeddings_dataset": load_dataset(f"Matthijs/cmu-arctic-xvectors", split="validation"),
124
+ # "device": "cuda:0"
125
+ # },
126
+ # "speechbrain/mtl-mimic-voicebank": {
127
+ # "model": WaveformEnhancement.from_hparams(source="speechbrain/mtl-mimic-voicebank", savedir="models/mtl-mimic-voicebank"),
128
+ # "device": "cuda:0"
129
+ # },
130
+ "microsoft/speecht5_vc":{
131
+ "processor": SpeechT5Processor.from_pretrained(f"{local_models}microsoft/speecht5_vc"),
132
+ "model": SpeechT5ForSpeechToSpeech.from_pretrained(f"{local_models}microsoft/speecht5_vc"),
133
+ "vocoder": SpeechT5HifiGan.from_pretrained(f"{local_models}microsoft/speecht5_hifigan"),
134
+ "embeddings_dataset": load_dataset(f"{local_models}Matthijs/cmu-arctic-xvectors", split="validation"),
135
+ "device": "cuda:0"
136
+ },
137
+ # "julien-c/wine-quality": {
138
+ # "model": joblib.load(cached_download(hf_hub_url("julien-c/wine-quality", "sklearn_model.joblib")))
139
+ # },
140
+ # "facebook/timesformer-base-finetuned-k400": {
141
+ # "processor": AutoImageProcessor.from_pretrained(f"facebook/timesformer-base-finetuned-k400"),
142
+ # "model": TimesformerForVideoClassification.from_pretrained(f"facebook/timesformer-base-finetuned-k400"),
143
+ # "device": "cuda:0"
144
+ # },
145
+ "facebook/maskformer-swin-base-coco": {
146
+ "feature_extractor": MaskFormerFeatureExtractor.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"),
147
+ "model": MaskFormerForInstanceSegmentation.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"),
148
+ "device": "cuda:0"
149
+ },
150
+ "Intel/dpt-hybrid-midas": {
151
+ "model": DPTForDepthEstimation.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas", low_cpu_mem_usage=True),
152
+ "feature_extractor": DPTFeatureExtractor.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas"),
153
+ "device": "cuda:0"
154
+ }
155
+ }
156
+
157
+ if local_deployment in ["full", "standard"]:
158
+ standard_pipes = {
159
+ # "nlpconnect/vit-gpt2-image-captioning":{
160
+ # "model": VisionEncoderDecoderModel.from_pretrained(f"{local_models}nlpconnect/vit-gpt2-image-captioning"),
161
+ # "feature_extractor": ViTImageProcessor.from_pretrained(f"{local_models}nlpconnect/vit-gpt2-image-captioning"),
162
+ # "tokenizer": AutoTokenizer.from_pretrained(f"{local_models}nlpconnect/vit-gpt2-image-captioning"),
163
+ # "device": "cuda:0"
164
+ # },
165
+ "espnet/kan-bayashi_ljspeech_vits": {
166
+ "model": Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_vits"),
167
+ "device": "cuda:0"
168
+ },
169
+ # "lambdalabs/sd-image-variations-diffusers": {
170
+ # "model": DiffusionPipeline.from_pretrained(f"{local_models}lambdalabs/sd-image-variations-diffusers"), #torch_dtype=torch.float16
171
+ # "device": "cuda:0"
172
+ # },
173
+ "runwayml/stable-diffusion-v1-5": {
174
+ "model": DiffusionPipeline.from_pretrained(f"{local_models}runwayml/stable-diffusion-v1-5"),
175
+ "device": "cuda:0"
176
+ },
177
+ # "superb/wav2vec2-base-superb-ks": {
178
+ # "model": pipeline(task="audio-classification", model=f"superb/wav2vec2-base-superb-ks"),
179
+ # "device": "cuda:0"
180
+ # },
181
+ "openai/whisper-base": {
182
+ "model": pipeline(task="automatic-speech-recognition", model=f"{local_models}openai/whisper-base"),
183
+ "device": "cuda:0"
184
+ },
185
+ # "microsoft/speecht5_asr": {
186
+ # "model": pipeline(task="automatic-speech-recognition", model=f"{local_models}microsoft/speecht5_asr"),
187
+ # "device": "cuda:0"
188
+ # },
189
+ "Intel/dpt-large": {
190
+ "model": pipeline(task="depth-estimation", model=f"{local_models}Intel/dpt-large"),
191
+ "device": "cuda:0"
192
+ },
193
+ # "microsoft/beit-base-patch16-224-pt22k-ft22k": {
194
+ # "model": pipeline(task="image-classification", model=f"microsoft/beit-base-patch16-224-pt22k-ft22k"),
195
+ # "device": "cuda:0"
196
+ # },
197
+ "facebook/detr-resnet-50-panoptic": {
198
+ "model": pipeline(task="image-segmentation", model=f"{local_models}facebook/detr-resnet-50-panoptic"),
199
+ "device": "cuda:0"
200
+ },
201
+ "facebook/detr-resnet-101": {
202
+ "model": pipeline(task="object-detection", model=f"{local_models}facebook/detr-resnet-101"),
203
+ "device": "cuda:0"
204
+ },
205
+ # "openai/clip-vit-large-patch14": {
206
+ # "model": pipeline(task="zero-shot-image-classification", model=f"openai/clip-vit-large-patch14"),
207
+ # "device": "cuda:0"
208
+ # },
209
+ # "google/owlvit-base-patch32": {
210
+ # "model": pipeline(task="zero-shot-object-detection", model=f"{local_models}google/owlvit-base-patch32"),
211
+ # "device": "cuda:0"
212
+ # },
213
+ # "microsoft/DialoGPT-medium": {
214
+ # "model": pipeline(task="conversational", model=f"microsoft/DialoGPT-medium"),
215
+ # "device": "cuda:0"
216
+ # },
217
+ # "bert-base-uncased": {
218
+ # "model": pipeline(task="fill-mask", model=f"bert-base-uncased"),
219
+ # "device": "cuda:0"
220
+ # },
221
+ # "deepset/roberta-base-squad2": {
222
+ # "model": pipeline(task = "question-answering", model=f"deepset/roberta-base-squad2"),
223
+ # "device": "cuda:0"
224
+ # },
225
+ # "facebook/bart-large-cnn": {
226
+ # "model": pipeline(task="summarization", model=f"facebook/bart-large-cnn"),
227
+ # "device": "cuda:0"
228
+ # },
229
+ # "google/tapas-base-finetuned-wtq": {
230
+ # "model": pipeline(task="table-question-answering", model=f"google/tapas-base-finetuned-wtq"),
231
+ # "device": "cuda:0"
232
+ # },
233
+ # "distilbert-base-uncased-finetuned-sst-2-english": {
234
+ # "model": pipeline(task="text-classification", model=f"distilbert-base-uncased-finetuned-sst-2-english"),
235
+ # "device": "cuda:0"
236
+ # },
237
+ # "gpt2": {
238
+ # "model": pipeline(task="text-generation", model="gpt2"),
239
+ # "device": "cuda:0"
240
+ # },
241
+ # "mrm8488/t5-base-finetuned-question-generation-ap": {
242
+ # "model": pipeline(task="text2text-generation", model=f"mrm8488/t5-base-finetuned-question-generation-ap"),
243
+ # "device": "cuda:0"
244
+ # },
245
+ # "Jean-Baptiste/camembert-ner": {
246
+ # "model": pipeline(task="token-classification", model=f"Jean-Baptiste/camembert-ner", aggregation_strategy="simple"),
247
+ # "device": "cuda:0"
248
+ # },
249
+ # "t5-base": {
250
+ # "model": pipeline(task="translation", model=f"t5-base"),
251
+ # "device": "cuda:0"
252
+ # },
253
+ "impira/layoutlm-document-qa": {
254
+ "model": pipeline(task="document-question-answering", model=f"{local_models}impira/layoutlm-document-qa"),
255
+ "device": "cuda:0"
256
+ },
257
+ "ydshieh/vit-gpt2-coco-en": {
258
+ "model": pipeline(task="image-to-text", model=f"{local_models}ydshieh/vit-gpt2-coco-en"),
259
+ "device": "cuda:0"
260
+ },
261
+ "dandelin/vilt-b32-finetuned-vqa": {
262
+ "model": pipeline(task="visual-question-answering", model=f"{local_models}dandelin/vilt-b32-finetuned-vqa"),
263
+ "device": "cuda:0"
264
+ }
265
+ }
266
+
267
+ if local_deployment in ["full", "standard", "minimal"]:
268
+
269
+ controlnet = ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
270
+ controlnetpipe = StableDiffusionControlNetPipeline.from_pretrained(
271
+ f"{local_models}runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
272
+ )
273
+
274
+
275
+ hed_network = HEDdetector.from_pretrained('lllyasviel/ControlNet')
276
+
277
+ controlnet_sd_pipes = {
278
+ "openpose-control": {
279
+ "model": OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
280
+ },
281
+ "mlsd-control": {
282
+ "model": MLSDdetector.from_pretrained('lllyasviel/ControlNet')
283
+ },
284
+ "hed-control": {
285
+ "model": hed_network
286
+ },
287
+ "scribble-control": {
288
+ "model": hed_network
289
+ },
290
+ "midas-control": {
291
+ "model": MidasDetector.from_pretrained('lllyasviel/ControlNet')
292
+ },
293
+ "canny-control": {
294
+ "model": CannyDetector()
295
+ },
296
+ "lllyasviel/sd-controlnet-canny":{
297
+ "control": controlnet,
298
+ "model": controlnetpipe,
299
+ "device": "cuda:0"
300
+ },
301
+ "lllyasviel/sd-controlnet-depth":{
302
+ "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16),
303
+ "model": controlnetpipe,
304
+ "device": "cuda:0"
305
+ },
306
+ "lllyasviel/sd-controlnet-hed":{
307
+ "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-hed", torch_dtype=torch.float16),
308
+ "model": controlnetpipe,
309
+ "device": "cuda:0"
310
+ },
311
+ "lllyasviel/sd-controlnet-mlsd":{
312
+ "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-mlsd", torch_dtype=torch.float16),
313
+ "model": controlnetpipe,
314
+ "device": "cuda:0"
315
+ },
316
+ "lllyasviel/sd-controlnet-openpose":{
317
+ "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16),
318
+ "model": controlnetpipe,
319
+ "device": "cuda:0"
320
+ },
321
+ "lllyasviel/sd-controlnet-scribble":{
322
+ "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-scribble", torch_dtype=torch.float16),
323
+ "model": controlnetpipe,
324
+ "device": "cuda:0"
325
+ },
326
+ "lllyasviel/sd-controlnet-seg":{
327
+ "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16),
328
+ "model": controlnetpipe,
329
+ "device": "cuda:0"
330
+ }
331
+ }
332
+ pipes = {**standard_pipes, **other_pipes, **controlnet_sd_pipes}
333
+ return pipes
334
+
335
+ pipes = load_pipes(local_deployment)
336
+
337
+ end = time.time()
338
+ during = end - start
339
+
340
+ print(f"[ ready ] {during}s")
341
+
342
+ def running():
343
+ return {"running": True}
344
+
345
+ def status(model_id):
346
+ disabled_models = ["microsoft/trocr-base-printed", "microsoft/trocr-base-handwritten"]
347
+ if model_id in pipes.keys() and model_id not in disabled_models:
348
+ print(f"[ check {model_id} ] success")
349
+ return {"loaded": True}
350
+ else:
351
+ print(f"[ check {model_id} ] failed")
352
+ return {"loaded": False}
353
+
354
+ def models(model_id, data):
355
+ while "using" in pipes[model_id] and pipes[model_id]["using"]:
356
+ print(f"[ inference {model_id} ] waiting")
357
+ time.sleep(0.1)
358
+ pipes[model_id]["using"] = True
359
+ print(f"[ inference {model_id} ] start")
360
+
361
+ start = time.time()
362
+
363
+ pipe = pipes[model_id]["model"]
364
+
365
+ if "device" in pipes[model_id]:
366
+ try:
367
+ pipe.to(pipes[model_id]["device"])
368
+ except:
369
+ pipe.device = torch.device(pipes[model_id]["device"])
370
+ pipe.model.to(pipes[model_id]["device"])
371
+
372
+ result = None
373
+ try:
374
+ # text to video
375
+ if model_id == "damo-vilab/text-to-video-ms-1.7b":
376
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
377
+ # pipe.enable_model_cpu_offload()
378
+ prompt = data["text"]
379
+ video_frames = pipe(prompt, num_inference_steps=50, num_frames=40).frames
380
+ file_name = str(uuid.uuid4())[:4]
381
+ video_path = export_to_video(video_frames, f"public/videos/{file_name}.mp4")
382
+
383
+ new_file_name = str(uuid.uuid4())[:4]
384
+ os.system(f"ffmpeg -i {video_path} -vcodec libx264 public/videos/{new_file_name}.mp4")
385
+
386
+ if os.path.exists(f"public/videos/{new_file_name}.mp4"):
387
+ result = {"path": f"/videos/{new_file_name}.mp4"}
388
+ else:
389
+ result = {"path": f"/videos/{file_name}.mp4"}
390
+
391
+ # controlnet
392
+ if model_id.startswith("lllyasviel/sd-controlnet-"):
393
+ pipe.controlnet.to('cpu')
394
+ pipe.controlnet = pipes[model_id]["control"].to(pipes[model_id]["device"])
395
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
396
+ control_image = load_image(data["img_url"])
397
+ # generator = torch.manual_seed(66)
398
+ out_image: Image = pipe(data["text"], num_inference_steps=20, image=control_image).images[0]
399
+ file_name = str(uuid.uuid4())[:4]
400
+ out_image.save(f"public/images/{file_name}.png")
401
+ result = {"path": f"/images/{file_name}.png"}
402
+
403
+ if model_id.endswith("-control"):
404
+ image = load_image(data["img_url"])
405
+ if "scribble" in model_id:
406
+ control = pipe(image, scribble = True)
407
+ elif "canny" in model_id:
408
+ control = pipe(image, low_threshold=100, high_threshold=200)
409
+ else:
410
+ control = pipe(image)
411
+ file_name = str(uuid.uuid4())[:4]
412
+ control.save(f"public/images/{file_name}.png")
413
+ result = {"path": f"/images/{file_name}.png"}
414
+
415
+ # image to image
416
+ if model_id == "lambdalabs/sd-image-variations-diffusers":
417
+ im = load_image(data["img_url"])
418
+ file_name = str(uuid.uuid4())[:4]
419
+ with open(f"public/images/{file_name}.png", "wb") as f:
420
+ im.save(f, format="PNG")  # keep a copy of the loaded input image; writing the raw request dict would fail
421
+ tform = transforms.Compose([
422
+ transforms.ToTensor(),
423
+ transforms.Resize(
424
+ (224, 224),
425
+ interpolation=transforms.InterpolationMode.BICUBIC,
426
+ antialias=False,
427
+ ),
428
+ transforms.Normalize(
429
+ [0.48145466, 0.4578275, 0.40821073],
430
+ [0.26862954, 0.26130258, 0.27577711]),
431
+ ])
432
+ inp = tform(im).to(pipes[model_id]["device"]).unsqueeze(0)
433
+ out = pipe(inp, guidance_scale=3)
434
+ out["images"][0].save(f"public/images/{file_name}.jpg")
435
+ result = {"path": f"/images/{file_name}.jpg"}
436
+
437
+ # image to text
438
+ if model_id == "Salesforce/blip-image-captioning-large":
439
+ raw_image = load_image(data["img_url"]).convert('RGB')
440
+ text = data["text"]
441
+ inputs = pipes[model_id]["processor"](raw_image, return_tensors="pt").to(pipes[model_id]["device"])
442
+ out = pipe.generate(**inputs)
443
+ caption = pipes[model_id]["processor"].decode(out[0], skip_special_tokens=True)
444
+ result = {"generated text": caption}
445
+ if model_id == "ydshieh/vit-gpt2-coco-en":
446
+ img_url = data["img_url"]
447
+ generated_text = pipe(img_url)[0]['generated_text']
448
+ result = {"generated text": generated_text}
449
+ if model_id == "nlpconnect/vit-gpt2-image-captioning":
450
+ image = load_image(data["img_url"]).convert("RGB")
451
+ pixel_values = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").pixel_values
452
+ pixel_values = pixel_values.to(pipes[model_id]["device"])
453
+ generated_ids = pipe.generate(pixel_values, **{"max_length": 200, "num_beams": 1})
454
+ generated_text = pipes[model_id]["tokenizer"].batch_decode(generated_ids, skip_special_tokens=True)[0]
455
+ result = {"generated text": generated_text}
456
+ # image to text: OCR
457
+ if model_id == "microsoft/trocr-base-printed" or model_id == "microsoft/trocr-base-handwritten":
458
+ image = load_image(data["img_url"]).convert("RGB")
459
+ pixel_values = pipes[model_id]["processor"](image, return_tensors="pt").pixel_values
460
+ pixel_values = pixel_values.to(pipes[model_id]["device"])
461
+ generated_ids = pipe.generate(pixel_values)
462
+ generated_text = pipes[model_id]["processor"].batch_decode(generated_ids, skip_special_tokens=True)[0]
463
+ result = {"generated text": generated_text}
464
+
465
+ # text to image
466
+ if model_id == "runwayml/stable-diffusion-v1-5":
467
+ file_name = str(uuid.uuid4())[:4]
468
+ text = data["text"]
469
+ out = pipe(prompt=text)
470
+ out["images"][0].save(f"public/images/{file_name}.jpg")
471
+ result = {"path": f"/images/{file_name}.jpg"}
472
+
473
+ # object detection
474
+ if model_id == "google/owlvit-base-patch32" or model_id == "facebook/detr-resnet-101":
475
+ img_url = data["img_url"]
476
+ open_types = ["cat", "couch", "person", "car", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird"]
477
+ result = pipe(img_url, candidate_labels=open_types)
478
+
479
+ # VQA
480
+ if model_id == "dandelin/vilt-b32-finetuned-vqa":
481
+ question = data["text"]
482
+ img_url = data["img_url"]
483
+ result = pipe(question=question, image=img_url)
484
+
485
+ #DQA
486
+ if model_id == "impira/layoutlm-document-qa":
487
+ question = data["text"]
488
+ img_url = data["img_url"]
489
+ result = pipe(img_url, question)
490
+
491
+ # depth-estimation
492
+ if model_id == "Intel/dpt-large":
493
+ output = pipe(data["img_url"])
494
+ image = output['depth']
495
+ name = str(uuid.uuid4())[:4]
496
+ image.save(f"public/images/{name}.jpg")
497
+ result = {"path": f"/images/{name}.jpg"}
498
+
499
+ if model_id == "Intel/dpt-hybrid-midas":
500
+ image = load_image(data["img_url"])
501
+ inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt")
502
+ with torch.no_grad():
503
+ outputs = pipe(**inputs)
504
+ predicted_depth = outputs.predicted_depth
505
+ prediction = torch.nn.functional.interpolate(
506
+ predicted_depth.unsqueeze(1),
507
+ size=image.size[::-1],
508
+ mode="bicubic",
509
+ align_corners=False,
510
+ )
511
+ output = prediction.squeeze().cpu().numpy()
512
+ formatted = (output * 255 / np.max(output)).astype("uint8")
513
+ image = Image.fromarray(formatted)
514
+ name = str(uuid.uuid4())[:4]
515
+ image.save(f"public/images/{name}.jpg")
516
+ result = {"path": f"/images/{name}.jpg"}
517
+
518
+ # TTS
519
+ if model_id == "espnet/kan-bayashi_ljspeech_vits":
520
+ text = data["text"]
521
+ wav = pipe(text)["wav"]
522
+ name = str(uuid.uuid4())[:4]
523
+ sf.write(f"public/audios/{name}.wav", wav.cpu().numpy(), pipe.fs, "PCM_16")
524
+ result = {"path": f"/audios/{name}.wav"}
525
+
526
+ if model_id == "microsoft/speecht5_tts":
527
+ text = data["text"]
528
+ inputs = pipes[model_id]["processor"](text=text, return_tensors="pt")
529
+ embeddings_dataset = pipes[model_id]["embeddings_dataset"]
530
+ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(pipes[model_id]["device"])
531
+ pipes[model_id]["vocoder"].to(pipes[model_id]["device"])
532
+ speech = pipe.generate_speech(inputs["input_ids"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"])
533
+ name = str(uuid.uuid4())[:4]
534
+ sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
535
+ result = {"path": f"/audios/{name}.wav"}
536
+
537
+ # ASR
538
+ if model_id == "openai/whisper-base" or model_id == "microsoft/speecht5_asr":
539
+ audio_url = data["audio_url"]
540
+ result = { "text": pipe(audio_url)["text"]}
541
+
542
+ # audio to audio
543
+ if model_id == "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k":
544
+ audio_url = data["audio_url"]
545
+ wav, sr = torchaudio.load(audio_url)
546
+ with torch.no_grad():
547
+ result_wav = pipe(wav.to(pipes[model_id]["device"]))
548
+ name = str(uuid.uuid4())[:4]
549
+ sf.write(f"public/audios/{name}.wav", result_wav.cpu().squeeze().numpy(), sr)
550
+ result = {"path": f"/audios/{name}.wav"}
551
+
552
+ if model_id == "microsoft/speecht5_vc":
553
+ audio_url = data["audio_url"]
554
+ wav, sr = torchaudio.load(audio_url)
555
+ inputs = pipes[model_id]["processor"](audio=wav, sampling_rate=sr, return_tensors="pt")
556
+ embeddings_dataset = pipes[model_id]["embeddings_dataset"]
557
+ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
558
+ pipes[model_id]["vocoder"].to(pipes[model_id]["device"])
559
+ speech = pipe.generate_speech(inputs["input_values"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"])
560
+ name = str(uuid.uuid4())[:4]
561
+ sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
562
+ result = {"path": f"/audios/{name}.wav"}
563
+
564
+ # segmentation
565
+ if model_id == "facebook/detr-resnet-50-panoptic":
566
+ result = []
567
+ segments = pipe(data["img_url"])
568
+ image = load_image(data["img_url"])
569
+
570
+ colors = []
571
+ for i in range(len(segments)):
572
+ colors.append((random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 50))
573
+
574
+ for i, segment in enumerate(segments):
575
+ mask = segment["mask"]
576
+ mask = mask.convert('L')
577
+ layer = Image.new('RGBA', mask.size, colors[i])
578
+ image.paste(layer, (0, 0), mask)
579
+ name = str(uuid.uuid4())[:4]
580
+ image.save(f"public/images/{name}.jpg")
581
+ result = {"path": f"/images/{name}.jpg"}
582
+
583
+ if model_id == "facebook/maskformer-swin-base-coco" or model_id == "facebook/maskformer-swin-large-ade":
584
+ image = load_image(data["img_url"])
585
+ inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").to(pipes[model_id]["device"])
586
+ outputs = pipe(**inputs)
587
+ result = pipes[model_id]["feature_extractor"].post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
588
+ predicted_panoptic_map = result["segmentation"].cpu().numpy()
589
+ predicted_panoptic_map = Image.fromarray(predicted_panoptic_map.astype(np.uint8))
590
+ name = str(uuid.uuid4())[:4]
591
+ predicted_panoptic_map.save(f"public/images/{name}.jpg")
592
+ result = {"path": f"/images/{name}.jpg"}
593
+
594
+ except Exception as e:
595
+ print(e)
596
+ traceback.print_exc()
597
+ result = {"error": {"message": "Error when running the model inference."}}
598
+
599
+ if "device" in pipes[model_id]:
600
+ try:
601
+ pipe.to("cpu")
602
+ torch.cuda.empty_cache()
603
+ except:
604
+ pipe.device = torch.device("cpu")
605
+ pipe.model.to("cpu")
606
+ torch.cuda.empty_cache()
607
+
608
+ pipes[model_id]["using"] = False
609
+
610
+ if result is None:
611
+ result = {"error": {"message": "model not found"}}
612
+
613
+ end = time.time()
614
+ during = end - start
615
+ print(f"[ complete {model_id} ] {during}s")
616
+ print(f"[ result {model_id} ] {result}")
617
+
618
+ return result
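models_server.py exposes running(), status(model_id) and models(model_id, data) as plain functions; for example, models("runwayml/stable-diffusion-v1-5", {"text": "a boy is running"}) returns a dict such as {"path": "/images/xxxx.jpg"}. Since the requirements.txt added below pins flask, flask_cors and waitress, one plausible way to serve these functions over HTTP is the purely hypothetical wrapper sketched here; routes, payload shapes and the port are assumptions, not part of this commit.

```python
# Hypothetical HTTP wrapper around models_server.py; not part of this commit.
from flask import Flask, jsonify, request
from waitress import serve

import models_server  # note: importing this loads every pipeline at import time

app = Flask(__name__)

@app.route("/running")
def running():
    return jsonify(models_server.running())

@app.route("/status/<path:model_id>")  # path converter keeps the "/" in model ids
def status(model_id):
    return jsonify(models_server.status(model_id))

@app.route("/models/<path:model_id>", methods=["POST"])
def run_model(model_id):
    # e.g. {"text": "a boy is running"} or {"img_url": "public/examples/a.jpg"}
    return jsonify(models_server.models(model_id, request.get_json()))

if __name__ == "__main__":
    serve(app, host="0.0.0.0", port=8004)
```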
packages.txt ADDED
@@ -0,0 +1 @@
1
+ tesseract-ocr
public/examples/a.jpg ADDED
public/examples/b.jpg ADDED
public/examples/c.jpg ADDED
public/examples/d.jpg ADDED
public/examples/e.jpg ADDED
public/examples/f.jpg ADDED
public/examples/g.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ git+https://github.com/huggingface/diffusers.git@8c530fc2f6a76a2aefb6b285dce6df1675092ac6#egg=diffusers
2
+ git+https://github.com/huggingface/transformers@c612628045822f909020f7eb6784c79700813eda#egg=transformers
3
+ git+https://github.com/patrickvonplaten/controlnet_aux@78efc716868a7f5669c288233d65b471f542ce40#egg=controlnet_aux
4
+ tiktoken==0.3.3
5
+ pydub==0.25.1
6
+ espnet==202301
7
+ espnet_model_zoo==0.1.7
8
+ flask==2.2.3
9
+ flask_cors==3.0.10
10
+ waitress==2.1.2
11
+ datasets==2.11.0
12
+ asteroid==0.6.0
13
+ speechbrain==0.5.14
14
+ timm==0.6.13
15
+ typeguard==2.13.3
16
+ accelerate==0.18.0
17
+ pytesseract==0.3.10
18
+ basicsr==1.4.2