nxphi47 committed
Commit bd0e607
Parent: f862dea

Upload 41 files

Files changed (42)
  1. .gitattributes +6 -0
  2. LICENSE +201 -0
  3. app.py +115 -0
  4. assets/attention_all_you_need.pdf +0 -0
  5. assets/attention_short.pdf +0 -0
  6. assets/doc_gif.gif +3 -0
  7. assets/dog_monalisa.jpeg +0 -0
  8. assets/image_demo.gif +3 -0
  9. assets/image_doc.gif +3 -0
  10. assets/image_doc_rag.gif +3 -0
  11. assets/rag_gif.gif +3 -0
  12. assets/text_completion_gif.gif +3 -0
  13. assets/upload_chat.json +10 -0
  14. assets/upload_few_shot.json +10 -0
  15. llama_cpp_requirements.txt +1 -0
  16. mlx_requirements.txt +2 -0
  17. multipurpose_chatbot/.DS_Store +0 -0
  18. multipurpose_chatbot/__init__.py +0 -0
  19. multipurpose_chatbot/configs.py +110 -0
  20. multipurpose_chatbot/demos/.DS_Store +0 -0
  21. multipurpose_chatbot/demos/__init__.py +8 -0
  22. multipurpose_chatbot/demos/base_demo.py +105 -0
  23. multipurpose_chatbot/demos/batch_inference.py +246 -0
  24. multipurpose_chatbot/demos/chat_interface.py +704 -0
  25. multipurpose_chatbot/demos/multimodal_chat_interface.py +1293 -0
  26. multipurpose_chatbot/demos/rag_chat_interface.py +642 -0
  27. multipurpose_chatbot/demos/text_completion.py +199 -0
  28. multipurpose_chatbot/engines/.DS_Store +0 -0
  29. multipurpose_chatbot/engines/__init__.py +54 -0
  30. multipurpose_chatbot/engines/base_engine.py +46 -0
  31. multipurpose_chatbot/engines/debug_engine.py +49 -0
  32. multipurpose_chatbot/engines/llama_cpp_engine.py +131 -0
  33. multipurpose_chatbot/engines/llava15_transformers_engine.py +230 -0
  34. multipurpose_chatbot/engines/llava_llama_cpp_engine.py +280 -0
  35. multipurpose_chatbot/engines/mlx_engine.py +202 -0
  36. multipurpose_chatbot/engines/transformers_engine.py +452 -0
  37. multipurpose_chatbot/engines/vllm_engine.py +233 -0
  38. multipurpose_chatbot/globals.py +33 -0
  39. pyproject.toml +0 -0
  40. requirements.txt +11 -0
  41. transformers_requirements.txt +1 -0
  42. vllm_requirements.txt +2 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/doc_gif.gif filter=lfs diff=lfs merge=lfs -text
+ assets/image_demo.gif filter=lfs diff=lfs merge=lfs -text
+ assets/image_doc_rag.gif filter=lfs diff=lfs merge=lfs -text
+ assets/image_doc.gif filter=lfs diff=lfs merge=lfs -text
+ assets/rag_gif.gif filter=lfs diff=lfs merge=lfs -text
+ assets/text_completion_gif.gif filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,115 @@
+ # Copyright: DAMO Academy, Alibaba Group
+ # By Xuan Phi Nguyen at DAMO Academy, Alibaba Group
+
+ # Description:
+ """
+ Demo script to launch Language chat model
+ """
+
+
+ import os
+ from gradio.themes import ThemeClass as Theme
+ import numpy as np
+ import argparse
+ # import torch
+ import gradio as gr
+ from typing import Any, Iterator
+ from typing import Iterator, List, Optional, Tuple
+ import filelock
+ import glob
+ import json
+ import time
+ from gradio.routes import Request
+ from gradio.utils import SyncToAsyncIterator, async_iteration
+ from gradio.helpers import special_args
+ import anyio
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
+
+ from gradio_client.documentation import document, set_documentation_group
+
+ from typing import List, Optional, Union, Dict, Tuple
+ from tqdm.auto import tqdm
+ from huggingface_hub import snapshot_download
+ from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
+ from gradio.components import Button, Component
+ from gradio.events import Dependency, EventListenerMethod
+
+ from multipurpose_chatbot.demos.base_demo import CustomTabbedInterface
+
+ from multipurpose_chatbot.configs import (
+     MODEL_TITLE,
+     MODEL_DESC,
+     MODEL_INFO,
+     CITE_MARKDOWN,
+     ALLOWED_PATHS,
+     PROXY,
+     PORT,
+     MODEL_PATH,
+     MODEL_NAME,
+     BACKEND,
+     DEMOS,
+ )
+
+
+ demo = None
+
+
+ def launch_demo():
+     global demo, MODEL_ENGINE
+     model_desc = MODEL_DESC
+     model_path = MODEL_PATH
+
+     print(f'Begin importing models')
+     from multipurpose_chatbot.demos import get_demo_class
+
+     # demos = {
+     #     k: get_demo_class(k)().create_demo()
+     #     for k in demo_and_tab_names.keys()
+     # }
+     print(f'{DEMOS=}')
+     demo_class_objects = {
+         k: get_demo_class(k)()
+         for k in DEMOS
+     }
+     demos = {
+         k: get_demo_class(k)().create_demo()
+         for k in DEMOS
+     }
+     demos_names = [x.tab_name for x in demo_class_objects.values()]
+
+     descriptions = model_desc
+     if MODEL_INFO is not None and MODEL_INFO != "":
+         descriptions += (
+             f"<br>" +
+             MODEL_INFO.format(model_path=model_path)
+         )
+
+     demo = CustomTabbedInterface(
+         interface_list=list(demos.values()),
+         tab_names=demos_names,
+         title=f"{MODEL_TITLE}",
+         description=descriptions,
+     )
+
+     demo.title = MODEL_NAME
+
+     with demo:
+         gr.Markdown(CITE_MARKDOWN)
+
+     demo.queue(api_open=False)
+     return demo
+
+
+ if __name__ == "__main__":
+     demo = launch_demo()
+     if PROXY is not None and PROXY != "":
+         print(f'{PROXY=} {PORT=}')
+         demo.launch(server_port=PORT, root_path=PROXY, show_api=False, allowed_paths=ALLOWED_PATHS)
+     else:
+         demo.launch(server_port=PORT, show_api=False, allowed_paths=ALLOWED_PATHS)
+
+
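Note: app.py pulls every setting from environment variables through multipurpose_chatbot.configs, so configuration must be in place before the module is imported. A minimal launch sketch under that assumption (the values below are illustrative; running `python app.py` with the same variables exported is equivalent to the `__main__` block above):

```python
# Minimal sketch, assuming the repo root is the working directory and the
# requirements are installed. All names come from configs.py and app.py above.
import os

# configs.py reads os.environ at import time, so set these before importing.
os.environ["BACKEND"] = "debug"              # default backend; swap for "transformers", "vllm", etc.
os.environ["MODEL_PATH"] = "teknium/OpenHermes-2.5-Mistral-7B"
os.environ["DEMOS"] = "ChatInterfaceDemo,TextCompletionDemo"
os.environ["PORT"] = "7860"

from app import launch_demo                  # app.py sits at the repo root

demo = launch_demo()
demo.launch(server_port=int(os.environ["PORT"]), show_api=False)
```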
assets/attention_all_you_need.pdf ADDED
Binary file (858 kB).
 
assets/attention_short.pdf ADDED
Binary file (236 kB).
 
assets/doc_gif.gif ADDED

Git LFS Details

  • SHA256: b04ced9f35bec0f27045a895cf991790d112b84e72e279b653cc846447994c9d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
assets/dog_monalisa.jpeg ADDED
assets/image_demo.gif ADDED

Git LFS Details

  • SHA256: 6dc4b375bb283cc7486d9134efa256dc9675c29ebef79a7d163b4bba49a5994a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.46 MB
assets/image_doc.gif ADDED

Git LFS Details

  • SHA256: e26a39469ffc5be2d4ca2a24744cea3b2aaefea09a335659529b00d9dec0087a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.62 MB
assets/image_doc_rag.gif ADDED

Git LFS Details

  • SHA256: 0d1bb3ac99fedb5ba0f462b59c735cf5787239edd4ab65451dc14e0387750bce
  • Pointer size: 132 Bytes
  • Size of remote file: 9.9 MB
assets/rag_gif.gif ADDED

Git LFS Details

  • SHA256: e6dc50a2c2ec4e57d3247f9a1233da8ea3c4408d232c6f818cca11f5bfd83cf9
  • Pointer size: 132 Bytes
  • Size of remote file: 7.36 MB
assets/text_completion_gif.gif ADDED

Git LFS Details

  • SHA256: a0f8138f146ac1b784a8eda8413ecaa7e0efbe90d6a47916ae7ee122b849dcf1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
assets/upload_chat.json ADDED
@@ -0,0 +1,10 @@
+ [
+     {
+         "id": "1",
+         "prompt": "Tell me something about AI?"
+     },
+     {
+         "id": "2",
+         "prompt": "Who are you?"
+     }
+ ]
assets/upload_few_shot.json ADDED
@@ -0,0 +1,10 @@
+ [
+     {
+         "id": "0",
+         "prompt": "Translate Indonesian to English.\nIndonesian: \"Mereka melakukan hal ini dengan cara memancarkan sebuah partikel kecil cahaya kecil yang biasa disebut \"foton\".\"\nEnglish: They do this by emitting a tiny particle of light called a \"photon\".\n\nTranslate Indonesian to English.\nIndonesian: Kami melewati waktu seperti rangkaian peristiwa yang berlalu dari masa depan hingga masa kini lalu ke masa lalu.\nEnglish: We experience time as a series of events passing from the future through the present to the past.\n\nTranslate Indonesian to English.\nIndonesian: Canyoning (atau: canyoneering) adalah segala aktivitas yang terjadi di dasar ngarai, yang kering atau penuh air.\nEnglish: Canyoning (or: canyoneering) is about going in a bottom of a canyon, which is either dry or full of water.\n\nTranslate Indonesian to English.\nIndonesian: Mohon diingat bahwa intinya Anda sedang berkunjung ke situs kuburan massal, serta situs yang maknanya tak terhitung bagi sejumlah populasi dunia yang signifikan.\nEnglish:"
+     },
+     {
+         "id": "1",
+         "prompt": "Translate Indonesian to English.\nIndonesian: \"Mereka melakukan hal ini dengan cara memancarkan sebuah partikel kecil cahaya kecil yang biasa disebut \"foton\".\"\nEnglish: They do this by emitting a tiny particle of light called a \"photon\".\n\nTranslate Indonesian to English.\nIndonesian: Kami melewati waktu seperti rangkaian peristiwa yang berlalu dari masa depan hingga masa kini lalu ke masa lalu.\nEnglish: We experience time as a series of events passing from the future through the present to the past.\n\nTranslate Indonesian to English.\nIndonesian: Canyoning (atau: canyoneering) adalah segala aktivitas yang terjadi di dasar ngarai, yang kering atau penuh air.\nEnglish: Canyoning (or: canyoneering) is about going in a bottom of a canyon, which is either dry or full of water.\n\nTranslate Indonesian to English.\nIndonesian: Serangga adalah hewan pertama yang menjelajah angkasa. Kemampuan terbangnya membantu mereka menghindari musuh dengan lebih mudah dan mencari makanan dan pasangan dengan lebih efisien.\nEnglish:"
+     }
+ ]
llama_cpp_requirements.txt ADDED
@@ -0,0 +1 @@
+ llama-cpp-python
mlx_requirements.txt ADDED
@@ -0,0 +1,2 @@
+ mlx
+ mlx-lm
multipurpose_chatbot/.DS_Store ADDED
Binary file (6.15 kB).
 
multipurpose_chatbot/__init__.py ADDED
File without changes
multipurpose_chatbot/configs.py ADDED
@@ -0,0 +1,110 @@
+
+ import os
+
+ # ! UI Markdown information
+
+ MODEL_TITLE = "<h1>Multi-Purpose Chatbot</h1>"
+
+ MODEL_DESC = f"""
+ <div style='display:flex; gap: 0.25rem; '>
+ <a href='https://github.com/DAMO-NLP-SG/Multipurpose-Chatbot'><img src='https://img.shields.io/badge/Github-Code-success'></a>
+ </div>
+ <span style="font-size: larger">
+ A multi-purpose helpful assistant with multiple functionalities (Chat, text-completion, RAG chat, batch inference).
+ </span>
+ """.strip()
+
+
+
+ MODEL_INFO = """
+ <h4>Model Name: {model_path}</h4>
+ """
+
+ CITE_MARKDOWN = """
+ ## Citation
+ If you find our project useful, hope you can star our repo and cite our repo as follows:
+ ```
+ @article{multipurpose_chatbot_2024,
+   author = {Xuan-Phi Nguyen, },
+   title = {Multipurpose Chatbot},
+   year = 2024,
+ }
+ ```
+ """
+
+ USE_PANEL = bool(int(os.environ.get("USE_PANEL", "1")))
+ CHATBOT_HEIGHT = int(os.environ.get("CHATBOT_HEIGHT", "500"))
+
+ ALLOWED_PATHS = []
+
+
+ DEMOS = os.environ.get("DEMOS", "")
+
+ DEMOS = DEMOS.split(",") if DEMOS.strip() != "" else [
+     "DocChatInterfaceDemo",
+     "ChatInterfaceDemo",
+     "TextCompletionDemo",
+     # "RagChatInterfaceDemo",
+     # "VisionChatInterfaceDemo",
+     # "VisionDocChatInterfaceDemo",
+ ]
+
+ # DEMOS=DocChatInterfaceDemo,ChatInterfaceDemo,RagChatInterfaceDemo,TextCompletionDemo
+
+
+
+ # ! server info
+
+ PORT = int(os.environ.get("PORT", "7860"))
+ PROXY = os.environ.get("PROXY", "").strip()
+
+ # ! backend info
+
+ BACKEND = os.environ.get("BACKEND", "debug")
+
+ # ! model information
+ # for RAG
+ RAG_EMBED_MODEL_NAME = os.environ.get("RAG_EMBED_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")
+ CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1024"))
+ CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "50"))
+
+
+ SYSTEM_PROMPT = os.environ.get("SYSTEM_PROMPT", """You are a helpful, respectful, honest and safe AI assistant.""")
+
+ MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
+ TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.7"))
+ # ! these values currently not used
+ FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.0"))
+ PRESENCE_PENALTY = float(os.environ.get("PRESENCE_PENALTY", "0.0"))
+
+
+ # Transformers or vllm
+ MODEL_PATH = os.environ.get("MODEL_PATH", "teknium/OpenHermes-2.5-Mistral-7B")
+ MODEL_NAME = os.environ.get("MODEL_NAME", "Cool-Chatbot")
+ DTYPE = os.environ.get("DTYPE", "bfloat16")
+ DEVICE = os.environ.get("DEVICE", "cuda")
+
+ # VLLM
+ GPU_MEMORY_UTILIZATION = float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.9"))
+ TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
+ QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))
+ STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
+ # how many iterations to perform safety check on response
+ STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
+
+ # llama.cpp
+ DEFAULT_CHAT_TEMPLATE = os.environ.get("DEFAULT_CHAT_TEMPLATE", "chatml")
+ N_CTX = int(os.environ.get("N_CTX", "4096"))
+ N_GPU_LAYERS = int(os.environ.get("N_GPU_LAYERS", "-1"))
+
+ # llava.llama.cpp
+ # ! pending development
+
+ # Multimodal
+ # IMAGE_TOKEN = os.environ.get("IMAGE_TOKEN", "[IMAGE]<|image|>[/IMAGE]")
+ IMAGE_TOKEN = os.environ.get("IMAGE_TOKEN", "<image>")
+ IMAGE_TOKEN_INTERACTIVE = bool(int(os.environ.get("IMAGE_TOKEN_INTERACTIVE", "0")))
+ # ! IMAGE_TOKEN_LENGTH expected embedding lengths of an image to calculate the actual tokens
+ IMAGE_TOKEN_LENGTH = int(os.environ.get("IMAGE_TOKEN_LENGTH", "576"))
+ # ! Llava1.6 to calculate the maximum number of patches in an image (max=5 for Llava1.6)
+ MAX_PACHES = int(os.environ.get("MAX_PACHES", "1"))
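The RAG settings above (RAG_EMBED_MODEL_NAME, CHUNK_SIZE, CHUNK_OVERLAP) are consumed by rag_chat_interface.py, which is not reproduced in this excerpt, so the following is only a plausible sketch of how such values are typically wired up with the langchain_community embeddings already imported in app.py:

```python
# Illustrative only: not the repo's exact RAG code, just the generic pattern
# these config values suggest. The splitter import path depends on the
# installed langchain version (it may live in langchain_text_splitters).
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

from multipurpose_chatbot.configs import RAG_EMBED_MODEL_NAME, CHUNK_SIZE, CHUNK_OVERLAP

embedder = HuggingFaceEmbeddings(model_name=RAG_EMBED_MODEL_NAME)
splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)

document_text = "Attention is all you need. " * 200   # stand-in for extracted PDF text
chunks = splitter.split_text(document_text)            # CHUNK_SIZE chars with CHUNK_OVERLAP overlap
vectors = embedder.embed_documents(chunks)              # one embedding per chunk
```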
multipurpose_chatbot/demos/.DS_Store ADDED
Binary file (6.15 kB).
 
multipurpose_chatbot/demos/__init__.py ADDED
@@ -0,0 +1,8 @@
+
+ from .base_demo import *
+
+ from .chat_interface import ChatInterfaceDemo
+ from .rag_chat_interface import RagChatInterfaceDemo
+ from .multimodal_chat_interface import *
+ from .text_completion import *
+ from .batch_inference import *
multipurpose_chatbot/demos/base_demo.py ADDED
@@ -0,0 +1,105 @@
+ import os
+ from gradio.themes import ThemeClass as Theme
+ import numpy as np
+ import argparse
+ import gradio as gr
+ from typing import Any, Iterator
+ from typing import Iterator, List, Optional, Tuple
+ import filelock
+ import glob
+ import json
+ import time
+ from gradio.routes import Request
+ from gradio.utils import SyncToAsyncIterator, async_iteration
+ from gradio.helpers import special_args
+ import anyio
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
+
+ from gradio_client.documentation import document, set_documentation_group
+ from gradio.components import Button, Component
+ from gradio.events import Dependency, EventListenerMethod
+ from typing import List, Optional, Union, Dict, Tuple
+ from tqdm.auto import tqdm
+ from huggingface_hub import snapshot_download
+
+
+ def create_class_func_registry():
+     registry = {}
+     def register_registry(cls, exist_ok=False):
+         assert exist_ok or cls.__name__ not in registry, f'{cls} already in registry: {registry}'
+         registry[cls.__name__] = cls
+         return cls
+
+     def get_registry(name):
+         assert name in registry, f'{name} not in registry: {registry}'
+         return registry[name]
+
+     return registry, register_registry, get_registry
+
+ DEMOS, register_demo, get_demo_class = create_class_func_registry()
+
+
+ class BaseDemo(object):
+     """
+     All demo should be created from BaseDemo and registered with @register_demo
+     """
+     def __init__(self) -> None:
+         pass
+
+     @property
+     def tab_name(self):
+         return "Demo"
+
+     def create_demo(
+         self,
+         title: Optional[str] = None,
+         description: Optional[str] = None,
+         **kwargs,
+     ) -> gr.Blocks:
+         pass
+
+
+ @document()
+ class CustomTabbedInterface(gr.Blocks):
+     def __init__(
+         self,
+         interface_list: list[gr.Interface],
+         tab_names: Optional[list[str]] = None,
+         title: Optional[str] = None,
+         description: Optional[str] = None,
+         theme: Optional[gr.Theme] = None,
+         analytics_enabled: Optional[bool] = None,
+         css: Optional[str] = None,
+     ):
+         """
+         Parameters:
+             interface_list: a list of interfaces to be rendered in tabs.
+             tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
+             title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
+             analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
+             css: custom css or path to custom css file to apply to entire Blocks
+         Returns:
+             a Gradio Tabbed Interface for the given interfaces
+         """
+         super().__init__(
+             title=title or "Gradio",
+             theme=theme,
+             analytics_enabled=analytics_enabled,
+             mode="tabbed_interface",
+             css=css,
+         )
+         self.description = description
+         if tab_names is None:
+             tab_names = [f"Tab {i}" for i in range(len(interface_list))]
+         with self:
+             if title:
+                 gr.Markdown(
+                     f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
+                 )
+             if description:
+                 gr.Markdown(description)
+             with gr.Tabs():
+                 for interface, tab_name in zip(interface_list, tab_names):
+                     with gr.Tab(label=tab_name):
+                         interface.render()
+
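base_demo.py defines the registry that configs.DEMOS and app.py rely on: demo classes register themselves under their class name and are later instantiated via get_demo_class(name)(). A hedged sketch of adding a new tab (EchoDemo is hypothetical and not part of this commit; the module defining it would also need to be imported, e.g. from demos/__init__.py, for the registration to run):

```python
# Hypothetical example of plugging a new tab into the registry above.
import gradio as gr
from multipurpose_chatbot.demos.base_demo import BaseDemo, register_demo

@register_demo
class EchoDemo(BaseDemo):
    @property
    def tab_name(self):
        return "Echo"

    def create_demo(self, title=None, description=None, **kwargs) -> gr.Blocks:
        def echo(message):
            return message
        # Any gr.Blocks works here; gr.Interface is a Blocks subclass.
        return gr.Interface(fn=echo, inputs="text", outputs="text", description=description)

# With DEMOS="EchoDemo" set in the environment, app.py would resolve this class
# through get_demo_class("EchoDemo") and render it as one of the tabs.
```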
multipurpose_chatbot/demos/batch_inference.py ADDED
@@ -0,0 +1,246 @@
1
+ import os
2
+ from gradio.themes import ThemeClass as Theme
3
+ import numpy as np
4
+ import argparse
5
+ import gradio as gr
6
+ from typing import Any, Iterator
7
+ from typing import Iterator, List, Optional, Tuple
8
+ import filelock
9
+ import glob
10
+ import json
11
+ import time
12
+ from gradio.routes import Request
13
+ from gradio.utils import SyncToAsyncIterator, async_iteration
14
+ from gradio.helpers import special_args
15
+ import anyio
16
+ from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
17
+
18
+ from gradio_client.documentation import document, set_documentation_group
19
+ from gradio.components import Button, Component
20
+ from gradio.events import Dependency, EventListenerMethod
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+
25
+
26
+ import inspect
27
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
28
+
29
+ import anyio
30
+ from gradio_client import utils as client_utils
31
+ from gradio_client.documentation import document
32
+
33
+ from gradio.blocks import Blocks
34
+ from gradio.components import (
35
+ Button,
36
+ Chatbot,
37
+ Component,
38
+ Markdown,
39
+ State,
40
+ Textbox,
41
+ get_component_instance,
42
+ )
43
+ from gradio.events import Dependency, on
44
+ from gradio.helpers import create_examples as Examples # noqa: N812
45
+ from gradio.helpers import special_args
46
+ from gradio.layouts import Accordion, Group, Row
47
+ from gradio.routes import Request
48
+ from gradio.themes import ThemeClass as Theme
49
+ from gradio.utils import SyncToAsyncIterator, async_iteration
50
+
51
+
52
+ from .base_demo import register_demo, get_demo_class, BaseDemo
53
+ from ..configs import (
54
+ SYSTEM_PROMPT,
55
+ MODEL_NAME,
56
+ MAX_TOKENS,
57
+ TEMPERATURE,
58
+ USE_PANEL,
59
+ CHATBOT_HEIGHT,
60
+ )
61
+
62
+ from ..globals import MODEL_ENGINE
63
+
64
+ from .chat_interface import gradio_history_to_conversation_prompt
65
+
66
+ # Batch inference file upload
67
+ ENABLE_BATCH_INFER = bool(int(os.environ.get("ENABLE_BATCH_INFER", "1")))
68
+ BATCH_INFER_MAX_ITEMS = int(os.environ.get("BATCH_INFER_MAX_ITEMS", "100"))
69
+ BATCH_INFER_MAX_FILE_SIZE = int(os.environ.get("BATCH_INFER_MAX_FILE_SIZE", "500"))
70
+ BATCH_INFER_MAX_PROMPT_TOKENS = int(os.environ.get("BATCH_INFER_MAX_PROMPT_TOKENS", "4000"))
71
+ BATCH_INFER_SAVE_TMP_FILE = os.environ.get("BATCH_INFER_SAVE_TMP_FILE", "./tmp/pred.json")
72
+
73
+
74
+
75
+ FILE_UPLOAD_DESCRIPTION = f"""Upload JSON file as list of dict with < {BATCH_INFER_MAX_ITEMS} items, \
76
+ each item has `prompt` key. We put guardrails to enhance safety, so do not input any harmful content or personal information! Re-upload the file after every submit. See the examples below.
77
+ ```
78
+ [ {{"id": 0, "prompt": "Hello world"}} , {{"id": 1, "prompt": "Hi there?"}}]
79
+ ```
80
+ """
81
+
82
+ def validate_file_item(filename, index, item: Dict[str, str]):
83
+ """
84
+ check safety for items in files
85
+ """
86
+ global MODEL_ENGINE
87
+ message = item['prompt'].strip()
88
+
89
+ if len(message) == 0:
90
+ raise gr.Error(f'Prompt {index} empty')
91
+
92
+ num_tokens = len(MODEL_ENGINE.tokenizer.encode(message))
93
+ if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
94
+ raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
95
+
96
+
97
+ def read_validate_json_files(files: Union[str, List[str]]):
98
+ files = files if isinstance(files, list) else [files]
99
+ filenames = [f.name for f in files]
100
+ all_items = []
101
+ for fname in filenames:
102
+ # check each files
103
+ print(f'Reading {fname}')
104
+ with open(fname, 'r', encoding='utf-8') as f:
105
+ items = json.load(f)
106
+ assert isinstance(items, list), f'Data {fname} not list'
107
+ assert all(isinstance(x, dict) for x in items), f'item in input file not list'
108
+ assert all("prompt" in x for x in items), f'key prompt should be in dict item of input file'
109
+
110
+ for i, x in enumerate(items):
111
+ validate_file_item(fname, i, x)
112
+
113
+ all_items.extend(items)
114
+
115
+ if len(all_items) > BATCH_INFER_MAX_ITEMS:
116
+ raise gr.Error(f"Num samples {len(all_items)} > {BATCH_INFER_MAX_ITEMS} allowed.")
117
+
118
+ return all_items, filenames
119
+
120
+
121
+ def remove_gradio_cache(exclude_names=None):
122
+ """remove gradio cache to avoid flooding"""
123
+ import shutil
124
+ for root, dirs, files in os.walk('/tmp/gradio/'):
125
+ for f in files:
126
+ # if not any(f in ef for ef in except_files):
127
+ if exclude_names is None or not any(ef in f for ef in exclude_names):
128
+ print(f'Remove: {f}')
129
+ os.unlink(os.path.join(root, f))
130
+
131
+
132
+ def free_form_prompt(prompt, history=None, system_prompt=None):
133
+ return prompt
134
+
135
+
136
+
137
+
138
+ def batch_inference_engine(
139
+ files: Union[str, List[str]],
140
+ prompt_mode: str,
141
+ temperature: float,
142
+ max_tokens: int,
143
+ stop_strings: str = "<s>,</s>,<|im_start|>",
144
+ system_prompt: Optional[str] = SYSTEM_PROMPT,
145
+ ):
146
+ global MODEL_ENGINE
147
+ temperature = float(temperature)
148
+ max_tokens = int(max_tokens)
149
+ stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
150
+
151
+ all_items, filenames = read_validate_json_files(files)
152
+
153
+ # remove all items in /tmp/gradio/
154
+ remove_gradio_cache(exclude_names=['upload_chat.json', 'upload_few_shot.json'])
155
+
156
+ if prompt_mode == 'chat':
157
+ prompt_format_fn = gradio_history_to_conversation_prompt
158
+ elif prompt_mode == 'few-shot':
159
+ from functools import partial
160
+ prompt_format_fn = free_form_prompt
161
+ else:
162
+ raise gr.Error(f'Wrong mode {prompt_mode}')
163
+
164
+ full_prompts = [
165
+ prompt_format_fn(
166
+ x['prompt'], [], system_prompt=system_prompt
167
+ )
168
+ for i, x in enumerate(all_items)
169
+ ]
170
+ print(f'{full_prompts[0]}\n')
171
+
172
+ full_num_tokens = [
173
+ len(MODEL_ENGINE.tokenizer.encode(p))
174
+ for p in full_prompts
175
+ ]
176
+ if any(x >= MODEL_ENGINE.max_position_embeddings - 128 for x in full_num_tokens):
177
+ raise gr.Error(f"Some prompt is too long!")
178
+
179
+ # ! batch inference
180
+ responses = MODEL_ENGINE.batch_generate(
181
+ full_prompts,
182
+ temperature=temperature, max_tokens=max_tokens,
183
+ stop_strings=stop_strings,
184
+ )
185
+
186
+ if len(responses) != len(all_items):
187
+ raise gr.Error(f'inconsistent lengths {len(responses)} != {len(all_items)}')
188
+
189
+ for res, item in zip(responses, all_items):
190
+ item['response'] = res
191
+
192
+ save_path = BATCH_INFER_SAVE_TMP_FILE
193
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
194
+ with open(save_path, 'w', encoding='utf-8') as f:
195
+ json.dump(all_items, f, indent=4, ensure_ascii=False)
196
+
197
+ print_items = all_items[:2]
198
+ print(json.dumps(print_items, indent=4, ensure_ascii=False))
199
+ return save_path, print_items
200
+
201
+
202
+ class BatchInferenceDemo(BaseDemo):
203
+ def tab_name(self):
204
+ return "Batch Inference"
205
+
206
+
207
+ def create_demo(
208
+ self,
209
+ title: str | None = None,
210
+ description: str | None = None,
211
+ **kwargs
212
+ ) -> gr.Blocks:
213
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
214
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
215
+ temperature = kwargs.get("temperature", TEMPERATURE)
216
+ model_name = kwargs.get("model_name", MODEL_NAME)
217
+
218
+
219
+ demo_file_upload = gr.Interface(
220
+ batch_inference_engine,
221
+ inputs=[
222
+ gr.File(file_count='single', file_types=['json']),
223
+ gr.Radio(["chat", "few-shot"], value='chat', label="Chat or Few-shot mode", info="Chat's output more user-friendly, Few-shot's output more consistent with few-shot patterns."),
224
+ gr.Number(value=temperature, label='Temperature', info="Higher -> more random"),
225
+ gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
226
+ # gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
227
+ # gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
228
+ gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1),
229
+ # gr.Number(value=0, label='current_time', visible=False),
230
+ gr.Textbox(value=system_prompt, label='System prompt', lines=4)
231
+ ],
232
+ outputs=[
233
+ # "file",
234
+ gr.File(label="Generated file"),
235
+ # "json"
236
+ gr.JSON(label='Example outputs (display 2 samples)')
237
+ ],
238
+ description=FILE_UPLOAD_DESCRIPTION,
239
+ allow_flagging=False,
240
+ examples=[
241
+ ["upload_chat.json", "chat"],
242
+ ["upload_few_shot.json", "few-shot"],
243
+ ],
244
+ cache_examples=False,
245
+ )
246
+ return demo_file_upload
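The batch tab expects the uploaded JSON to match what read_validate_json_files() checks above: a list of dicts, each with a `prompt` key, with at most BATCH_INFER_MAX_ITEMS entries. A small sketch for producing such a file (the file name and prompts are illustrative):

```python
# Build an input file in the format validated by read_validate_json_files().
import json

items = [
    {"id": "1", "prompt": "Tell me something about AI?"},
    {"id": "2", "prompt": "Who are you?"},
]

with open("my_batch_input.json", "w", encoding="utf-8") as f:  # hypothetical file name
    json.dump(items, f, ensure_ascii=False, indent=2)

# Upload this file in the "Batch Inference" tab; batch_inference_engine() returns a
# JSON file in which each item gains a "response" field with the generated answer.
```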
multipurpose_chatbot/demos/chat_interface.py ADDED
@@ -0,0 +1,704 @@
1
+ import os
2
+ from gradio.themes import ThemeClass as Theme
3
+ import numpy as np
4
+ import argparse
5
+ import gradio as gr
6
+ from typing import Any, Iterator
7
+ from typing import Iterator, List, Optional, Tuple
8
+ import filelock
9
+ import glob
10
+ import json
11
+ import time
12
+ from gradio.routes import Request
13
+ from gradio.utils import SyncToAsyncIterator, async_iteration
14
+ from gradio.helpers import special_args
15
+ import anyio
16
+ from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
17
+
18
+ from gradio_client.documentation import document, set_documentation_group
19
+ from gradio.components import Button, Component
20
+ from gradio.events import Dependency, EventListenerMethod
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+
25
+
26
+ import inspect
27
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
28
+
29
+ import anyio
30
+ from gradio_client import utils as client_utils
31
+ from gradio_client.documentation import document
32
+
33
+ from gradio.blocks import Blocks
34
+ from gradio.components import (
35
+ Button,
36
+ Chatbot,
37
+ Component,
38
+ Markdown,
39
+ State,
40
+ Textbox,
41
+ get_component_instance,
42
+ )
43
+ from gradio.events import Dependency, on
44
+ from gradio.helpers import create_examples as Examples # noqa: N812
45
+ from gradio.helpers import special_args
46
+ from gradio.layouts import Accordion, Group, Row
47
+ from gradio.routes import Request
48
+ from gradio.themes import ThemeClass as Theme
49
+ from gradio.utils import SyncToAsyncIterator, async_iteration
50
+
51
+
52
+ from .base_demo import register_demo, get_demo_class, BaseDemo
53
+ from ..configs import (
54
+ SYSTEM_PROMPT,
55
+ MODEL_NAME,
56
+ MAX_TOKENS,
57
+ TEMPERATURE,
58
+ USE_PANEL,
59
+ CHATBOT_HEIGHT,
60
+ )
61
+
62
+ from ..globals import MODEL_ENGINE
63
+
64
+ CHAT_EXAMPLES = [
65
+ ["Explain general relativity."],
66
+ ]
67
+ DATETIME_FORMAT = "Current date time: {cur_datetime}."
68
+
69
+
70
+ def gradio_history_to_openai_conversations(message=None, history=None, system_prompt=None):
71
+ conversations = []
72
+ system_prompt = system_prompt or SYSTEM_PROMPT
73
+ if history is not None and len(history) > 0:
74
+ for i, (prompt, res) in enumerate(history):
75
+ if prompt is not None:
76
+ conversations.append({"role": "user", "content": prompt.strip()})
77
+ if res is not None:
78
+ conversations.append({"role": "assistant", "content": res.strip()})
79
+ if message is not None:
80
+ if len(message.strip()) == 0:
81
+ raise gr.Error("The message cannot be empty!")
82
+ conversations.append({"role": "user", "content": message.strip()})
83
+ if conversations[0]['role'] != 'system':
84
+ conversations = [{"role": "system", "content": system_prompt}] + conversations
85
+ return conversations
86
+
87
+
88
+ def gradio_history_to_conversation_prompt(message=None, history=None, system_prompt=None):
89
+ global MODEL_ENGINE
90
+ full_prompt = MODEL_ENGINE.apply_chat_template(
91
+ gradio_history_to_openai_conversations(
92
+ message, history=history, system_prompt=system_prompt),
93
+ add_generation_prompt=True
94
+ )
95
+ return full_prompt
96
+
97
+
98
+
99
+ def get_datetime_string():
100
+ from datetime import datetime
101
+ now = datetime.now()
102
+ # dd/mm/YY H:M:S
103
+ dt_string = now.strftime("%B %d, %Y, %H:%M:%S")
104
+ return dt_string
105
+
106
+
107
+ def format_conversation(history, system_prompt=None):
108
+ _str = '\n'.join([
109
+ (
110
+ f'<<<User>>> {h[0]}\n'
111
+ f'<<<Asst>>> {h[1]}'
112
+ )
113
+ for h in history
114
+ ])
115
+ _str = ""
116
+ for mes, res in history:
117
+ if mes is not None:
118
+ _str += f'<<<User>>> {mes}\n'
119
+ if res is not None:
120
+ _str += f'<<<Asst>>> {res}\n'
121
+ if system_prompt is not None:
122
+ _str = f"<<<Syst>>> {system_prompt}\n" + _str
123
+ return _str
124
+
125
+
126
+ def chat_response_stream_multiturn_engine(
127
+ message: str,
128
+ history: List[Tuple[str, str]],
129
+ temperature: float,
130
+ max_tokens: int,
131
+ system_prompt: Optional[str] = SYSTEM_PROMPT,
132
+ ):
133
+ global MODEL_ENGINE
134
+ temperature = float(temperature)
135
+ # ! remove frequency_penalty
136
+ # frequency_penalty = float(frequency_penalty)
137
+ max_tokens = int(max_tokens)
138
+ message = message.strip()
139
+ if len(message) == 0:
140
+ raise gr.Error("The message cannot be empty!")
141
+ # ! skip safety
142
+ if DATETIME_FORMAT in system_prompt:
143
+ # ! This sometime works sometimes dont
144
+ system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
145
+ full_prompt = gradio_history_to_conversation_prompt(message.strip(), history=history, system_prompt=system_prompt)
146
+ # ! length checked
147
+ num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
148
+ if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
149
+ raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
150
+ print(full_prompt)
151
+ outputs = None
152
+ response = None
153
+ num_tokens = -1
154
+ for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
155
+ prompt=full_prompt,
156
+ temperature=temperature,
157
+ max_tokens=max_tokens,
158
+ )):
159
+ if isinstance(outputs, tuple):
160
+ response, num_tokens = outputs
161
+ else:
162
+ response, num_tokens = outputs, -1
163
+ yield response, num_tokens
164
+
165
+ print(format_conversation(history + [[message, response]]))
166
+
167
+ if response is not None:
168
+ yield response, num_tokens
169
+
170
+
171
+ class CustomizedChatInterface(gr.ChatInterface):
172
+ """
173
+ Fixing some issue with chatinterace
174
+ """
175
+
176
+ def __init__(
177
+ self,
178
+ fn: Callable,
179
+ *,
180
+ chatbot: Chatbot | None = None,
181
+ textbox: Textbox | None = None,
182
+ additional_inputs: str | Component | list[str | Component] | None = None,
183
+ additional_inputs_accordion_name: str | None = None,
184
+ additional_inputs_accordion: str | Accordion | None = None,
185
+ examples: list[str] | None = None,
186
+ cache_examples: bool | None = None,
187
+ title: str | None = None,
188
+ description: str | None = None,
189
+ theme: Theme | str | None = None,
190
+ css: str | None = None,
191
+ js: str | None = None,
192
+ head: str | None = None,
193
+ analytics_enabled: bool | None = None,
194
+ submit_btn: str | None | Button = "Submit",
195
+ stop_btn: str | None | Button = "Stop",
196
+ retry_btn: str | None | Button = "🔄 Retry",
197
+ undo_btn: str | None | Button = "↩️ Undo",
198
+ clear_btn: str | None | Button = "🗑️ Clear",
199
+ autofocus: bool = True,
200
+ concurrency_limit: int | None | Literal["default"] = "default",
201
+ fill_height: bool = True,
202
+ ):
203
+ """
204
+ Parameters:
205
+ fn: The function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
206
+ chatbot: An instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
207
+ textbox: An instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
208
+ additional_inputs: An instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
209
+ additional_inputs_accordion_name: Deprecated. Will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead.
210
+ additional_inputs_accordion: If a string is provided, this is the label of the `gr.Accordion` to use to contain additional inputs. A `gr.Accordion` object can be provided as well to configure other properties of the container holding the additional inputs. Defaults to a `gr.Accordion(label="Additional Inputs", open=False)`. This parameter is only used if `additional_inputs` is provided.
211
+ examples: Sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
212
+ cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
213
+ title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
214
+ description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
215
+ theme: Theme to use, loaded from gradio.themes.
216
+ css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
217
+ js: Custom js or path to js file to run when demo is first loaded. This javascript will be included in the demo webpage.
218
+ head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
219
+ analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
220
+ submit_btn: Text to display on the submit button. If None, no button will be displayed. If a Button object, that button will be used.
221
+ stop_btn: Text to display on the stop button, which replaces the submit_btn when the submit_btn or retry_btn is clicked and response is streaming. Clicking on the stop_btn will halt the chatbot response. If set to None, stop button functionality does not appear in the chatbot. If a Button object, that button will be used as the stop button.
222
+ retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
223
+ undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
224
+ clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
225
+ autofocus: If True, autofocuses to the textbox when the page loads.
226
+ concurrency_limit: If set, this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default).
227
+ fill_height: If True, the chat interface will expand to the height of window.
228
+ """
229
+ try:
230
+ super(gr.ChatInterface, self).__init__(
231
+ analytics_enabled=analytics_enabled,
232
+ mode="chat_interface",
233
+ css=css,
234
+ title=title or "Gradio",
235
+ theme=theme,
236
+ js=js,
237
+ head=head,
238
+ fill_height=fill_height,
239
+ )
240
+ except Exception as e:
241
+ # Handling some old gradio version with out fill_height
242
+ super(gr.ChatInterface, self).__init__(
243
+ analytics_enabled=analytics_enabled,
244
+ mode="chat_interface",
245
+ css=css,
246
+ title=title or "Gradio",
247
+ theme=theme,
248
+ js=js,
249
+ head=head,
250
+ # fill_height=fill_height,
251
+ )
252
+ self.concurrency_limit = concurrency_limit
253
+ self.fn = fn
254
+ self.is_async = inspect.iscoroutinefunction(
255
+ self.fn
256
+ ) or inspect.isasyncgenfunction(self.fn)
257
+ self.is_generator = inspect.isgeneratorfunction(
258
+ self.fn
259
+ ) or inspect.isasyncgenfunction(self.fn)
260
+ self.examples = examples
261
+ if self.space_id and cache_examples is None:
262
+ self.cache_examples = True
263
+ else:
264
+ self.cache_examples = cache_examples or False
265
+ self.buttons: list[Button | None] = []
266
+
267
+ if additional_inputs:
268
+ if not isinstance(additional_inputs, list):
269
+ additional_inputs = [additional_inputs]
270
+ self.additional_inputs = [
271
+ get_component_instance(i)
272
+ for i in additional_inputs # type: ignore
273
+ ]
274
+ else:
275
+ self.additional_inputs = []
276
+ if additional_inputs_accordion_name is not None:
277
+ print(
278
+ "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
279
+ )
280
+ self.additional_inputs_accordion_params = {
281
+ "label": additional_inputs_accordion_name
282
+ }
283
+ if additional_inputs_accordion is None:
284
+ self.additional_inputs_accordion_params = {
285
+ "label": "Additional Inputs",
286
+ "open": False,
287
+ }
288
+ elif isinstance(additional_inputs_accordion, str):
289
+ self.additional_inputs_accordion_params = {
290
+ "label": additional_inputs_accordion
291
+ }
292
+ elif isinstance(additional_inputs_accordion, Accordion):
293
+ self.additional_inputs_accordion_params = (
294
+ additional_inputs_accordion.recover_kwargs(
295
+ additional_inputs_accordion.get_config()
296
+ )
297
+ )
298
+ else:
299
+ raise ValueError(
300
+ f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
301
+ )
302
+
303
+ with self:
304
+ if title:
305
+ Markdown(
306
+ f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
307
+ )
308
+ if description:
309
+ Markdown(description)
310
+
311
+ if chatbot:
312
+ self.chatbot = chatbot.render()
313
+ else:
314
+ self.chatbot = Chatbot(
315
+ label="Chatbot", scale=1, height=200 if fill_height else None
316
+ )
317
+
318
+ with Row():
319
+ for btn in [retry_btn, undo_btn, clear_btn]:
320
+ if btn is not None:
321
+ if isinstance(btn, Button):
322
+ btn.render()
323
+ elif isinstance(btn, str):
324
+ btn = Button(btn, variant="secondary", size="sm")
325
+ else:
326
+ raise ValueError(
327
+ f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
328
+ )
329
+ self.buttons.append(btn) # type: ignore
330
+
331
+ with Group():
332
+ with Row():
333
+ if textbox:
334
+ textbox.container = False
335
+ textbox.show_label = False
336
+ textbox_ = textbox.render()
337
+ assert isinstance(textbox_, Textbox)
338
+ self.textbox = textbox_
339
+ else:
340
+ self.textbox = Textbox(
341
+ container=False,
342
+ show_label=False,
343
+ label="Message",
344
+ placeholder="Type a message...",
345
+ scale=7,
346
+ autofocus=autofocus,
347
+ )
348
+ if submit_btn is not None:
349
+ if isinstance(submit_btn, Button):
350
+ submit_btn.render()
351
+ elif isinstance(submit_btn, str):
352
+ submit_btn = Button(
353
+ submit_btn,
354
+ variant="primary",
355
+ scale=2,
356
+ min_width=150,
357
+ )
358
+ else:
359
+ raise ValueError(
360
+ f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
361
+ )
362
+ if stop_btn is not None:
363
+ if isinstance(stop_btn, Button):
364
+ stop_btn.visible = False
365
+ stop_btn.render()
366
+ elif isinstance(stop_btn, str):
367
+ stop_btn = Button(
368
+ stop_btn,
369
+ variant="stop",
370
+ visible=False,
371
+ scale=2,
372
+ min_width=150,
373
+ )
374
+ else:
375
+ raise ValueError(
376
+ f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
377
+ )
378
+ self.num_tokens = Textbox(
379
+ container=False,
380
+ show_label=False,
381
+ label="num_tokens",
382
+ placeholder="0 tokens",
383
+ scale=1,
384
+ interactive=False,
385
+ # autofocus=autofocus,
386
+ min_width=10
387
+ )
388
+ self.buttons.extend([submit_btn, stop_btn]) # type: ignore
389
+
390
+ self.fake_api_btn = Button("Fake API", visible=False)
391
+ self.fake_response_textbox = Textbox(label="Response", visible=False)
392
+ (
393
+ self.retry_btn,
394
+ self.undo_btn,
395
+ self.clear_btn,
396
+ self.submit_btn,
397
+ self.stop_btn,
398
+ ) = self.buttons
399
+
400
+ if examples:
401
+ if self.is_generator:
402
+ examples_fn = self._examples_stream_fn
403
+ else:
404
+ examples_fn = self._examples_fn
405
+
406
+ self.examples_handler = Examples(
407
+ examples=examples,
408
+ inputs=[self.textbox] + self.additional_inputs,
409
+ outputs=self.chatbot,
410
+ fn=examples_fn,
411
+ )
412
+
413
+ any_unrendered_inputs = any(
414
+ not inp.is_rendered for inp in self.additional_inputs
415
+ )
416
+ if self.additional_inputs and any_unrendered_inputs:
417
+ with Accordion(**self.additional_inputs_accordion_params): # type: ignore
418
+ for input_component in self.additional_inputs:
419
+ if not input_component.is_rendered:
420
+ input_component.render()
421
+
422
+ # The example caching must happen after the input components have rendered
423
+ if cache_examples:
424
+ client_utils.synchronize_async(self.examples_handler.cache)
425
+
426
+ self.saved_input = State()
427
+ self.chatbot_state = (
428
+ State(self.chatbot.value) if self.chatbot.value else State([])
429
+ )
430
+
431
+ self._setup_events()
432
+ self._setup_api()
433
+
434
+ # replace events so that the submit button is disabled during generation if stop_btn is not found
435
+ # this prevents weird behavior
436
+ def _setup_stop_events(
437
+ self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
438
+ ) -> None:
439
+ from gradio.components import State
440
+ event_triggers = event_triggers if isinstance(event_triggers, (list, tuple)) else [event_triggers]
441
+ if self.stop_btn and self.is_generator:
442
+ if self.submit_btn:
443
+ for event_trigger in event_triggers:
444
+ event_trigger(
445
+ lambda: (
446
+ Button(visible=False),
447
+ Button(visible=True),
448
+ ),
449
+ None,
450
+ [self.submit_btn, self.stop_btn],
451
+ api_name=False,
452
+ queue=False,
453
+ )
454
+ event_to_cancel.then(
455
+ lambda: (Button(visible=True), Button(visible=False)),
456
+ None,
457
+ [self.submit_btn, self.stop_btn],
458
+ api_name=False,
459
+ queue=False,
460
+ )
461
+ else:
462
+ for event_trigger in event_triggers:
463
+ event_trigger(
464
+ lambda: Button(visible=True),
465
+ None,
466
+ [self.stop_btn],
467
+ api_name=False,
468
+ queue=False,
469
+ )
470
+ event_to_cancel.then(
471
+ lambda: Button(visible=False),
472
+ None,
473
+ [self.stop_btn],
474
+ api_name=False,
475
+ queue=False,
476
+ )
477
+ self.stop_btn.click(
478
+ None,
479
+ None,
480
+ None,
481
+ cancels=event_to_cancel,
482
+ api_name=False,
483
+ )
484
+ else:
485
+ if self.submit_btn:
486
+ for event_trigger in event_triggers:
487
+ event_trigger(
488
+ lambda: Button(interactive=False),
489
+ None,
490
+ [self.submit_btn],
491
+ api_name=False,
492
+ queue=False,
493
+ )
494
+ event_to_cancel.then(
495
+ lambda: Button(interactive=True),
496
+ None,
497
+ [self.submit_btn],
498
+ api_name=False,
499
+ queue=False,
500
+ )
501
+ # upon clear, cancel the submit event as well
502
+ if self.clear_btn:
503
+ self.clear_btn.click(
504
+ lambda: ([], [], None, Button(interactive=True)),
505
+ None,
506
+ [self.chatbot, self.chatbot_state, self.saved_input, self.submit_btn],
507
+ queue=False,
508
+ api_name=False,
509
+ cancels=event_to_cancel,
510
+ )
511
+
512
+ def _setup_events(self) -> None:
513
+ from gradio.components import State
514
+ has_on = False
515
+ try:
516
+ from gradio.events import Dependency, EventListenerMethod, on
517
+ has_on = True
518
+ except ImportError as ie:
519
+ has_on = False
520
+ submit_fn = self._stream_fn if self.is_generator else self._submit_fn
521
+ if not self.is_generator:
522
+ raise NotImplementedError(f'should use generator')
523
+
524
+ if has_on:
525
+ # new version
526
+ submit_triggers = (
527
+ [self.textbox.submit, self.submit_btn.click]
528
+ if self.submit_btn
529
+ else [self.textbox.submit]
530
+ )
531
+ submit_event = (
532
+ on(
533
+ submit_triggers,
534
+ self._clear_and_save_textbox,
535
+ [self.textbox],
536
+ [self.textbox, self.saved_input],
537
+ api_name=False,
538
+ queue=False,
539
+ )
540
+ .then(
541
+ self._display_input,
542
+ [self.saved_input, self.chatbot_state],
543
+ [self.chatbot, self.chatbot_state],
544
+ api_name=False,
545
+ queue=False,
546
+ )
547
+ .then(
548
+ submit_fn,
549
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
550
+ [self.chatbot, self.chatbot_state, self.num_tokens],
551
+ api_name=False,
552
+ )
553
+ )
554
+ self._setup_stop_events(submit_triggers, submit_event)
555
+ else:
556
+ raise ValueError(f'Please install a gradio version newer than 3.44.0')
557
+
558
+ if self.retry_btn:
559
+ retry_event = (
560
+ self.retry_btn.click(
561
+ self._delete_prev_fn,
562
+ [self.chatbot_state],
563
+ [self.chatbot, self.saved_input, self.chatbot_state],
564
+ api_name=False,
565
+ queue=False,
566
+ )
567
+ .then(
568
+ self._display_input,
569
+ [self.saved_input, self.chatbot_state],
570
+ [self.chatbot, self.chatbot_state],
571
+ api_name=False,
572
+ queue=False,
573
+ )
574
+ .then(
575
+ submit_fn,
576
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
577
+ [self.chatbot, self.chatbot_state, self.num_tokens],
578
+ api_name=False,
579
+ )
580
+ )
581
+ self._setup_stop_events([self.retry_btn.click], retry_event)
582
+
583
+ if self.undo_btn:
584
+ self.undo_btn.click(
585
+ self._delete_prev_fn,
586
+ [self.chatbot_state],
587
+ [self.chatbot, self.saved_input, self.chatbot_state],
588
+ api_name=False,
589
+ queue=False,
590
+ ).then(
591
+ lambda x: x,
592
+ [self.saved_input],
593
+ [self.textbox],
594
+ api_name=False,
595
+ queue=False,
596
+ )
597
+ # Reconfigure clear_btn to stop and clear text box
598
+
599
+ def _clear_and_save_textbox(self, message: str) -> tuple[str, str]:
600
+ return "", message
601
+
602
+ def _display_input(
603
+ self, message: str, history: List[List[Union[str, None]]]
604
+ ) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
605
+ if message is not None and message.strip() != "":
606
+ history.append([message, None])
607
+ return history, history
608
+
609
+ async def _stream_fn(
610
+ self,
611
+ message: str,
612
+ history_with_input,
613
+ request: Request,
614
+ *args,
615
+ ) -> AsyncGenerator:
616
+ history = history_with_input[:-1]
617
+ inputs, _, _ = special_args(
618
+ self.fn, inputs=[message, history, *args], request=request
619
+ )
620
+
621
+ if self.is_async:
622
+ generator = self.fn(*inputs)
623
+ else:
624
+ generator = await anyio.to_thread.run_sync(
625
+ self.fn, *inputs, limiter=self.limiter
626
+ )
627
+ generator = SyncToAsyncIterator(generator, self.limiter)
628
+
629
+ # ! In case of error, yield the previous history & undo any generation before raising error
630
+ try:
631
+ first_response_pack = await async_iteration(generator)
632
+ if isinstance(first_response_pack, (tuple, list)):
633
+ first_response, num_tokens = first_response_pack
634
+ else:
635
+ first_response, num_tokens = first_response_pack, -1
636
+ update = history + [[message, first_response]]
637
+ yield update, update, f"{num_tokens} toks"
638
+ except StopIteration:
639
+ update = history + [[message, None]]
640
+ yield update, update, "NaN toks"
641
+ except Exception as e:
642
+ yield history, history, "NaN toks"
643
+ raise e
644
+
645
+ try:
646
+ async for response_pack in generator:
647
+ if isinstance(response_pack, (tuple, list)):
648
+ response, num_tokens = response_pack
649
+ else:
650
+ response, num_tokens = response_pack, "NaN toks"
651
+ update = history + [[message, response]]
652
+ yield update, update, f"{num_tokens} toks"
653
+ except Exception as e:
654
+ yield history, history, "NaN toks"
655
+ raise e
656
+
657
+ @register_demo
658
+ class ChatInterfaceDemo(BaseDemo):
659
+ @property
660
+ def tab_name(self):
661
+ return "Chat"
662
+
663
+ def create_demo(
664
+ self,
665
+ title: str | None = None,
666
+ description: str | None = None,
667
+ **kwargs
668
+ ) -> gr.Blocks:
669
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
670
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
671
+ temperature = kwargs.get("temperature", TEMPERATURE)
672
+ model_name = kwargs.get("model_name", MODEL_NAME)
673
+ # frequence_penalty = FREQUENCE_PENALTY
674
+ # presence_penalty = PRESENCE_PENALTY
675
+
676
+ demo_chat = CustomizedChatInterface(
677
+ chat_response_stream_multiturn_engine,
678
+ chatbot=gr.Chatbot(
679
+ label=model_name,
680
+ bubble_full_width=False,
681
+ latex_delimiters=[
682
+ { "left": "$", "right": "$", "display": False},
683
+ { "left": "$$", "right": "$$", "display": True},
684
+ ],
685
+ show_copy_button=True,
686
+ layout="panel" if USE_PANEL else "bubble",
687
+ height=CHATBOT_HEIGHT,
688
+ ),
689
+ textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
690
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
691
+ title=title,
692
+ description=description,
693
+ additional_inputs=[
694
+ gr.Number(value=temperature, label='Temperature (higher -> more random)'),
695
+ gr.Number(value=max_tokens, label='Max generated tokens (increase for longer responses)'),
696
+ # gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
697
+ # gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
698
+ gr.Textbox(value=system_prompt, label='System prompt', lines=4)
699
+ ],
700
+ examples=CHAT_EXAMPLES,
701
+ cache_examples=False
702
+ )
703
+ return demo_chat
704
+
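The chat demo above is only defined in this file; an app entry point (e.g. app.py in this commit) is expected to build and launch it. As a rough, hypothetical usage sketch (not part of the committed code; it assumes ChatInterfaceDemo needs no constructor arguments and that the global MODEL_ENGINE has already been initialized):

    import gradio as gr
    from multipurpose_chatbot.demos.chat_interface import ChatInterfaceDemo

    # Assumption: MODEL_ENGINE is loaded elsewhere (normally by app.py) before
    # the demo's streaming callback is ever invoked.
    demo: gr.Blocks = ChatInterfaceDemo().create_demo(
        title="Multipurpose Chatbot",                  # illustrative values
        description="Multi-turn streaming chat demo",
    )
    demo.queue().launch()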
multipurpose_chatbot/demos/multimodal_chat_interface.py ADDED
@@ -0,0 +1,1293 @@
1
+ import os
2
+ from gradio.themes import ThemeClass as Theme
3
+ import numpy as np
4
+ import argparse
5
+ import gradio as gr
6
+ from typing import Any, Iterator
7
+ from typing import Iterator, List, Optional, Tuple
8
+ import filelock
9
+ import glob
10
+ import json
11
+ import time
12
+ from gradio.routes import Request
13
+ from gradio.utils import SyncToAsyncIterator, async_iteration
14
+ from gradio.helpers import special_args
15
+ import anyio
16
+ from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
17
+
18
+ from gradio_client.documentation import document, set_documentation_group
19
+ from gradio.components import Button, Component
20
+ from gradio.events import Dependency, EventListenerMethod
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+ from gradio.components.base import Component
25
+
26
+ from .base_demo import register_demo, get_demo_class, BaseDemo
27
+
28
+
29
+ from .chat_interface import (
30
+ SYSTEM_PROMPT,
31
+ MODEL_NAME,
32
+ MAX_TOKENS,
33
+ TEMPERATURE,
34
+ CHAT_EXAMPLES,
35
+ format_conversation,
36
+ gradio_history_to_openai_conversations,
37
+ gradio_history_to_conversation_prompt,
38
+ DATETIME_FORMAT,
39
+ get_datetime_string,
40
+ chat_response_stream_multiturn_engine,
41
+ ChatInterfaceDemo,
42
+ CustomizedChatInterface,
43
+ )
44
+
45
+ from gradio.events import Events
46
+
47
+ import inspect
48
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
49
+
50
+ import anyio
51
+ from gradio_client import utils as client_utils
52
+ from gradio_client.documentation import document
53
+
54
+ from gradio.blocks import Blocks
55
+ from gradio.components import (
56
+ Button,
57
+ Chatbot,
58
+ Component,
59
+ Markdown,
60
+ State,
61
+ Textbox,
62
+ get_component_instance,
63
+ )
64
+ from gradio.events import Dependency, on
65
+ from gradio.helpers import create_examples as Examples # noqa: N812
66
+ from gradio.helpers import special_args
67
+ from gradio.layouts import Accordion, Group, Row
68
+ from gradio.routes import Request
69
+ from gradio.themes import ThemeClass as Theme
70
+ from gradio.utils import SyncToAsyncIterator, async_iteration
71
+
72
+ from ..globals import MODEL_ENGINE
73
+
74
+ from ..configs import (
75
+ USE_PANEL,
76
+ IMAGE_TOKEN,
77
+ IMAGE_TOKEN_INTERACTIVE,
78
+ CHATBOT_HEIGHT,
79
+ )
80
+
81
+
82
+
83
+ CSS = """
84
+ .message-fit {
85
+ min-width: 20em;
86
+ width: fit-content !important;
87
+ }
88
+
89
+ .message.svelte-1lcyrx4.svelte-1lcyrx4.svelte-1lcyrx4 {
90
+ padding-top: 1em;
91
+ padding-bottom: 1em;
92
+ }
93
+ """
94
+
95
+
96
+ DOC_TEMPLATE = """###
97
+ {content}
98
+ ###
99
+
100
+ """
101
+
102
+ DOC_INSTRUCTION = """Answer the following query exclusively based on the information provided in the document above. \
103
+ If the information is not found, please say so instead of making up facts! Remember to answer the question in the same language as the user query!
104
+ """
105
+
106
+
107
+ def undo_history(history):
108
+ if len(history) == 0:
109
+ return history
110
+ if history[-1][-1] is not None:
111
+ if history[-1][0] is not None:
112
+ history[-1][-1] = None
113
+ else:
114
+ history = history[:-1]
115
+ else:
116
+ history = history[:-1]
117
+ return history
118
+
119
+
120
+ def undo_history_until_last_assistant_turn(history):
121
+ history = undo_history(history)
122
+ while len(history) > 0 and history[-1][-1] is None:
123
+ history = undo_history(history)
124
+ return history, history
125
+
126
+
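For orientation (this note is not part of the committed file): the two undo helpers above blank the last assistant reply and then pop every trailing turn whose reply is None, so the Undo button rolls the history back to the previous completed exchange. A small traced example, assuming the gradio-style history format used throughout this module:

    history = [
        ["Hi", "Hello!"],              # earlier completed exchange
        [("dog.png",), None],          # uploaded image, no reply
        ["What is this?", "A dog."],   # most recent completed exchange
    ]
    # undo_history_until_last_assistant_turn(history) first sets "A dog." to None,
    # then keeps popping turns while the last reply is None, returning
    # ([["Hi", "Hello!"]], [["Hi", "Hello!"]])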
127
+ class MultiModalChatInterface(CustomizedChatInterface):
128
+ def __init__(
129
+ self,
130
+ fn: Callable,
131
+ *,
132
+ chatbot: Chatbot | None = None,
133
+ textbox: Textbox | None = None,
134
+ additional_inputs: str | Component | list[str | Component] | None = None,
135
+ additional_inputs_accordion_name: str | None = None,
136
+ additional_inputs_accordion: str | Accordion | None = None,
137
+ add_multimodal_fn: Callable | None = None,
138
+ render_additional_inputs_fn: Callable | None = None,
139
+ examples: list[str] | None = None,
140
+ cache_examples: bool | None = None,
141
+ title: str | None = None,
142
+ description: str | None = None,
143
+ theme: Theme | str | None = None,
144
+ css: str | None = None,
145
+ js: str | None = None,
146
+ head: str | None = None,
147
+ analytics_enabled: bool | None = None,
148
+ submit_btn: str | None | Button = "Submit",
149
+ stop_btn: str | None | Button = "Stop",
150
+ retry_btn: str | None | Button = "🔄 Retry",
151
+ undo_btn: str | None | Button = "↩️ Undo",
152
+ clear_btn: str | None | Button = "🗑️ Clear",
153
+ autofocus: bool = True,
154
+ concurrency_limit: int | None | Literal["default"] = "default",
155
+ fill_height: bool = True,
156
+ ):
157
+ """
158
+ Parameters:
159
+ fn: The function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
160
+ chatbot: An instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
161
+ textbox: An instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
162
+ additional_inputs: An instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
163
+ additional_inputs_accordion_name: Deprecated. Will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead.
164
+ additional_inputs_accordion: If a string is provided, this is the label of the `gr.Accordion` to use to contain additional inputs. A `gr.Accordion` object can be provided as well to configure other properties of the container holding the additional inputs. Defaults to a `gr.Accordion(label="Additional Inputs", open=False)`. This parameter is only used if `additional_inputs` is provided.
165
+ examples: Sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
166
+ cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
167
+ title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
168
+ description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
169
+ theme: Theme to use, loaded from gradio.themes.
170
+ css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
171
+ js: Custom js or path to js file to run when demo is first loaded. This javascript will be included in the demo webpage.
172
+ head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
173
+ analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
174
+ submit_btn: Text to display on the submit button. If None, no button will be displayed. If a Button object, that button will be used.
175
+ stop_btn: Text to display on the stop button, which replaces the submit_btn when the submit_btn or retry_btn is clicked and response is streaming. Clicking on the stop_btn will halt the chatbot response. If set to None, stop button functionality does not appear in the chatbot. If a Button object, that button will be used as the stop button.
176
+ retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
177
+ undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
178
+ clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
179
+ autofocus: If True, autofocuses to the textbox when the page loads.
180
+ concurrency_limit: If set, this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default).
181
+ fill_height: If True, the chat interface will expand to the height of window.
182
+ """
183
+ try:
184
+ super(gr.ChatInterface, self).__init__(
185
+ analytics_enabled=analytics_enabled,
186
+ mode="chat_interface",
187
+ css=css,
188
+ title=title or "Gradio",
189
+ theme=theme,
190
+ js=js,
191
+ head=head,
192
+ fill_height=fill_height,
193
+ )
194
+ except Exception as e:
195
+ # Handle old gradio versions without fill_height
196
+ super(gr.ChatInterface, self).__init__(
197
+ analytics_enabled=analytics_enabled,
198
+ mode="chat_interface",
199
+ css=css,
200
+ title=title or "Gradio",
201
+ theme=theme,
202
+ js=js,
203
+ head=head,
204
+ # fill_height=fill_height,
205
+ )
206
+
207
+ self.concurrency_limit = concurrency_limit
208
+ self.fn = fn
209
+ self.add_multimodal_fn = add_multimodal_fn
210
+ self.render_additional_inputs_fn = render_additional_inputs_fn
211
+ self.multimodal_inputs = []
212
+ self.is_async = inspect.iscoroutinefunction(
213
+ self.fn
214
+ ) or inspect.isasyncgenfunction(self.fn)
215
+ self.is_generator = inspect.isgeneratorfunction(
216
+ self.fn
217
+ ) or inspect.isasyncgenfunction(self.fn)
218
+ self.examples = examples
219
+ if self.space_id and cache_examples is None:
220
+ self.cache_examples = True
221
+ else:
222
+ self.cache_examples = cache_examples or False
223
+ self.buttons: list[Button | None] = []
224
+
225
+ if additional_inputs:
226
+ if not isinstance(additional_inputs, list):
227
+ additional_inputs = [additional_inputs]
228
+ self.additional_inputs = [
229
+ get_component_instance(i)
230
+ for i in additional_inputs # type: ignore
231
+ ]
232
+ else:
233
+ self.additional_inputs = []
234
+ if additional_inputs_accordion_name is not None:
235
+ print(
236
+ "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
237
+ )
238
+ self.additional_inputs_accordion_params = {
239
+ "label": additional_inputs_accordion_name
240
+ }
241
+ if additional_inputs_accordion is None:
242
+ self.additional_inputs_accordion_params = {
243
+ "label": "Additional Inputs",
244
+ "open": False,
245
+ }
246
+ elif isinstance(additional_inputs_accordion, str):
247
+ self.additional_inputs_accordion_params = {
248
+ "label": additional_inputs_accordion
249
+ }
250
+ elif isinstance(additional_inputs_accordion, Accordion):
251
+ self.additional_inputs_accordion_params = (
252
+ additional_inputs_accordion.recover_kwargs(
253
+ additional_inputs_accordion.get_config()
254
+ )
255
+ )
256
+ else:
257
+ raise ValueError(
258
+ f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
259
+ )
260
+
261
+ with self:
262
+ if title:
263
+ Markdown(
264
+ f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
265
+ )
266
+ if description:
267
+ Markdown(description)
268
+
269
+ if chatbot:
270
+ self.chatbot = chatbot.render()
271
+ else:
272
+ self.chatbot = Chatbot(
273
+ label="Chatbot", scale=1, height=200 if fill_height else None
274
+ )
275
+
276
+ with Row():
277
+ for btn in [retry_btn, undo_btn, clear_btn]:
278
+ if btn is not None:
279
+ if isinstance(btn, Button):
280
+ btn.render()
281
+ elif isinstance(btn, str):
282
+ btn = Button(btn, variant="secondary", size="sm")
283
+ else:
284
+ raise ValueError(
285
+ f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
286
+ )
287
+ self.buttons.append(btn) # type: ignore
288
+
289
+ with Group():
290
+ with Row():
291
+ if textbox:
292
+ textbox.container = False
293
+ textbox.show_label = False
294
+ textbox_ = textbox.render()
295
+ assert isinstance(textbox_, Textbox)
296
+ self.textbox = textbox_
297
+ else:
298
+ self.textbox = Textbox(
299
+ container=False,
300
+ show_label=False,
301
+ label="Message",
302
+ placeholder="Type a message...",
303
+ scale=7,
304
+ autofocus=autofocus,
305
+ )
306
+ if submit_btn is not None:
307
+ if isinstance(submit_btn, Button):
308
+ submit_btn.render()
309
+ elif isinstance(submit_btn, str):
310
+ submit_btn = Button(
311
+ submit_btn,
312
+ variant="primary",
313
+ scale=2,
314
+ min_width=150,
315
+ )
316
+ else:
317
+ raise ValueError(
318
+ f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
319
+ )
320
+ if stop_btn is not None:
321
+ if isinstance(stop_btn, Button):
322
+ stop_btn.visible = False
323
+ stop_btn.render()
324
+ elif isinstance(stop_btn, str):
325
+ stop_btn = Button(
326
+ stop_btn,
327
+ variant="stop",
328
+ visible=False,
329
+ scale=2,
330
+ min_width=150,
331
+ )
332
+ else:
333
+ raise ValueError(
334
+ f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
335
+ )
336
+ self.num_tokens = Textbox(
337
+ container=False,
338
+ show_label=False,
339
+ label="num_tokens",
340
+ placeholder="0 tokens",
341
+ scale=1,
342
+ interactive=False,
343
+ # autofocus=autofocus,
344
+ min_width=10
345
+ )
346
+ self.buttons.extend([submit_btn, stop_btn]) # type: ignore
347
+
348
+ self.fake_api_btn = Button("Fake API", visible=False)
349
+ self.fake_response_textbox = Textbox(label="Response", visible=False)
350
+ (
351
+ self.retry_btn,
352
+ self.undo_btn,
353
+ self.clear_btn,
354
+ self.submit_btn,
355
+ self.stop_btn,
356
+ ) = self.buttons
357
+
358
+
359
+ any_unrendered_inputs = any(
360
+ not inp.is_rendered for inp in self.additional_inputs
361
+ )
362
+ if self.add_multimodal_fn is not None:
363
+ with Row():
364
+ self.multimodal_inputs = self.add_multimodal_fn()
365
+ if self.additional_inputs and any_unrendered_inputs:
366
+ with Accordion(**self.additional_inputs_accordion_params): # type: ignore
367
+ if self.render_additional_inputs_fn is not None:
368
+ self.render_additional_inputs_fn()
369
+ else:
370
+ for input_component in self.additional_inputs:
371
+ if not input_component.is_rendered:
372
+ input_component.render()
373
+ else:
374
+ if self.additional_inputs and any_unrendered_inputs:
375
+ with Accordion(**self.additional_inputs_accordion_params): # type: ignore
376
+ if self.render_additional_inputs_fn is not None:
377
+ self.render_additional_inputs_fn()
378
+ else:
379
+ for input_component in self.additional_inputs:
380
+ if not input_component.is_rendered:
381
+ input_component.render()
382
+
383
+ if examples:
384
+ if self.is_generator:
385
+ examples_fn = self._examples_stream_fn
386
+ else:
387
+ # examples_fn = self._examples_fn
388
+ raise NotImplementedError(f'Non-streaming example functions are not implemented; use a generator')
389
+
390
+ self.examples_handler = Examples(
391
+ examples=examples,
392
+ inputs=[self.textbox] + self.multimodal_inputs + self.additional_inputs,
393
+ outputs=self.chatbot,
394
+ fn=examples_fn,
395
+ )
396
+
397
+ # The example caching must happen after the input components have rendered
398
+ if cache_examples:
399
+ client_utils.synchronize_async(self.examples_handler.cache)
400
+
401
+ self.saved_input = State()
402
+ self.chatbot_state = (
403
+ State(self.chatbot.value) if self.chatbot.value else State([])
404
+ )
405
+
406
+ self._setup_events()
407
+ self._setup_api()
408
+
409
+ def _clear_and_save_textbox(self, message: str, *multimodal_inputs) -> tuple[str, str]:
410
+ saved_input = [message] + list(multimodal_inputs)
411
+ outputs = [''] + [None] * len(multimodal_inputs)
412
+ return outputs + [saved_input]
413
+
414
+ def _add_inputs_to_history(self, history: List[List[Union[str, None]]], *args):
415
+ message = args[0]
416
+ multimodal_inputs = args[1:1 + len(self.multimodal_inputs)] if len(args) > 1 else None
417
+ if multimodal_inputs is not None:
418
+ is_file_exists = [(x is not None and os.path.exists(x)) for x in multimodal_inputs]
419
+ if any(is_file_exists):
420
+ file_exists = [f for f, ise in zip(multimodal_inputs, is_file_exists) if ise]
421
+ if len(file_exists) > 1:
422
+ raise gr.Error(f"Cannot have more than 1 multimodal input at a time.")
423
+ fname = file_exists[0]
424
+ history.append([(fname,), None])
425
+ if message is not None and message.strip() != "":
426
+ history.append([message, None])
427
+ return history
428
+
429
+
430
+ def _display_input(
431
+ self, saved_input: List[str], history: List[List[Union[str, None]]]
432
+ ) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
433
+ # message = saved_input[0]
434
+ # multimodal_inputs = saved_input[1:] if len(saved_input) > 1 else None
435
+ # # ! If things wrong, return original history and give warning
436
+ # if multimodal_inputs is not None:
437
+ # is_file_exists = [(x is not None and os.path.exists(x)) for x in multimodal_inputs]
438
+ # if any(is_file_exists):
439
+ # file_exists = [f for f, ise in zip(multimodal_inputs, is_file_exists) if ise]
440
+ # if len(file_exists) > 1:
441
+ # raise gr.Error(f"Cannot have more than 1 multimodal input at a time.")
442
+ # fname = file_exists[0]
443
+ # history.append([(fname,), None])
444
+ # if message is not None and message.strip() != "":
445
+ # history.append([message, None])
446
+ history = self._add_inputs_to_history(history, *saved_input)
447
+ return history, history
448
+
449
+ def _delete_prev_fn(
450
+ self, history: list[list[str | None]]
451
+ ) -> tuple[list[list[str | None]], str, list[list[str | None]]]:
452
+ try:
453
+ message, _ = history.pop()
454
+ except IndexError:
455
+ message = ""
456
+ saved_input = [message or ""] + [None] * len(self.multimodal_inputs)
457
+ return history, saved_input, history
458
+
459
+ def _setup_events(self) -> None:
460
+ from gradio.components import State
461
+ has_on = False
462
+ try:
463
+ from gradio.events import Dependency, EventListenerMethod, on
464
+ has_on = True
465
+ except ImportError as ie:
466
+ has_on = False
467
+ submit_fn = self._stream_fn if self.is_generator else self._submit_fn
468
+ if not self.is_generator:
469
+ raise NotImplementedError(f'should use generator')
470
+
471
+ if has_on:
472
+ # new version
473
+ submit_triggers = (
474
+ [self.textbox.submit, self.submit_btn.click]
475
+ if self.submit_btn
476
+ else [self.textbox.submit]
477
+ )
478
+ submit_event = (
479
+ on(
480
+ submit_triggers,
481
+ self._clear_and_save_textbox,
482
+ [self.textbox] + self.multimodal_inputs,
483
+ [self.textbox] + self.multimodal_inputs + [self.saved_input],
484
+ api_name=False,
485
+ queue=False,
486
+ )
487
+ .then(
488
+ self._display_input,
489
+ [self.saved_input, self.chatbot_state],
490
+ [self.chatbot, self.chatbot_state],
491
+ api_name=False,
492
+ queue=False,
493
+ )
494
+ .success(
495
+ submit_fn,
496
+ [self.chatbot_state] + self.additional_inputs,
497
+ [self.chatbot, self.chatbot_state, self.num_tokens],
498
+ api_name=False,
499
+ )
500
+ )
501
+ self._setup_stop_events(submit_triggers, submit_event)
502
+ else:
503
+ raise ValueError(f'Please install a gradio version newer than 3.44.0')
504
+
505
+ if self.retry_btn:
506
+ retry_event = (
507
+ self.retry_btn.click(
508
+ self._delete_prev_fn,
509
+ [self.chatbot_state],
510
+ [self.chatbot, self.saved_input, self.chatbot_state],
511
+ api_name=False,
512
+ queue=False,
513
+ )
514
+ .then(
515
+ self._display_input,
516
+ [self.saved_input, self.chatbot_state],
517
+ [self.chatbot, self.chatbot_state],
518
+ api_name=False,
519
+ queue=False,
520
+ )
521
+ .success(
522
+ submit_fn,
523
+ [self.chatbot_state] + self.additional_inputs,
524
+ [self.chatbot, self.chatbot_state, self.num_tokens],
525
+ api_name=False,
526
+ )
527
+ )
528
+ self._setup_stop_events([self.retry_btn.click], retry_event)
529
+
530
+ if self.undo_btn:
531
+ self.undo_btn.click(
532
+ # self._delete_prev_fn,
533
+ # [self.chatbot_state],
534
+ # [self.chatbot, self.saved_input, self.chatbot_state],
535
+ undo_history_until_last_assistant_turn,
536
+ [self.chatbot_state],
537
+ [self.chatbot, self.chatbot_state],
538
+ api_name=False,
539
+ queue=False,
540
+ )
541
+ # .then(
542
+ # lambda x: x,
543
+ # [self.saved_input],
544
+ # [self.textbox],
545
+ # api_name=False,
546
+ # queue=False,
547
+ # )
548
+
549
+ async def _stream_fn(
550
+ self,
551
+ # message: str,
552
+ history_with_input,
553
+ request: Request,
554
+ *args,
555
+ ) -> AsyncGenerator:
556
+ history = history_with_input[:-1]
557
+ message = history_with_input[-1][0]
558
+ inputs, _, _ = special_args(
559
+ self.fn, inputs=[history_with_input, *args], request=request
560
+ )
561
+
562
+ if self.is_async:
563
+ generator = self.fn(*inputs)
564
+ else:
565
+ generator = await anyio.to_thread.run_sync(
566
+ self.fn, *inputs, limiter=self.limiter
567
+ )
568
+ generator = SyncToAsyncIterator(generator, self.limiter)
569
+
570
+ # ! In case of error, yield the previous history & undo any generation before raising error
571
+ try:
572
+ first_response_pack = await async_iteration(generator)
573
+ if isinstance(first_response_pack, (tuple, list)):
574
+ first_response, num_tokens = first_response_pack
575
+ else:
576
+ first_response, num_tokens = first_response_pack, -1
577
+ update = history + [[message, first_response]]
578
+ yield update, update, f"{num_tokens} toks"
579
+ except StopIteration:
580
+ update = history + [[message, None]]
581
+ yield update, update, "NaN toks"
582
+ except Exception as e:
583
+ yield history, history, "NaN toks"
584
+ raise e
585
+
586
+ try:
587
+ async for response_pack in generator:
588
+ if isinstance(response_pack, (tuple, list)):
589
+ response, num_tokens = response_pack
590
+ else:
591
+ response, num_tokens = response_pack, "NaN toks"
592
+ update = history + [[message, response]]
593
+ yield update, update, f"{num_tokens} toks"
594
+ except Exception as e:
595
+ yield history, history, "NaN toks"
596
+ raise e
597
+
598
+ async def _examples_stream_fn(
599
+ self,
600
+ # message: str,
601
+ *args,
602
+ ) -> AsyncGenerator:
603
+ history = []
604
+ input_len = 1 + len(self.multimodal_inputs)
605
+ saved_input = args[:input_len]
606
+ message = saved_input[0]
607
+ additional_inputs = [] if len(args) <= input_len else args[input_len:]
608
+ history = self._add_inputs_to_history(history, *saved_input)
609
+ inputs, _, _ = special_args(self.fn, inputs=[history, *additional_inputs], request=None)
610
+
611
+ if self.is_async:
612
+ generator = self.fn(*inputs)
613
+ else:
614
+ generator = await anyio.to_thread.run_sync(
615
+ self.fn, *inputs, limiter=self.limiter
616
+ )
617
+ generator = SyncToAsyncIterator(generator, self.limiter)
618
+ # async for response in generator:
619
+ # yield [[message, response]]
620
+
621
+ try:
622
+ async for response_pack in generator:
623
+ if isinstance(response_pack, (tuple, list)):
624
+ response, num_tokens = response_pack
625
+ else:
626
+ response, num_tokens = response_pack, "NaN toks"
627
+ update = history + [[message, response]]
628
+ yield update, update, f"{num_tokens} toks"
629
+ except Exception as e:
630
+ yield history, history, "NaN toks"
631
+ raise e
632
+
633
+ async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
634
+ raise NotImplementedError
635
+ inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
636
+
637
+ if self.is_async:
638
+ response = await self.fn(*inputs)
639
+ else:
640
+ response = await anyio.to_thread.run_sync(
641
+ self.fn, *inputs, limiter=self.limiter
642
+ )
643
+ return [[message, response]]
644
+
645
+
646
+
647
+ def gradio_history_to_openai_conversations(message=None, history=None, system_prompt=None):
648
+ conversations = []
649
+ system_prompt = system_prompt or SYSTEM_PROMPT
650
+ if history is not None and len(history) > 0:
651
+ for i, (prompt, res) in enumerate(history):
652
+ if prompt is not None:
653
+ conversations.append({"role": "user", "content": prompt.strip()})
654
+ if res is not None:
655
+ conversations.append({"role": "assistant", "content": res.strip()})
656
+ if message is not None:
657
+ if len(message.strip()) == 0:
658
+ raise gr.Error("The message cannot be empty!")
659
+ conversations.append({"role": "user", "content": message.strip()})
660
+ if conversations[0]['role'] != 'system':
661
+ conversations = [{"role": "system", "content": system_prompt}] + conversations
662
+ return conversations
663
+
664
+
665
+ def gradio_history_to_conversation_prompt(message=None, history=None, system_prompt=None):
666
+ global MODEL_ENGINE
667
+ full_prompt = MODEL_ENGINE.apply_chat_template(
668
+ gradio_history_to_openai_conversations(
669
+ message, history=history, system_prompt=system_prompt),
670
+ add_generation_prompt=True
671
+ )
672
+ return full_prompt
673
+
674
+
675
+ def gradio_history_to_vision_conversation_prompt_paths(
676
+ history, system_prompt=None, image_token=None
677
+ ):
678
+ """
679
+ Aggregate gradio history into openai conversations
680
+ history = [
681
+ ["Hello", "Response"],
682
+ [(file,), None],
683
+ ]
684
+ --->
685
+ [
686
+ {"role": "user", "content": ...}
687
+ ]
688
+ """
689
+ global MODEL_ENGINE
690
+ image_token = image_token or IMAGE_TOKEN
691
+ conversations = []
692
+ image_paths = []
693
+ for i, his in enumerate(history):
694
+ prompt, response = his
695
+ last_turn = conversations[-1] if len(conversations) > 0 else None
696
+ if prompt is not None:
697
+ if isinstance(prompt, tuple):
698
+ image_path = prompt[0]
699
+ if last_turn is not None and last_turn['role'] == 'user':
700
+ last_turn['content'] += f" {image_token}"
701
+ else:
702
+ # last_turn None or last_turn['role'] == 'assistant'
703
+ conversations.append({
704
+ "role": "user",
705
+ "content": f"{image_token}"
706
+ })
707
+ image_paths.append(image_path)
708
+ else:
709
+ assert prompt is not None and isinstance(prompt, str)
710
+ if last_turn is not None and last_turn['role'] == 'user':
711
+ last_turn['content'] += f"\n{prompt}"
712
+ else:
713
+ conversations.append({
714
+ "role": "user",
715
+ "content": prompt,
716
+ })
717
+ if response is not None:
718
+ assert isinstance(response, str)
719
+ conversations.append({
720
+ "role": "assistant",
721
+ "content": response,
722
+ })
723
+
724
+ if conversations[0]['role'] != 'system':
725
+ system_prompt = system_prompt or SYSTEM_PROMPT
726
+ conversations = [{"role": "system", "content": system_prompt}] + conversations
727
+
728
+ # print(f'convo: {json.dumps(conversations, indent=4, ensure_ascii=False)}\n{image_paths=}')
729
+ full_prompt = MODEL_ENGINE.apply_chat_template(
730
+ conversations,
731
+ add_generation_prompt=True
732
+ )
733
+ return full_prompt, image_paths, conversations
734
+
735
+
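To make the aggregation above concrete, here is an illustrative trace of the mapping (not part of the committed file; it assumes IMAGE_TOKEN is "<image>" and shows the conversation list before the chat template is applied):

    history = [
        [("cat.jpg",), None],                    # uploaded image, no reply yet
        ["What is in this image?", "A cat."],    # text turn and its response
        ["Describe it in more detail.", None],   # pending user turn
    ]
    # gradio_history_to_vision_conversation_prompt_paths(history) collects
    # image_paths == ["cat.jpg"] and builds conversations roughly as:
    conversations = [
        {"role": "system", "content": "<system prompt>"},
        {"role": "user", "content": "<image>\nWhat is in this image?"},
        {"role": "assistant", "content": "A cat."},
        {"role": "user", "content": "Describe it in more detail."},
    ]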
736
+ def is_doc(file_path):
737
+ is_doc_allowed = file_path.endswith((".pdf", ".docx", ".txt"))
738
+ return is_doc_allowed
739
+
740
+
741
+ def read_doc(file_path):
742
+ from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
743
+ if file_path.endswith('.pdf'):
744
+ loader = PyPDFLoader(file_path)
745
+ elif file_path.endswith('.docx'):
746
+ loader = Docx2txtLoader(file_path)
747
+ elif file_path.endswith('.txt'):
748
+ loader = TextLoader(file_path)
749
+ texts = loader.load()
750
+ text = "\n\n".join([t.page_content for t in texts])
751
+ return text
752
+
753
+
754
+ def doc_file_to_instruct_content(file_path, doc_instruction=None):
755
+ doc_instruction = doc_instruction or DOC_INSTRUCTION
756
+ content = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=read_doc(file_path))
757
+ return content
758
+
759
+
760
+ def gradio_history_to_doc_conversation_prompt(
761
+ history, system_prompt=None, doc_instruction=None,
762
+ ):
763
+ """
764
+ Aggregate gradio history into openai conversations
765
+ history = [
766
+ ["Hello", "Response"],
767
+ [(file,), None],
768
+ ]
769
+ --->
770
+ [
771
+ {"role": "user", "content": ...}
772
+ ]
773
+ """
774
+ global MODEL_ENGINE
775
+ # image_token = image_token or IMAGE_TOKEN
776
+ doc_instruction = doc_instruction or DOC_INSTRUCTION
777
+ conversations = []
778
+ image_paths = []
779
+ for i, his in enumerate(history):
780
+ prompt, response = his
781
+ last_turn = conversations[-1] if len(conversations) > 0 else None
782
+ if prompt is not None:
783
+ if isinstance(prompt, tuple):
784
+ file_path = prompt[0]
785
+ if not is_doc(file_path):
786
+ raise gr.Error(f'Unsupported document file: {file_path}')
787
+ content = doc_file_to_instruct_content(file_path, doc_instruction)
788
+ if last_turn is not None and last_turn['role'] == 'user':
789
+ last_turn['content'] += f"{content}"
790
+ else:
791
+ # last_turn None or last_turn['role'] == 'assistant'
792
+ conversations.append({
793
+ "role": "user",
794
+ "content": f"{content}"
795
+ })
796
+ else:
797
+ assert prompt is not None and isinstance(prompt, str)
798
+ if last_turn is not None and last_turn['role'] == 'user':
799
+ last_turn['content'] += f"\n{prompt}"
800
+ else:
801
+ conversations.append({
802
+ "role": "user",
803
+ "content": prompt,
804
+ })
805
+ if response is not None:
806
+ assert isinstance(response, str)
807
+ conversations.append({
808
+ "role": "assistant",
809
+ "content": response,
810
+ })
811
+
812
+ if conversations[0]['role'] != 'system':
813
+ system_prompt = system_prompt or SYSTEM_PROMPT
814
+ conversations = [{"role": "system", "content": system_prompt}] + conversations
815
+
816
+ full_prompt = MODEL_ENGINE.apply_chat_template(
817
+ conversations,
818
+ add_generation_prompt=True
819
+ )
820
+ return full_prompt, conversations
821
+
822
+
823
+ def gradio_history_to_vision_doc_conversation_prompt_paths(
824
+ history, system_prompt=None, image_token=None, doc_instruction=None,
825
+ ):
826
+ """
827
+ Aggregate gradio history into openai conversations
828
+ history = [
829
+ ["Hello", "Response"],
830
+ [(file,), None],
831
+ ]
832
+ --->
833
+ [
834
+ {"role": "user", "content": ...}
835
+ ]
836
+ """
837
+ global MODEL_ENGINE
838
+ image_token = image_token or IMAGE_TOKEN
839
+ doc_instruction = doc_instruction or DOC_INSTRUCTION
840
+ conversations = []
841
+ image_paths = []
842
+ for i, his in enumerate(history):
843
+ prompt, response = his
844
+ last_turn = conversations[-1] if len(conversations) > 0 else None
845
+ if prompt is not None:
846
+ if isinstance(prompt, tuple):
847
+ file_path = prompt[0]
848
+ if is_doc(file_path):
849
+ content = doc_file_to_instruct_content(file_path, doc_instruction)
850
+ if last_turn is not None and last_turn['role'] == 'user':
851
+ last_turn['content'] += f"{content}"
852
+ else:
853
+ # last_turn None or last_turn['role'] == 'assistant'
854
+ conversations.append({
855
+ "role": "user",
856
+ "content": f"{content}"
857
+ })
858
+ else:
859
+ if last_turn is not None and last_turn['role'] == 'user':
860
+ last_turn['content'] += f" {image_token}"
861
+ else:
862
+ # last_turn None or last_turn['role'] == 'assistant'
863
+ conversations.append({
864
+ "role": "user",
865
+ "content": f"{image_token}"
866
+ })
867
+ image_paths.append(file_path)
868
+ else:
869
+ assert prompt is not None and isinstance(prompt, str)
870
+ if last_turn is not None and last_turn['role'] == 'user':
871
+ last_turn['content'] += f"\n{prompt}"
872
+ else:
873
+ conversations.append({
874
+ "role": "user",
875
+ "content": prompt,
876
+ })
877
+ if response is not None:
878
+ assert isinstance(response, str)
879
+ conversations.append({
880
+ "role": "assistant",
881
+ "content": response,
882
+ })
883
+
884
+ if conversations[0]['role'] != 'system':
885
+ system_prompt = system_prompt or SYSTEM_PROMPT
886
+ conversations = [{"role": "system", "content": system_prompt}] + conversations
887
+
888
+ full_prompt = MODEL_ENGINE.apply_chat_template(
889
+ conversations,
890
+ add_generation_prompt=True
891
+ )
892
+ return full_prompt, image_paths, conversations
893
+
894
+
895
+ def vision_chat_response_stream_multiturn_engine(
896
+ history: List[Tuple[str, str]],
897
+ temperature: float,
898
+ max_tokens: int,
899
+ system_prompt: Optional[str] = SYSTEM_PROMPT,
900
+ image_token: Optional[str] = IMAGE_TOKEN,
901
+ ):
902
+ global MODEL_ENGINE
903
+ temperature = float(temperature)
904
+ # ! remove frequency_penalty
905
+ # frequency_penalty = float(frequency_penalty)
906
+ max_tokens = int(max_tokens)
907
+ # ! skip safety
908
+ if DATETIME_FORMAT in system_prompt:
909
+ # ! This sometimes works, sometimes doesn't
910
+ system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
911
+ # ! history now can have multimodal
912
+
913
+ full_prompt, image_paths, conversations = gradio_history_to_vision_conversation_prompt_paths(
914
+ history=history, system_prompt=system_prompt, image_token=image_token
915
+ )
916
+
917
+ if hasattr(MODEL_ENGINE, "get_multimodal_tokens"):
918
+ num_tokens = MODEL_ENGINE.get_multimodal_tokens(full_prompt, image_paths=image_paths)
919
+ else:
920
+ num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
921
+ if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
922
+ raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
923
+
924
+ print(f'{image_paths=}')
925
+ print(full_prompt)
926
+ outputs = None
927
+ response = None
928
+ num_tokens = -1
929
+ for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
930
+ prompt=full_prompt,
931
+ temperature=temperature,
932
+ max_tokens=max_tokens,
933
+ image_paths=image_paths,
934
+ )):
935
+ if isinstance(outputs, tuple):
936
+ response, num_tokens = outputs
937
+ else:
938
+ response, num_tokens = outputs, -1
939
+ yield response, num_tokens
940
+
941
+ print(format_conversation(history + [[None, response]]))
942
+
943
+ if response is not None:
944
+ yield response, num_tokens
945
+
946
+
947
+ def doc_chat_response_stream_multiturn_engine(
948
+ history: List[Tuple[str, str]],
949
+ temperature: float,
950
+ max_tokens: int,
951
+ system_prompt: Optional[str] = SYSTEM_PROMPT,
952
+ doc_instruction: Optional[str] = DOC_INSTRUCTION,
953
+ ):
954
+ global MODEL_ENGINE
955
+ temperature = float(temperature)
956
+ # ! remove frequency_penalty
957
+ # frequency_penalty = float(frequency_penalty)
958
+ max_tokens = int(max_tokens)
959
+ # ! skip safety
960
+ if DATETIME_FORMAT in system_prompt:
961
+ # ! This sometimes works, sometimes doesn't
962
+ system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
963
+ # ! history now can have multimodal
964
+
965
+ full_prompt, conversations = gradio_history_to_doc_conversation_prompt(
966
+ history=history, system_prompt=system_prompt, doc_instruction=doc_instruction
967
+ )
968
+
969
+ # ! length checked
970
+ num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
971
+ if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
972
+ raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
973
+
974
+ print(full_prompt)
975
+ outputs = None
976
+ response = None
977
+ num_tokens = -1
978
+ for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
979
+ prompt=full_prompt,
980
+ temperature=temperature,
981
+ max_tokens=max_tokens,
982
+ # image_paths=image_paths,
983
+ )):
984
+ if isinstance(outputs, tuple):
985
+ response, num_tokens = outputs
986
+ else:
987
+ response, num_tokens = outputs, -1
988
+ yield response, num_tokens
989
+
990
+ print(format_conversation(history + [[None, response]]))
991
+
992
+ if response is not None:
993
+ yield response, num_tokens
994
+
995
+
996
+
997
+
998
+ def vision_doc_chat_response_stream_multiturn_engine(
999
+ history: List[Tuple[str, str]],
1000
+ temperature: float,
1001
+ max_tokens: int,
1002
+ system_prompt: Optional[str] = SYSTEM_PROMPT,
1003
+ image_token: Optional[str] = IMAGE_TOKEN,
1004
+ doc_instruction: Optional[str] = DOC_INSTRUCTION,
1005
+ ):
1006
+ global MODEL_ENGINE
1007
+ temperature = float(temperature)
1008
+ # ! remove frequency_penalty
1009
+ # frequency_penalty = float(frequency_penalty)
1010
+ max_tokens = int(max_tokens)
1011
+ # ! skip safety
1012
+ if DATETIME_FORMAT in system_prompt:
1013
+ # ! This sometimes works, sometimes doesn't
1014
+ system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
1015
+ # ! history now can have multimodal
1016
+
1017
+ full_prompt, image_paths, conversations = gradio_history_to_vision_doc_conversation_prompt_paths(
1018
+ history=history, system_prompt=system_prompt, image_token=image_token, doc_instruction=doc_instruction
1019
+ )
1020
+
1021
+ # ! length check
1022
+ if hasattr(MODEL_ENGINE, "get_multimodal_tokens"):
1023
+ num_tokens = MODEL_ENGINE.get_multimodal_tokens(full_prompt, image_paths=image_paths)
1024
+ else:
1025
+ num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
1026
+ if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
1027
+ raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
1028
+
1029
+ print(full_prompt)
1030
+ print(f'{image_paths=}')
1031
+ outputs = None
1032
+ response = None
1033
+ num_tokens = -1
1034
+ for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
1035
+ prompt=full_prompt,
1036
+ temperature=temperature,
1037
+ max_tokens=max_tokens,
1038
+ image_paths=image_paths,
1039
+ )):
1040
+ if isinstance(outputs, tuple):
1041
+ response, num_tokens = outputs
1042
+ else:
1043
+ response, num_tokens = outputs, -1
1044
+ yield response, num_tokens
1045
+
1046
+ print(format_conversation(history + [[None, response]]))
1047
+
1048
+ if response is not None:
1049
+ yield response, num_tokens
1050
+
1051
+
1052
+
1053
+ @register_demo
1054
+ class VisionChatInterfaceDemo(ChatInterfaceDemo):
1055
+ """
1056
+ Accept vision image
1057
+ """
1058
+
1059
+ @property
1060
+ def tab_name(self):
1061
+ return "Vision Chat"
1062
+
1063
+ @property
1064
+ def examples(self):
1065
+ return [
1066
+ ["What's strange about this image?", "assets/dog_monalisa.jpeg",],
1067
+ ["Explain why the sky is blue.", None,],
1068
+ ]
1069
+
1070
+ def create_demo(
1071
+ self,
1072
+ title: str | None = None,
1073
+ description: str | None = None,
1074
+ **kwargs
1075
+ ) -> gr.Blocks:
1076
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
1077
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
1078
+ temperature = kwargs.get("temperature", TEMPERATURE)
1079
+ model_name = kwargs.get("model_name", MODEL_NAME)
1080
+ description = description or """Upload an image to ask questions about it."""
1081
+
1082
+ def add_multimodal_fn() -> List[Component]:
1083
+ image_input = gr.Image(label="Input Image", type="filepath", )
1084
+ return [image_input]
1085
+
1086
+ additional_inputs = [
1087
+ gr.Number(value=temperature, label='Temperature', min_width=20),
1088
+ gr.Number(value=max_tokens, label='Max-tokens', min_width=20),
1089
+ gr.Textbox(value=system_prompt, label='System prompt', lines=1),
1090
+ gr.Textbox(value=IMAGE_TOKEN, label='Visual token', lines=1, interactive=IMAGE_TOKEN_INTERACTIVE, min_width=20),
1091
+ ]
1092
+ def render_additional_inputs_fn():
1093
+ with Row():
1094
+ additional_inputs[0].render()
1095
+ additional_inputs[1].render()
1096
+ additional_inputs[3].render()
1097
+ additional_inputs[2].render()
1098
+
1099
+ demo_chat = MultiModalChatInterface(
1100
+ vision_chat_response_stream_multiturn_engine,
1101
+ chatbot=gr.Chatbot(
1102
+ label=model_name,
1103
+ bubble_full_width=False,
1104
+ latex_delimiters=[
1105
+ { "left": "$", "right": "$", "display": False},
1106
+ { "left": "$$", "right": "$$", "display": True},
1107
+ ],
1108
+ show_copy_button=True,
1109
+ layout="panel" if USE_PANEL else "bubble",
1110
+ height=CHATBOT_HEIGHT,
1111
+ ),
1112
+ # textbox=gr.Textbox(placeholder='Type message', lines=4, max_lines=128, min_width=200),
1113
+ textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
1114
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
1115
+ # ! consider preventing the stop button
1116
+ # stop_btn=None,
1117
+ add_multimodal_fn=add_multimodal_fn,
1118
+ title=title,
1119
+ description=description,
1120
+ additional_inputs=additional_inputs,
1121
+ render_additional_inputs_fn=render_additional_inputs_fn,
1122
+ additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
1123
+ examples=self.examples,
1124
+ cache_examples=False,
1125
+ css=CSS,
1126
+ )
1127
+ return demo_chat
1128
+
1129
+
1130
+ def add_document_upload():
1131
+ file_input = gr.File(label='Upload pdf, docx, txt', file_count='single', file_types=['pdf', 'docx', 'txt'])
1132
+ # ! Some platforms have problems with gr.File, so use UploadButton instead
1133
+ # with Group():
1134
+ # file_input = gr.Textbox(value=None, label='Document path', lines=1, interactive=False)
1135
+ # upload_button = gr.UploadButton("Click to Upload document", file_types=['pdf', 'docx', 'txt'], file_count="single")
1136
+ # upload_button.upload(lambda x: x.name, upload_button, file_input)
1137
+ return file_input
1138
+
1139
+
1140
+ @register_demo
1141
+ class DocChatInterfaceDemo(ChatInterfaceDemo):
1142
+ """
1143
+ Accept document (full length no RAG)
1144
+ """
1145
+ @property
1146
+ def tab_name(self):
1147
+ return "Doc Chat"
1148
+
1149
+ @property
1150
+ def examples(self):
1151
+ return [
1152
+ ["Summarize the document", "assets/attention_short.pdf",],
1153
+ ["Explain why the sky is blue.", None,],
1154
+ ]
1155
+
1156
+ def create_demo(
1157
+ self,
1158
+ title: str | None = None,
1159
+ description: str | None = None,
1160
+ **kwargs
1161
+ ) -> gr.Blocks:
1162
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
1163
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
1164
+ temperature = kwargs.get("temperature", TEMPERATURE)
1165
+ model_name = kwargs.get("model_name", MODEL_NAME)
1166
+ # frequence_penalty = FREQUENCE_PENALTY
1167
+ # presence_penalty = PRESENCE_PENALTY
1168
+ description = description or """Upload a short document to ask questions about it."""
1169
+
1170
+ def add_multimodal_fn() -> List[Component]:
1171
+ file_input = add_document_upload()
1172
+ # image_input = gr.Image(label="Input Image", type="filepath", )
1173
+ return [file_input]
1174
+
1175
+ additional_inputs = [
1176
+ gr.Number(value=temperature, label='Temperature', min_width=20),
1177
+ gr.Number(value=max_tokens, label='Max-tokens', min_width=20),
1178
+ gr.Textbox(value=system_prompt, label='System prompt', lines=1),
1179
+ gr.Textbox(value=DOC_INSTRUCTION, label='Doc instruction', lines=1),
1180
+ ]
1181
+ def render_additional_inputs_fn():
1182
+ with Row():
1183
+ additional_inputs[0].render()
1184
+ additional_inputs[1].render()
1185
+ additional_inputs[2].render()
1186
+ additional_inputs[3].render()
1187
+
1188
+ demo_chat = MultiModalChatInterface(
1189
+ doc_chat_response_stream_multiturn_engine,
1190
+ chatbot=gr.Chatbot(
1191
+ label=model_name,
1192
+ bubble_full_width=False,
1193
+ latex_delimiters=[
1194
+ { "left": "$", "right": "$", "display": False},
1195
+ { "left": "$$", "right": "$$", "display": True},
1196
+ ],
1197
+ show_copy_button=True,
1198
+ layout="panel" if USE_PANEL else "bubble",
1199
+ height=CHATBOT_HEIGHT,
1200
+ ),
1201
+ textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
1202
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
1203
+ # ! consider preventing the stop button
1204
+ add_multimodal_fn=add_multimodal_fn,
1205
+ title=title,
1206
+ description=description,
1207
+ additional_inputs=additional_inputs,
1208
+ render_additional_inputs_fn=render_additional_inputs_fn,
1209
+ additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
1210
+ examples=self.examples,
1211
+ cache_examples=False,
1212
+ css=CSS,
1213
+ )
1214
+ return demo_chat
1215
+
1216
+
1217
+ @register_demo
1218
+ class VisionDocChatInterfaceDemo(ChatInterfaceDemo):
1219
+ """
1220
+ Accept either a vision image or a document (full length, no RAG)
1221
+ """
1222
+ @property
1223
+ def tab_name(self):
1224
+ return "Vision Doc Chat"
1225
+
1226
+ @property
1227
+ def examples(self):
1228
+ return [
1229
+ ["What's strange about this image?", None, "assets/dog_monalisa.jpeg",],
1230
+ ["Summarize the document", "assets/attention_short.pdf", None,],
1231
+ ["Explain why the sky is blue.", None, None],
1232
+ ]
1233
+
1234
+ def create_demo(
1235
+ self,
1236
+ title: str | None = None,
1237
+ description: str | None = None,
1238
+ **kwargs
1239
+ ) -> gr.Blocks:
1240
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
1241
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
1242
+ temperature = kwargs.get("temperature", TEMPERATURE)
1243
+ model_name = kwargs.get("model_name", MODEL_NAME)
1244
+ # frequence_penalty = FREQUENCE_PENALTY
1245
+ # presence_penalty = PRESENCE_PENALTY
1246
+ description = description or """Upload either an image or a short document to ask questions about it."""
1247
+
1248
+ def add_multimodal_fn() -> List[Component]:
1249
+ file_input = add_document_upload()
1250
+ image_input = gr.Image(label="Input Image", type="filepath", )
1251
+ return [file_input, image_input]
1252
+
1253
+ additional_inputs = [
1254
+ gr.Number(value=temperature, label='Temperature', min_width=20),
1255
+ gr.Number(value=max_tokens, label='Max-tokens', min_width=20),
1256
+ gr.Textbox(value=system_prompt, label='System prompt', lines=1),
1257
+ gr.Textbox(value=IMAGE_TOKEN, label='Visual token', lines=1, interactive=IMAGE_TOKEN_INTERACTIVE, min_width=2),
1258
+ gr.Textbox(value=DOC_INSTRUCTION, label='Doc instruction', lines=1),
1259
+ ]
1260
+ def render_additional_inputs_fn():
1261
+ with Row():
1262
+ additional_inputs[0].render()
1263
+ additional_inputs[1].render()
1264
+ additional_inputs[3].render()
1265
+ additional_inputs[2].render()
1266
+ additional_inputs[4].render()
1267
+
1268
+ demo_chat = MultiModalChatInterface(
1269
+ vision_doc_chat_response_stream_multiturn_engine,
1270
+ chatbot=gr.Chatbot(
1271
+ label=MODEL_NAME,
1272
+ bubble_full_width=False,
1273
+ latex_delimiters=[
1274
+ { "left": "$", "right": "$", "display": False},
1275
+ { "left": "$$", "right": "$$", "display": True},
1276
+ ],
1277
+ show_copy_button=True,
1278
+ layout="panel" if USE_PANEL else "bubble",
1279
+ height=CHATBOT_HEIGHT,
1280
+ ),
1281
+ textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
1282
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
1283
+ add_multimodal_fn=add_multimodal_fn,
1284
+ title=title,
1285
+ description=description,
1286
+ additional_inputs=additional_inputs,
1287
+ render_additional_inputs_fn=render_additional_inputs_fn,
1288
+ additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
1289
+ examples=self.examples,
1290
+ cache_examples=False,
1291
+ css=CSS,
1292
+ )
1293
+ return demo_chat
multipurpose_chatbot/demos/rag_chat_interface.py ADDED
@@ -0,0 +1,642 @@
1
+ import os
2
+ from gradio.themes import ThemeClass as Theme
3
+ import numpy as np
4
+ import argparse
5
+ import gradio as gr
6
+ from typing import Any, Iterator
7
+ from typing import Iterator, List, Optional, Tuple
8
+ import filelock
9
+ import glob
10
+ import json
11
+ import time
12
+ from gradio.routes import Request
13
+ from gradio.utils import SyncToAsyncIterator, async_iteration
14
+ from gradio.helpers import special_args
15
+ import anyio
16
+ from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
17
+
18
+ from gradio_client.documentation import document, set_documentation_group
19
+ from gradio.components import Button, Component
20
+ from gradio.events import Dependency, EventListenerMethod
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+ from gradio.themes import ThemeClass as Theme
25
+
26
+ from .base_demo import register_demo, get_demo_class, BaseDemo
27
+
28
+ import inspect
29
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
30
+
31
+ import anyio
32
+ from gradio_client import utils as client_utils
33
+ from gradio_client.documentation import document
34
+
35
+ from gradio.blocks import Blocks
36
+ from gradio.components import (
37
+ Button,
38
+ Chatbot,
39
+ Component,
40
+ Markdown,
41
+ State,
42
+ Textbox,
43
+ get_component_instance,
44
+ )
45
+ from gradio.events import Dependency, on
46
+ from gradio.helpers import create_examples as Examples # noqa: N812
47
+ from gradio.helpers import special_args
48
+ from gradio.layouts import Accordion, Group, Row
49
+ from gradio.routes import Request
50
+ from gradio.themes import ThemeClass as Theme
51
+ from gradio.utils import SyncToAsyncIterator, async_iteration
52
+
53
+
54
+ from ..globals import MODEL_ENGINE, RAG_CURRENT_FILE, RAG_EMBED, load_embeddings, get_rag_embeddings
55
+
56
+ from .chat_interface import (
57
+ SYSTEM_PROMPT,
58
+ MODEL_NAME,
59
+ MAX_TOKENS,
60
+ TEMPERATURE,
61
+ CHAT_EXAMPLES,
62
+ gradio_history_to_openai_conversations,
63
+ gradio_history_to_conversation_prompt,
64
+ DATETIME_FORMAT,
65
+ get_datetime_string,
66
+ format_conversation,
67
+ chat_response_stream_multiturn_engine,
68
+ ChatInterfaceDemo,
69
+ CustomizedChatInterface,
70
+ )
71
+
72
+ from ..configs import (
73
+ CHUNK_SIZE,
74
+ CHUNK_OVERLAP,
75
+ RAG_EMBED_MODEL_NAME,
76
+ )
77
+
78
+ RAG_CURRENT_VECTORSTORE = None
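+ # caches the FAISS vectorstore built from the most recently uploaded document so repeated queries can reuse it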
79
+
80
+
81
+ def load_document_split_vectorstore(file_path):
82
+ global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
83
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
84
+ from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
85
+ from langchain_community.vectorstores import Chroma, FAISS
86
+ from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
87
+ splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
88
+ if file_path.endswith('.pdf'):
89
+ loader = PyPDFLoader(file_path)
90
+ elif file_path.endswith('.docx'):
91
+ loader = Docx2txtLoader(file_path)
92
+ elif file_path.endswith('.txt'):
93
+ loader = TextLoader(file_path)
94
+ splits = loader.load_and_split(splitter)
95
+ RAG_CURRENT_VECTORSTORE = FAISS.from_texts(texts=[s.page_content for s in splits], embedding=get_rag_embeddings())
96
+ return RAG_CURRENT_VECTORSTORE
97
+
98
+ def docs_to_context_content(docs: List[Any]):
99
+ content = "\n".join([d.page_content for d in docs])
100
+ return content
101
+
102
+
103
+ DOC_TEMPLATE = """###
104
+ {content}
105
+ ###
106
+
107
+ """
108
+
109
+ DOC_INSTRUCTION = """Answer the following query exclusively based on the information provided in the document above. \
110
+ If the information is not found, please say so instead of making up facts! Remember to answer the question in the same language as the user query!
111
+ """
112
+
113
+
114
+ def docs_to_rag_context(docs: List[Any], doc_instruction=None):
115
+ doc_instruction = doc_instruction or DOC_INSTRUCTION
116
+ content = docs_to_context_content(docs)
117
+ context = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=content)
118
+ return context
119
+
120
+
121
+ def maybe_get_doc_context(message, file_input, rag_num_docs: Optional[int] = 3):
+ global RAG_CURRENT_FILE, RAG_CURRENT_VECTORSTORE
122
+ doc_context = None
123
+ if file_input is not None:
124
+ if file_input == RAG_CURRENT_FILE:
125
+ # reuse
126
+ vectorstore = RAG_CURRENT_VECTORSTORE
127
+ print(f'Reuse vectorstore: {file_input}')
128
+ else:
129
+ vectorstore = load_document_split_vectorstore(file_input)
130
+ print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
131
+ RAG_CURRENT_FILE = file_input
132
+ docs = vectorstore.similarity_search(message, k=rag_num_docs)
133
+ doc_context = docs_to_rag_context(docs)
134
+ return doc_context
135
+
136
+
137
+ def chat_response_stream_multiturn_doc_engine(
138
+ message: str,
139
+ history: List[Tuple[str, str]],
140
+ file_input: Optional[str] = None,
141
+ temperature: float = 0.7,
142
+ max_tokens: int = 1024,
143
+ system_prompt: Optional[str] = SYSTEM_PROMPT,
144
+ rag_num_docs: Optional[int] = 3,
145
+ doc_instruction: Optional[str] = DOC_INSTRUCTION,
146
+ # profile: Optional[gr.OAuthProfile] = None,
147
+ ):
148
+ global MODEL_ENGINE, RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
149
+ if len(message) == 0:
150
+ raise gr.Error("The message cannot be empty!")
151
+
152
+ rag_num_docs = int(rag_num_docs)
153
+ doc_instruction = doc_instruction or DOC_INSTRUCTION
154
+ doc_context = None
155
+ if file_input is not None:
156
+ if file_input == RAG_CURRENT_FILE:
157
+ # reuse
158
+ vectorstore = RAG_CURRENT_VECTORSTORE
159
+ print(f'Reuse vectorstore: {file_input}')
160
+ else:
161
+ vectorstore = load_document_split_vectorstore(file_input)
162
+ print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
163
+ RAG_CURRENT_FILE = file_input
164
+ docs = vectorstore.similarity_search(message, k=rag_num_docs)
165
+ # doc_context = docs_to_rag_context(docs)
166
+ rag_content = docs_to_context_content(docs)
167
+ doc_context = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=rag_content)
168
+
169
+ if doc_context is not None:
170
+ message = f"{doc_context}\n\n{message}"
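+ # prepend the retrieved document context to the user message before delegating to the plain chat engine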
171
+
172
+ for response, num_tokens in chat_response_stream_multiturn_engine(
173
+ message, history, temperature, max_tokens, system_prompt
174
+ ):
175
+ # ! yield another content which is doc_context
176
+ yield response, num_tokens, doc_context
177
+
178
+
179
+
180
+ class RagChatInterface(CustomizedChatInterface):
181
+ def __init__(
182
+ self,
183
+ fn: Callable[..., Any],
184
+ *,
185
+ chatbot: gr.Chatbot | None = None,
186
+ textbox: gr.Textbox | None = None,
187
+ additional_inputs: str | Component | list[str | Component] | None = None,
188
+ additional_inputs_accordion_name: str | None = None,
189
+ additional_inputs_accordion: str | gr.Accordion | None = None,
190
+ render_additional_inputs_fn: Callable | None = None,
191
+ examples: list[str] | None = None,
192
+ cache_examples: bool | None = None,
193
+ title: str | None = None,
194
+ description: str | None = None,
195
+ theme: Theme | str | None = None,
196
+ css: str | None = None,
197
+ js: str | None = None,
198
+ head: str | None = None,
199
+ analytics_enabled: bool | None = None,
200
+ submit_btn: str | Button | None = "Submit",
201
+ stop_btn: str | Button | None = "Stop",
202
+ retry_btn: str | Button | None = "🔄 Retry",
203
+ undo_btn: str | Button | None = "↩️ Undo",
204
+ clear_btn: str | Button | None = "🗑️ Clear",
205
+ autofocus: bool = True,
206
+ concurrency_limit: int | Literal['default'] | None = "default",
207
+ fill_height: bool = True
208
+ ):
209
+ try:
210
+ super(gr.ChatInterface, self).__init__(
211
+ analytics_enabled=analytics_enabled,
212
+ mode="chat_interface",
213
+ css=css,
214
+ title=title or "Gradio",
215
+ theme=theme,
216
+ js=js,
217
+ head=head,
218
+ fill_height=fill_height,
219
+ )
220
+ except Exception as e:
221
+ # Handle old gradio versions without fill_height
222
+ super(gr.ChatInterface, self).__init__(
223
+ analytics_enabled=analytics_enabled,
224
+ mode="chat_interface",
225
+ css=css,
226
+ title=title or "Gradio",
227
+ theme=theme,
228
+ js=js,
229
+ head=head,
230
+ # fill_height=fill_height,
231
+ )
232
+ self.concurrency_limit = concurrency_limit
233
+ self.fn = fn
234
+ self.render_additional_inputs_fn = render_additional_inputs_fn
235
+ self.is_async = inspect.iscoroutinefunction(
236
+ self.fn
237
+ ) or inspect.isasyncgenfunction(self.fn)
238
+ self.is_generator = inspect.isgeneratorfunction(
239
+ self.fn
240
+ ) or inspect.isasyncgenfunction(self.fn)
241
+ self.examples = examples
242
+ if self.space_id and cache_examples is None:
243
+ self.cache_examples = True
244
+ else:
245
+ self.cache_examples = cache_examples or False
246
+ self.buttons: list[Button | None] = []
247
+
248
+ if additional_inputs:
249
+ if not isinstance(additional_inputs, list):
250
+ additional_inputs = [additional_inputs]
251
+ self.additional_inputs = [
252
+ get_component_instance(i)
253
+ for i in additional_inputs # type: ignore
254
+ ]
255
+ else:
256
+ self.additional_inputs = []
257
+ if additional_inputs_accordion_name is not None:
258
+ print(
259
+ "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
260
+ )
261
+ self.additional_inputs_accordion_params = {
262
+ "label": additional_inputs_accordion_name
263
+ }
264
+ if additional_inputs_accordion is None:
265
+ self.additional_inputs_accordion_params = {
266
+ "label": "Additional Inputs",
267
+ "open": False,
268
+ }
269
+ elif isinstance(additional_inputs_accordion, str):
270
+ self.additional_inputs_accordion_params = {
271
+ "label": additional_inputs_accordion
272
+ }
273
+ elif isinstance(additional_inputs_accordion, Accordion):
274
+ self.additional_inputs_accordion_params = (
275
+ additional_inputs_accordion.recover_kwargs(
276
+ additional_inputs_accordion.get_config()
277
+ )
278
+ )
279
+ else:
280
+ raise ValueError(
281
+ f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
282
+ )
283
+
284
+ with self:
285
+ if title:
286
+ Markdown(
287
+ f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
288
+ )
289
+ if description:
290
+ Markdown(description)
291
+
292
+ if chatbot:
293
+ self.chatbot = chatbot.render()
294
+ else:
295
+ self.chatbot = Chatbot(
296
+ label="Chatbot", scale=1, height=200 if fill_height else None
297
+ )
298
+
299
+ with Row():
300
+ for btn in [retry_btn, undo_btn, clear_btn]:
301
+ if btn is not None:
302
+ if isinstance(btn, Button):
303
+ btn.render()
304
+ elif isinstance(btn, str):
305
+ btn = Button(btn, variant="secondary", size="sm")
306
+ else:
307
+ raise ValueError(
308
+ f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
309
+ )
310
+ self.buttons.append(btn) # type: ignore
311
+
312
+ with Group():
313
+ with Row():
314
+ if textbox:
315
+ textbox.container = False
316
+ textbox.show_label = False
317
+ textbox_ = textbox.render()
318
+ assert isinstance(textbox_, Textbox)
319
+ self.textbox = textbox_
320
+ else:
321
+ self.textbox = Textbox(
322
+ container=False,
323
+ show_label=False,
324
+ label="Message",
325
+ placeholder="Type a message...",
326
+ scale=7,
327
+ autofocus=autofocus,
328
+ )
329
+ if submit_btn is not None:
330
+ if isinstance(submit_btn, Button):
331
+ submit_btn.render()
332
+ elif isinstance(submit_btn, str):
333
+ submit_btn = Button(
334
+ submit_btn,
335
+ variant="primary",
336
+ scale=2,
337
+ min_width=150,
338
+ )
339
+ else:
340
+ raise ValueError(
341
+ f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
342
+ )
343
+ if stop_btn is not None:
344
+ if isinstance(stop_btn, Button):
345
+ stop_btn.visible = False
346
+ stop_btn.render()
347
+ elif isinstance(stop_btn, str):
348
+ stop_btn = Button(
349
+ stop_btn,
350
+ variant="stop",
351
+ visible=False,
352
+ scale=2,
353
+ min_width=150,
354
+ )
355
+ else:
356
+ raise ValueError(
357
+ f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
358
+ )
359
+ self.num_tokens = Textbox(
360
+ container=False,
361
+ label="num_tokens",
362
+ placeholder="0 tokens",
363
+ scale=1,
364
+ interactive=False,
365
+ # autofocus=autofocus,
366
+ min_width=10
367
+ )
368
+ self.buttons.extend([submit_btn, stop_btn]) # type: ignore
369
+
370
+ self.fake_api_btn = Button("Fake API", visible=False)
371
+ self.fake_response_textbox = Textbox(label="Response", visible=False)
372
+ (
373
+ self.retry_btn,
374
+ self.undo_btn,
375
+ self.clear_btn,
376
+ self.submit_btn,
377
+ self.stop_btn,
378
+ ) = self.buttons
379
+
380
+ if examples:
381
+ if self.is_generator:
382
+ examples_fn = self._examples_stream_fn
383
+ else:
384
+ examples_fn = self._examples_fn
385
+
386
+ self.examples_handler = Examples(
387
+ examples=examples,
388
+ inputs=[self.textbox] + self.additional_inputs,
389
+ outputs=self.chatbot,
390
+ fn=examples_fn,
391
+ )
392
+
393
+ any_unrendered_inputs = any(
394
+ not inp.is_rendered for inp in self.additional_inputs
395
+ )
396
+ if self.additional_inputs and any_unrendered_inputs:
397
+ with Accordion(**self.additional_inputs_accordion_params): # type: ignore
398
+ if self.render_additional_inputs_fn is not None:
399
+ self.render_additional_inputs_fn()
400
+ else:
401
+ for input_component in self.additional_inputs:
402
+ if not input_component.is_rendered:
403
+ input_component.render()
404
+
405
+ self.rag_content = gr.Textbox(
406
+ scale=4,
407
+ lines=16,
408
+ label='Retrieved RAG context',
409
+ placeholder="RAG context and instruction will show up here",
410
+ interactive=False
411
+ )
412
+
413
+ # The example caching must happen after the input components have rendered
414
+ if cache_examples:
415
+ client_utils.synchronize_async(self.examples_handler.cache)
416
+
417
+ self.saved_input = State()
418
+ self.chatbot_state = (
419
+ State(self.chatbot.value) if self.chatbot.value else State([])
420
+ )
421
+
422
+ self._setup_events()
423
+ self._setup_api()
424
+
425
+ def _setup_events(self) -> None:
426
+ from gradio.components import State
427
+ has_on = False
428
+ try:
429
+ from gradio.events import Dependency, EventListenerMethod, on
430
+ has_on = True
431
+ except ImportError as ie:
432
+ has_on = False
433
+ submit_fn = self._stream_fn if self.is_generator else self._submit_fn
434
+ if not self.is_generator:
435
+ raise NotImplementedError(f'The chat function must be a generator (streaming) function')
436
+
437
+ if has_on:
438
+ # new version
439
+ submit_triggers = (
440
+ [self.textbox.submit, self.submit_btn.click]
441
+ if self.submit_btn
442
+ else [self.textbox.submit]
443
+ )
444
+ submit_event = (
445
+ on(
446
+ submit_triggers,
447
+ self._clear_and_save_textbox,
448
+ [self.textbox],
449
+ [self.textbox, self.saved_input],
450
+ api_name=False,
451
+ queue=False,
452
+ )
453
+ .then(
454
+ self._display_input,
455
+ [self.saved_input, self.chatbot_state],
456
+ [self.chatbot, self.chatbot_state],
457
+ api_name=False,
458
+ queue=False,
459
+ )
460
+ .then(
461
+ submit_fn,
462
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
463
+ [self.chatbot, self.chatbot_state, self.num_tokens, self.rag_content],
464
+ api_name=False,
465
+ )
466
+ )
467
+ self._setup_stop_events(submit_triggers, submit_event)
468
+ else:
469
+ raise ValueError(f'Please install a gradio version newer than 3.44.0')
470
+
471
+ if self.retry_btn:
472
+ retry_event = (
473
+ self.retry_btn.click(
474
+ self._delete_prev_fn,
475
+ [self.chatbot_state],
476
+ [self.chatbot, self.saved_input, self.chatbot_state],
477
+ api_name=False,
478
+ queue=False,
479
+ )
480
+ .then(
481
+ self._display_input,
482
+ [self.saved_input, self.chatbot_state],
483
+ [self.chatbot, self.chatbot_state],
484
+ api_name=False,
485
+ queue=False,
486
+ )
487
+ .then(
488
+ submit_fn,
489
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
490
+ [self.chatbot, self.chatbot_state, self.num_tokens, self.rag_content],
491
+ api_name=False,
492
+ )
493
+ )
494
+ self._setup_stop_events([self.retry_btn.click], retry_event)
495
+
496
+ if self.undo_btn:
497
+ self.undo_btn.click(
498
+ self._delete_prev_fn,
499
+ [self.chatbot_state],
500
+ [self.chatbot, self.saved_input, self.chatbot_state],
501
+ api_name=False,
502
+ queue=False,
503
+ ).then(
504
+ lambda x: x,
505
+ [self.saved_input],
506
+ [self.textbox],
507
+ api_name=False,
508
+ queue=False,
509
+ )
510
+ # Reconfigure clear_btn to stop and clear text box
511
+
512
+ async def _stream_fn(
513
+ self,
514
+ message: str,
515
+ history_with_input,
516
+ request: Request,
517
+ *args,
518
+ ) -> AsyncGenerator:
519
+ history = history_with_input[:-1]
520
+ inputs, _, _ = special_args(
521
+ self.fn, inputs=[message, history, *args], request=request
522
+ )
523
+
524
+ if self.is_async:
525
+ generator = self.fn(*inputs)
526
+ else:
527
+ generator = await anyio.to_thread.run_sync(
528
+ self.fn, *inputs, limiter=self.limiter
529
+ )
530
+ generator = SyncToAsyncIterator(generator, self.limiter)
531
+
532
+ # ! In case of error, yield the previous history & undo any generation before raising error
533
+ try:
534
+ first_response_pack = await async_iteration(generator)
535
+ if isinstance(first_response_pack, (tuple, list)):
536
+ first_response, num_tokens, rag_content = first_response_pack
537
+ else:
538
+ first_response, num_tokens, rag_content = first_response_pack, -1, ""
539
+ update = history + [[message, first_response]]
540
+ yield update, update, f"{num_tokens} toks", rag_content
541
+ except StopIteration:
542
+ update = history + [[message, None]]
543
+ yield update, update, "NaN toks", ""
544
+ except Exception as e:
545
+ yield history, history, "NaN toks", ""
546
+ raise e
547
+
548
+ try:
549
+ async for response_pack in generator:
550
+ if isinstance(response_pack, (tuple, list)):
551
+ response, num_tokens, rag_content = response_pack
552
+ else:
553
+ response, num_tokens, rag_content = response_pack, "NaN toks", ""
554
+ update = history + [[message, response]]
555
+ yield update, update, f"{num_tokens} toks", rag_content
556
+ except Exception as e:
557
+ yield history, history, "NaN toks", ""
558
+ raise e
559
+
560
+
561
+
562
+ @register_demo
563
+ class RagChatInterfaceDemo(ChatInterfaceDemo):
564
+
565
+ @property
566
+ def examples(self):
567
+ return [
568
+ ["Explain how attention works.", "assets/attention_all_you_need.pdf"],
569
+ ["Explain why the sky is blue.", None],
570
+ ]
571
+
572
+ @property
573
+ def tab_name(self):
574
+ return "RAG Chat"
575
+
576
+ def create_demo(
577
+ self,
578
+ title: str | None = None,
579
+ description: str | None = None,
580
+ **kwargs
581
+ ) -> gr.Blocks:
582
+ load_embeddings()
583
+ global RAG_EMBED
584
+ # assert RAG_EMBED is not None
585
+ print(F'{RAG_EMBED=}')
586
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
587
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
588
+ temperature = kwargs.get("temperature", TEMPERATURE)
589
+ model_name = kwargs.get("model_name", MODEL_NAME)
590
+ rag_num_docs = kwargs.get("rag_num_docs", 3)
591
+
592
+ from ..configs import RAG_EMBED_MODEL_NAME
593
+
594
+ description = (
595
+ description or
596
+ f"""Upload a long document to ask questions about it with RAG. The retrieved RAG text segments are shown at the bottom.
597
+ Adjust the `RAG instruction` to fit your language. Embedding model: {RAG_EMBED_MODEL_NAME}."""
598
+ )
599
+
600
+ additional_inputs = [
601
+ gr.File(label='Upload Document', file_count='single', file_types=['pdf', 'docx', 'txt']),
602
+ gr.Number(value=temperature, label='Temperature', min_width=20),
603
+ gr.Number(value=max_tokens, label='Max tokens', min_width=20),
604
+ gr.Textbox(value=system_prompt, label='System prompt', lines=2),
605
+ gr.Number(value=rag_num_docs, label='RAG Top-K', min_width=20),
606
+ gr.Textbox(value=DOC_INSTRUCTION, label='RAG instruction'),
607
+ ]
608
+ def render_additional_inputs_fn():
609
+ additional_inputs[0].render()
610
+ with Row():
611
+ additional_inputs[1].render()
612
+ additional_inputs[2].render()
613
+ additional_inputs[4].render()
614
+ additional_inputs[3].render()
615
+ additional_inputs[5].render()
616
+
617
+ demo_chat = RagChatInterface(
618
+ chat_response_stream_multiturn_doc_engine,
619
+ chatbot=gr.Chatbot(
620
+ label=model_name,
621
+ bubble_full_width=False,
622
+ latex_delimiters=[
623
+ { "left": "$", "right": "$", "display": False},
624
+ { "left": "$$", "right": "$$", "display": True},
625
+ ],
626
+ show_copy_button=True,
627
+ ),
628
+ textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
629
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
630
+ # ! consider preventing the stop button
631
+ # stop_btn=None,
632
+ title=title,
633
+ description=description,
634
+ additional_inputs=additional_inputs,
635
+ render_additional_inputs_fn=render_additional_inputs_fn,
636
+ additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
637
+ examples=self.examples,
638
+ cache_examples=False,
639
+ )
640
+ return demo_chat
641
+
642
+
multipurpose_chatbot/demos/text_completion.py ADDED
@@ -0,0 +1,199 @@
1
+ import os
2
+ from gradio.themes import ThemeClass as Theme
3
+ import numpy as np
4
+ import argparse
5
+ import gradio as gr
6
+ from typing import Any, Iterator
7
+ from typing import Iterator, List, Optional, Tuple
8
+ import filelock
9
+ import glob
10
+ import json
11
+ import time
12
+ from gradio.routes import Request
13
+ from gradio.utils import SyncToAsyncIterator, async_iteration
14
+ from gradio.helpers import special_args
15
+ import anyio
16
+ from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
17
+
18
+ from gradio_client.documentation import document, set_documentation_group
19
+ from gradio.components import Button, Component
20
+ from gradio.events import Dependency, EventListenerMethod
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+
25
+
26
+ import inspect
27
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
28
+
29
+ import anyio
30
+ from gradio_client import utils as client_utils
31
+ from gradio_client.documentation import document
32
+
33
+ from gradio.blocks import Blocks
34
+ from gradio.components import (
35
+ Button,
36
+ Chatbot,
37
+ Component,
38
+ Markdown,
39
+ State,
40
+ Textbox,
41
+ get_component_instance,
42
+ )
43
+ from gradio.events import Dependency, on
44
+ from gradio.helpers import create_examples as Examples # noqa: N812
45
+ from gradio.helpers import special_args
46
+ from gradio.layouts import Accordion, Group, Row
47
+ from gradio.routes import Request
48
+ from gradio.themes import ThemeClass as Theme
49
+ from gradio.utils import SyncToAsyncIterator, async_iteration
50
+
51
+
52
+ from .base_demo import register_demo, get_demo_class, BaseDemo
53
+
54
+
55
+ from ..configs import (
56
+ SYSTEM_PROMPT,
57
+ MODEL_NAME,
58
+ MAX_TOKENS,
59
+ TEMPERATURE,
60
+ )
61
+
62
+ from ..globals import MODEL_ENGINE
63
+
64
+
65
+ def generate_text_completion_stream_engine(
66
+ message: str,
67
+ temperature: float,
68
+ max_tokens: int,
69
+ stop_strings: str = '<s>,</s>,<|im_start|>,<|im_end|>',
70
+ ):
71
+ global MODEL_ENGINE
72
+ temperature = float(temperature)
73
+ # ! remove frequency_penalty
74
+ # frequency_penalty = float(frequency_penalty)
75
+ max_tokens = int(max_tokens)
76
+ # message = message.strip()
77
+ stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
78
+ stop_strings = list(set(stop_strings + ['</s>', '<|im_start|>', '<|im_end|>']))
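+ # always include the common special tokens as stop strings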
79
+ if message.strip() != message:
80
+ gr.Warning(f'There are leading/trailing spaces in the message; this may lead to unexpected behavior')
81
+ if len(message) == 0:
82
+ raise gr.Error("The message cannot be empty!")
83
+ num_tokens = len(MODEL_ENGINE.tokenizer.encode(message))
84
+ if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
85
+ raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
86
+
87
+ outputs = None
88
+ response = None
89
+ num_tokens = -1
90
+ for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
91
+ prompt=message,
92
+ temperature=temperature,
93
+ max_tokens=max_tokens,
94
+ stop_strings=stop_strings,
95
+ )):
96
+ if isinstance(outputs, tuple):
97
+ response, num_tokens = outputs
98
+ else:
99
+ response, num_tokens = outputs, -1
100
+ yield message + response, f"{num_tokens} tokens"
101
+
102
+ if response is not None:
103
+ yield message + response, f"{num_tokens} tokens"
104
+
105
+
106
+ @register_demo
107
+ class TextCompletionDemo(BaseDemo):
108
+ @property
109
+ def tab_name(self):
110
+ return "Text Completion"
111
+
112
+ def create_demo(
113
+ self,
114
+ title: str | None = None,
115
+ description: str | None = None,
116
+ **kwargs
117
+ ) -> gr.Blocks:
118
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
119
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
120
+ temperature = kwargs.get("temperature", TEMPERATURE)
121
+ model_name = kwargs.get("model_name", MODEL_NAME)
122
+ # frequence_penalty = FREQUENCE_PENALTY
123
+ # presence_penalty = PRESENCE_PENALTY
124
+ max_tokens = max_tokens // 2
125
+
126
+ description = description or f"""Put any context string (like few-shot prompts)"""
127
+
128
+ with gr.Blocks() as demo_text_completion:
129
+ if title:
130
+ gr.Markdown(title)
131
+ if description:
132
+ gr.Markdown(description)
133
+ with gr.Row():
134
+ txt = gr.Textbox(
135
+ scale=4,
136
+ lines=16,
137
+ show_label=False,
138
+ placeholder="Enter any free form text and submit",
139
+ container=False,
140
+ )
141
+ with gr.Row():
142
+ submit_button = gr.Button('Submit', variant='primary', scale=9)
143
+ stop_button = gr.Button('Stop', variant='stop', scale=9, visible=False)
144
+ num_tokens = Textbox(
145
+ container=False,
146
+ show_label=False,
147
+ label="num_tokens",
148
+ placeholder="0 tokens",
149
+ scale=1,
150
+ interactive=False,
151
+ min_width=10
152
+ )
153
+ with gr.Row():
154
+ temp_input = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
155
+ length_input = gr.Number(value=max_tokens, label='Max tokens', info='Increase for longer generations')
156
+ stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>,<|im_end|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1)
157
+ examples = gr.Examples(
158
+ examples=[
159
+ ["The following is a recitation of the Declaration of Independence:",]
160
+ ],
161
+ inputs=[txt, temp_input, length_input, stop_strings],
162
+ # outputs=[txt]
163
+ )
164
+ # ! Handle stop button
165
+ submit_trigger = submit_button.click
166
+ submit_event = submit_button.click(
167
+ # submit_trigger,
168
+ generate_text_completion_stream_engine,
169
+ [txt, temp_input, length_input, stop_strings],
170
+ [txt, num_tokens],
171
+ # api_name=False,
172
+ # queue=False,
173
+ )
174
+
175
+ submit_trigger(
176
+ lambda: (
177
+ Button(visible=False), Button(visible=True),
178
+ ),
179
+ None,
180
+ [submit_button, stop_button],
181
+ api_name=False,
182
+ queue=False,
183
+ )
184
+ submit_event.then(
185
+ lambda: (Button(visible=True), Button(visible=False)),
186
+ None,
187
+ [submit_button, stop_button],
188
+ api_name=False,
189
+ queue=False,
190
+ )
191
+ stop_button.click(
192
+ None,
193
+ None,
194
+ None,
195
+ cancels=submit_event,
196
+ api_name=False,
197
+ )
198
+
199
+ return demo_text_completion
multipurpose_chatbot/engines/.DS_Store ADDED
Binary file (6.15 kB)
multipurpose_chatbot/engines/__init__.py ADDED
@@ -0,0 +1,54 @@
1
+
2
+ from .base_engine import BaseEngine
3
+
4
+ BACKENDS = [
5
+ "mlx",
6
+ "vllm",
7
+ "transformers",
8
+ "llava15_transformers",
9
+ "llama_cpp",
10
+ # "llava_llama_cpp",
11
+ "debug",
12
+ ]
13
+
14
+ ENGINE_LOADED = False
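+ # module-level guard so load_multipurpose_chatbot_engine is only called once per process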
15
+
16
+
17
+ def load_multipurpose_chatbot_engine(backend: str):
18
+ # ! lazy import other engines
19
+ global ENGINE_LOADED
20
+ assert backend in BACKENDS, f'{backend} not in {BACKENDS}'
21
+ if ENGINE_LOADED:
22
+ raise RuntimeError(f'{ENGINE_LOADED=}: load_multipurpose_chatbot_engine has already been called! Check your code.')
23
+ print(f'Load model from {backend}')
24
+ if backend == "mlx":
25
+ from .mlx_engine import MlxEngine
26
+ model_engine = MlxEngine()
27
+ elif backend == 'vllm':
28
+ from .vllm_engine import VllmEngine
29
+ model_engine = VllmEngine()
30
+ elif backend == 'transformers':
31
+ from .transformers_engine import TransformersEngine
32
+ model_engine = TransformersEngine()
33
+ elif backend == 'llava15_transformers':
34
+ from .llava15_transformers_engine import Llava15TransformersEngine
35
+ model_engine = Llava15TransformersEngine()
36
+ elif backend == 'llama_cpp':
37
+ from .llama_cpp_engine import LlamaCppEngine
38
+ model_engine = LlamaCppEngine()
39
+ # ! llava_llama_cpp currently not done due to bugs
40
+ # elif backend == 'llava_llama_cpp':
41
+ # from .llava_llama_cpp_engine import LlavaLlamaCppEngine
42
+ # model_engine = LlavaLlamaCppEngine()
43
+ elif backend == 'debug':
44
+ from .debug_engine import DebugEngine
45
+ model_engine = DebugEngine()
46
+ else:
47
+ raise ValueError(f'backend invalid: {BACKENDS} vs {backend}')
48
+
49
+ model_engine.load_model()
50
+ ENGINE_LOADED = True
51
+ return model_engine
52
+ # ! add more llama.cpp engine here.
53
+
54
+
multipurpose_chatbot/engines/base_engine.py ADDED
@@ -0,0 +1,46 @@
1
+ import os
2
+ import numpy as np
3
+ from huggingface_hub import snapshot_download
4
+ # ! Avoid importing transformers
5
+ # from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
6
+ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
7
+ import time
8
+
9
+
10
+ class BaseEngine(object):
11
+ def __init__(self, **kwargs) -> None:
12
+ pass
13
+
14
+ @property
15
+ def max_position_embeddings(self) -> int:
16
+ return 10000
17
+
18
+ @property
19
+ def tokenizer(self):
20
+ raise NotImplementedError
21
+
22
+ @property
23
+ def processor(self):
24
+ raise NotImplementedError
25
+
26
+ def load_model(self, ):
27
+ raise NotImplementedError
28
+
29
+ def apply_chat_template(self, conversations, add_generation_prompt: bool, add_special_tokens=False, **kwargs) -> str:
30
+ """
31
+ Return the conversation formatted as a prompt string; special tokens are not included and should be added later.
32
+ """
33
+ bos_token = self.tokenizer.bos_token
34
+ eos_token = self.tokenizer.eos_token
35
+ if not add_special_tokens:
36
+ # prevent bos being added to string
37
+ self.tokenizer.bos_token = ""
38
+ self.tokenizer.eos_token = ""
39
+ full_prompt = self.tokenizer.apply_chat_template(
40
+ conversations, add_generation_prompt=add_generation_prompt,
41
+ tokenize=False,
42
+ )
43
+ self.tokenizer.bos_token = bos_token
44
+ self.tokenizer.eos_token = eos_token
45
+ return full_prompt
46
+
multipurpose_chatbot/engines/debug_engine.py ADDED
@@ -0,0 +1,49 @@
1
+ import os
2
+ import numpy as np
3
+ from huggingface_hub import snapshot_download
4
+ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
5
+ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
6
+ import time
7
+
8
+ from .base_engine import BaseEngine
9
+
10
+ from ..configs import (
11
+ MODEL_PATH,
12
+ )
13
+
14
+ FAKE_MODEL_PATH = os.environ.get("FAKE_MODEL_PATH", MODEL_PATH)
15
+ FAKE_RESPONSE = "Wow that's very very cool, please try again."
16
+
17
+
18
+ class DebugEngine(BaseEngine):
19
+ """
20
+ It will always yield FAKE_RESPONSE
21
+ """
22
+
23
+ def __init__(self, **kwargs) -> None:
24
+ super().__init__(**kwargs)
25
+ self._model = None
26
+ self._tokenizer = None
27
+
28
+ @property
29
+ def tokenizer(self) -> PreTrainedTokenizer:
30
+ if self._tokenizer is None:
31
+ self._tokenizer = AutoTokenizer.from_pretrained(FAKE_MODEL_PATH, trust_remote_code=True)
32
+ return self._tokenizer
33
+
34
+ def load_model(self):
35
+ print(f"Load fake model with tokenizer: {self.tokenizer}")
36
+
37
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
38
+
39
+ num_tokens = len(self.tokenizer.encode(prompt))
40
+ response = FAKE_RESPONSE
41
+ for i in range(len(response)):
42
+ time.sleep(0.01)
43
+ yield response[:i], num_tokens
44
+
45
+ num_tokens = len(self.tokenizer.encode(prompt + response))
46
+ yield response, num_tokens
47
+
48
+ def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
49
+ return [p + " -- Test" for p in prompts]
multipurpose_chatbot/engines/llama_cpp_engine.py ADDED
@@ -0,0 +1,131 @@
1
+ import os
2
+ import numpy as np
3
+ import argparse
4
+ import gradio as gr
5
+ from typing import Any, Iterator
6
+ from typing import Iterator, List, Optional, Tuple
7
+ import filelock
8
+ import glob
9
+ import json
10
+ import time
11
+ from gradio.routes import Request
12
+ from gradio.utils import SyncToAsyncIterator, async_iteration
13
+ from gradio.helpers import special_args
14
+ import anyio
15
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
16
+
17
+ from gradio_client.documentation import document, set_documentation_group
18
+
19
+ from typing import List, Optional, Union, Dict, Tuple
20
+ from tqdm.auto import tqdm
21
+ from huggingface_hub import snapshot_download
22
+ import types
23
+
24
+ from gradio.components import Button
25
+ from gradio.events import Dependency, EventListenerMethod
26
+
27
+ import types
28
+ import sys
29
+
30
+ from .base_engine import BaseEngine
31
+
32
+ # ! Remember to use static cache
33
+
34
+ from ..configs import (
35
+ MODEL_PATH,
36
+ DEFAULT_CHAT_TEMPLATE,
37
+ N_CTX,
38
+ N_GPU_LAYERS,
39
+ )
40
+
41
+
42
+
43
+ def encode_tokenize(self, prompt: str, **kwargs):
44
+ """Mimic behavior of transformers tokenizer"""
45
+ prompt_tokens: List[int] = (
46
+ (
47
+ self.tokenize(prompt.encode("utf-8"), special=True)
48
+ if prompt != ""
49
+ else [self.token_bos()]
50
+ )
51
+ if isinstance(prompt, str)
52
+ else prompt
53
+ )
54
+ return prompt_tokens
55
+
56
+
57
+ conversations = [
58
+ {"role": "system", "content": "You are good."},
59
+ {"role": "user", "content": "Hello."},
60
+ {"role": "assistant", "content": "Hi."},
61
+ ]
62
+
63
+
64
+ class LlamaCppEngine(BaseEngine):
65
+ """
66
+ need to create an engine.tokenizer.encode(text) method
67
+ """
68
+ @property
69
+ def max_position_embeddings(self) -> int:
70
+ # raise ValueError
71
+ return self._model.context_params.n_ctx
72
+
73
+ def apply_chat_template(self, conversations, add_generation_prompt: bool, add_special_tokens=False, **kwargs) -> str:
74
+ """
75
+ return string convo, add_special_tokens should be added later
76
+ remember to remove <s> if any,
77
+ """
78
+ from llama_cpp.llama_chat_format import Jinja2ChatFormatter
79
+
80
+ formatter = Jinja2ChatFormatter(
81
+ template=self._model.metadata['tokenizer.chat_template'],
82
+ # bos_token=self._model._model.token_get_text(self._model.token_bos()),
83
+ bos_token="",
84
+ eos_token=self._model._model.token_get_text(self._model.token_eos()),
85
+ add_generation_prompt=add_generation_prompt,
86
+ )
87
+
88
+ full_prompt = formatter(messages=conversations).prompt
89
+ # ! it may still contain a bos token
90
+ return full_prompt
91
+
92
+ @property
93
+ def tokenizer(self):
94
+ return self._model
95
+
96
+ def load_model(self):
97
+ # from transformers import AutoTokenizer, AutoModelForCausalLM
98
+
99
+ from llama_cpp import Llama
100
+ self.model_path = MODEL_PATH
101
+ self._model = Llama(
102
+ model_path=self.model_path,
103
+ n_gpu_layers=N_GPU_LAYERS, # number of layers to offload to the GPU (0 disables GPU acceleration)
104
+ # seed=1337, # Uncomment to set a specific seed
105
+ n_ctx=N_CTX, # context window size
106
+ )
107
+ self._tokenizer = self._model
108
+ self._model.encode = types.MethodType(encode_tokenize, self._model)
109
+ print(f'Load model: {self.model_path=} | {N_GPU_LAYERS=} | {N_CTX=}')
110
+
111
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
112
+ stop_strings = list(stop_strings) if stop_strings is not None else []
113
+ stop_strings = list(set(stop_strings + ["</s>", "<|im_end|>"]))
114
+ generator = self._model(
115
+ prompt,
116
+ max_tokens=max_tokens, # maximum number of new tokens; set to None to generate up to the end of the context window
117
+ temperature=temperature,
118
+ stop=stop_strings, # stop generation when any of these strings is produced
119
+ stream=True,
120
+ )
121
+ response = ""
122
+ num_tokens = len(self.tokenizer.encode(prompt))
123
+ for g in generator:
124
+ response += g['choices'][0]['text']
125
+ yield response, num_tokens
126
+
127
+ if response is not None and len(response) > 0:
128
+ num_tokens = len(self.tokenizer.encode(prompt + response))
129
+ yield response, num_tokens
130
+
131
+
multipurpose_chatbot/engines/llava15_transformers_engine.py ADDED
@@ -0,0 +1,230 @@
1
+
2
+ import os
3
+ import numpy as np
4
+ import argparse
5
+ import torch
6
+ import gradio as gr
7
+ from typing import Any, Iterator
8
+ from typing import Iterator, List, Optional, Tuple
9
+ import filelock
10
+ import glob
11
+ import json
12
+ import time
13
+ from gradio.routes import Request
14
+ from gradio.utils import SyncToAsyncIterator, async_iteration
15
+ from gradio.helpers import special_args
16
+ import anyio
17
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
18
+
19
+ from gradio_client.documentation import document, set_documentation_group
20
+
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+
25
+ from gradio.components import Button
26
+ from gradio.events import Dependency, EventListenerMethod
27
+ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
28
+ import types
29
+ import sys
30
+ from .base_engine import BaseEngine
31
+ from .transformers_engine import TransformersEngine, NewGenerationMixin
32
+
33
+ from ..configs import (
34
+ STREAM_CHECK_MULTIPLE,
35
+ STREAM_YIELD_MULTIPLE,
36
+ )
37
+
38
+ from ..configs import (
39
+ STREAM_CHECK_MULTIPLE,
40
+ STREAM_YIELD_MULTIPLE,
41
+ IMAGE_TOKEN,
42
+ IMAGE_TOKEN_INTERACTIVE,
43
+ IMAGE_TOKEN_LENGTH,
44
+ MAX_PACHES,
45
+ DTYPE,
46
+ DEVICE,
47
+ )
48
+
49
+ CODE_PATH = os.environ.get("CODE_PATH", "")
50
+ MODEL_PATH = os.environ.get("MODEL_PATH", "")
51
+
52
+ # IMAGE_TOKEN = "<image"
53
+
54
+ # IMAGE_LENGTH = 576
55
+ # MAX_PACHES = 1
56
+
57
+
58
+ # ! Still working on it....
59
+ # Should only do with
60
+
61
+ """
62
+ This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers. 这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。
63
+
64
+ ### Human: <image_placeholder>
65
+ Describe the cats and what they are doing in detail.
66
+ ### Assistant:
67
+ """
68
+
69
+ # prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
70
+ # image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
71
+
72
+ # conv_llava_llama_2 = Conversation(
73
+ # system="You are a helpful language and vision assistant. "
74
+ # "You are able to understand the visual content that the user provides, "
75
+ # "and assist the user with a variety of tasks using natural language.",
76
+ # roles=("USER", "ASSISTANT"),
77
+ # version="llama_v2",
78
+ # messages=(),
79
+ # offset=0,
80
+ # sep_style=SeparatorStyle.LLAMA_2,
81
+ # sep="<s>",
82
+ # sep2="</s>",
83
+ # )
84
+
85
+
86
+ LLAVA_CHAT_TEMPLATE = """"""
87
+
88
+
89
+ # "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '</s>'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
90
+
91
+
92
+ if IMAGE_TOKEN != "<image>":
93
+ print(f'WARNING!!!! {IMAGE_TOKEN=} is not <image>, this can lead to problems')
94
+
95
+
96
+ class Llava15TransformersEngine(TransformersEngine):
97
+ """
98
+ Llava 1.5 hardcoded
99
+ """
100
+ @property
101
+ def image_token(self):
102
+ return IMAGE_TOKEN
103
+
104
+ @property
105
+ def max_position_embeddings(self) -> int:
106
+ return self._model.config.text_config.max_position_embeddings
107
+
108
+ @property
109
+ def tokenizer(self):
110
+ return self._tokenizer
111
+
112
+ @property
113
+ def processor(self):
114
+ return self._processor
115
+
116
+
117
+ def apply_chat_template(self, conversations, add_generation_prompt: bool, add_special_tokens=False, **kwargs) -> str:
118
+ """
119
+ return string convo, add_special_tokens should be added later
120
+ """
121
+ prompt = ""
122
+ for turn in conversations:
123
+ if turn['role'] == 'system':
124
+ prompt += turn['content'] + "\n\n"
125
+ elif turn['role'] == 'user':
126
+ prompt += f"USER: {turn['content']}\n"
127
+ elif turn['role'] == 'assistant':
128
+ prompt += f"ASSISTANT: {turn['content']}\n"
129
+ if add_generation_prompt:
130
+ prompt += f"ASSISTANT:"
131
+ return prompt
132
+
133
+
134
+ def load_model(self):
135
+ import requests
136
+ from PIL import Image
137
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
138
+
139
+ self.model_path = model_path = MODEL_PATH
140
+ self.torch_dtype = torch.bfloat16 if DTYPE == 'bfloat16' else torch.float16
141
+ self.device_map = DEVICE
142
+ print(f'Loading model from {model_path} on {self.device_map} with {self.torch_dtype} | LlavaForConditionalGeneration')
143
+
144
+ self._processor = AutoProcessor.from_pretrained(self.model_path)
145
+ self._model = LlavaForConditionalGeneration.from_pretrained(
146
+ MODEL_PATH,
147
+ torch_dtype=self.torch_dtype, device_map=self.device_map, trust_remote_code=True
148
+ ).eval()
149
+ self._model.sample_old = self._model.sample
150
+ # self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
151
+ self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
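+ # patch the HF generation sampling method with a streaming variant so tokens can be yielded one at a time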
152
+
153
+ self._tokenizer = self._processor.tokenizer
154
+ print(self._model)
155
+ print(f"{self.max_position_embeddings=}")
156
+
157
+ def get_multimodal_tokens(self, full_prompt, image_paths=None):
158
+ num_tokens = len(self.tokenizer.encode(full_prompt))
159
+ for image_path in image_paths:
160
+ num_tokens += IMAGE_TOKEN_LENGTH * MAX_PACHES
161
+ return num_tokens
162
+
163
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
164
+ from transformers.generation.utils import GenerationConfig
165
+ from PIL import Image
166
+ image_paths = kwargs.get("image_paths", None)
167
+ image_paths = image_paths or []
168
+
169
+ images = [Image.open(x) for x in image_paths] if len(image_paths) > 0 else None
170
+
171
+ with torch.no_grad():
172
+ inputs = self.processor(prompt, images, return_tensors='pt')
173
+ # inputs = inputs.to("cuda", torch.bfloat16)
174
+ inputs = {k: v.to(self.device_map) for k, v in inputs.items() if v is not None}
175
+ num_tokens = self.get_multimodal_tokens(prompt, image_paths)
176
+ # non-streaming generation
177
+ # output = self._model.generate(
178
+ # **inputs,
179
+ # do_sample=True,
180
+ # temperature=temperature,
181
+ # max_new_tokens=max_tokens,
182
+ # pad_token_id=self.processor.tokenizer.pad_token_id,
183
+ # )
184
+ # # response = self.processor.tokenizer.decode(output[0][-inputs.input_ids.size(-1):], skip_special_tokens=True)
185
+ # full_output_text = self.processor.decode(output[0], skip_special_tokens=True)
186
+ # response = full_output_text.split("<|im_start|>assistant\n")[-1]
187
+ # num_tokens = self.get_multimodal_tokens(prompt + response, image_paths)
188
+ # print(prompt)
189
+ # print(response)
190
+ # print(num_tokens)
191
+ # yield response, num_tokens
192
+
193
+ # if i % 4 == 0 and i > 1:
194
+ # message_safety = safety_check(response)
195
+ # if message_safety is not None:
196
+ # history = undo_history(history)
197
+ # yield history, "", None
198
+ # raise gr.Error(message_safety)
199
+
200
+ # # ! streaming
201
+ generator = self._model.generate(
202
+ **inputs,
203
+ do_sample=True,
204
+ temperature=temperature,
205
+ max_new_tokens=max_tokens,
206
+ pad_token_id=self.processor.tokenizer.pad_token_id,
207
+ )
208
+
209
+ out_tokens = []
210
+ response = None
211
+ for index, token in enumerate(generator):
212
+ out_tokens.append(token.item())
213
+ response = self.processor.tokenizer.decode(out_tokens)
214
+
215
+ yield response, num_tokens
216
+
217
+ del generator
218
+
219
+ if response is not None:
220
+
221
+ full_text = prompt + response
222
+ num_tokens = self.get_multimodal_tokens(full_text, image_paths)
223
+ yield response, num_tokens
224
+
225
+ # raw_image = Image.open(requests.get(image_file, stream=True).raw)
226
+ # inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
227
+
228
+
229
+
230
+
multipurpose_chatbot/engines/llava_llama_cpp_engine.py ADDED
@@ -0,0 +1,280 @@
1
+ import os
2
+ import numpy as np
3
+ import argparse
4
+ import gradio as gr
5
+ from typing import Any, Iterator
6
+ from typing import Iterator, List, Optional, Tuple
7
+ import filelock
8
+ import glob
9
+ import json
10
+ import time
11
+ from gradio.routes import Request
12
+ from gradio.utils import SyncToAsyncIterator, async_iteration
13
+ from gradio.helpers import special_args
14
+ import anyio
15
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
16
+
17
+ from gradio_client.documentation import document, set_documentation_group
18
+
19
+ from typing import List, Optional, Union, Dict, Tuple
20
+ from tqdm.auto import tqdm
21
+ from huggingface_hub import snapshot_download
22
+ import types
23
+
24
+ from gradio.components import Button
25
+ from gradio.events import Dependency, EventListenerMethod
26
+
27
+ import types
28
+ import sys
29
+
30
+ from .base_engine import BaseEngine
31
+
32
+ # ! Remember to use static cache
33
+
34
+ from ..configs import (
35
+ MODEL_PATH,
36
+ DEFAULT_CHAT_TEMPLATE,
37
+ N_CTX,
38
+ N_GPU_LAYERS,
39
+ IMAGE_TOKEN,
40
+ IMAGE_TOKEN_INTERACTIVE,
41
+ IMAGE_TOKEN_LENGTH,
42
+ MAX_PACHES,
43
+ )
44
+
45
+ from .llama_cpp_engine import (
46
+ encode_tokenize,
47
+ LlamaCppEngine,
48
+ )
49
+
50
+
51
+
52
+ # resource: https://llama-cpp-python.readthedocs.io/en/latest/#multi-modal-models
53
+
54
+ import base64
55
+
56
+ def image_to_base64_data_uri(file_path):
57
+ with open(file_path, "rb") as img_file:
58
+ base64_data = base64.b64encode(img_file.read()).decode('utf-8')
59
+ return f"data:image/png;base64,{base64_data}"
60
+
61
+
62
+ # file_path = 'file_path.png'
63
+ # data_uri = image_to_base64_data_uri(file_path)
64
+
65
+ # data_uri = image_to_base64_data_uri(file_path)
66
+
67
+ # messages = [
68
+ # {"role": "system", "content": "You are an assistant who perfectly describes images."},
69
+ # {
70
+ # "role": "user",
71
+ # "content": [
72
+ # {"type": "image_url", "image_url": {"url": data_uri }},
73
+ # {"type" : "text", "text": "Describe this image in detail please."}
74
+ # ]
75
+ # }
76
+ # ]
77
+
78
+
79
+ def llava_15_chat_handler_call(
80
+ self,
81
+ *,
82
+ llama: Any,
83
+ # messages: List[Any],
84
+ prompt: Union[str, List[int]],
85
+ image_data_uris: Optional[List[Any]] = None,
86
+ image_token: str = None,
87
+ functions: Optional[List[Any]] = None,
88
+ function_call: Optional[Any] = None,
89
+ tools: Optional[List[Any]] = None,
90
+ tool_choice: Optional[Any] = None,
91
+ temperature: float = 0.2,
92
+ top_p: float = 0.95,
93
+ top_k: int = 40,
94
+ min_p: float = 0.05,
95
+ typical_p: float = 1.0,
96
+ stream: bool = False,
97
+ stop: Optional[Union[str, List[str]]] = [],
98
+ response_format: Optional[
99
+ Any
100
+ ] = None,
101
+ max_tokens: Optional[int] = None,
102
+ presence_penalty: float = 0.0,
103
+ frequency_penalty: float = 0.0,
104
+ repeat_penalty: float = 1.1,
105
+ tfs_z: float = 1.0,
106
+ mirostat_mode: int = 0,
107
+ mirostat_tau: float = 5.0,
108
+ mirostat_eta: float = 0.1,
109
+ model: Optional[str] = None,
110
+ logits_processor: Optional[Any] = None,
111
+ grammar: Optional[Any] = None,
112
+ **kwargs, # type: ignore
113
+ ):
114
+ from llama_cpp.llama_chat_format import (
115
+ ctypes,
116
+ suppress_stdout_stderr,
117
+ )
118
+ assert (
119
+ llama.context_params.logits_all is True
120
+ ) # BUG: logits_all=True is required for llava
121
+ assert self.clip_ctx is not None
122
+ # ! split prompt into different parts
123
+ assert image_token is not None
124
+ prompt_parts = prompt.split(image_token)
125
+ # assert len(prompt_parts)
126
+ assert len(prompt_parts) == len(image_data_uris) + 1, f'invalid {len(prompt_parts)=} != {len(image_data_uris)=}'
127
+ llama.reset()
128
+ prefix = prompt_parts[0]
129
+ remaining_texts = prompt_parts[1:]
130
+ llama.reset()
131
+ llama.eval(llama.tokenize(prefix.encode("utf8"), add_bos=True))
132
+ for index, (image_uri, prompt_p) in enumerate(zip(image_data_uris, remaining_texts)):
133
+ image_bytes = self.load_image(image_uri)
134
+ import array
135
+ data_array = array.array("B", image_bytes)
136
+ c_ubyte_ptr = (
137
+ ctypes.c_ubyte * len(data_array)
138
+ ).from_buffer(data_array)
139
+ with suppress_stdout_stderr(disable=self.verbose):
140
+ embed = (
141
+ self._llava_cpp.llava_image_embed_make_with_bytes(
142
+ self.clip_ctx,
143
+ llama.context_params.n_threads,
144
+ c_ubyte_ptr,
145
+ len(image_bytes),
146
+ )
147
+ )
148
+ try:
149
+ n_past = ctypes.c_int(llama.n_tokens)
150
+ n_past_p = ctypes.pointer(n_past)
151
+ with suppress_stdout_stderr(disable=self.verbose):
152
+ self._llava_cpp.llava_eval_image_embed(
153
+ llama.ctx,
154
+ embed,
155
+ llama.n_batch,
156
+ n_past_p,
157
+ )
158
+ assert llama.n_ctx() >= n_past.value
159
+ llama.n_tokens = n_past.value
160
+ finally:
161
+ with suppress_stdout_stderr(disable=self.verbose):
162
+ self._llava_cpp.llava_image_embed_free(embed)
163
+
164
+ llama.eval(llama.tokenize(prompt_p.encode("utf8"), add_bos=False))
165
+ assert llama.n_ctx() >= llama.n_tokens
166
+
167
+ prompt = llama.input_ids[: llama.n_tokens].tolist()
168
+ # from llava-1.5
169
+ return llama.create_completion(
170
+ prompt=prompt,
171
+ temperature=temperature,
172
+ top_p=top_p,
173
+ top_k=top_k,
174
+ min_p=min_p,
175
+ typical_p=typical_p,
176
+ stream=stream,
177
+ stop=stop,
178
+ max_tokens=max_tokens,
179
+ presence_penalty=presence_penalty,
180
+ frequency_penalty=frequency_penalty,
181
+ repeat_penalty=repeat_penalty,
182
+ tfs_z=tfs_z,
183
+ mirostat_mode=mirostat_mode,
184
+ mirostat_tau=mirostat_tau,
185
+ mirostat_eta=mirostat_eta,
186
+ model=model,
187
+ logits_processor=logits_processor,
188
+ grammar=grammar,
189
+ )
190
+
191
+
192
+
193
+ class LlavaLlamaCppEngine(LlamaCppEngine):
194
+ """
195
+ Still in development; expect bugs.
196
+
197
+ Known issue (root cause not yet identified):
198
+ objc[61055]: Class GGMLMetalClass is implemented in both miniconda3/envs/native/lib/python3.12/site-packages/llama_cpp/libllama.dylib (0x12cb40290) and miniconda3/envs/native/lib/python3.12/site-packages/llama_cpp/libllava.dylib (0x12d9c8290). One of the two will be used. Which one is undefined.
199
+
200
+ """
201
+ @property
202
+ def image_token(self):
203
+ return IMAGE_TOKEN
204
+
205
+ def get_multimodal_tokens(self, full_prompt, image_paths=None):
206
+ num_tokens = len(self.tokenizer.encode(full_prompt))
207
+ for image_path in image_paths:
208
+ num_tokens += IMAGE_TOKEN_LENGTH * MAX_PACHES
209
+ return num_tokens
210
+
211
+ def load_model(self):
212
+ # from transformers import AutoTokenizer, AutoModelForCausalLM
213
+ from llama_cpp import Llama
214
+ from llama_cpp.llama_chat_format import Llava15ChatHandler
215
+ model_dir = os.path.dirname(MODEL_PATH)
216
+ self.chat_handler = Llava15ChatHandler(clip_model_path=os.path.join(model_dir, "mmproj.bin"))
217
+
218
+ self.chat_handler.__call__ = types.MethodType(llava_15_chat_handler_call, self.chat_handler)
219
+
220
+ self.model_path = MODEL_PATH
221
+ self._model = Llama(
222
+ model_path=self.model_path,
223
+ n_gpu_layers=N_GPU_LAYERS, # number of layers to offload to the GPU (0 = CPU only)
224
+ # seed=1337, # Uncomment to set a specific seed
225
+ chat_handler=self.chat_handler,
226
+ n_ctx=N_CTX, # context window size
227
+ logits_all=True, # needed to make llava work
228
+ )
229
+ self._tokenizer = self._model
230
+ self._model.encode = types.MethodType(encode_tokenize, self._model)
231
+ print(f'Load model: {self.model_path=} | {N_GPU_LAYERS=} | {N_CTX=}')
232
+
233
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
234
+ image_paths = kwargs.get("image_paths", [])
235
+
236
+ image_data_uris = [
237
+ image_to_base64_data_uri(ip)
238
+ for ip in image_paths
239
+ ]
240
+
241
+ stop_strings = list(stop_strings) if stop_strings is not None else []
242
+ stop_strings = list(set(stop_strings + ["</s>", "<|im_end|>"]))
243
+ # generator = self._model(
244
+ generator = self.chat_handler(
245
+ prompt=prompt,
246
+ image_data_uris=image_data_uris,
247
+ image_token=self.image_token,
248
+ max_tokens=max_tokens, # maximum number of new tokens; None generates up to the end of the context window
249
+ temperature=temperature,
250
+ stop=stop_strings, # stop strings that terminate generation
251
+ stream=True,
252
+ )
253
+ response = ""
254
+ num_tokens = len(self.tokenizer.encode(prompt))
255
+ for g in generator:
256
+ response += g['choices'][0]['text']
257
+ yield response, num_tokens
258
+
259
+ if response is not None and len(response) > 0:
260
+ num_tokens = len(self.tokenizer.encode(prompt + response))
261
+ yield response, num_tokens
262
+
263
+
264
+ """
265
+
266
+ export BACKEND=llama_cpp
267
+ export MODEL_PATH=/Users/nguyenxuanphi/Desktop/projects/cache/seallms/SeaLLMs/SeaLLM-7B-v2-gguf/seallm-v2.chatml.Q4_K_M.gguf
268
+ export N_CTX=4096
270
+ python app.py
271
+
272
+
273
+ export BACKEND=llava_llama_cpp
274
+ export MODEL_PATH=/Users/nguyenxuanphi/Desktop/projects/cache/llava/llava-1.5/ggml-model-q4_k.gguf
275
+ export N_CTX=4096
276
+ export IMAGE_TOKEN="<image>"
277
+ python app.py
278
+
279
+
280
+ """
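A minimal usage sketch for the engine above (editorial, not part of the uploaded files). It assumes the llava_llama_cpp environment variables from the snippet above are already exported, that the engine can be constructed with no arguments, that the prompt follows the ChatML template implied by the default stop strings, and that dog.jpeg is a hypothetical local file; the prompt must contain one <image> placeholder per entry in image_paths.

from multipurpose_chatbot.engines.llava_llama_cpp_engine import LlavaLlamaCppEngine

engine = LlavaLlamaCppEngine()
engine.load_model()  # loads the GGUF weights plus mmproj.bin from the same directory

# one <image> placeholder per entry in image_paths
prompt = "<|im_start|>user\n<image>\nDescribe this image.<|im_end|>\n<|im_start|>assistant\n"
response = None
for response, num_tokens in engine.generate_yield_string(
    prompt, temperature=0.2, max_tokens=256, image_paths=["dog.jpeg"]
):
    pass  # each step yields the response decoded so far and a running token count
print(response, num_tokens)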
multipurpose_chatbot/engines/mlx_engine.py ADDED
@@ -0,0 +1,202 @@
1
+ import os
2
+ import numpy as np
3
+ import mlx.core as mx
4
+ import mlx.nn as nn
5
+ from huggingface_hub import snapshot_download
6
+ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
7
+ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
8
+ import time
9
+ from mlx_lm import load, generate
10
+ from mlx_lm.utils import generate_step
11
+
12
+ from .base_engine import BaseEngine
13
+
14
+ from ..configs import (
15
+ MODEL_PATH,
16
+ )
17
+
18
+ def generate_string(
19
+ model: nn.Module,
20
+ tokenizer: PreTrainedTokenizer,
21
+ prompt: str,
22
+ temp: float = 0.0,
23
+ max_tokens: int = 100,
24
+ verbose: bool = False,
25
+ formatter: Callable = None,
26
+ repetition_penalty: Optional[float] = None,
27
+ repetition_context_size: Optional[int] = None,
28
+ stop_strings: Optional[Tuple[str]] = None
29
+ ):
30
+ prompt_tokens = mx.array(tokenizer.encode(prompt))
31
+ stop_strings = stop_strings if stop_strings is None or isinstance(stop_strings, tuple) else tuple(stop_strings)
32
+ assert stop_strings is None or isinstance(stop_strings, tuple), f'invalid {stop_strings}'
33
+
34
+ tic = time.perf_counter()
35
+ tokens = []
36
+ skip = 0
37
+ REPLACEMENT_CHAR = "\ufffd"
38
+
39
+ for (token, prob), n in zip(
40
+ generate_step(
41
+ prompt_tokens,
42
+ model,
43
+ temp,
44
+ repetition_penalty,
45
+ repetition_context_size,
46
+ ),
47
+ range(max_tokens),
48
+ ):
49
+ if token == tokenizer.eos_token_id:
50
+ break
51
+ if n == 0:
52
+ prompt_time = time.perf_counter() - tic
53
+ tic = time.perf_counter()
54
+ tokens.append(token.item())
55
+ if stop_strings is not None:
56
+ token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
57
+ if token_string.strip().endswith(stop_strings):
58
+ break
59
+ token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
60
+ return token_string
61
+
62
+
63
+
64
+ def generate_yield_string(
65
+ model: nn.Module,
66
+ tokenizer: PreTrainedTokenizer,
67
+ prompt: str,
68
+ temp: float = 0.0,
69
+ max_tokens: int = 100,
70
+ verbose: bool = False,
71
+ formatter: Callable = None,
72
+ repetition_penalty: Optional[float] = None,
73
+ repetition_context_size: Optional[int] = None,
74
+ stop_strings: Optional[Tuple[str]] = None
75
+ ):
76
+ """
77
+ Generate text from the model.
78
+ Args:
79
+ model (nn.Module): The language model.
80
+ tokenizer (PreTrainedTokenizer): The tokenizer.
81
+ prompt (str): The string prompt.
82
+ temp (float): The temperature for sampling (default 0).
83
+ max_tokens (int): The maximum number of tokens (default 100).
84
+ verbose (bool): If ``True``, print tokens and timing information
85
+ (default ``False``).
86
+ formatter (Optional[Callable]): A function which takes a token and a
87
+ probability and displays it.
88
+ repetition_penalty (float, optional): The penalty factor for repeating tokens.
89
+ repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
90
+ """
91
+ if verbose:
92
+ print("=" * 10)
93
+ print("Prompt:", prompt)
94
+ stop_strings = stop_strings if stop_strings is None or isinstance(stop_strings, tuple) else tuple(stop_strings)
95
+ assert stop_strings is None or isinstance(stop_strings, tuple), f'invalid {stop_strings}'
96
+ prompt_tokens = mx.array(tokenizer.encode(prompt))
97
+ tic = time.perf_counter()
98
+ tokens = []
99
+ skip = 0
100
+ REPLACEMENT_CHAR = "\ufffd"
101
+ for (token, prob), n in zip(
102
+ generate_step(
103
+ prompt_tokens,
104
+ model,
105
+ temp,
106
+ repetition_penalty,
107
+ repetition_context_size,
108
+ ),
109
+ range(max_tokens),
110
+ ):
111
+ if token == tokenizer.eos_token_id:
112
+ break
113
+ # if n == 0:
114
+ # prompt_time = time.perf_counter() - tic
115
+ # tic = time.perf_counter()
116
+ tokens.append(token.item())
117
+ # if verbose:
118
+ # s = tokenizer.decode(tokens)
119
+ # if formatter:
120
+ # formatter(s[skip:], prob.item())
121
+ # skip = len(s)
122
+ # elif REPLACEMENT_CHAR not in s:
123
+ # print(s[skip:], end="", flush=True)
124
+ # skip = len(s)
125
+ token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
126
+ yield token_string
127
+ if stop_strings is not None and token_string.strip().endswith(stop_strings):
128
+ break
129
+
130
+ # token_count = len(tokens)
131
+ # token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
132
+
133
+ # if verbose:
134
+ # print(token_string[skip:], flush=True)
135
+ # gen_time = time.perf_counter() - tic
136
+ # print("=" * 10)
137
+ # if token_count == 0:
138
+ # print("No tokens generated for this prompt")
139
+ # return
140
+ # prompt_tps = prompt_tokens.size / prompt_time
141
+ # gen_tps = (token_count - 1) / gen_time
142
+ # print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
143
+ # print(f"Generation: {gen_tps:.3f} tokens-per-sec")
144
+
145
+ # return token_string
146
+
147
+
148
+ class MlxEngine(BaseEngine):
149
+
150
+ def __init__(self, **kwargs) -> None:
151
+ super().__init__(**kwargs)
152
+ self._model = None
153
+ self._tokenizer = None
154
+
155
+ @property
156
+ def tokenizer(self) -> PreTrainedTokenizer:
157
+ return self._tokenizer
158
+
159
+ def load_model(self, ):
160
+ model_path = MODEL_PATH
161
+ self._model, self._tokenizer = load(model_path)
162
+ self.model_path = model_path
163
+ print(f'Load MLX model from {model_path}')
164
+
165
+
166
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
167
+ num_tokens = len(self.tokenizer.encode(prompt))
168
+ response = None
169
+ for response in generate_yield_string(
170
+ self._model, self._tokenizer,
171
+ prompt, temp=temperature, max_tokens=max_tokens,
172
+ repetition_penalty=kwargs.get("repetition_penalty", None),
173
+ stop_strings=stop_strings,
174
+ ):
175
+ yield response, num_tokens
176
+ if response is not None:
177
+ full_text = prompt + response
178
+ num_tokens = len(self.tokenizer.encode(full_text))
179
+ yield response, num_tokens
180
+
181
+ def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
182
+ """
183
+ ! MLX does not support batched generation here; each prompt is generated sequentially, one at a time.
184
+ """
185
+ responses = [
186
+ generate_string(
187
+ self._model, self._tokenizer,
188
+ s, temp=temperature, max_tokens=max_tokens,
189
+ repetition_penalty=kwargs.get("repetition_penalty", None),
190
+ stop_strings=stop_strings,
191
+ )
192
+ for s in prompts
193
+ ]
194
+ return responses
195
+
196
+
197
+
198
+
199
+
200
+
201
+
202
+
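A minimal usage sketch for MlxEngine (editorial, not part of the uploaded files), assuming MODEL_PATH points to an MLX-converted checkpoint; the stop_strings tuple is forwarded to the endswith() check in generate_yield_string above.

from multipurpose_chatbot.engines.mlx_engine import MlxEngine

engine = MlxEngine()
engine.load_model()  # mlx_lm.load() reads the model and tokenizer from MODEL_PATH

prompt = "Question: What is the capital of France?\nAnswer:"
response = None
for response, num_tokens in engine.generate_yield_string(
    prompt, temperature=0.0, max_tokens=64, stop_strings=("\nQuestion:",)
):
    pass  # streams the text decoded so far
print(response, num_tokens)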
multipurpose_chatbot/engines/transformers_engine.py ADDED
@@ -0,0 +1,452 @@
1
+
2
+ import os
3
+ import numpy as np
4
+ import argparse
5
+ import torch
6
+ import gradio as gr
7
+ from typing import Any, Iterator
8
+ from typing import Iterator, List, Optional, Tuple
9
+ import filelock
10
+ import glob
11
+ import json
12
+ import time
13
+ from gradio.routes import Request
14
+ from gradio.utils import SyncToAsyncIterator, async_iteration
15
+ from gradio.helpers import special_args
16
+ import anyio
17
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
18
+
19
+ from gradio_client.documentation import document, set_documentation_group
20
+
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+ import types
25
+
26
+ from gradio.components import Button
27
+ from gradio.events import Dependency, EventListenerMethod
28
+
29
+ from .base_engine import BaseEngine
30
+
31
+ # ! Remember to use static cache
32
+
33
+ from transformers import (
34
+ GenerationConfig,
35
+ GenerationMixin,
36
+ LogitsProcessorList,
37
+ StoppingCriteriaList,
38
+ DisjunctiveConstraint,
39
+ BeamSearchScorer,
40
+ PhrasalConstraint,
41
+ ConstrainedBeamSearchScorer,
42
+ PreTrainedModel,
43
+ )
44
+ import numpy as np
45
+ import random
46
+ import warnings
47
+ import inspect
48
+ from transformers.generation.utils import GenerateOutput, SampleOutput, logger
49
+ import torch
50
+ from typing import Callable, List, Optional, Union
51
+ from torch import nn
52
+ import torch.distributed as dist
53
+ import copy
54
+
55
+ from ..configs import (
56
+ MODEL_PATH,
57
+ DTYPE,
58
+ DEVICE,
59
+ )
60
+
61
+
62
+ def setup_seed(seed):
63
+ if seed == -1:
64
+ return
65
+ torch.manual_seed(seed)
66
+ if torch.cuda.is_available():
67
+ torch.cuda.manual_seed_all(seed)
68
+ np.random.seed(seed)
69
+ random.seed(seed)
70
+ torch.backends.cudnn.deterministic = True
71
+
72
+
73
+ class NewGenerationMixin(GenerationMixin):
74
+ """
75
+ Allow generator sampling
76
+
77
+ """
78
+
79
+ # ! Copy from transformers.generation.utils -> GenerationMixin
80
+ # Change sample function to sample_stream
81
+ @torch.no_grad()
82
+ def sample_stream(
83
+ self,
84
+ input_ids: torch.LongTensor,
85
+ logits_processor: Optional[LogitsProcessorList] = None,
86
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
87
+ logits_warper: Optional[LogitsProcessorList] = None,
88
+ max_length: Optional[int] = None,
89
+ pad_token_id: Optional[int] = None,
90
+ eos_token_id: Optional[Union[int, List[int]]] = None,
91
+ output_attentions: Optional[bool] = None,
92
+ output_hidden_states: Optional[bool] = None,
93
+ output_scores: Optional[bool] = None,
94
+ output_logits: Optional[bool] = None,
95
+ return_dict_in_generate: Optional[bool] = None,
96
+ synced_gpus: bool = False,
97
+ streamer: Optional["BaseStreamer"] = None,
98
+ **model_kwargs,
99
+ ):
100
+ r"""
101
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
102
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
103
+
104
+ <Tip warning={true}>
105
+
106
+ In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
107
+ For an overview of generation strategies and code examples, check the [following
108
+ guide](../generation_strategies).
109
+
110
+ </Tip>
111
+
112
+ Parameters:
113
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
114
+ The sequence used as a prompt for the generation.
115
+ logits_processor (`LogitsProcessorList`, *optional*):
116
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
117
+ used to modify the prediction scores of the language modeling head applied at each generation step.
118
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
119
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
120
+ used to tell if the generation loop should stop.
121
+ logits_warper (`LogitsProcessorList`, *optional*):
122
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
123
+ to warp the prediction score distribution of the language modeling head applied before multinomial
124
+ sampling at each generation step.
125
+ max_length (`int`, *optional*, defaults to 20):
126
+ **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
127
+ tokens. The maximum length of the sequence to be generated.
128
+ pad_token_id (`int`, *optional*):
129
+ The id of the *padding* token.
130
+ eos_token_id (`Union[int, List[int]]`, *optional*):
131
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
132
+ output_attentions (`bool`, *optional*, defaults to `False`):
133
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
134
+ returned tensors for more details.
135
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
136
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
137
+ for more details.
138
+ output_scores (`bool`, *optional*, defaults to `False`):
139
+ Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
140
+ output_logits (`bool`, *optional*, defaults to `False`):
141
+ Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for
142
+ more details.
143
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
144
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
145
+ synced_gpus (`bool`, *optional*, defaults to `False`):
146
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
147
+ streamer (`BaseStreamer`, *optional*):
148
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
149
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
150
+ model_kwargs:
151
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
152
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
153
+
154
+ Return:
155
+ [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
156
+ A `torch.LongTensor` containing the generated tokens (default behaviour) or a
157
+ [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
158
+ `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
159
+ `model.config.is_encoder_decoder=True`.
160
+
161
+ Examples:
162
+
163
+ ```python
164
+ >>> from transformers import (
165
+ ... AutoTokenizer,
166
+ ... AutoModelForCausalLM,
167
+ ... LogitsProcessorList,
168
+ ... MinLengthLogitsProcessor,
169
+ ... TopKLogitsWarper,
170
+ ... TemperatureLogitsWarper,
171
+ ... StoppingCriteriaList,
172
+ ... MaxLengthCriteria,
173
+ ... )
174
+ >>> import torch
175
+
176
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
177
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
178
+
179
+ >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token
180
+ >>> model.config.pad_token_id = model.config.eos_token_id
181
+ >>> model.generation_config.pad_token_id = model.config.eos_token_id
182
+
183
+ >>> input_prompt = "Today is a beautiful day, and"
184
+ >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
185
+
186
+ >>> # instantiate logits processors
187
+ >>> logits_processor = LogitsProcessorList(
188
+ ... [
189
+ ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
190
+ ... ]
191
+ ... )
192
+ >>> # instantiate logits processors
193
+ >>> logits_warper = LogitsProcessorList(
194
+ ... [
195
+ ... TopKLogitsWarper(50),
196
+ ... TemperatureLogitsWarper(0.7),
197
+ ... ]
198
+ ... )
199
+
200
+ >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
201
+
202
+ >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
203
+ >>> outputs = model.sample(
204
+ ... input_ids,
205
+ ... logits_processor=logits_processor,
206
+ ... logits_warper=logits_warper,
207
+ ... stopping_criteria=stopping_criteria,
208
+ ... )
209
+
210
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
211
+ ['Today is a beautiful day, and we must do everything possible to make it a day of celebration.']
212
+ ```"""
213
+ # init values
214
+ from transformers.generation.utils import (
215
+ validate_stopping_criteria, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
216
+ )
217
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
218
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
219
+ if max_length is not None:
220
+ warnings.warn(
221
+ "`max_length` is deprecated in this function, use"
222
+ " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
223
+ UserWarning,
224
+ )
225
+ stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
226
+ logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
227
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
228
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
229
+ if isinstance(eos_token_id, int):
230
+ eos_token_id = [eos_token_id]
231
+ eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
232
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
233
+ output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
234
+ output_attentions = (
235
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
236
+ )
237
+ output_hidden_states = (
238
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
239
+ )
240
+ return_dict_in_generate = (
241
+ return_dict_in_generate
242
+ if return_dict_in_generate is not None
243
+ else self.generation_config.return_dict_in_generate
244
+ )
245
+
246
+ # init attention / hidden states / scores tuples
247
+ scores = () if (return_dict_in_generate and output_scores) else None
248
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
249
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
250
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
251
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
252
+
253
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
254
+ if return_dict_in_generate and self.config.is_encoder_decoder:
255
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
256
+ encoder_hidden_states = (
257
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
258
+ )
259
+ # keep track of which sequences are already finished
260
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
261
+
262
+ this_peer_finished = False # used by synced_gpus only
263
+ # auto-regressive generation
264
+ while True:
265
+ if synced_gpus:
266
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
267
+ # The following logic allows an early break if all peers finished generating their sequence
268
+ this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
269
+ # send 0.0 if we finished, 1.0 otherwise
270
+ dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
271
+ # did all peers finish? the reduced sum will be 0.0 then
272
+ if this_peer_finished_flag.item() == 0.0:
273
+ break
274
+
275
+ # prepare model inputs
276
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
277
+
278
+ # forward pass to get next token
279
+ outputs = self(
280
+ **model_inputs,
281
+ return_dict=True,
282
+ output_attentions=output_attentions,
283
+ output_hidden_states=output_hidden_states,
284
+ )
285
+
286
+ if synced_gpus and this_peer_finished:
287
+ continue # don't waste resources running the code we don't need
288
+
289
+ next_token_logits = outputs.logits[:, -1, :]
290
+
291
+ # pre-process distribution
292
+ next_token_scores = logits_processor(input_ids, next_token_logits)
293
+ next_token_scores = logits_warper(input_ids, next_token_scores)
294
+
295
+ # Store scores, attentions and hidden_states when required
296
+ if return_dict_in_generate:
297
+ if output_scores:
298
+ scores += (next_token_scores,)
299
+ if output_logits:
300
+ raw_logits += (next_token_logits,)
301
+ if output_attentions:
302
+ decoder_attentions += (
303
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
304
+ )
305
+ if self.config.is_encoder_decoder:
306
+ cross_attentions += (outputs.cross_attentions,)
307
+
308
+ if output_hidden_states:
309
+ decoder_hidden_states += (
310
+ (outputs.decoder_hidden_states,)
311
+ if self.config.is_encoder_decoder
312
+ else (outputs.hidden_states,)
313
+ )
314
+
315
+ # sample
316
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
317
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
318
+
319
+ # finished sentences should have their next token be a padding token
320
+ if eos_token_id is not None:
321
+ if pad_token_id is None:
322
+ raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
323
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
324
+
325
+ yield next_tokens.cpu()
326
+
327
+ # update generated ids, model inputs, and length for next step
328
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
329
+ if streamer is not None:
330
+ streamer.put(next_tokens.cpu())
331
+
332
+ next_model_inputs = {}
333
+ if "cache_position" in model_inputs:
334
+ next_model_inputs['cache_position'] = model_inputs['cache_position']
335
+
336
+ try:
337
+ model_kwargs = self._update_model_kwargs_for_generation(
338
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
339
+ # model_inputs=model_inputs
340
+ model_inputs=next_model_inputs,
341
+ )
342
+ except Exception as e:
343
+ # Older version dont have model_inputs
344
+ model_kwargs = self._update_model_kwargs_for_generation(
345
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
346
+ )
347
+
348
+
349
+ # if eos_token was found in one sentence, set sentence to finished
350
+ if eos_token_id_tensor is not None:
351
+ unfinished_sequences = unfinished_sequences.mul(
352
+ next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
353
+ )
354
+
355
+ # stop when each sentence is finished
356
+ if unfinished_sequences.max() == 0:
357
+ this_peer_finished = True
358
+
359
+ # stop if we exceed the maximum length
360
+ if stopping_criteria(input_ids, scores):
361
+ this_peer_finished = True
362
+
363
+ if this_peer_finished and not synced_gpus:
364
+ break
365
+
366
+ if streamer is not None:
367
+ streamer.end()
368
+
369
+ # if return_dict_in_generate:
370
+ # if self.config.is_encoder_decoder:
371
+ # return GenerateEncoderDecoderOutput(
372
+ # sequences=input_ids,
373
+ # scores=scores,
374
+ # logits=raw_logits,
375
+ # encoder_attentions=encoder_attentions,
376
+ # encoder_hidden_states=encoder_hidden_states,
377
+ # decoder_attentions=decoder_attentions,
378
+ # cross_attentions=cross_attentions,
379
+ # decoder_hidden_states=decoder_hidden_states,
380
+ # past_key_values=model_kwargs.get("past_key_values"),
381
+ # )
382
+ # else:
383
+ # return GenerateDecoderOnlyOutput(
384
+ # sequences=input_ids,
385
+ # scores=scores,
386
+ # logits=raw_logits,
387
+ # attentions=decoder_attentions,
388
+ # hidden_states=decoder_hidden_states,
389
+ # past_key_values=model_kwargs.get("past_key_values"),
390
+ # )
391
+ # else:
392
+ # return input_ids
393
+
394
+
395
+
396
+ class TransformersEngine(BaseEngine):
397
+ @property
398
+ def max_position_embeddings(self) -> int:
399
+ return self._model.config.max_position_embeddings
400
+
401
+ @property
402
+ def tokenizer(self):
403
+ return self._tokenizer
404
+
405
+ def load_model(self):
406
+ from transformers import AutoTokenizer, AutoModelForCausalLM
407
+ import sys
408
+ # caution: path[0] is reserved for script path (or '' in REPL)
409
+ # sys.path.append(CODE_PATH)
410
+ self.model_path = model_path = MODEL_PATH
411
+ self.torch_dtype = torch.bfloat16 if DTYPE == 'bfloat16' else torch.float16
412
+ self.device_map = DEVICE
413
+ print(f'Loading model from {model_path} on {self.device_map} with {self.torch_dtype}')
414
+
415
+ self._tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
416
+ assert self._tokenizer.chat_template is not None and self._tokenizer.chat_template != "", f"{self._tokenizer.chat_template=} not found!"
417
+ self._model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=self.torch_dtype, device_map=self.device_map, trust_remote_code=True).eval()
418
+ self._model.sample_old = self._model.sample
419
+ self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
420
+ print(self._model)
421
+ print(f"{self.max_position_embeddings=}")
422
+
423
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
424
+
425
+ # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
426
+ with torch.no_grad():
427
+ inputs = self.tokenizer(prompt, return_tensors='pt')
428
+ num_tokens = inputs.input_ids.size(1)
429
+
430
+ inputs = {k: v.to(self.device_map) for k, v in inputs.items() if v is not None}
431
+ generator = self._model.generate(
432
+ **inputs,
433
+ do_sample=True,
434
+ temperature=temperature,
435
+ max_new_tokens=max_tokens,
436
+ pad_token_id=self.tokenizer.pad_token_id, # this engine has no processor attribute; use its tokenizer directly
437
+ )
438
+
439
+ out_tokens = []
440
+ response = None
441
+ for token in generator:
442
+ out_tokens.append(token.item())
443
+ response = self.tokenizer.decode(out_tokens)
444
+ num_tokens += 1
445
+ # print(f"{num_tokens=}", end='\r')
446
+ # sys.stdout.flush()
447
+ yield response, num_tokens
448
+
449
+ if response is not None:
450
+ full_text = prompt + response
451
+ num_tokens = len(self.tokenizer.encode(full_text))
452
+ yield response, num_tokens
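A minimal usage sketch for TransformersEngine (editorial, not part of the uploaded files). The monkey-patched sample_stream above yields one token per decoding step, and generate_yield_string re-decodes the accumulated token list on every step so callers always see a valid partial string. The sketch assumes a no-argument constructor and a tokenizer that ships a chat template (load_model asserts this).

from multipurpose_chatbot.engines.transformers_engine import TransformersEngine

engine = TransformersEngine()
engine.load_model()  # MODEL_PATH, DTYPE and DEVICE are read from configs / environment

prompt = engine.tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    tokenize=False,
    add_generation_prompt=True,
)
response = None
for response, num_tokens in engine.generate_yield_string(prompt, temperature=0.7, max_tokens=128):
    pass  # response grows token by token
print(response)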
multipurpose_chatbot/engines/vllm_engine.py ADDED
@@ -0,0 +1,233 @@
1
+ import os
2
+ import numpy as np
3
+ import argparse
4
+ import gradio as gr
5
+ from typing import Any, Iterator
6
+ from typing import Iterator, List, Optional, Tuple
7
+ import filelock
8
+ import glob
9
+ import json
10
+ import time
11
+ from gradio.routes import Request
12
+ from gradio.utils import SyncToAsyncIterator, async_iteration
13
+ from gradio.helpers import special_args
14
+ import anyio
15
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
16
+
17
+ from gradio_client.documentation import document, set_documentation_group
18
+
19
+ from typing import List, Optional, Union, Dict, Tuple
20
+ from tqdm.auto import tqdm
21
+ from huggingface_hub import snapshot_download
22
+
23
+ from gradio.components import Button
24
+ from gradio.events import Dependency, EventListenerMethod
25
+
26
+ from .base_engine import BaseEngine
27
+ # @@ environments ================
28
+
29
+ from ..configs import (
30
+ DTYPE,
31
+ TENSOR_PARALLEL,
32
+ MODEL_PATH,
33
+ QUANTIZATION,
34
+ MAX_TOKENS,
35
+ TEMPERATURE,
36
+ FREQUENCE_PENALTY,
37
+ PRESENCE_PENALTY,
38
+ GPU_MEMORY_UTILIZATION,
39
+ STREAM_CHECK_MULTIPLE,
40
+ STREAM_YIELD_MULTIPLE,
41
+
42
+ )
43
+
44
+
45
+ llm = None
46
+ demo = None
47
+
48
+
49
+
50
+ def vllm_abort(self):
51
+ sh = self.llm_engine.scheduler
52
+ for g in (sh.waiting + sh.running + sh.swapped):
53
+ sh.abort_seq_group(g.request_id)
54
+ from vllm.sequence import SequenceStatus
55
+ scheduler = self.llm_engine.scheduler
56
+ for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
57
+ for seq_group in state_queue:
58
+ # if seq_group.request_id == request_id:
59
+ # Remove the sequence group from the state queue.
60
+ state_queue.remove(seq_group)
61
+ for seq in seq_group.seqs:
62
+ if seq.is_finished():
63
+ continue
64
+ scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
65
+
66
+
67
+ def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
68
+ from vllm.outputs import RequestOutput
69
+ # Initialize tqdm.
70
+ if use_tqdm:
71
+ num_requests = self.llm_engine.get_num_unfinished_requests()
72
+ pbar = tqdm(total=num_requests, desc="Processed prompts")
73
+ # Run the engine.
74
+ outputs: Dict[str, RequestOutput] = {}
75
+ while self.llm_engine.has_unfinished_requests():
76
+ step_outputs = self.llm_engine.step()
77
+ for output in step_outputs:
78
+ outputs[output.request_id] = output
79
+ if len(outputs) > 0:
80
+ yield outputs
81
+
82
+
83
+ def vllm_generate_stream(
84
+ self: Any,
85
+ prompts: Optional[Union[str, List[str]]] = None,
86
+ sampling_params: Optional[Any] = None,
87
+ prompt_token_ids: Optional[List[List[int]]] = None,
88
+ use_tqdm: bool = False,
89
+ ) -> Dict[str, Any]:
90
+ """Generates the completions for the input prompts.
91
+
92
+ NOTE: This class automatically batches the given prompts, considering
93
+ the memory constraint. For the best performance, put all of your prompts
94
+ into a single list and pass it to this method.
95
+
96
+ Args:
97
+ prompts: A list of prompts to generate completions for.
98
+ sampling_params: The sampling parameters for text generation. If
99
+ None, we use the default sampling parameters.
100
+ prompt_token_ids: A list of token IDs for the prompts. If None, we
101
+ use the tokenizer to convert the prompts to token IDs.
102
+ use_tqdm: Whether to use tqdm to display the progress bar.
103
+
104
+ Returns:
105
+ A list of `RequestOutput` objects containing the generated
106
+ completions in the same order as the input prompts.
107
+ """
108
+ from vllm import LLM, SamplingParams
109
+ if prompts is None and prompt_token_ids is None:
110
+ raise ValueError("Either prompts or prompt_token_ids must be "
111
+ "provided.")
112
+ if isinstance(prompts, str):
113
+ # Convert a single prompt to a list.
114
+ prompts = [prompts]
115
+ if prompts is not None and prompt_token_ids is not None:
116
+ if len(prompts) != len(prompt_token_ids):
117
+ raise ValueError("The lengths of prompts and prompt_token_ids "
118
+ "must be the same.")
119
+ if sampling_params is None:
120
+ # Use default sampling params.
121
+ sampling_params = SamplingParams()
122
+ # Add requests to the engine.
123
+ if prompts is not None:
124
+ num_requests = len(prompts)
125
+ else:
126
+ num_requests = len(prompt_token_ids)
127
+ for i in range(num_requests):
128
+ prompt = prompts[i] if prompts is not None else None
129
+ if prompt_token_ids is None:
130
+ token_ids = None
131
+ else:
132
+ token_ids = prompt_token_ids[i]
133
+ self._add_request(prompt, sampling_params, token_ids)
134
+ # return self._run_engine(use_tqdm)
135
+ yield from _vllm_run_engine(self, use_tqdm)
136
+
137
+
138
+
139
+ class VllmEngine(BaseEngine):
140
+ def __init__(self, **kwargs) -> None:
141
+ super().__init__(**kwargs)
142
+
143
+ @property
144
+ def tokenizer(self):
145
+ return self._model.get_tokenizer()
146
+
147
+ def load_model(self, ):
148
+ import torch
149
+ try:
150
+ compute_capability = torch.cuda.get_device_capability()
151
+ print(f'Torch CUDA compute_capability: {compute_capability}')
152
+ except Exception as e:
153
+ print(f'Failed to print compute_capability version: {e}')
154
+
155
+ import vllm
156
+ from vllm import LLM
157
+
158
+ print(f'VLLM: {vllm.__version__=}')
159
+
160
+ if QUANTIZATION == 'awq':
161
+ print(f'Loading model with AWQ 4-bit quantization')
162
+ llm = LLM(
163
+ model=MODEL_PATH,
164
+ dtype="float16",
165
+ tensor_parallel_size=TENSOR_PARALLEL,
166
+ gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
167
+ quantization="awq",
168
+ max_model_len=MAX_TOKENS
169
+ )
170
+ else:
171
+ llm = LLM(
172
+ model=MODEL_PATH,
173
+ dtype=DTYPE,
174
+ tensor_parallel_size=TENSOR_PARALLEL,
175
+ gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
176
+ max_model_len=MAX_TOKENS
177
+ )
178
+
179
+ try:
180
+ print(llm.llm_engine.workers[0].model)
181
+ except Exception as e:
182
+ print(f'Cannot print model worker: {e}')
183
+
184
+ try:
185
+ llm.llm_engine.scheduler_config.max_model_len = MAX_TOKENS
186
+ llm.llm_engine.scheduler_config.max_num_batched_tokens = MAX_TOKENS
187
+ except Exception as e:
188
+ print(f'Cannot set parameters: {e}')
189
+
190
+ self._model = llm
191
+
192
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
193
+ from vllm import SamplingParams
194
+ # ! must abort previous ones
195
+ vllm_abort(self._model) # module-level `llm` is never assigned by load_model; abort on the loaded engine
196
+ sampling_params = SamplingParams(
197
+ temperature=temperature,
198
+ max_tokens=max_tokens,
199
+ # frequency_penalty=frequency_penalty,
200
+ # presence_penalty=presence_penalty,
201
+ stop=stop_strings,
202
+ )
203
+ cur_out = None
204
+ num_tokens = len(self.tokenizer.encode(prompt))
205
+ for j, gen in enumerate(vllm_generate_stream(self._model, prompt, sampling_params)):
206
+ if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
207
+ yield cur_out, num_tokens
208
+ assert len(gen) == 1, f'{gen}'
209
+ item = next(iter(gen.values()))
210
+ cur_out = item.outputs[0].text
211
+
212
+ if cur_out is not None:
213
+ full_text = prompt + cur_out
214
+ num_tokens = len(self.tokenizer.encode(full_text))
215
+ yield cur_out, num_tokens
216
+
217
+ def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
218
+ """
219
+ Only vLLM supports true batched generation; the other engines only handle batch size 1.
220
+ """
221
+ from vllm import SamplingParams
222
+ # ! must abort previous ones
223
+ vllm_abort(self._model)
224
+ sampling_params = SamplingParams(
225
+ temperature=temperature,
226
+ max_tokens=max_tokens,
227
+ # frequency_penalty=frequency_penalty,
228
+ # presence_penalty=presence_penalty,
229
+ stop=stop_strings,
230
+ )
231
+ generated = self._model.generate(prompts, sampling_params, use_tqdm=False)
232
+ responses = [g.outputs[0].text for g in generated]
233
+ return responses
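A minimal usage sketch for VllmEngine (editorial, not part of the uploaded files), assuming a CUDA machine with vllm installed and MODEL_PATH, DTYPE and TENSOR_PARALLEL set in the environment. batch_generate is the only batched path; the streaming path aborts any in-flight requests before adding a new one.

from multipurpose_chatbot.engines.vllm_engine import VllmEngine

engine = VllmEngine()
engine.load_model()

prompts = [
    "Translate to French: good morning",
    "Translate to French: thank you",
]
responses = engine.batch_generate(prompts, temperature=0.0, max_tokens=64, stop_strings=["</s>"])
for p, r in zip(prompts, responses):
    print(p, "->", r)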
multipurpose_chatbot/globals.py ADDED
@@ -0,0 +1,33 @@
1
+ import os
2
+
3
+ global MODEL_ENGINE
4
+
5
+ from multipurpose_chatbot.engines import load_multipurpose_chatbot_engine
6
+ from multipurpose_chatbot.demos import get_demo_class
7
+
8
+ from .configs import (
9
+ BACKEND,
10
+ RAG_EMBED_MODEL_NAME,
11
+ )
12
+
13
+ MODEL_ENGINE = load_multipurpose_chatbot_engine(BACKEND)
14
+
15
+
16
+ RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE = None, None, None
17
+
18
+
19
+ def load_embeddings():
20
+ global RAG_EMBED
21
+ if RAG_EMBED is None:
22
+ from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
23
+ print(f'Loading embeddings: {RAG_EMBED_MODEL_NAME}')
24
+ RAG_EMBED = HuggingFaceEmbeddings(model_name=RAG_EMBED_MODEL_NAME, model_kwargs={'trust_remote_code':True, "device": "cpu"})
25
+ else:
26
+ print(f'RAG_EMBED ALREADY EXIST: {RAG_EMBED_MODEL_NAME}: {RAG_EMBED=}')
27
+ return RAG_EMBED
28
+
29
+
30
+ def get_rag_embeddings():
31
+ return load_embeddings()
32
+
33
+
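A rough sketch of how the lazily-loaded embeddings above might feed a retrieval store (editorial, not part of the uploaded files). The loader, splitter and input path are assumptions consistent with the langchain/chromadb/pypdf entries in requirements.txt; also note that importing this module loads the chat model engine as a side effect of MODEL_ENGINE above.

from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma

from multipurpose_chatbot.globals import get_rag_embeddings

docs = PyPDFLoader("document.pdf").load_and_split()  # hypothetical input file
store = Chroma.from_documents(documents=docs, embedding=get_rag_embeddings())
hits = store.similarity_search("What is the main contribution?", k=3)
print(hits[0].page_content[:200])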
pyproject.toml ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ torch
2
+ gradio
3
+ tiktoken
4
+ openai
5
+ transformers
6
+ langchain
7
+ langchain-community
8
+ langchain-core
9
+ chromadb
10
+ pypdf
11
+ docx2txt
transformers_requirements.txt ADDED
@@ -0,0 +1 @@
1
+ transformers
vllm_requirements.txt ADDED
@@ -0,0 +1,2 @@
1
+ transformers
2
+ vllm