muxingyin committed on
Commit
f6086aa
1 Parent(s): 8e4b175

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ examples/1.jpeg filter=lfs diff=lfs merge=lfs -text
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
1
+ name: Run Python script
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+
8
+ jobs:
9
+ build:
10
+ runs-on: ubuntu-latest
11
+
12
+ steps:
13
+ - name: Checkout
14
+ uses: actions/checkout@v2
15
+
16
+ - name: Set up Python
17
+ uses: actions/setup-python@v2
18
+ with:
19
+ python-version: '3.9'
20
+
21
+ - name: Install Gradio
22
+ run: python -m pip install gradio
23
+
24
+ - name: Log in to Hugging Face
25
+ run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
26
+
27
+ - name: Deploy to Spaces
28
+ run: gradio deploy
.gitignore ADDED
@@ -0,0 +1,164 @@
1
+ checkpoints/
2
+ runs/
3
+ model/__pycache__/
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
BigModel/.gitattributes ADDED
@@ -0,0 +1,34 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
BigModel/README.md ADDED
@@ -0,0 +1,13 @@
1
+ ---
2
+ title: BigModel
3
+ emoji: 🌍
4
+ colorFrom: yellow
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 3.33.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: openrail
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright Zhengxiao Du
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
MODEL_LICENSE.txt ADDED
@@ -0,0 +1,33 @@
1
+ The VisualGLM-6B License
2
+
3
+ 1. Definitions
4
+
5
+ “Licensor” means the VisualGLM-6B Model Team that distributes its Software.
6
+
7
+ “Software” means the VisualGLM-6B model parameters made available under this license.
8
+
9
+ 2. License Grant
10
+
11
+ Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes.
12
+
13
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
14
+
15
+ 3. Restriction
16
+
17
+ You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes.
18
+
19
+ You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
20
+
21
+ 4. Disclaimer
22
+
23
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24
+
25
+ 5. Limitation of Liability
26
+
27
+ EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
28
+
29
+ 6. Dispute Resolution
30
+
31
+ This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
32
+
33
+ Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us.
README.md CHANGED
@@ -1,12 +1,329 @@
1
  ---
2
- title: VisualGLM 6B
3
- emoji: 🌖
4
- colorFrom: green
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.33.1
8
- app_file: app.py
9
- pinned: false
10
  ---
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: VisualGLM-6B
3
+ app_file: web_demo_hf.py
4
  sdk: gradio
5
+ sdk_version: 3.33.0
6
  ---
7
+ # VisualGLM-6B
8
 
9
+ <p align="center">
10
+ 🤗 <a href="https://huggingface.co/THUDM/visualglm-6b" target="_blank">HF Repo</a> • ⚒️ <a href="https://github.com/THUDM/SwissArmyTransformer" target="_blank">SwissArmyTransformer (sat)</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a>
11
+ </p>
12
+ <p align="center">
13
+ • 📃 <a href="https://arxiv.org/abs/2105.13290" target="_blank">[CogView@NeurIPS 21]</a> <a href="https://github.com/THUDM/CogView" target="_blank">[GitHub]</a> • 📃 <a href="https://arxiv.org/abs/2103.10360" target="_blank">[GLM@ACL 22]</a> <a href="https://github.com/THUDM/GLM" target="_blank">[GitHub]</a> <br>
14
+ </p>
15
+ <p align="center">
16
+ 👋 加入我们的 <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1th2q5u69-7tURzFuOPanmuHy9hsZnKA" target="_blank">Slack</a> 和 <a href="resources/WECHAT.md" target="_blank">WeChat</a>
17
+ </p>
18
+ <!-- <p align="center">
19
+ 🤖<a href="https://huggingface.co/spaces/THUDM/visualglm-6b" target="_blank">VisualGLM-6B在线演示网站</a>
20
+ </p> -->
21
+
22
+ ## 介绍
23
+
24
+ VisualGLM-6B is an open-source, multi-modal dialog language model that supports **images, Chinese, and English**. The language model is based on [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) with 6.2 billion parameters; the image part builds a bridge between the visual model and the language model through the training of [BLIP2-Qformer](https://arxiv.org/abs/2301.12597), with the total model comprising 7.8 billion parameters. **[Click here for English version.](README_en.md)**
25
+
26
+ VisualGLM-6B 是一个开源的,支持**图像、中文和英文**的多模态对话语言模型,语言模型基于 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B),具有 62 亿参数;图像部分通过训练 [BLIP2-Qformer](https://arxiv.org/abs/2301.12597) 构建起视觉模型与语言模型的桥梁,整体模型共78亿参数。
27
+
28
+ VisualGLM-6B 依靠来自于 [CogView](https://arxiv.org/abs/2105.13290) 数据集的30M高质量中文图文对,与300M经过筛选的英文图文对进行预训练,中英文权重相同。该训练方式较好地将视觉信息对齐到ChatGLM的语义空间;之后的微调阶段,模型在长视觉问答数据上训练,以生成符合人类偏好的答案。
29
+
30
+ VisualGLM-6B 由 [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer)(简称`sat`) 库训练,这是一个支持Transformer灵活修改、训练的工具库,支持Lora、P-tuning等参数高效微调方法。本项目提供了符合用户习惯的huggingface接口,也提供了基于sat的接口。
31
+
32
+ 结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4量化级别下最低只需8.7G显存)。
33
+
34
+ -----
35
+
36
+ VisualGLM-6B 开源模型旨在与开源社区一起推动大模型技术发展,恳请开发者和大家遵守开源协议,勿将该开源模型和代码及基于该开源项目产生的衍生物用于任何可能给国家和社会带来危害的用途以及用于任何未经过安全评估和备案的服务。目前,本项目官方未基于 VisualGLM-6B 开发任何应用,包括网站、安卓App、苹果 iOS应用及 Windows App 等。
37
+
38
+ 由于 VisualGLM-6B 仍处于v1版本,目前已知其具有相当多的[**局限性**](README.md#局限性),如图像描述事实性/模型幻觉问题,图像细节信息捕捉不足,以及一些来自语言模型的局限性。尽管模型在训练的各个阶段都尽力确保数据的合规性和准确性,但由于 VisualGLM-6B 模型规模较小,且模型受概率随机性因素影响,无法保证输出内容的准确性,且模型易被误导(详见局限性部分)。在VisualGLM之后的版本中,将会着力对此类问题进行优化。本项目不承担开源模型和代码导致的数据安全、舆情风险或发生任何模型被误导、滥用、传播、不当利用而产生的风险和责任。
39
+
40
+ ## 样例
41
+ VisualGLM-6B 可以进行图像的描述的相关知识的问答。
42
+ ![泰坦尼克号样例](examples/chat_example1.png)
43
+
44
+ <details>
45
+ <summary>也能结合常识或提出有趣的观点,点击展开/折叠更多样例</summary>
46
+
47
+ ![出租车熨衣服样例](examples/chat_example2.png)
48
+ ![蒙娜丽莎狗样例](examples/chat_example3.png)
49
+
50
+ </details>
51
+
52
+ ## 友情链接
53
+
54
+ * [XrayGLM](https://github.com/WangRongsheng/XrayGLM) 是基于visualGLM-6B在X光诊断数据集上微调的X光诊断问答的项目,能根据X光片回答医学相关询问。
55
+ <details>
56
+ <summary>点击查看样例</summary>
57
+
58
+ ![样例](https://github.com/WangRongsheng/XrayGLM/raw/main/assets/images/xrayglm.png)
59
+ </details>
60
+
61
+ ## 使用
62
+
63
+ ### 模型推理
64
+
65
+ 使用pip安装依赖
66
+ ```
67
+ pip install -i https://pypi.org/simple -r requirements.txt
68
+ # 国内请使用aliyun镜像,TUNA等镜像同步最近出现问题,命令如下
69
+ pip install -i https://mirrors.aliyun.com/pypi/simple/ -r requirements.txt
70
+ ```
71
+ 此时默认会安装`deepspeed`库(支持`sat`库训练),此库对于模型推理并非必要,同时部分Windows环境安装此库时会遇到问题。
72
+ 如果想绕过`deepspeed`安装,我们可以将命令改为
73
+ ```
74
+ pip install -i https://mirrors.aliyun.com/pypi/simple/ -r requirements_wo_ds.txt
75
+ pip install -i https://mirrors.aliyun.com/pypi/simple/ --no-deps "SwissArmyTransformer>=0.3.6"
76
+ ```
77
+
78
+ 如果使用Huggingface transformers库调用模型(**也需要安装上述依赖包!**),可以通过如下代码(其中图像路径为本地路径):
79
+ ```python
80
+ from transformers import AutoTokenizer, AutoModel
81
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
82
+ model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
83
+ image_path = "your image path"
84
+ response, history = model.chat(tokenizer, image_path, "描述这张图片。", history=[])
85
+ print(response)
86
+ response, history = model.chat(tokenizer, image_path, "这张图片可能是在什么场所拍摄的?", history=history)
87
+ print(response)
88
+ ```
89
+ 以上代码会由 `transformers` 自动下载模型实现和参数。完整的模型实现可以在 [Hugging Face Hub](https://huggingface.co/THUDM/visualglm-6b)。如果你从 Hugging Face Hub 上下载模型参数的速度较慢,可以从[这里](https://cloud.tsinghua.edu.cn/d/43ffb021ca5f4897b56a/)手动下载模型参数文件,并从本地加载模型。具体做法请参考[从本地加载模型](https://github.com/THUDM/ChatGLM-6B#%E4%BB%8E%E6%9C%AC%E5%9C%B0%E5%8A%A0%E8%BD%BD%E6%A8%A1%E5%9E%8B)。关于基于 transformers 库模型的量化、CPU推理、Mac MPS 后端加速等内容,请参考 [ChatGLM-6B 的低成本部署](https://github.com/THUDM/ChatGLM-6B#%E4%BD%8E%E6%88%90%E6%9C%AC%E9%83%A8%E7%BD%B2)。
90
+
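For reference, loading from a locally downloaded copy only requires pointing `from_pretrained` at the local directory; a minimal sketch, assuming the downloaded implementation and parameter files were placed in `./visualglm-6b` (an illustrative path):

```python
# Minimal sketch: load VisualGLM-6B from a local directory instead of the Hub.
# "./visualglm-6b" is an assumed example path containing the downloaded files.
from transformers import AutoTokenizer, AutoModel

local_dir = "./visualglm-6b"
tokenizer = AutoTokenizer.from_pretrained(local_dir, trust_remote_code=True)
model = AutoModel.from_pretrained(local_dir, trust_remote_code=True).half().cuda()
response, history = model.chat(tokenizer, "your image path", "描述这张图片。", history=[])
```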
91
+ 如果使用SwissArmyTransformer库调用模型,方法类似,可以使用环境变量`SAT_HOME`决定模型下载位置。在本仓库目录下:
92
+ ```python
93
+ import argparse
94
+ from transformers import AutoTokenizer
95
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
96
+ from model import chat, VisualGLMModel
97
+ model, model_args = VisualGLMModel.from_pretrained('visualglm-6b', args=argparse.Namespace(fp16=True, skip_init=True))
98
+ from sat.model.mixins import CachedAutoregressiveMixin
99
+ model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
100
+ image_path = "your image path or URL"
101
+ response, history, cache_image = chat(image_path, model, tokenizer, "描述这张图片。", history=[])
102
+ print(response)
103
+ response, history, cache_image = chat(None, model, tokenizer, "这张图片可能是在什么场所拍摄的?", history=history, image=cache_image)
104
+ print(response)
105
+ ```
106
+ 使用`sat`库也可以轻松进行参数高效微调。<!-- TODO 具体代码 -->
107
+
108
+ ## 模型微调
109
+
110
+ 多模态任务分布广、种类多,预训练往往不能面面俱到。
111
+ 这里我们提供了一个小样本微调的例子,使用20张标注图增强模型回答“背景”问题的能力。
112
+
113
+ 解压`fewshot-data.zip`以后运行如下命令:
114
+ ```
115
+ bash finetune/finetune_visualglm.sh
116
+ ```
117
+
118
+ 目前支持三种方式的微调:
119
+
120
+ * LoRA:样例中为ChatGLM模型的第0层和第14层加入了rank=10的LoRA微调,可以根据具体情景和数据量调整`--layer_range`和`--lora_rank`参数。
121
+ * QLoRA:如果资源有限,可以考虑使用`bash finetune/finetune_visualglm_qlora.sh`,QLoRA将ChatGLM的线性层进行了4-bit量化,只需要9.8GB显存即可微调。
122
+ * P-tuning:可以将`--use_lora`替换为`--use_ptuning`,不过不推荐使用,除非模型应用场景非常固定。
123
+
124
+ 训练好以后可以使用如下命令推理:
125
+
126
+ ```
127
+ python cli_demo.py --from_pretrained your_checkpoint_path --prompt_zh 这张图片的背景里有什么内容?
128
+ ```
129
+
130
+ <details>
131
+ <summary>微调前后的效果对比</summary>
132
+
133
+ 微调前:
134
+
135
+ ```
136
+ 欢迎使用 VisualGLM-6B 模型,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序
137
+ 请输入图像路径或URL(回车进入纯文本对话): fewshot-data/2p.png
138
+ 用户:这张图片的背景里有什么内容?
139
+ VisualGLM-6B:人们总是被生活琐事所困扰,无法摆脱束缚.
140
+ 用户:clear
141
+ 请输入图像路径或URL(回车进入纯文本对话): fewshot-data/meme.png
142
+ 用户:这张图片的背景里有什么内容?
143
+ VisualGLM-6B:这张图片的背景是一张狗的图片,它躺在木地板上。
144
+ 用户:clear
145
+ 请输入图像路径或URL(回车进入纯文本对话): fewshot-data/passport.png
146
+ 用户:这张图片的背景里有什么内容?
147
+ VisualGLM-6B:这张图片的背景是一台电脑键盘和两张护照。
148
+ 用户:clear
149
+ 请输入图像路径或URL(回车进入纯文本对话):
150
+ ```
151
+
152
+ 微调后:
153
+
154
+ ```
155
+ 欢迎使用 VisualGLM-6B 模型,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序
156
+ 请输入图像路径或URL(回车进入纯文本对话): fewshot-data/2p.png
157
+ 用户:这张图片的背景里有什么内容?
158
+ VisualGLM-6B:这张图片的背景是雨天,
159
+ 用户:描述一下这张图片。
160
+ VisualGLM-6B:这张图片描绘了雨天的情景,雨水淅淅沥沥地落在地面上。人们穿着雨衣或雨伞,走在雨中,似乎在寻找着什么.
161
+ 用户:clear
162
+ 请输入图像路径或URL(回车进入纯文本对话): fewshot-data/meme.png
163
+ 用户:这张图片的背景里有什么内容?
164
+ VisualGLM-6B:这张图片的背景是木质地板。
165
+ 用户:什么颜色的地板
166
+ VisualGLM-6B:这张图片的背景是蓝色的。
167
+ 用户:描述一下这张图片。
168
+ VisualGLM-6B:这张图片描绘了一只棕色的小狗躺在蓝色的木质地板上。
169
+ 用户:clear
170
+ 请输入图像路径或URL(回车进入纯文本对话): fewshot-data/passport.png
171
+ 用户:这张图片的背景里有什么内容?
172
+ VisualGLM-6B:这张图片的背景是电脑键盘。
173
+ 用户:描述一下这张图片。
174
+ VisualGLM-6B:这张图片的特点是电脑键盘和护照。
175
+ 用户:有几张护照
176
+ VisualGLM-6B:两张护照。
177
+ 用户:clear
178
+ 请输入图像路径或URL(回车进入纯文本对话):
179
+ ```
180
+
181
+ </details>
182
+
183
+ 如果希望把LoRA部分的参数合并到原始的权重,可以调用`merge_lora()`,例如:
184
+
185
+ ```python
186
+ from finetune_visualglm import FineTuneVisualGLMModel
187
+ import argparse
188
+
189
+ model, args = FineTuneVisualGLMModel.from_pretrained('checkpoints/finetune-visualglm-6b-05-19-07-36',
190
+ args=argparse.Namespace(
191
+ fp16=True,
192
+ skip_init=True,
193
+ use_gpu_initialization=True,
194
+ device='cuda',
195
+ ))
196
+ model.get_mixin('lora').merge_lora()
197
+ args.layer_range = []
198
+ args.save = 'merge_lora'
199
+ args.mode = 'inference'
200
+ from sat.training.model_io import save_checkpoint
201
+ save_checkpoint(1, model, None, None, args)
202
+ ```
203
+
204
+ 微调需要安装`deepspeed`库,目前本流程仅支持linux系统,更多的样例说明和Windows系统的流程说明将在近期完成。
205
+
206
+ ## 部署工具
207
+
208
+ ### 命令行 Demo
209
+
210
+ ```shell
211
+ python cli_demo.py
212
+ ```
213
+ 程序会自动下载sat模型,并在命令行中进行交互式的对话,输入指示并回车即可生成回复,输入 clear 可以清空对话历史,输入 stop 终止程序。
214
+
215
+ ![cli_demo](examples/thu.png)
216
+ 程序提供如下超参数控制生成过程与量化精度:
217
+ ```
218
+ usage: cli_demo.py [-h] [--max_length MAX_LENGTH] [--top_p TOP_P] [--top_k TOP_K] [--temperature TEMPERATURE] [--english] [--quant {8,4}]
219
+
220
+ optional arguments:
221
+ -h, --help show this help message and exit
222
+ --max_length MAX_LENGTH
223
+ max length of the total sequence
224
+ --top_p TOP_P top p for nucleus sampling
225
+ --top_k TOP_K top k for top k sampling
226
+ --temperature TEMPERATURE
227
+ temperature for sampling
228
+ --english only output English
229
+ --quant {8,4} quantization bits
230
+ ```
231
+ 需要注意的是,在训练时英文问答对的提示词为`Q: A:`,而中文为`问:答:`,在网页demo中采取了中文的提示,因此英文回复会差一些且夹杂中文;如果需要英文回复,请使用`cli_demo.py`中的`--english`选项。
232
+
233
+ 我们也提供了继承自`ChatGLM-6B`的打字机效果命令行工具,此工具使用Huggingface模型:
234
+ ```shell
235
+ python cli_demo_hf.py
236
+ ```
237
+
238
+ ### 网页版 Demo
239
+ ![web_demo](examples/web_demo.png)
240
+
241
+ 我们提供了一个基于 [Gradio](https://gradio.app) 的网页版 Demo,首先安装 Gradio:`pip install gradio`。
242
+ 然后下载并进入本仓库运行`web_demo.py`:
243
+
244
+ ```
245
+ git clone https://github.com/THUDM/VisualGLM-6B
246
+ cd VisualGLM-6B
247
+ python web_demo.py
248
+ ```
249
+ 程序会自动下载 sat 模型,并运行一个 Web Server,并输出地址。在浏览器中打开输出的地址即可使用。
250
+
251
+
252
+ 我们也提供了继承自`ChatGLM-6B`的打字机效果网页版工具,此工具使用 Huggingface 模型,启动后将运行在`:8080`端口上:
253
+ ```shell
254
+ python web_demo_hf.py
255
+ ```
256
+
257
+ 两种网页版 demo 均接受命令行参数`--share`以生成 gradio 公开链接,接受`--quant 4`和`--quant 8`以分别使用4比特量化/8比特量化减少显存占用。
258
+
259
+ ### API部署
260
+ 首先需要安装额外的依赖 `pip install fastapi uvicorn`,然后运行仓库中的 [api.py](api.py):
261
+ ```shell
262
+ python api.py
263
+ ```
264
+ 程序会自动下载 sat 模型,默认部署在本地的 8080 端口,通过 POST 方法进行调用。下面是用`curl`请求的例子,一般而言也可以用代码发送 POST 请求。
265
+ ```shell
266
+ echo "{\"image\":\"$(base64 path/to/example.jpg)\",\"text\":\"描述这张图片\",\"history\":[]}" > temp.json
267
+ curl -X POST -H "Content-Type: application/json" -d @temp.json http://127.0.0.1:8080
268
+ ```
269
+ 得到的返回值为
270
+ ```
271
+ {
272
+ "response":"这张图片展现了一只可爱的卡通羊驼,它站在一个透明的背景上。这只羊驼长着一张毛茸茸的耳朵和一双大大的眼睛,它的身体是白色的,带有棕色斑点。",
273
+ "history":[('描述这张图片', '这张图片展现了一只可爱的卡通羊驼,它站在一个透明的背景上。这只羊驼长着一张毛茸茸的耳朵和一双大大的眼睛,它的身体是白色的,带有棕色斑点。')],
274
+ "status":200,
275
+ "time":"2023-05-16 20:20:10"
276
+ }
277
+ ```
278
+
279
+ 我们也提供了使用Huggingface模型的 [api_hf.py](api_hf.py),用法和sat模型的api一致:
280
+ ```shell
281
+ python api_hf.py
282
+ ```
283
+
284
+
285
+ ## 模型量化
286
+ 在Huggingface实现中,模型默认以 FP16 精度加载,运行上述代码需要大概 15GB 显存。如果你的 GPU 显存有限,可以尝试以量化方式加载模型。
287
+ 使用方法如下:
288
+ ```python
289
+ # 按需修改,目前只支持 4/8 bit 量化。下面将只量化ChatGLM,ViT 量化时误差较大
290
+ model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).quantize(8).half().cuda()
291
+ ```
292
+
293
+ 在sat实现中,需先传参将加载位置改为`cpu`,再进行量化。方法如下,详见`cli_demo.py`:
294
+ ```python
295
+ from sat.quantization.kernels import quantize
296
+ model = quantize(model.transformer, args.quant).cuda()
297
+ # 指定 model.transformer 只量化 ChatGLM,ViT 量化时误差较大
298
+ ```
299
+
300
+ ## 局限性
301
+ 本项目正处于V1版本,视觉和语言模型的参数、计算量都较小,我们总结了如下主要存在的改进方向:
302
+ - 图像描述事实性/模型幻觉问题。在生成图像长描述的时候,距离图像较远时,语言模型将占主导,有一定可能根据上下文生成并不存在于图像的内容。
303
+ - 属性错配问题。在多物体的场景中,部分物体的某些属性,经常被错误安插到其他物体上。
304
+ - 分辨率问题。本项目使用了224*224的分辨率,也是视觉模型中最为常用的尺寸;然而为了进行更细粒度的理解,更大的分辨率和计算量是必要的。
305
+ - 由于数据等方面原因,模型暂时不具有中文ocr的能力(英文ocr能力有一些),我们会在后续版本中增加这个能力。
306
+ ## 协议
307
+
308
+ 本仓库的代码依照 [Apache-2.0](LICENSE.txt) 协议开源,VisualGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE.txt)。
309
+
310
+ ## 引用与致谢
311
+ 如果你觉得我们的工作有帮助的话,请考虑引用下列论文
312
+ ```
313
+ @inproceedings{du2022glm,
314
+ title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
315
+ author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
316
+ booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
317
+ pages={320--335},
318
+ year={2022}
319
+ }
320
+ @article{ding2021cogview,
321
+ title={Cogview: Mastering text-to-image generation via transformers},
322
+ author={Ding, Ming and Yang, Zhuoyi and Hong, Wenyi and Zheng, Wendi and Zhou, Chang and Yin, Da and Lin, Junyang and Zou, Xu and Shao, Zhou and Yang, Hongxia and others},
323
+ journal={Advances in Neural Information Processing Systems},
324
+ volume={34},
325
+ pages={19822--19835},
326
+ year={2021}
327
+ }
328
+ ```
329
+ 在VisualGLM-6B的指令微调阶段的数据集中,包含了来自[MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4)和[LLAVA](https://github.com/haotian-liu/LLaVA)项目的一部分英文图文数据,以及许多经典的跨模态工作数据集,衷心感谢他们的贡献。
README_en.md ADDED
@@ -0,0 +1,246 @@
1
+ # VisualGLM-6B
2
+ <p align="center">
3
+ 🤗 <a href="https://huggingface.co/THUDM/visualglm-6b" target="_blank">HF Repo</a> • ⚒️ <a href="https://github.com/THUDM/SwissArmyTransformer" target="_blank">SwissArmyTransformer (sat)</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a>
4
+ </p>
5
+ <p align="center">
6
+ • 📃 <a href="https://arxiv.org/abs/2105.13290" target="_blank">[CogView@NeurIPS 21]</a> <a href="https://github.com/THUDM/CogView" target="_blank">[GitHub]</a> • 📃 <a href="https://arxiv.org/abs/2103.10360" target="_blank">[GLM@ACL 22]</a> <a href="https://github.com/THUDM/GLM" target="_blank">[GitHub]</a> <br>
7
+ </p>
8
+ <p align="center">
9
+ 👋 Join us on <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1th2q5u69-7tURzFuOPanmuHy9hsZnKA" target="_blank">Slack</a> and <a href="resources/WECHAT.md" target="_blank">WeChat</a>
10
+ </p>
11
+ <!-- <p align="center">
12
+ 🤖<a href="https://huggingface.co/spaces/THUDM/visualglm-6b" target="_blank">VisualGLM-6B Online Demo Website</a>
13
+ </p> -->
14
+
15
+ ## Introduction
16
+ VisualGLM-6B is an open-source, multi-modal dialog language model that supports **images, Chinese, and English**. The language model is based on [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) with 6.2 billion parameters; the image part builds a bridge between the visual model and the language model through the training of [BLIP2-Qformer](https://arxiv.org/abs/2301.12597), with the total model comprising 7.8 billion parameters.
17
+
18
+ VisualGLM-6B relies on 30M high-quality Chinese image-text pairs from the [CogView](https://arxiv.org/abs/2105.13290) dataset and 300M filtered English image-text pairs for pre-training, with equal weight for Chinese and English. This training method aligns visual information well to the semantic space of ChatGLM. In the subsequent fine-tuning phase, the model is trained on long visual question answering data to generate answers that align with human preferences.
19
+
20
+ VisualGLM-6B is trained using the [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer) (abbreviated as sat) library, a utility library for flexible modification and training of Transformer, supporting efficient fine-tuning methods like Lora and P-tuning. This project provides a user-friendly huggingface interface, as well as an interface based on sat.
21
+
22
+ However, as VisualGLM-6B is still at the v1 stage, it is known to have quite a few [**limitations**](#Limitations), such as factual inaccuracy/model hallucination in image description, lack of capturing image detail information, and some limitations from the language model. Please be aware of these issues and evaluate the potential risks before using. In future versions of VisualGLM, we will strive to optimize these issues.
23
+
24
+ With model quantization technology, users can deploy locally on consumer-grade graphics cards (requiring as little as 8.7G memory under INT4 quantization level).
25
+
26
+ ## Examples
27
+ VisualGLM-6B can answer questions related to image description.
28
+ ![Titanic example](examples/chat_example1.png)
29
+
30
+ <details>
31
+ <summary>It can also combine common sense or propose interesting views. Click to expand/collapse more examples</summary>
32
+
33
+ ![Ironing shirt taxi example](examples/chat_example2.png)
34
+ ![Mona Lisa dog example](examples/chat_example3.png)
35
+
36
+ </details>
37
+
38
+
39
+ ## Usage
40
+
41
+ ### Model Inference
42
+
43
+ Install dependencies with pip
44
+ ```
45
+ pip install -i https://pypi.org/simple -r requirements.txt
46
+ pip install -i https://mirrors.aliyun.com/pypi/simple/ -r requirements.txt
47
+ ```
48
+ This will default to installing the `deepspeed` library (which supports training with the `sat` library). This library is not necessary for model inference, and installing it can cause problems in some Windows environments.
49
+ If you want to bypass deepspeed installation, you can change the command to:
50
+ ```
51
+ pip install -i https://mirrors.aliyun.com/pypi/simple/ -r requirements_wo_ds.txt
52
+ pip install -i https://mirrors.aliyun.com/pypi/simple/ --no-deps "SwissArmyTransformer>=0.3.6"
53
+ ```
54
+
55
+ If you are calling the model using the Huggingface transformers library (you also need to install the above dependency packages!), you can use the following code (where the image path is the local path):
56
+ ```python
57
+ from transformers import AutoTokenizer, AutoModel
58
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
59
+ model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
60
+ image_path = "your image path"
61
+ response, history = model.chat(tokenizer, image_path, "描述这张图片。", history=[])
62
+ print(response)
63
+ response, history = model.chat(tokenizer, image_path, "这张图片可能是在什么场所拍摄的?", history=history)
64
+ print(response)
65
+ ```
66
+
67
+ If you use the SwissArmyTransformer library to call the model, the method is similar, and you can use the environment variable SAT_HOME to determine the model download location. In the directory of this repository:
68
+ ```python
69
+ import argparse
70
+ from transformers import AutoTokenizer
71
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
72
+ from model import chat, VisualGLMModel
73
+ model, model_args = VisualGLMModel.from_pretrained('visualglm-6b', args=argparse.Namespace(fp16=True, skip_init=True))
74
+ from sat.model.mixins import CachedAutoregressiveMixin
75
+ model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
76
+ image_path = "your image path or URL"
77
+ response, history, cache_image = chat(image_path, model, tokenizer, "Describe this picture.", history=[])
78
+ print(response)
79
+ response, history, cache_image = chat(None, model, tokenizer, "Where could this picture possibly have been taken?", history=history, image=cache_image)
80
+ print(response)
81
+ ```
82
+
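As a side note, since the `sat` checkpoint download location is controlled by the `SAT_HOME` environment variable mentioned above, one option (a sketch; the cache path below is only an assumed example) is to set it before the model is loaded:

```python
# Sketch: cache the sat checkpoint in a custom directory via SAT_HOME.
# "/data/sat_models" is an assumed example path; set it before from_pretrained runs.
import os
os.environ["SAT_HOME"] = "/data/sat_models"
```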
83
+ The `sat` library also makes parameter-efficient fine-tuning straightforward. <!-- TODO specific code -->
84
+
85
+ Please note that the Huggingface model implementation is located in the [Huggingface repository](https://huggingface.co/THUDM/visualglm-6b), and the `sat` model implementation is included in this repository.
86
+
87
+ ## Model Fine-tuning
88
+
89
+ Multimodal tasks are wide-ranging and diverse, and pre-training often cannot cover all bases.
90
+ Here we provide an example of few-shot fine-tuning, using 20 labeled images to enhance the model's ability to answer "background" questions.
91
+
92
+ After unzipping fewshot-data.zip, run the following command:
93
+ ```
94
+ bash finetune/finetune_visualglm.sh
95
+ ```
96
+
97
+ Currently we support three types of (parameter-efficient) fine-tuning:
98
+
99
+ * LoRA: In the given example, we add rank=10 LoRA for layer 0 and layer 14 in ChatGLM. You can adjust `--layer_range` and `--lora_rank` to fit your application and data amount.
100
+ * QLoRA: If your resources are limited, consider using `bash finetune/finetune_visualglm_qlora.sh`, which performs 4-bit quantization of the ChatGLM linear layers, reducing the required GPU memory to 9.8 GB.
101
+ * P-tuning: You can replace `--use_lora` with `--use_ptuning`, but this is not recommended unless your application has a relatively fixed input and output template.
102
+
103
+ After training, you can use the following command for inference:
104
+
105
+ ```
106
+ python cli_demo.py --from_pretrained your_checkpoint_path --prompt_zh 这张图片的背景里有什么内容?
107
+ ```
108
+
109
+ Fine-tuning requires the installation of the deepspeed library, and currently this process only supports the Linux system. More examples and instructions for the Windows system will be completed in the near future.
110
+
111
+ If you want to merge LoRA weights into original weights, just call `merge_lora()`:
112
+
113
+ ```python
114
+ from finetune_visualglm import FineTuneVisualGLMModel
115
+ import argparse
116
+
117
+ model, args = FineTuneVisualGLMModel.from_pretrained('checkpoints/finetune-visualglm-6b-05-19-07-36',
118
+ args=argparse.Namespace(
119
+ fp16=True,
120
+ skip_init=True,
121
+ use_gpu_initialization=True,
122
+ device='cuda',
123
+ ))
124
+ model.get_mixin('lora').merge_lora()
125
+ args.layer_range = []
126
+ args.save = 'merge_lora'
127
+ args.mode = 'inference'
128
+ from sat.training.model_io import save_checkpoint
129
+ save_checkpoint(1, model, None, None, args)
130
+ ```
131
+
132
+ ## Deployment Tools
133
+
134
+ ### Command Line Demo
135
+
136
+ ```shell
137
+ python cli_demo.py
138
+ ```
139
+ The program will automatically download the sat model and interact in the command line. You can generate replies by entering instructions and pressing enter. Enter 'clear' to clear the conversation history and 'stop' to stop the program.
140
+
141
+ ![cli_demo](examples/thu.png)
142
+ The program provides the following hyperparameters to control the generation process and quantization accuracy:
143
+ ```
144
+ usage: cli_demo.py [-h] [--max_length MAX_LENGTH] [--top_p TOP_P] [--top_k TOP_K] [--temperature TEMPERATURE] [--english] [--quant {8,4}]
145
+
146
+ optional arguments:
147
+ -h, --help show this help message and exit
148
+ --max_length MAX_LENGTH
149
+ max length of the total sequence
150
+ --top_p TOP_P top p for nucleus sampling
151
+ --top_k TOP_K top k for top k sampling
152
+ --temperature TEMPERATURE
153
+ temperature for sampling
154
+ --english only output English
155
+ --quant {8,4} quantization bits
156
+ ```
157
+ Note that during training, the prompt words for English Q&A pairs are 'Q: A:', while in Chinese they are '问:答:'. The web demo uses Chinese prompts, so the English replies will be worse and interspersed with Chinese; if you need English replies, please use the --english option in cli_demo.py.
158
+
159
+ We also provide a typewriter effect command line tool inherited from ChatGLM-6B, which uses the Huggingface model:
160
+ ```shell
161
+ python cli_demo_hf.py
162
+ ```
163
+
164
+ ### Web Demo
165
+ ![web_demo](examples/web_demo.png)
166
+
167
+ We provide a web demo based on [Gradio](https://gradio.app). First, install Gradio: `pip install gradio`.
168
+ Then download and enter this repository and run `web_demo.py`:
169
+
170
+ ```
171
+ git clone https://github.com/THUDM/VisualGLM-6B
172
+ cd VisualGLM-6B
173
+ python web_demo.py
174
+ ```
175
+ The program will automatically download the sat model and run a Web Server, outputting the address. Open the output address in your browser to use it.
176
+
177
+ We also provide a web tool with a typewriter effect inherited from ChatGLM-6B, which uses the Huggingface model and will run on port :8080 after starting:
178
+ ```shell
179
+ python web_demo_hf.py
180
+ ```
181
+
182
+ Both web demos accept the command line parameter --share to generate a public link for gradio, and accept --quant 4 and --quant 8 to use 4-bit quantization/8-bit quantization to reduce GPU memory usage.
183
+
184
+ ### API Deployment
185
+ First, install the additional dependencies with `pip install fastapi uvicorn`, then run `api.py` in the repository:
186
+ ```shell
187
+ python api.py
188
+ ```
189
+ The program will automatically download the sat model; by default it is deployed on local port 8080 and called via the POST method. Below is an example request using `curl`; in general, you can also send the POST request from code.
190
+ ```shell
191
+ echo "{\"image\":\"$(base64 path/to/example.jpg)\",\"text\":\"Describe this picture\",\"history\":[]}" > temp.json
192
+ curl -X POST -H "Content-Type: application/json" -d @temp.json http://127.0.0.1:8080
193
+ ```
194
+
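The same request can also be issued from Python; below is a minimal client sketch, assuming the `requests` package is installed, the `api.py` server above is listening on port 8080, and `path/to/example.jpg` stands in for a real image:

```python
# Minimal sketch of calling the api.py endpoint from Python.
import base64
import requests

with open("path/to/example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {"image": image_b64, "text": "Describe this picture", "history": []}
resp = requests.post("http://127.0.0.1:8080", json=payload)
data = resp.json()
# api.py returns the answer under "result"; the README example shows "response",
# so check both to be safe.
print(data.get("result") or data.get("response"))
# data["history"] can be sent back as "history" in a follow-up request.
```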
195
+ We also provide an api_hf.py that uses the Huggingface model, which works the same way as the sat model's api:
196
+ ```shell
197
+ python api_hf.py
198
+ ```
199
+
200
+
201
+ ## Model Quantization
202
+ In the Huggingface implementation, the model is loaded with FP16 precision by default, and running the above code requires about 15GB of GPU memory. If your GPU memory is limited, you can try loading the model in a quantized manner.
203
+ Here's how:
204
+ ```python
205
+ # Modify as needed, currently only 4/8 bit quantization is supported. The following will only quantize ChatGLM, as the error is larger when quantizing ViT
206
+ model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).quantize(8).half().cuda()
207
+ ```
208
+
209
+ In the sat implementation, you need to change the loading location to 'cpu' first, and then perform quantization. Here's how, see cli_demo.py for details:
210
+ ```python
211
+ from sat.quantization.kernels import quantize
212
+ model = quantize(model.transformer, args.quant).cuda()
213
+ # Specify model.transformer to only quantize ChatGLM, as the error is larger when quantizing ViT
214
+ ```
215
+
216
+ ## Limitations
217
+ This project is currently at the V1 stage; the parameters and computation of both the visual and language models are relatively small. We have summarized the main directions for improvement as follows:
218
+
219
+ - Image description factuality / model hallucination. When generating long descriptions, as the generated text moves further from the image, the language model comes to dominate, and it may produce context-driven content that does not actually exist in the image.
220
+ - Attribute mismatch problem. In scenes with multiple objects, some attributes of some objects are often incorrectly inserted onto other objects.
221
+ - Resolution issue. This project uses a resolution of 224*224, which is the most commonly used size in visual models; however, for more fine-grained understanding, larger resolution and computation are necessary.
222
+ - Due to data and other reasons, the model currently cannot perform Chinese OCR (it has some ability for English OCR); we will add this capability in future versions.
223
+ ## License
224
+
225
+ The code in this repository is open source under the Apache-2.0 license, while the use of the VisualGLM-6B model weights must comply with the Model License.
226
+
227
+ ## Citation & Acknowledgements
228
+ If you find our work helpful, please consider citing the following papers
229
+ ```
230
+ @inproceedings{du2022glm,
231
+ title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
232
+ author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
233
+ booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
234
+ pages={320--335},
235
+ year={2022}
236
+ }
237
+ @article{ding2021cogview,
238
+ title={Cogview: Mastering text-to-image generation via transformers},
239
+ author={Ding, Ming and Yang, Zhuoyi and Hong, Wenyi and Zheng, Wendi and Zhou, Chang and Yin, Da and Lin, Junyang and Zou, Xu and Shao, Zhou and Yang, Hongxia and others},
240
+ journal={Advances in Neural Information Processing Systems},
241
+ volume={34},
242
+ pages={19822--19835},
243
+ year={2021}
244
+ }
245
+ ```
246
+ The dataset used in the instruction fine-tuning phase of VisualGLM-6B includes some English image-text data from the [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) and [LLAVA](https://github.com/haotian-liu/LLaVA) projects, as well as many classic cross-modal datasets. We sincerely thank them for their contributions.
__pycache__/finetune_visualglm.cpython-310.pyc ADDED
Binary file (7 kB). View file
 
__pycache__/lora_mixin.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
api.py ADDED
@@ -0,0 +1,51 @@
1
+ import os
2
+ import json
3
+ import uvicorn
4
+ from fastapi import FastAPI, Request
5
+ from model import is_chinese, get_infer_setting, generate_input, chat
6
+ import datetime
7
+ import torch
8
+
9
+ gpu_number = 0
10
+ model, tokenizer = get_infer_setting(gpu_device=gpu_number)
11
+
12
+ app = FastAPI()
13
+ @app.post('/')
14
+ async def visual_glm(request: Request):
15
+ json_post_raw = await request.json()
16
+ print("Start to process request")
17
+
18
+ json_post = json.dumps(json_post_raw)
19
+ request_data = json.loads(json_post)
20
+ input_text, input_image_encoded, history = request_data['text'], request_data['image'], request_data['history']
21
+ input_para = {
22
+ "max_length": 2048,
23
+ "min_length": 50,
24
+ "temperature": 0.8,
25
+ "top_p": 0.4,
26
+ "top_k": 100,
27
+ "repetition_penalty": 1.2
28
+ }
29
+ input_para.update(request_data)
30
+
31
+ is_zh = is_chinese(input_text)
32
+ input_data = generate_input(input_text, input_image_encoded, history, input_para)
33
+ input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
34
+ with torch.no_grad():
35
+ answer, history, _ = chat(None, model, tokenizer, input_text, history=history, image=input_image, \
36
+ max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
37
+ top_k = gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)
38
+
39
+ now = datetime.datetime.now()
40
+ time = now.strftime("%Y-%m-%d %H:%M:%S")
41
+ response = {
42
+ "result": answer,
43
+ "history": history,
44
+ "status": 200,
45
+ "time": time
46
+ }
47
+ return response
48
+
49
+
50
+ if __name__ == '__main__':
51
+ uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
api_hf.py ADDED
@@ -0,0 +1,49 @@
1
+ import os
2
+ import json
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import uvicorn
5
+ from fastapi import FastAPI, Request
6
+ import datetime
7
+ from model import process_image
8
+ import torch
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
11
+ model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
12
+
13
+
14
+ app = FastAPI()
15
+ @app.post('/')
16
+ async def visual_glm(request: Request):
17
+ json_post_raw = await request.json()
18
+ print("Start to process request")
19
+
20
+ json_post = json.dumps(json_post_raw)
21
+ request_data = json.loads(json_post)
22
+
23
+ history = request_data.get("history")
24
+ image_encoded = request_data.get("image")
25
+ query = request_data.get("text")
26
+ image_path = process_image(image_encoded)
27
+
28
+ with torch.no_grad():
29
+ result = model.stream_chat(tokenizer, image_path, query, history=history)
30
+ last_result = None
31
+ for value in result:
32
+ last_result = value
33
+ answer = last_result[0]
34
+
35
+ if os.path.isfile(image_path):
36
+ os.remove(image_path)
37
+ now = datetime.datetime.now()
38
+ time = now.strftime("%Y-%m-%d %H:%M:%S")
39
+ response = {
40
+ "result": answer,
41
+ "history": history,
42
+ "status": 200,
43
+ "time": time
44
+ }
45
+ return response
46
+
47
+
48
+ if __name__ == "__main__":
49
+ uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
cli_demo.py ADDED
@@ -0,0 +1,103 @@
1
+ # -*- encoding: utf-8 -*-
2
+
3
+ import os
4
+ import sys
5
+ import torch
6
+ import argparse
7
+ from transformers import AutoTokenizer
8
+ from sat.model.mixins import CachedAutoregressiveMixin
9
+ from sat.quantization.kernels import quantize
10
+
11
+ from model import VisualGLMModel, chat
12
+ from finetune_visualglm import FineTuneVisualGLMModel
13
+ from sat.model import AutoModel
14
+
15
+
16
+ def main():
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
19
+ parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')
20
+ parser.add_argument("--top_k", type=int, default=100, help='top k for top k sampling')
21
+ parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
22
+ parser.add_argument("--english", action='store_true', help='only output English')
23
+ parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
24
+ parser.add_argument("--from_pretrained", type=str, default="visualglm-6b", help='pretrained ckpt')
25
+ parser.add_argument("--prompt_zh", type=str, default="描述这张图片。", help='Chinese prompt for the first round')
26
+ parser.add_argument("--prompt_en", type=str, default="Describe the image.", help='English prompt for the first round')
27
+ args = parser.parse_args()
28
+
29
+ # load model
30
+ model, model_args = AutoModel.from_pretrained(
31
+ args.from_pretrained,
32
+ args=argparse.Namespace(
33
+ fp16=True,
34
+ skip_init=True,
35
+ use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,
36
+ device='cuda' if (torch.cuda.is_available() and args.quant is None) else 'cpu',
37
+ ))
38
+ model = model.eval()
39
+
40
+ if args.quant:
41
+ quantize(model.transformer, args.quant)
42
+ if torch.cuda.is_available():
43
+ model = model.cuda()
44
+
45
+ model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
46
+
47
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
48
+ if not args.english:
49
+ print('欢迎使用 VisualGLM-6B 模型,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序')
50
+ else:
51
+ print('Welcome to VisualGLM-6B model. Enter an image URL or local file path to load an image. Continue inputting text to engage in a conversation. Type "clear" to start over, or "stop" to end the program.')
52
+ with torch.no_grad():
53
+ while True:
54
+ history = None
55
+ cache_image = None
56
+ if not args.english:
57
+ image_path = input("请输入图像路径或URL(回车进入纯文本对话): ")
58
+ else:
59
+ image_path = input("Please enter the image path or URL (press Enter for plain text conversation): ")
60
+
61
+ if image_path == 'stop':
62
+ break
63
+ if len(image_path) > 0:
64
+ query = args.prompt_en if args.english else args.prompt_zh
65
+ else:
66
+ if not args.english:
67
+ query = input("用户:")
68
+ else:
69
+ query = input("User: ")
70
+ while True:
71
+ if query == "clear":
72
+ break
73
+ if query == "stop":
74
+ sys.exit(0)
75
+ try:
76
+ response, history, cache_image = chat(
77
+ image_path,
78
+ model,
79
+ tokenizer,
80
+ query,
81
+ history=history,
82
+ image=cache_image,
83
+ max_length=args.max_length,
84
+ top_p=args.top_p,
85
+ temperature=args.temperature,
86
+ top_k=args.top_k,
87
+ english=args.english,
88
+ invalid_slices=[slice(63823, 130000)] if args.english else []
89
+ )
90
+ except Exception as e:
91
+ print(e)
92
+ break
93
+ sep = 'A:' if args.english else '答:'
94
+ print("VisualGLM-6B:"+response.split(sep)[-1].strip())
95
+ image_path = None
96
+ if not args.english:
97
+ query = input("用户:")
98
+ else:
99
+ query = input("User: ")
100
+
101
+
102
+ if __name__ == "__main__":
103
+ main()
cli_demo_hf.py ADDED
@@ -0,0 +1,69 @@
1
+ import os
2
+ import platform
3
+ import signal
4
+ from transformers import AutoTokenizer, AutoModel
5
+ import torch
6
+
7
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
8
+ model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
9
+ model = model.eval()
10
+
11
+ os_name = platform.system()
12
+ clear_command = 'cls' if os_name == 'Windows' else 'clear'
13
+ stop_stream = False
14
+
15
+
16
+ def build_prompt(history, prefix):
17
+ prompt = prefix
18
+ for query, response in history:
19
+ prompt += f"\n\n用户:{query}"
20
+ prompt += f"\n\nVisualGLM-6B:{response}"
21
+ return prompt
22
+
23
+
24
+ def signal_handler(signal, frame):
25
+ global stop_stream
26
+ stop_stream = True
27
+
28
+
29
+ def main():
30
+ global stop_stream
31
+ while True:
32
+ history = []
33
+ prefix = "欢迎使用 VisualGLM-6B 模型,输入图片路径和内容即可进行对话,clear 清空对话历史,stop 终止程序"
34
+ print(prefix)
35
+ image_path = input("\n请输入图片路径:")
36
+ if image_path == "stop":
37
+ break
38
+ prefix = prefix + "\n" + image_path
39
+ query = "描述这张图片。"
40
+ while True:
41
+ count = 0
42
+ with torch.no_grad():
43
+ for response, history in model.stream_chat(tokenizer, image_path, query, history=history):
44
+ if stop_stream:
45
+ stop_stream = False
46
+ break
47
+ else:
48
+ count += 1
49
+ if count % 8 == 0:
50
+ os.system(clear_command)
51
+ print(build_prompt(history, prefix), flush=True)
52
+ signal.signal(signal.SIGINT, signal_handler)
53
+ os.system(clear_command)
54
+ print(build_prompt(history, prefix), flush=True)
55
+ query = input("\n用户:")
56
+ if query.strip() == "clear":
57
+ break
58
+ if query.strip() == "stop":
59
+ stop_stream = True
60
+ exit(0)
61
+ # if query.strip() == "clear":
62
+ # history = []
63
+ # os.system(clear_command)
64
+ # print(prefix)
65
+ # continue
66
+
67
+
68
+ if __name__ == "__main__":
69
+ main()
examples/1.jpeg ADDED

Git LFS Details

  • SHA256: 821ef1b9aaf417f2438bf5395f43201b7774b482b8fd74bfd07893ab4ef03694
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
examples/2.jpeg ADDED
examples/3.jpeg ADDED
examples/chat_example1.png ADDED
examples/chat_example2.png ADDED
examples/chat_example3.png ADDED
examples/example_inputs.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ {"id":1, "text": "描述一下这个场景", "image": "examples/1.jpeg"}
2
+ {"id":2, "text": "这是什么东西", "image": "examples/2.jpeg"}
3
+ {"id":3, "text": "这张图片描述了什么", "image": "examples/3.jpeg"}
examples/thu.png ADDED
examples/web_demo.png ADDED
fewshot-data.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e96484834c7d9bda898f8da5f658ea63268ebd3aa10ac7f0da3b3dc40a86e1b7
3
+ size 6695260
finetune/finetune_visualglm.sh ADDED
@@ -0,0 +1,58 @@
1
+ #! /bin/bash
2
+ NUM_WORKERS=1
3
+ NUM_GPUS_PER_WORKER=8
4
+ MP_SIZE=1
5
+
6
+ script_path=$(realpath $0)
7
+ script_dir=$(dirname $script_path)
8
+ main_dir=$(dirname $script_dir)
9
+ MODEL_TYPE="visualglm-6b"
10
+ MODEL_ARGS="--max_source_length 64 \
11
+ --max_target_length 256 \
12
+ --lora_rank 10 \
13
+ --layer_range 0 14 \
14
+ --pre_seq_len 4"
15
+
16
+ # OPTIONS_SAT="SAT_HOME=$1" #"SAT_HOME=/raid/dm/sat_models"
17
+ OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
18
+ HOST_FILE_PATH="hostfile"
19
+ HOST_FILE_PATH="hostfile_single"
20
+
21
+ train_data="./fewshot-data/dataset.json"
22
+ eval_data="./fewshot-data/dataset.json"
23
+
24
+
25
+ gpt_options=" \
26
+ --experiment-name finetune-$MODEL_TYPE \
27
+ --model-parallel-size ${MP_SIZE} \
28
+ --mode finetune \
29
+ --train-iters 300 \
30
+ --resume-dataloader \
31
+ $MODEL_ARGS \
32
+ --train-data ${train_data} \
33
+ --valid-data ${eval_data} \
34
+ --distributed-backend nccl \
35
+ --lr-decay-style cosine \
36
+ --warmup .02 \
37
+ --checkpoint-activations \
38
+ --save-interval 300 \
39
+ --eval-interval 10000 \
40
+ --save "./checkpoints" \
41
+ --split 1 \
42
+ --eval-iters 10 \
43
+ --eval-batch-size 8 \
44
+ --zero-stage 1 \
45
+ --lr 0.0001 \
46
+ --batch-size 4 \
47
+ --skip-init \
48
+ --fp16 \
49
+ --use_lora
50
+ "
51
+
52
+
53
+
54
+ run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} finetune_visualglm.py ${gpt_options}"
55
+ echo ${run_cmd}
56
+ eval ${run_cmd}
57
+
58
+ set +x
finetune/finetune_visualglm_qlora.sh ADDED
@@ -0,0 +1,59 @@
1
+ #! /bin/bash
2
+ NUM_WORKERS=1
3
+ NUM_GPUS_PER_WORKER=8
4
+ MP_SIZE=1
5
+
6
+ script_path=$(realpath $0)
7
+ script_dir=$(dirname $script_path)
8
+ main_dir=$(dirname $script_dir)
9
+ MODEL_TYPE="visualglm-6b"
10
+ MODEL_ARGS="--max_source_length 64 \
11
+ --max_target_length 256 \
12
+ --lora_rank 10 \
13
+ --layer_range 0 14 \
14
+ --pre_seq_len 4"
15
+
16
+ # OPTIONS_SAT="SAT_HOME=$1" #"SAT_HOME=/raid/dm/sat_models"
17
+ OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
18
+ HOST_FILE_PATH="hostfile"
19
+ HOST_FILE_PATH="hostfile_single"
20
+
21
+ train_data="./fewshot-data/dataset.json"
22
+ eval_data="./fewshot-data/dataset.json"
23
+
24
+
25
+ gpt_options=" \
26
+ --experiment-name finetune-$MODEL_TYPE \
27
+ --model-parallel-size ${MP_SIZE} \
28
+ --mode finetune \
29
+ --train-iters 300 \
30
+ --resume-dataloader \
31
+ $MODEL_ARGS \
32
+ --train-data ${train_data} \
33
+ --valid-data ${eval_data} \
34
+ --distributed-backend nccl \
35
+ --lr-decay-style cosine \
36
+ --warmup .02 \
37
+ --checkpoint-activations \
38
+ --save-interval 300 \
39
+ --eval-interval 10000 \
40
+ --save "./checkpoints" \
41
+ --split 1 \
42
+ --eval-iters 10 \
43
+ --eval-batch-size 8 \
44
+ --zero-stage 1 \
45
+ --lr 0.0001 \
46
+ --batch-size 1 \
47
+ --gradient-accumulation-steps 4 \
48
+ --skip-init \
49
+ --fp16 \
50
+ --use_qlora
51
+ "
52
+
53
+
54
+
55
+ run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --include localhost:0 --hostfile ${HOST_FILE_PATH} finetune_visualglm.py ${gpt_options}"
56
+ echo ${run_cmd}
57
+ eval ${run_cmd}
58
+
59
+ set +x
finetune_visualglm.py ADDED
@@ -0,0 +1,195 @@
1
+ import os
2
+ import torch
3
+ import argparse
4
+
5
+ from sat import mpu, get_args, get_tokenizer
6
+ from sat.training.deepspeed_training import training_main
7
+ from model import VisualGLMModel
8
+ from sat.model.finetune import PTuningV2Mixin
9
+ from lora_mixin import LoraMixin
10
+
11
+ class FineTuneVisualGLMModel(VisualGLMModel):
12
+ def __init__(self, args, transformer=None, parallel_output=True, **kw_args):
13
+ super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kw_args)
14
+ if args.use_ptuning:
15
+ self.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.pre_seq_len))
16
+ if args.use_lora:
17
+ # If you use LoRA on another "normal" Transformer, just use it with head_first=False (the default)
18
+ self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, head_first=True, num_attention_heads=args.num_attention_heads, hidden_size_per_attention_head=args.hidden_size // args.num_attention_heads, layer_range=args.layer_range), reinit=True)
19
+ # self.get_mixin("eva").model.glm_proj = replace_linear_with_lora(self.get_mixin("eva").model.glm_proj, LoraLinear, args.lora_rank)
20
+ elif args.use_qlora:
21
+ self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, head_first=True, num_attention_heads=args.num_attention_heads, hidden_size_per_attention_head=args.hidden_size // args.num_attention_heads, layer_range=args.layer_range, qlora=True), reinit=True)
22
+ self.args = args
23
+
24
+ @classmethod
25
+ def add_model_specific_args(cls, parser):
26
+ group = parser.add_argument_group('VisualGLM-finetune', 'VisualGLM finetune Configurations')
27
+ group.add_argument('--pre_seq_len', type=int, default=8)
28
+ group.add_argument('--lora_rank', type=int, default=10)
29
+ group.add_argument('--use_ptuning', action="store_true")
30
+ group.add_argument('--use_lora', action="store_true")
31
+ group.add_argument('--use_qlora', action="store_true")
32
+ group.add_argument('--layer_range', nargs='+', type=int, default=None)
33
+ return super().add_model_specific_args(parser)
34
+
35
+ def disable_untrainable_params(self):
36
+ enable = []
37
+ if self.args.use_ptuning:
38
+ enable.extend(['ptuning'])
39
+ if self.args.use_lora or self.args.use_qlora:
40
+ enable.extend(['matrix_A', 'matrix_B'])
41
+ for n, p in self.named_parameters():
42
+ flag = False
43
+ for e in enable:
44
+ if e.lower() in n.lower():
45
+ flag = True
46
+ break
47
+ if not flag:
48
+ p.requires_grad_(False)
49
+ else:
50
+ print(n)
51
+
52
+
53
+ def get_batch(data_iterator, args, timers):
54
+ # Items and their type.
55
+ keys = ['input_ids', 'labels']
56
+ datatype = torch.int64
57
+
58
+ # Broadcast data.
59
+ timers('data loader').start()
60
+ if data_iterator is not None:
61
+ data = next(data_iterator)
62
+ else:
63
+ data = None
64
+ timers('data loader').stop()
65
+ data_b = mpu.broadcast_data(keys, data, datatype)
66
+ data_i = mpu.broadcast_data(['image'], data, torch.float32)
67
+ # Unpack.
68
+ tokens = data_b['input_ids'].long()
69
+ labels = data_b['labels'].long()
70
+ img = data_i['image']
71
+ if args.fp16:
72
+ img = img.half()
73
+
74
+ return tokens, labels, img, data['pre_image']
75
+
76
+
77
+ from torch.nn import CrossEntropyLoss
78
+
79
+ def forward_step(data_iterator, model, args, timers):
80
+ """Forward step."""
81
+
82
+ # Get the batch.
83
+ timers('batch generator').start()
84
+ tokens, labels, image, pre_image = get_batch(
85
+ data_iterator, args, timers)
86
+ timers('batch generator').stop()
87
+
88
+ logits = model(input_ids=tokens, image=image, pre_image=pre_image)[0]
89
+ dtype = logits.dtype
90
+ lm_logits = logits.to(torch.float32)
91
+
92
+ # Shift so that tokens < n predict n
93
+ shift_logits = lm_logits[..., :-1, :].contiguous()
94
+ shift_labels = labels[..., 1:].contiguous()
95
+ # Flatten the tokens
96
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
97
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
98
+
99
+ lm_logits = lm_logits.to(dtype)
100
+ loss = loss.to(dtype)
101
+ return loss, {'loss': loss}
102
+
103
+
104
+ from model.blip2 import BlipImageEvalProcessor
105
+ from torch.utils.data import Dataset
106
+ import json
107
+ from PIL import Image
108
+
109
+ class FewShotDataset(Dataset):
110
+ def __init__(self, path, processor, tokenizer, args):
111
+ max_seq_length = args.max_source_length + args.max_target_length
112
+ with open(path, 'r', encoding='utf-8') as f:
113
+ data = json.load(f)
114
+ self.images = []
115
+ self.input_ids = []
116
+ self.labels = []
117
+ for item in data:
118
+ image = processor(Image.open(item['img']).convert('RGB'))
119
+ input0 = tokenizer.encode("<img>", add_special_tokens=False)
120
+ input1 = [tokenizer.pad_token_id] * args.image_length
121
+ input2 = tokenizer.encode("</img>问:"+item['prompt']+"\n答:", add_special_tokens=False)
122
+ a_ids = sum([input0, input1, input2], [])
123
+ b_ids = tokenizer.encode(text=item['label'], add_special_tokens=False)
124
+ if len(a_ids) > args.max_source_length - 1:
125
+ a_ids = a_ids[: args.max_source_length - 1]
126
+ if len(b_ids) > args.max_target_length - 2:
127
+ b_ids = b_ids[: args.max_target_length - 2]
128
+ pre_image = len(input0)
129
+ input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)
130
+
131
+ context_length = input_ids.index(tokenizer.bos_token_id)
132
+ mask_position = context_length - 1
133
+ labels = [-100] * context_length + input_ids[mask_position+1:]
134
+
135
+ pad_len = max_seq_length - len(input_ids)
136
+ input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
137
+ labels = labels + [tokenizer.pad_token_id] * pad_len
138
+ if args.ignore_pad_token_for_loss:
139
+ labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
140
+ self.images.append(image)
141
+ self.input_ids.append(input_ids)
142
+ self.labels.append(labels)
143
+ self.pre_image = pre_image
144
+
145
+ def __len__(self):
146
+ return len(self.images)
147
+
148
+ def __getitem__(self, idx):
149
+ return {
150
+ "image": self.images[idx],
151
+ "input_ids": self.input_ids[idx],
152
+ "labels": self.labels[idx],
153
+ "pre_image": self.pre_image
154
+ }
155
+
156
+
157
+ def create_dataset_function(path, args):
158
+ tokenizer = get_tokenizer(args)
159
+ image_processor = BlipImageEvalProcessor(224)
160
+
161
+ dataset = FewShotDataset(path, image_processor, tokenizer, args)
162
+ return dataset
163
+
164
+
165
+ if __name__ == '__main__':
166
+ py_parser = argparse.ArgumentParser(add_help=False)
167
+ py_parser.add_argument('--max_source_length', type=int)
168
+ py_parser.add_argument('--max_target_length', type=int)
169
+ py_parser.add_argument('--ignore_pad_token_for_loss', type=bool, default=True)
170
+ # py_parser.add_argument('--old_checkpoint', action="store_true")
171
+ py_parser.add_argument('--source_prefix', type=str, default="")
172
+ py_parser = FineTuneVisualGLMModel.add_model_specific_args(py_parser)
173
+ known, args_list = py_parser.parse_known_args()
174
+ args = get_args(args_list)
175
+ args = argparse.Namespace(**vars(args), **vars(known))
176
+ args.device = 'cpu'
177
+
178
+ model_type = 'visualglm-6b'
179
+ model, args = FineTuneVisualGLMModel.from_pretrained(model_type, args)
180
+ if torch.cuda.is_available():
181
+ model = model.to('cuda')
182
+ tokenizer = get_tokenizer(args)
183
+ label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
184
+ def data_collator(examples):
185
+ for example in examples:
186
+ example['input_ids'] = torch.tensor(example['input_ids'], dtype=torch.long)
187
+ example['labels'] = torch.tensor(example['labels'], dtype=torch.long)
188
+ ret = {
189
+ 'input_ids': torch.stack([example['input_ids'] for example in examples]),
190
+ 'labels': torch.stack([example['labels'] for example in examples]),
191
+ 'image': torch.stack([example['image'] for example in examples]),
192
+ 'pre_image': example['pre_image']
193
+ }
194
+ return ret
195
+ training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=create_dataset_function, collate_fn=data_collator)
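Note: FewShotDataset above reads the file that finetune/finetune_visualglm.sh points at (./fewshot-data/dataset.json): a JSON list of records with img, prompt and label keys. A sketch that writes such a file, with made-up image paths and texts purely for illustration:

import json

records = [
    {"img": "fewshot-data/example_1.jpg", "prompt": "描述这张图片。", "label": "一只猫趴在窗台上晒太阳。"},
    {"img": "fewshot-data/example_2.jpg", "prompt": "这是什么东西?", "label": "这是一辆停在路边的红色自行车。"},
]
with open("fewshot-data/dataset.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=2)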
lora_mixin.py ADDED
@@ -0,0 +1,260 @@
1
+ """
2
+ This mixin uses a different implementation from sat/model/finetune/lora.py:
3
+ it swaps the original linear layers for a drop-in LoraLinear wrapper instead of patching the model in place.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from sat.model.base_model import BaseMixin
9
+ import math
10
+ from sat.helpers import print_all
11
+ from sat.model.transformer import RowParallelLinear, ColumnParallelLinear
12
+
13
+ class HackLinear(nn.Linear):
14
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
15
+ if prefix + 'weight' in state_dict:
16
+ self.weight.data.copy_(state_dict[prefix+'weight'])
17
+ if prefix + 'bias' in state_dict:
18
+ self.bias.data.copy_(state_dict[prefix+'bias'])
19
+
20
+ class HackRowParallelLinear(RowParallelLinear):
21
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
22
+ if prefix + 'weight' in state_dict:
23
+ self.weight.data.copy_(state_dict[prefix+'weight'])
24
+ if prefix + 'bias' in state_dict:
25
+ self.bias.data.copy_(state_dict[prefix+'bias'])
26
+
27
+ class HackColumnParallelLinear(ColumnParallelLinear):
28
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
29
+ if prefix + 'weight' in state_dict:
30
+ self.weight.data.copy_(state_dict[prefix+'weight'])
31
+ if prefix + 'bias' in state_dict:
32
+ self.bias.data.copy_(state_dict[prefix+'bias'])
33
+
34
+ try:
35
+ from bitsandbytes.nn import LinearNF4
36
+ def copy_nested_list(src, dst):
37
+ for i in range(len(dst)):
38
+ if type(dst[i]) is torch.Tensor:
39
+ dst[i].copy_(src[i])
40
+ elif type(dst[i]) is list:
41
+ copy_nested_list(src[i], dst[i])
42
+ else:
43
+ dst[i] = src[i]
44
+ class HackLinearNF4(LinearNF4):
45
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
46
+ if prefix + 'weight' in state_dict:
47
+ self.weight.data.copy_(state_dict[prefix+'weight'])
48
+ if self.weight.data.dtype == torch.uint8:
49
+ copy_nested_list(state_dict[prefix+'quant_state'], self.weight.quant_state)
50
+ if prefix + 'bias' in state_dict:
51
+ self.bias.data.copy_(state_dict[prefix+'bias'])
52
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
53
+ super()._save_to_state_dict(destination, prefix, keep_vars)
54
+ destination[prefix+'quant_state'] = self.weight.quant_state
55
+ except Exception as exception:
56
+ print_all("Failed to load bitsandbytes:" + str(exception), level='WARNING')
57
+
58
+
59
+ class HackParameterList(nn.ParameterList):
60
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
61
+ for i in range(len(self)):
62
+ if prefix + str(i) in state_dict:
63
+ self[i].data.copy_(state_dict[prefix+str(i)])
64
+
65
+ map_cls = {
66
+ nn.Linear: (HackLinear, {}),
67
+ ColumnParallelLinear: (HackColumnParallelLinear, {'gather_output': False}),
68
+ RowParallelLinear: (HackRowParallelLinear, {'input_is_parallel': True})
69
+ }
70
+
71
+ class LoraLinear(nn.Module):
72
+ def __init__(self, original_cls, partition, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., head_first=False, num_attention_heads=None, hidden_size_per_attention_head=None, qlora=False):
73
+ """
74
+ This layer is only safe to use as-is when the fused query_key_value projection is laid out in plain [query, key, value] order.
75
+ If the layout differs (e.g. ChatGLM interleaves the heads), set head_first=True so the LoRA outputs are reshaped per attention head before being added.
76
+ """
77
+ super().__init__()
78
+ if lora_dropout and lora_dropout > 0:
79
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
80
+ else:
81
+ self.lora_dropout = lambda x: x
82
+ self.r = r
83
+ self.lora_alpha = lora_alpha
84
+ self.scaling = self.lora_alpha / self.r
85
+ if qlora:
86
+ try:
87
+ self.original = HackLinearNF4(in_dim, out_dim)
88
+ except:
89
+ raise Exception('Build 4bit layer failed. You need to install the latest bitsandbytes. Try `pip install bitsandbytes`. If you still meet error after installation, try running `from bitsandbytes.nn import LinearNF4` with python and fix the error.')
90
+ else:
91
+ base_cls, kwargs = map_cls[original_cls]
92
+ self.original = base_cls(in_dim, out_dim, **kwargs)
93
+ self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, in_dim))) for _ in range(partition)])
94
+ self.matrix_B = HackParameterList([nn.Parameter(torch.empty((out_dim // partition, r))) for _ in range(partition)])
95
+ for i in range(partition):
96
+ nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
97
+ nn.init.zeros_(self.matrix_B[i])
98
+ self.head_first = head_first
99
+ self.partition = partition
100
+ if head_first:
101
+ assert num_attention_heads is not None and hidden_size_per_attention_head is not None, "You should set num_attention_heads and hidden_size_per_attention_head if you use head_first=True!"
102
+ self.num_attention_heads = num_attention_heads
103
+ self.hidden_size_per_attention_head = hidden_size_per_attention_head
104
+
105
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
106
+ # This is not a perfect version, because it doesn't handle errors or unexpected keys.
107
+ if prefix + 'weight' in state_dict:
108
+ # load from normal Linear
109
+ self.original._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
110
+ else:
111
+ # load from LoraLinear
112
+ super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
113
+
114
+ def forward(self, x):
115
+ mixed_raw_layer = self.original(x)
116
+ lora_outputs = []
117
+ for i in range(self.partition):
118
+ lora_outputs.append((self.lora_dropout(x) @ self.matrix_A[i].T @ self.matrix_B[i].T) * self.scaling)
119
+ if self.head_first:
120
+ new_tensor_shape = lora_outputs[0].size()[:-1] + (
121
+ self.num_attention_heads,
122
+ self.hidden_size_per_attention_head,
123
+ )
124
+ for i in range(self.partition):
125
+ lora_outputs[i] = lora_outputs[i].view(*new_tensor_shape)
126
+ mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1).view(*mixed_raw_layer.size())
127
+ else:
128
+ mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1)
129
+
130
+ return mixed_raw_layer
131
+
132
+
133
+ def replace_linear_with_lora(lin, partition, r, *args, **kw_args):
134
+ # not supported for linear without bias for now
135
+ out_dim, in_dim = lin.weight.shape
136
+ original_cls = type(lin)
137
+ del lin
138
+ return LoraLinear(original_cls, partition, in_dim, out_dim, r, *args, **kw_args)
139
+
140
+ def merge_linear_lora(lin):
141
+ if lin.original.weight.data.dtype is not torch.uint8:
142
+ weight = lin.original.weight
143
+ out_dim, in_dim = weight.shape
144
+ new_lin = nn.Linear(in_dim, out_dim)
145
+ else:
146
+ import bitsandbytes.functional as F
147
+ weight = F.dequantize_fp4(lin.original.weight.data, lin.original.weight.quant_state).to(lin.original.bias.data.dtype)
148
+ out_dim, in_dim = weight.shape
149
+ new_lin = HackLinearNF4(in_dim, out_dim)
150
+ new_lin.bias.data = lin.original.bias.data
151
+ new_qkv = []
152
+ for i in range(lin.partition):
153
+ new_qkv.append(lin.matrix_A[i].data.T.float() @ lin.matrix_B[i].data.T.float() * lin.scaling)
154
+ if lin.head_first:
155
+ ini_shape = new_qkv[0].shape
156
+ new_qkv = [x.view(ini_shape[0], lin.num_attention_heads, -1) for x in new_qkv]
157
+ new_qkv = torch.cat(new_qkv, -1).view(ini_shape[0], lin.partition*ini_shape[1])
158
+ else:
159
+ new_qkv = torch.cat(new_qkv, -1)
160
+ new_lin.weight.data = weight + new_qkv.T.to(lin.original.bias.data.dtype)
161
+ return new_lin.cuda() if torch.cuda.is_available() else new_lin
162
+
163
+ class LoraMixin(BaseMixin):
164
+ def __init__(self,
165
+ layer_num,
166
+ r: int = 0,
167
+ lora_alpha: int = 1,
168
+ lora_dropout: float = 0.,
169
+ layer_range = None,
170
+ head_first = False,
171
+ num_attention_heads = None,
172
+ hidden_size_per_attention_head = None,
173
+ qlora = False,
174
+ cross_attention = True):
175
+ super().__init__()
176
+ self.r = r
177
+ self.lora_alpha = lora_alpha
178
+ self.lora_dropout = lora_dropout
179
+
180
+ if layer_range is None:
181
+ layer_range = [i for i in range(layer_num)]
182
+ self.layer_range = layer_range
183
+
184
+ self.scaling = self.lora_alpha / self.r
185
+ self.head_first = head_first
186
+ self.num_attention_heads = num_attention_heads
187
+ self.hidden_size_per_attention_head = hidden_size_per_attention_head
188
+ self.qlora = qlora
189
+ self.cross_attention = cross_attention
190
+
191
+ def reinit(self, parent_model):
192
+ for i in self.layer_range:
193
+ print(f'replacing layer {i} attention with lora')
194
+ parent_model.transformer.layers[i].attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
195
+ parent_model.transformer.layers[i].attention.query_key_value = replace_linear_with_lora(parent_model.transformer.layers[i].attention.query_key_value, 3, self.r, self.lora_alpha, self.lora_dropout, head_first=self.head_first, num_attention_heads=self.num_attention_heads, hidden_size_per_attention_head=self.hidden_size_per_attention_head, qlora=self.qlora)
196
+ if self.cross_attention and parent_model.transformer.layers[i].is_decoder:
197
+ print(f'replacing layer {i} cross attention with lora')
198
+ parent_model.transformer.layers[i].cross_attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
199
+ parent_model.transformer.layers[i].cross_attention.query = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.query, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
200
+ parent_model.transformer.layers[i].cross_attention.key_value = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.key_value, 2, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
201
+ if self.qlora:
202
+ print('replacing chatglm linear layer with 4bit')
203
+ def replace_linear_with_nf4(model, name=None, cache={}):
204
+ if type(model) in (nn.Linear, RowParallelLinear, ColumnParallelLinear):
205
+ out_dim, in_dim = model.weight.shape
206
+ return HackLinearNF4(in_dim, out_dim)
207
+ names = set()
208
+ for name, child in model.named_children():
209
+ if name not in names:
210
+ if child in cache:
211
+ new_child = cache[child]
212
+ else:
213
+ new_child = replace_linear_with_nf4(child, name=name, cache=cache)
214
+ cache[child] = new_child
215
+ setattr(model, name, new_child)
216
+ names.add(name)
217
+ flag = True
218
+ while flag:
219
+ flag = False
220
+ for name, child in model.named_children():
221
+ if name not in names:
222
+ setattr(model, name, cache[child])
223
+ names.add(name)
224
+ flag = True
225
+ return model
226
+ replace_linear_with_nf4(parent_model.transformer, None, {})
227
+
228
+ def merge_lora(self):
229
+ for i in self.layer_range:
230
+ print(f'merge layer {i} lora attention back to linear')
231
+ self.transformer.layers[i].attention.dense = merge_linear_lora(self.transformer.layers[i].attention.dense)
232
+ self.transformer.layers[i].attention.query_key_value = merge_linear_lora(self.transformer.layers[i].attention.query_key_value)
233
+ if self.transformer.layers[i].is_decoder:
234
+ print(f'merge layer {i} lora cross attention back to linear')
235
+ self.transformer.layers[i].cross_attention.dense = merge_linear_lora(self.transformer.layers[i].cross_attention.dense)
236
+ self.transformer.layers[i].cross_attention.query = merge_linear_lora(self.transformer.layers[i].cross_attention.query)
237
+ self.transformer.layers[i].cross_attention.key_value = merge_linear_lora(self.transformer.layers[i].cross_attention.key_value)
238
+
239
+ if __name__ == '__main__':
240
+ class Model(nn.Module):
241
+ def __init__(self):
242
+ super().__init__()
243
+ self.child = nn.Linear(100, 200)
244
+
245
+ def forward(self, x):
246
+ return self.child(x)
247
+
248
+ model = Model()
249
+ torch.save(model.state_dict(), "linear.pt")
250
+ x = torch.randn(2, 100)
251
+ out1 = model(x)
252
+ model.child = replace_linear_with_lora(model.child, 1, 10)  # partition=1, rank=10; wraps the saved nn.Linear in a LoraLinear
253
+ model.load_state_dict(torch.load("linear.pt"), strict=False)
254
+ out2 = model(x)
255
+ torch.save(model.state_dict(), "lora.pt")
256
+ ckpt = torch.load("lora.pt")
257
+ breakpoint()
258
+ model.load_state_dict(ckpt, strict=False)
259
+ out3 = model(x)
260
+ breakpoint()
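Note: the arithmetic inside LoraLinear.forward is the usual low-rank update: each partition adds (lora_alpha / r) * x @ A.T @ B.T on top of the frozen original projection, with matrix_A initialized Kaiming-uniform and matrix_B initialized to zero so training starts from the unmodified weights. A standalone numeric check of that property (dimensions are arbitrary and the snippet does not touch the classes above):

import torch
import torch.nn as nn

in_dim, out_dim, r, alpha = 16, 32, 4, 1.0
base = nn.Linear(in_dim, out_dim)            # stands in for the frozen original layer
A = torch.randn(r, in_dim)                   # matrix_A: (r, in_dim)
B = torch.zeros(out_dim, r)                  # matrix_B: (out_dim, r), zero at init
x = torch.randn(2, in_dim)

lora_out = base(x) + (x @ A.T @ B.T) * (alpha / r)
assert torch.allclose(lora_out, base(x))     # B == 0, so the LoRA branch adds nothing yet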
model/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .chat import chat
2
+ from .infer_util import *
3
+ from .blip2 import BlipImageEvalProcessor
model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (235 Bytes).
 
model/__pycache__/blip2.cpython-310.pyc ADDED
Binary file (4.66 kB).
 
model/__pycache__/chat.cpython-310.pyc ADDED
Binary file (4.89 kB).
 
model/__pycache__/infer_util.cpython-310.pyc ADDED
Binary file (1.98 kB).
 
model/__pycache__/visualglm.cpython-310.pyc ADDED
Binary file (2.25 kB).
 
model/blip2.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from sat.model import ViTModel, BaseModel
5
+ from sat.model import BaseMixin
6
+ from sat import AutoModel
7
+ from copy import deepcopy
8
+ from torchvision import transforms
9
+ from torchvision.transforms.functional import InterpolationMode
10
+
11
+ class LNFinalyMixin(BaseMixin):
12
+ def __init__(self, hidden_size):
13
+ super().__init__()
14
+ self.ln_vision = nn.LayerNorm(hidden_size)
15
+
16
+ def final_forward(self, logits, **kw_args):
17
+ return self.ln_vision(logits)
18
+
19
+ class EVAViT(ViTModel):
20
+ def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
21
+ super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kwargs)
22
+ self.del_mixin("cls")
23
+ self.add_mixin("cls", LNFinalyMixin(args.hidden_size))
24
+
25
+ def forward(self, image):
26
+ batch_size = image.size(0)
27
+ input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=image.device)
28
+ attention_mask = torch.tensor([[1.]], dtype=image.dtype, device=image.device)
29
+ return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask, image=image)
30
+
31
+ class QFormer(BaseModel):
32
+ def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
33
+ super().__init__(args, transformer=transformer, parallel_output=parallel_output, activation_func=nn.functional.gelu, **kwargs)
34
+ self.transformer.position_embeddings = None
35
+
36
+ def final_forward(self, logits, **kw_args):
37
+ return logits
38
+
39
+ def position_embedding_forward(self, position_ids, **kw_args):
40
+ return None
41
+
42
+ def forward(self, encoder_outputs):
43
+ batch_size = encoder_outputs.size(0)
44
+ input_ids = torch.arange(32, dtype=torch.long, device=encoder_outputs.device).unsqueeze(0).expand(batch_size, -1)
45
+ attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
46
+ cross_attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
47
+ return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask)
48
+
49
+
50
+ class BLIP2(torch.nn.Module):
51
+ def __init__(self, eva_args, qformer_args, vit=None, qformer=None, **kwargs):
52
+ super().__init__()
53
+ if vit is not None:
54
+ self.vit = vit
55
+ else:
56
+ self.vit = EVAViT(EVAViT.get_args(**eva_args))
57
+ if qformer is not None:
58
+ self.qformer = qformer
59
+ else:
60
+ self.qformer = QFormer(QFormer.get_args(**qformer_args))
61
+
62
+ self.glm_proj = nn.Linear(768, 4096).to(self.qformer.parameters().__next__().device).to(self.qformer.parameters().__next__().dtype)
63
+
64
+ def forward(self, image, **kwargs):
65
+ enc = self.vit(image)[0]
66
+ out = self.qformer(enc)[0]
67
+ return self.glm_proj(out)
68
+
69
+ class BlipImageBaseProcessor():
70
+ def __init__(self, mean=None, std=None):
71
+ if mean is None:
72
+ mean = (0.48145466, 0.4578275, 0.40821073)
73
+ if std is None:
74
+ std = (0.26862954, 0.26130258, 0.27577711)
75
+
76
+ self.normalize = transforms.Normalize(mean, std)
77
+
78
+ class BlipImageEvalProcessor(BlipImageBaseProcessor):
79
+ def __init__(self, image_size=384, mean=None, std=None):
80
+ super().__init__(mean=mean, std=std)
81
+
82
+ self.transform = transforms.Compose(
83
+ [
84
+ transforms.Resize(
85
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
86
+ ),
87
+ transforms.ToTensor(),
88
+ self.normalize,
89
+ ]
90
+ )
91
+
92
+ def __call__(self, item):
93
+ return self.transform(item)
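Note: in BLIP2.forward the EVA ViT encodes the image, the Q-Former condenses it into 32 query tokens of width 768 (the arange(32) input above), and glm_proj lifts them into ChatGLM's 4096-dimensional embedding space, matching the 32 pad placeholders reserved for the image. The projection step in isolation, with a random tensor standing in for the real Q-Former output:

import torch
import torch.nn as nn

qformer_out = torch.randn(1, 32, 768)   # (batch, 32 query tokens, hidden 768)
glm_proj = nn.Linear(768, 4096)
image_emb = glm_proj(qformer_out)
print(image_emb.shape)                  # torch.Size([1, 32, 4096])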
model/chat.py ADDED
@@ -0,0 +1,175 @@
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : chat.py
4
+ @Time : 2023/05/08 19:10:08
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ '''
8
+
9
+ import os
10
+ import sys
11
+ import re
12
+ from functools import partial
13
+ from typing import Optional, Tuple, Union, List, Callable, Dict, Any
14
+ import requests
15
+ from PIL import Image
16
+ from io import BytesIO
17
+
18
+ import torch
19
+ from sat.generation.autoregressive_sampling import filling_sequence, BaseStrategy
20
+
21
+ from .blip2 import BlipImageEvalProcessor
22
+
23
+ def get_masks_and_position_ids_glm(seq, mask_position, context_length):
24
+ '''GLM model, different from GPT.
25
+ Args:
26
+ seq: torch.IntTensor, [seq_len]
27
+ mask_position: int, the position of the masked place.
28
+ context_length: int, the length of context.
29
+ Returns:
30
+ tokens: torch.IntTensor, [1, seq_len]
31
+ attention_mask: torch.FloatTensor, [1, seq_len, seq_len]
32
+ position_ids: torch.IntTensor, [2, seq_len]
33
+ '''
34
+ tokens = seq.unsqueeze(0)
35
+
36
+ attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
37
+ attention_mask.tril_()
38
+ attention_mask[..., :context_length] = 1
39
+ attention_mask.unsqueeze_(1)
40
+
41
+ # 2D position ids
42
+ position_ids = torch.zeros(2, len(seq), device=tokens.device, dtype=torch.long)
43
+ torch.arange(0, context_length, out=position_ids[0, :context_length])
44
+ position_ids[0, context_length:] = mask_position
45
+ torch.arange(1, len(seq) - context_length + 1, out=position_ids[1, context_length:])
46
+
47
+ position_ids = position_ids.unsqueeze(0)
48
+ return tokens, attention_mask, position_ids
49
+
50
+ def process_response(response):
51
+ response = response.strip()
52
+ response = response.replace("[[训练时间]]", "2023年")
53
+ punkts = [
54
+ [",", ","],
55
+ ["!", "!"],
56
+ [":", ":"],
57
+ [";", ";"],
58
+ ["\?", "?"],
59
+ ]
60
+ for item in punkts:
61
+ response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
62
+ response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
63
+ return response
64
+
65
+ def process_image(text, image=None):
66
+ '''Process image in text.
67
+ Args:
68
+ text: str, text.
69
+ image: Optional, image path / url / PIL image.
70
+ '''
71
+ image_position = text.rfind("<img>") + 5
72
+ # extract path from <img></img> using re
73
+ image_path = re.findall(r"<img>(.*?)</img>", text)
74
+ image_path = image_path[-1] if image_path and image_path[-1] else None
75
+ if image_path is not None:
76
+ assert image is None, "image and image_path cannot be both not None."
77
+ text = text.replace(image_path, "")
78
+ image_path = image_path.strip()
79
+ # url
80
+ if image_path.startswith("http"):
81
+ response = requests.get(image_path, timeout=10)
82
+ image = Image.open(BytesIO(response.content))
83
+ # local path
84
+ else:
85
+ image = Image.open(image_path)
86
+ if image is not None and isinstance(image, Image.Image):
87
+ processor = BlipImageEvalProcessor(224)
88
+ image = processor(image.convert('RGB'))
89
+ image = image.unsqueeze(0)
90
+ return text, image_position, image
91
+
92
+
93
+ def chat(image_path, model, tokenizer,
94
+ query: str, history: List[Tuple[str, str]] = None, image: Image = None,
95
+ max_length: int = 1024, top_p=0.7, top_k=30, temperature=0.95, repetition_penalty=1.2,
96
+ invalid_slices=[], english=False
97
+ ):
98
+ if not history:
99
+ history = []
100
+ if image_path:
101
+ prompt = "<img>{}</img>".format(image_path if image_path else "")
102
+ else:
103
+ prompt = "<img></img>"
104
+ if english:
105
+ for i, (old_query, response) in enumerate(history):
106
+ prompt += "Q:{}\nA:{}\n".format(old_query, response)
107
+ prompt += "Q:{}\nA:".format(query)
108
+ else:
109
+ for i, (old_query, response) in enumerate(history):
110
+ prompt += "问:{}\n答:{}\n".format(old_query, response)
111
+ prompt += "问:{}\n答:".format(query)
112
+ # ---------------
113
+ # tokenizer, this is an example of huggingface tokenizer.
114
+ # input str, output['input_ids'] = tensor([[tokenized str, gmask, sop]])
115
+ prompt, image_position, torch_image = process_image(prompt, image=image)
116
+ if torch_image is not None:
117
+ torch_image = torch_image.to(next(model.parameters()).dtype).to(next(model.parameters()).device)
118
+ if image_position < 5: # no image
119
+ inputs = tokenizer([prompt], return_tensors="pt").to(model.parameters().__next__().device)['input_ids'][0]
120
+ pre_image = 0
121
+ else:
122
+ input0 = tokenizer.encode(prompt[:image_position], add_special_tokens=False)
123
+ input1 = [tokenizer.pad_token_id] * model.image_length
124
+ input2 = tokenizer.encode(prompt[image_position:], add_special_tokens=False)
125
+ inputs = sum([input0, input1, input2], [])
126
+ inputs = torch.tensor(tokenizer.build_inputs_with_special_tokens(inputs)).to(model.parameters().__next__().device)
127
+ pre_image = len(input0)
128
+ # ---------------
129
+ # Next, we manually set the format to keep flexibility.
130
+ mask_position = len(inputs) - 2
131
+ context_length = len(inputs) - 1 # all before sop
132
+ get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=context_length)
133
+ seq = torch.cat(
134
+ [inputs, torch.tensor([-1]*(max_length-len(inputs)), device=inputs.device)], dim=0
135
+ )
136
+ # ---------------
137
+ # from sat.generation.sampling_strategies import BeamSearchStrategy
138
+ # strategy = BeamSearchStrategy(num_beams, length_penalty=1., prefer_min_length=5, end_tokens=[tokenizer.eos_token_id], consider_end=True, no_repeat_ngram_size=5, stop_n_iter_unchanged=30, temperature=temperature, top_p=top_p, top_k=60, repetition_penalty=1.1)
139
+ strategy = BaseStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id],
140
+ invalid_slices=invalid_slices, repetition_penalty=repetition_penalty)
141
+ output = filling_sequence(
142
+ model, seq,
143
+ batch_size=1,
144
+ get_masks_and_position_ids=get_func,
145
+ strategy=strategy,
146
+ pre_image=pre_image,
147
+ image=torch_image,
148
+ )[0] # drop memory
149
+
150
+ # ---------------
151
+ # port from inference_glm.py, more general than chat mode
152
+ # clip -1s and fill back generated things into seq
153
+ if type(output) is not list:
154
+ output_list = output.tolist()
155
+ else:
156
+ output_list = output
157
+ for i in range(len(output_list)):
158
+ output = output_list[i]
159
+ if type(output) is not list:
160
+ output = output.tolist()
161
+ try:
162
+ unfinished = output.index(-1)
163
+ except ValueError:
164
+ unfinished = len(output)
165
+ if output[unfinished - 1] == tokenizer.eos_token_id:
166
+ unfinished -= 1
167
+ bog = output.index(tokenizer.bos_token_id)
168
+ output_list[i] = output[:mask_position] + output[bog + 1:unfinished] + output[mask_position + 1:bog]
169
+ # ---------------
170
+
171
+ response = tokenizer.decode(output_list[0])
172
+ sep = 'A:' if english else '答:'
173
+ response = process_response(response).split(sep)[-1].strip()
174
+ history = history + [(query, response)]
175
+ return response, history, torch_image
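Note: get_masks_and_position_ids_glm above builds GLM-style inputs: the attention mask is causal only after the context (every context token is visible to all positions), the first row of the 2D position ids freezes at mask_position for the generated part, and the second row counts 1, 2, … from the sop token onward. A toy walk-through that mirrors the function for a 6-token sequence (it only needs torch and does not import the module):

import torch

seq_len, mask_position, context_length = 6, 2, 4

attention_mask = torch.ones((1, seq_len, seq_len)).tril_()
attention_mask[..., :context_length] = 1   # context tokens are mutually visible
attention_mask.unsqueeze_(1)

position_ids = torch.zeros(2, seq_len, dtype=torch.long)
torch.arange(0, context_length, out=position_ids[0, :context_length])
position_ids[0, context_length:] = mask_position                        # block position frozen at the mask
torch.arange(1, seq_len - context_length + 1, out=position_ids[1, context_length:])
print(position_ids)   # tensor([[0, 1, 2, 3, 2, 2], [0, 0, 0, 0, 1, 2]])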
model/infer_util.py ADDED
@@ -0,0 +1,53 @@
1
+ import os
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ import base64
5
+ import re
6
+ import argparse
7
+ import torch
8
+ from transformers import AutoTokenizer
9
+ from sat.model.mixins import CachedAutoregressiveMixin
10
+ from sat.quantization.kernels import quantize
11
+ import hashlib
12
+ from .visualglm import VisualGLMModel
13
+
14
+ def get_infer_setting(gpu_device=0, quant=None):
15
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_device)
16
+ args = argparse.Namespace(
17
+ fp16=True,
18
+ skip_init=True,
19
+ device='cuda' if quant is None else 'cpu',
20
+ )
21
+ model, args = VisualGLMModel.from_pretrained('visualglm-6b', args)
22
+ model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
23
+ assert quant in [None, 4, 8]
24
+ if quant is not None:
25
+ quantize(model.transformer, quant)
26
+ model.eval()
27
+ model = model.cuda()
28
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
29
+ return model, tokenizer
30
+
31
+ def is_chinese(text):
32
+ zh_pattern = re.compile(u'[\u4e00-\u9fa5]+')
33
+ return zh_pattern.search(text)
34
+
35
+ def generate_input(input_text, input_image_prompt, history=[], input_para=None, image_is_encoded=True):
36
+ if not image_is_encoded:
37
+ image = input_image_prompt
38
+ else:
39
+ decoded_image = base64.b64decode(input_image_prompt)
40
+ image = Image.open(BytesIO(decoded_image))
41
+
42
+ input_data = {'input_query': input_text, 'input_image': image, 'history': history, 'gen_kwargs': input_para}
43
+ return input_data
44
+
45
+
46
+ def process_image(image_encoded):
47
+ decoded_image = base64.b64decode(image_encoded)
48
+ image = Image.open(BytesIO(decoded_image))
49
+ image_hash = hashlib.sha256(image.tobytes()).hexdigest()
50
+ image_path = f'./examples/{image_hash}.png'
51
+ if not os.path.isfile(image_path):
52
+ image.save(image_path)
53
+ return os.path.abspath(image_path)
model/visualglm.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+ from sat.model.official import ChatGLMModel
3
+ from sat.model.base_model import BaseMixin
4
+ from copy import deepcopy
5
+ import json
6
+ from .blip2 import BLIP2
7
+
8
+ from sat.resources.urls import MODEL_URLS
9
+ MODEL_URLS['visualglm-6b'] = 'https://cloud.tsinghua.edu.cn/f/348b98dffcc940b6a09d/?dl=1'
10
+
11
+ class ImageMixin(BaseMixin):
12
+ def __init__(self, args):
13
+ super().__init__()
14
+ self.args = deepcopy(args)
15
+ self.model = BLIP2(args.eva_args, args.qformer_args)
16
+
17
+ def word_embedding_forward(self, input_ids, output_cross_layer, **kw_args):
18
+ if kw_args["pre_image"] > input_ids.shape[1] or kw_args.get("image", None) is None:
19
+ return self.transformer.word_embeddings(input_ids)
20
+ image_emb = self.model(**kw_args)
21
+ # the image embedding is inserted right after 问:<img>, overwriting the image_length (32) pad placeholders
22
+ pre_id, pads, post_id = torch.tensor_split(input_ids, [kw_args["pre_image"], kw_args["pre_image"]+self.args.image_length], dim=1)
23
+ pre_txt_emb = self.transformer.word_embeddings(pre_id)
24
+ post_txt_emb = self.transformer.word_embeddings(post_id)
25
+ return torch.cat([pre_txt_emb, image_emb, post_txt_emb], dim=1)
26
+
27
+ class VisualGLMModel(ChatGLMModel):
28
+ def __init__(self, args, transformer=None, **kwargs):
29
+ super().__init__(args, transformer=transformer, **kwargs)
30
+ self.image_length = args.image_length
31
+ self.add_mixin("eva", ImageMixin(args))
32
+
33
+ @classmethod
34
+ def add_model_specific_args(cls, parser):
35
+ group = parser.add_argument_group('VisualGLM', 'VisualGLM Configurations')
36
+ group.add_argument('--image_length', type=int, default=32)
37
+ group.add_argument('--eva_args', type=json.loads, default={})
38
+ group.add_argument('--qformer_args', type=json.loads, default={})
39
+ return super().add_model_specific_args(parser)
40
+
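Note: ImageMixin.word_embedding_forward splices the projected image tokens into the text embeddings: input_ids carry image_length pad placeholders right after the <img> tag, pre_image marks where they start, and torch.tensor_split cuts the ids into before / placeholder / after segments so the BLIP2 output can replace the middle one. The splice on toy tensors (sizes shrunk so the shapes are easy to read):

import torch

pre_image, image_length, hidden = 3, 4, 8
input_ids = torch.arange(12).unsqueeze(0)            # (1, 12) pretend token ids
pre_id, pads, post_id = torch.tensor_split(
    input_ids, [pre_image, pre_image + image_length], dim=1)

embed = torch.nn.Embedding(100, hidden)              # stands in for transformer.word_embeddings
image_emb = torch.randn(1, image_length, hidden)     # stands in for the BLIP2 projection
full_emb = torch.cat([embed(pre_id), image_emb, embed(post_id)], dim=1)
print(full_emb.shape)                                # torch.Size([1, 12, 8])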
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ SwissArmyTransformer>=0.3.6
2
+ torch>1.10.0
3
+ torchvision
4
+ transformers>=4.27.1
5
+ mdtex2html
6
+ gradio
requirements_wo_ds.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch>1.10.0
2
+ torchvision
3
+ transformers>=4.27.1
4
+ mdtex2html
5
+ gradio
6
+ sentencepiece
7
+ tensorboardX
8
+ datasets
9
+ cpm_kernels
10
+ einops
web_demo.py ADDED
@@ -0,0 +1,129 @@
1
+ #!/usr/bin/env python
2
+
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import os
6
+ import json
7
+ from model import is_chinese, get_infer_setting, generate_input, chat
8
+ import torch
9
+
10
+ def generate_text_with_image(input_text, image, history=[], request_data=dict(), is_zh=True):
11
+ input_para = {
12
+ "max_length": 2048,
13
+ "min_length": 50,
14
+ "temperature": 0.8,
15
+ "top_p": 0.4,
16
+ "top_k": 100,
17
+ "repetition_penalty": 1.2
18
+ }
19
+ input_para.update(request_data)
20
+
21
+ input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False)
22
+ input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
23
+ with torch.no_grad():
24
+ answer, history, _ = chat(None, model, tokenizer, input_text, history=history, image=input_image, \
25
+ max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
26
+ top_k = gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)
27
+ return answer
28
+
29
+
30
+ def request_model(input_text, temperature, top_p, image_prompt, result_previous):
31
+ result_text = [(ele[0], ele[1]) for ele in result_previous]
32
+ for i in range(len(result_text)-1, -1, -1):
33
+ if result_text[i][0] == "" or result_text[i][1] == "":
34
+ del result_text[i]
35
+ print(f"history {result_text}")
36
+
37
+ is_zh = is_chinese(input_text)
38
+ if image_prompt is None:
39
+ if is_zh:
40
+ result_text.append((input_text, '图片为空!请上传图片并重试。'))
41
+ else:
42
+ result_text.append((input_text, 'Image empty! Please upload an image and retry.'))
43
+ return input_text, result_text
44
+ elif input_text == "":
45
+ result_text.append((input_text, 'Text empty! Please enter text and retry.'))
46
+ return "", result_text
47
+
48
+ request_para = {"temperature": temperature, "top_p": top_p}
49
+ image = Image.open(image_prompt)
50
+ try:
51
+ answer = generate_text_with_image(input_text, image, result_text.copy(), request_para, is_zh)
52
+ except Exception as e:
53
+ print(f"error: {e}")
54
+ if is_zh:
55
+ result_text.append((input_text, '超时!请稍等几分钟再重试。'))
56
+ else:
57
+ result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.'))
58
+ return "", result_text
59
+
60
+ result_text.append((input_text, answer))
61
+ print(result_text)
62
+ return "", result_text
63
+
64
+
65
+ DESCRIPTION = '''# <a href="https://github.com/THUDM/VisualGLM-6B">VisualGLM</a>'''
66
+
67
+ MAINTENANCE_NOTICE1 = 'Hint 1: If the app reports "Something went wrong, connection error out", please turn off your proxy and retry.\nHint 2: If you upload a large image (around 10MB), it may take some time to upload and process. Please be patient and wait.'
68
+ MAINTENANCE_NOTICE2 = '提示1: 如果应用报了“Something went wrong, connection error out”的错误,请关闭代理并重试。\n提示2: 如果你上传了很大的图片,比如10MB大小,那将需要一些时间来上传和处理,请耐心等待。'
69
+
70
+ NOTES = 'This app is adapted from <a href="https://github.com/THUDM/VisualGLM-6B">https://github.com/THUDM/VisualGLM-6B</a>. Check out the repo if you want to see the details of our model and training process.'
71
+
72
+
73
+ def clear_fn(value):
74
+ return "", [("", "Hi, What do you want to know about this image?")], None
75
+
76
+ def clear_fn2(value):
77
+ return [("", "Hi, What do you want to know about this image?")]
78
+
79
+
80
+ def main(args):
81
+ gr.close_all()
82
+ global model, tokenizer
83
+ model, tokenizer = get_infer_setting(gpu_device=0, quant=args.quant)
84
+
85
+ with gr.Blocks(css='style.css') as demo:
86
+ gr.Markdown(DESCRIPTION)
87
+ with gr.Row():
88
+ with gr.Column(scale=4.5):
89
+ with gr.Group():
90
+ input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.')
91
+ with gr.Row():
92
+ run_button = gr.Button('Generate')
93
+ clear_button = gr.Button('Clear')
94
+
95
+ image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)
96
+ with gr.Row():
97
+ temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature')
98
+ top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P')
99
+ with gr.Group():
100
+ with gr.Row():
101
+ maintenance_notice = gr.Markdown(MAINTENANCE_NOTICE1)
102
+ with gr.Column(scale=5.5):
103
+ result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")]).style(height=550)
104
+
105
+ gr.Markdown(NOTES)
106
+
107
+ print(gr.__version__)
108
+ run_button.click(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
109
+ outputs=[input_text, result_text])
110
+ input_text.submit(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
111
+ outputs=[input_text, result_text])
112
+ clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
113
+ image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
114
+ image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
115
+
116
+ print(gr.__version__)
117
+
118
+ demo.queue(concurrency_count=10)
119
+ demo.launch(share=args.share)
120
+
121
+
122
+ if __name__ == '__main__':
123
+ import argparse
124
+ parser = argparse.ArgumentParser()
125
+ parser.add_argument("--quant", choices=[8, 4], type=int, default=None)
126
+ parser.add_argument("--share", action="store_true")
127
+ args = parser.parse_args()
128
+ args.share = True  # force a public Gradio link regardless of the --share flag
129
+ main(args)
web_demo_hf.py ADDED
@@ -0,0 +1,143 @@
1
+ from transformers import AutoModel, AutoTokenizer
2
+ import gradio as gr
3
+ import mdtex2html
4
+ import torch
5
+
6
+ """Override Chatbot.postprocess"""
7
+
8
+ def postprocess(self, y):
9
+ if y is None:
10
+ return []
11
+ for i, (message, response) in enumerate(y):
12
+ y[i] = (
13
+ None if message is None else mdtex2html.convert((message)),
14
+ None if response is None else mdtex2html.convert(response),
15
+ )
16
+ return y
17
+
18
+ gr.Chatbot.postprocess = postprocess
19
+
20
+
21
+ def parse_text(text):
22
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
23
+ lines = text.split("\n")
24
+ lines = [line for line in lines if line != ""]
25
+ count = 0
26
+ for i, line in enumerate(lines):
27
+ if "```" in line:
28
+ count += 1
29
+ items = line.split('`')
30
+ if count % 2 == 1:
31
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
32
+ else:
33
+ lines[i] = f'<br></code></pre>'
34
+ else:
35
+ if i > 0:
36
+ if count % 2 == 1:
37
+ line = line.replace("`", "\`")
38
+ line = line.replace("<", "&lt;")
39
+ line = line.replace(">", "&gt;")
40
+ line = line.replace(" ", "&nbsp;")
41
+ line = line.replace("*", "&ast;")
42
+ line = line.replace("_", "&lowbar;")
43
+ line = line.replace("-", "&#45;")
44
+ line = line.replace(".", "&#46;")
45
+ line = line.replace("!", "&#33;")
46
+ line = line.replace("(", "&#40;")
47
+ line = line.replace(")", "&#41;")
48
+ line = line.replace("$", "&#36;")
49
+ lines[i] = "<br>"+line
50
+ text = "".join(lines)
51
+ return text
52
+
53
+
54
+ def predict(input, image_path, chatbot, max_length, top_p, temperature, history):
55
+ if image_path is None:
56
+ return [(input, "图片不能为空。请重新上传图片并重试。")], []
57
+ chatbot.append((parse_text(input), ""))
58
+ with torch.no_grad():
59
+ for response, history in model.stream_chat(tokenizer, image_path, input, history, max_length=max_length, top_p=top_p,
60
+ temperature=temperature):
61
+ chatbot[-1] = (parse_text(input), parse_text(response))
62
+
63
+ yield chatbot, history
64
+
65
+
66
+ def predict_new_image(image_path, chatbot, max_length, top_p, temperature):
67
+ input, history = "描述这张图片。", []
68
+ chatbot.append((parse_text(input), ""))
69
+ with torch.no_grad():
70
+ for response, history in model.stream_chat(tokenizer, image_path, input, history, max_length=max_length,
71
+ top_p=top_p,
72
+ temperature=temperature):
73
+ chatbot[-1] = (parse_text(input), parse_text(response))
74
+
75
+ yield chatbot, history
76
+
77
+
78
+ def reset_user_input():
79
+ return gr.update(value='')
80
+
81
+
82
+ def reset_state():
83
+ return None, [], []
84
+
85
+
86
+ DESCRIPTION = '''<h1 align="center"><a href="https://github.com/THUDM/VisualGLM-6B">VisualGLM</a></h1>'''
87
+ MAINTENANCE_NOTICE = 'Hint 1: If the app reports "Something went wrong, connection error out", please turn off your proxy and retry.\nHint 2: If you upload a large image (around 10MB), it may take some time to upload and process. Please be patient and wait.'
88
+ NOTES = 'This app is adapted from <a href="https://github.com/THUDM/VisualGLM-6B">https://github.com/THUDM/VisualGLM-6B</a>. Check out the repo if you want to see the details of our model and training process.'
89
+
90
+ def main(args):
91
+ global model, tokenizer
92
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
93
+ if args.quant in [4, 8]:
94
+ model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).quantize(args.quant).half().cuda()
95
+ else:
96
+ model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
97
+ model = model.eval()
98
+
99
+ with gr.Blocks(css='style.css') as demo:
100
+ gr.HTML(DESCRIPTION)
101
+
102
+ with gr.Row():
103
+ with gr.Column(scale=2):
104
+ image_path = gr.Image(type="filepath", label="Image Prompt", value=None).style(height=504)
105
+ with gr.Column(scale=4):
106
+ chatbot = gr.Chatbot().style(height=480)
107
+ with gr.Row():
108
+ with gr.Column(scale=2, min_width=100):
109
+ max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
110
+ top_p = gr.Slider(0, 1, value=0.4, step=0.01, label="Top P", interactive=True)
111
+ temperature = gr.Slider(0, 1, value=0.8, step=0.01, label="Temperature", interactive=True)
112
+ with gr.Column(scale=4):
113
+ with gr.Box():
114
+ with gr.Row():
115
+ with gr.Column(scale=2):
116
+ user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=4).style(
117
+ container=False)
118
+ with gr.Column(scale=1, min_width=64):
119
+ submitBtn = gr.Button("Submit", variant="primary")
120
+ emptyBtn = gr.Button("Clear History")
121
+ gr.Markdown(MAINTENANCE_NOTICE + '\n' + NOTES)
122
+ history = gr.State([])
123
+
124
+
125
+ submitBtn.click(predict, [user_input, image_path, chatbot, max_length, top_p, temperature, history], [chatbot, history],
126
+ show_progress=True)
127
+ image_path.upload(predict_new_image, [image_path, chatbot, max_length, top_p, temperature], [chatbot, history],
128
+ show_progress=True)
129
+ image_path.clear(reset_state, outputs=[image_path, chatbot, history], show_progress=True)
130
+ submitBtn.click(reset_user_input, [], [user_input])
131
+ emptyBtn.click(reset_state, outputs=[image_path, chatbot, history], show_progress=True)
132
+
133
+ print(gr.__version__)
134
+
135
+ demo.queue().launch(share=args.share, inbrowser=True, server_name='0.0.0.0', server_port=8080)
136
+
137
+ if __name__ == '__main__':
138
+ import argparse
139
+ parser = argparse.ArgumentParser()
140
+ parser.add_argument("--quant", choices=[8, 4], type=int, default=None)
141
+ parser.add_argument("--share", action="store_true")
142
+ args = parser.parse_args()
143
+ main(args)
your_logfile.log ADDED
@@ -0,0 +1,2 @@
1
+ nohup: 忽略输入
2
+ python: can't open file '/root/VisualGLM-6B/your_program.py': [Errno 2] No such file or directory