Spaces:
Runtime error
Runtime error
hysts
commited on
Commit
•
223c6f1
1
Parent(s):
91dedcc
Copy files from https://github.com/hysts/CogView2_demo
Browse files- .gitmodules +3 -0
- .pre-commit-config.yaml +46 -0
- .style.yapf +5 -0
- CogView2 +1 -0
- LICENSE +21 -0
- LICENSE.CogView2 +201 -0
- app.py +100 -24
- model.py +377 -0
- patch +51 -0
- requirements.txt +7 -3
- samples.txt +13 -0
- style.css +7 -0
.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "CogView2"]
|
2 |
+
path = CogView2
|
3 |
+
url = https://github.com/THUDM/CogView2
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
exclude: ^patch
|
2 |
+
repos:
|
3 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
4 |
+
rev: v4.2.0
|
5 |
+
hooks:
|
6 |
+
- id: check-executables-have-shebangs
|
7 |
+
- id: check-json
|
8 |
+
- id: check-merge-conflict
|
9 |
+
- id: check-shebang-scripts-are-executable
|
10 |
+
- id: check-toml
|
11 |
+
- id: check-yaml
|
12 |
+
- id: double-quote-string-fixer
|
13 |
+
- id: end-of-file-fixer
|
14 |
+
- id: mixed-line-ending
|
15 |
+
args: ['--fix=lf']
|
16 |
+
- id: requirements-txt-fixer
|
17 |
+
- id: trailing-whitespace
|
18 |
+
- repo: https://github.com/myint/docformatter
|
19 |
+
rev: v1.4
|
20 |
+
hooks:
|
21 |
+
- id: docformatter
|
22 |
+
args: ['--in-place']
|
23 |
+
- repo: https://github.com/pycqa/isort
|
24 |
+
rev: 5.10.1
|
25 |
+
hooks:
|
26 |
+
- id: isort
|
27 |
+
- repo: https://github.com/pre-commit/mirrors-mypy
|
28 |
+
rev: v0.812
|
29 |
+
hooks:
|
30 |
+
- id: mypy
|
31 |
+
args: ['--ignore-missing-imports']
|
32 |
+
- repo: https://github.com/google/yapf
|
33 |
+
rev: v0.32.0
|
34 |
+
hooks:
|
35 |
+
- id: yapf
|
36 |
+
args: ['--parallel', '--in-place']
|
37 |
+
- repo: https://github.com/kynan/nbstripout
|
38 |
+
rev: 0.5.0
|
39 |
+
hooks:
|
40 |
+
- id: nbstripout
|
41 |
+
args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
|
42 |
+
- repo: https://github.com/nbQA-dev/nbQA
|
43 |
+
rev: 1.3.1
|
44 |
+
hooks:
|
45 |
+
- id: nbqa-isort
|
46 |
+
- id: nbqa-yapf
|
.style.yapf
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[style]
|
2 |
+
based_on_style = pep8
|
3 |
+
blank_line_before_nested_class_or_def = false
|
4 |
+
spaces_before_comment = 2
|
5 |
+
split_before_logical_operator = true
|
CogView2
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 4e55cce981eb94b9c8c1f19ba9f632fd3ee42ba8
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 hysts
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
LICENSE.CogView2
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
app.py
CHANGED
@@ -1,32 +1,108 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
2 |
import gradio as gr
|
3 |
-
os.system("pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html")
|
4 |
-
os.system("git clone https://github.com/Sleepychord/Image-Local-Attention")
|
5 |
-
os.chdir("Image-Local-Attention")
|
6 |
-
os.system("python setup.py install")
|
7 |
-
os.chdir("..")
|
8 |
-
os.system("git clone https://github.com/NVIDIA/apex")
|
9 |
-
os.chdir("apex")
|
10 |
-
os.system('pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./')
|
11 |
-
os.chdir("..")
|
12 |
-
os.system("git clone https://github.com/THUDM/CogView2")
|
13 |
-
os.chdir("CogView2")
|
14 |
-
os.system("gdown https://drive.google.com/uc?id=1-2nI2TTUOdiQ2WpydGafk_bZIZggQBK4")
|
15 |
-
os.system("7za x coglm.zip")
|
16 |
-
os.system("gdown https://drive.google.com/uc?id=1ulfXJFstYZUestvWcQIadKkNNDVbpIdM")
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
|
|
|
|
|
|
|
|
31 |
|
32 |
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
|
7 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
+
from model import AppModel
|
10 |
+
|
11 |
+
DESCRIPTION = '''# CogView2 (text2image)
|
12 |
+
|
13 |
+
This is an unofficial demo for <a href="https://github.com/THUDM/CogView2">https://github.com/THUDM/CogView2</a>.
|
14 |
+
|
15 |
+
[This Space](https://huggingface.co/spaces/chinhon/translation_eng2ch) is used for translation from English to Chinese.
|
16 |
+
'''
|
17 |
+
|
18 |
+
|
19 |
+
def parse_args() -> argparse.Namespace:
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument('--only-first-stage', action='store_true')
|
22 |
+
parser.add_argument('--share', action='store_true')
|
23 |
+
return parser.parse_args()
|
24 |
+
|
25 |
+
|
26 |
+
def set_example_text(example: list) -> dict:
|
27 |
+
return gr.Textbox.update(value=example[0])
|
28 |
+
|
29 |
+
|
30 |
+
def main():
|
31 |
+
args = parse_args()
|
32 |
+
model = AppModel(args.only_first_stage)
|
33 |
+
|
34 |
+
with gr.Blocks(css='style.css') as demo:
|
35 |
+
gr.Markdown(DESCRIPTION)
|
36 |
|
37 |
+
with gr.Row():
|
38 |
+
with gr.Column():
|
39 |
+
with gr.Group():
|
40 |
+
text = gr.Textbox(label='Input Text')
|
41 |
+
translate = gr.Checkbox(label='Translate to Chinese',
|
42 |
+
value=False)
|
43 |
+
style = gr.Dropdown(choices=[
|
44 |
+
'mainbody',
|
45 |
+
'photo',
|
46 |
+
'flat',
|
47 |
+
'comics',
|
48 |
+
'oil',
|
49 |
+
'sketch',
|
50 |
+
'isometric',
|
51 |
+
'chinese',
|
52 |
+
'watercolor',
|
53 |
+
],
|
54 |
+
label='Style')
|
55 |
+
seed = gr.Slider(0,
|
56 |
+
100000,
|
57 |
+
step=1,
|
58 |
+
value=1234,
|
59 |
+
label='Seed')
|
60 |
+
only_first_stage = gr.Checkbox(
|
61 |
+
label='Only First Stage',
|
62 |
+
value=args.only_first_stage,
|
63 |
+
visible=not args.only_first_stage)
|
64 |
+
num_images = gr.Slider(1,
|
65 |
+
16,
|
66 |
+
step=1,
|
67 |
+
value=8,
|
68 |
+
label='Number of Images')
|
69 |
+
with open('samples.txt') as f:
|
70 |
+
samples = [[line.strip()] for line in f.readlines()]
|
71 |
+
examples = gr.Dataset(components=[text], samples=samples)
|
72 |
+
run_button = gr.Button('Run')
|
73 |
|
74 |
+
with gr.Column():
|
75 |
+
with gr.Group():
|
76 |
+
translated_text = gr.Textbox(label='Translated Text')
|
77 |
+
with gr.Tabs():
|
78 |
+
with gr.TabItem('Output (Grid View)'):
|
79 |
+
result_grid = gr.Image(show_label=False)
|
80 |
+
with gr.TabItem('Output (Gallery)'):
|
81 |
+
result_gallery = gr.Gallery(show_label=False)
|
82 |
|
83 |
+
run_button.click(fn=model.run_with_translation,
|
84 |
+
inputs=[
|
85 |
+
text,
|
86 |
+
translate,
|
87 |
+
style,
|
88 |
+
seed,
|
89 |
+
only_first_stage,
|
90 |
+
num_images,
|
91 |
+
],
|
92 |
+
outputs=[
|
93 |
+
translated_text,
|
94 |
+
result_grid,
|
95 |
+
result_gallery,
|
96 |
+
])
|
97 |
+
examples.click(fn=set_example_text,
|
98 |
+
inputs=examples,
|
99 |
+
outputs=examples.components)
|
100 |
|
101 |
+
demo.launch(
|
102 |
+
enable_queue=True,
|
103 |
+
share=args.share,
|
104 |
+
)
|
105 |
|
106 |
|
107 |
+
if __name__ == '__main__':
|
108 |
+
main()
|
model.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#This code is adapted from https://github.com/THUDM/CogView2/blob/4e55cce981eb94b9c8c1f19ba9f632fd3ee42ba8/cogview2_text2image.py
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import functools
|
7 |
+
import logging
|
8 |
+
import pathlib
|
9 |
+
import sys
|
10 |
+
import time
|
11 |
+
from typing import Any
|
12 |
+
|
13 |
+
import gradio as gr
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
from icetk import IceTokenizer
|
17 |
+
from SwissArmyTransformer import get_args
|
18 |
+
from SwissArmyTransformer.arguments import set_random_seed
|
19 |
+
from SwissArmyTransformer.generation.autoregressive_sampling import \
|
20 |
+
filling_sequence
|
21 |
+
from SwissArmyTransformer.model import CachedAutoregressiveModel
|
22 |
+
|
23 |
+
app_dir = pathlib.Path(__file__).parent
|
24 |
+
submodule_dir = app_dir / 'CogView2'
|
25 |
+
sys.path.insert(0, submodule_dir.as_posix())
|
26 |
+
|
27 |
+
from coglm_strategy import CoglmStrategy
|
28 |
+
from sr_pipeline import SRGroup
|
29 |
+
|
30 |
+
formatter = logging.Formatter(
|
31 |
+
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
|
32 |
+
datefmt='%Y-%m-%d %H:%M:%S')
|
33 |
+
stream_handler = logging.StreamHandler(stream=sys.stdout)
|
34 |
+
stream_handler.setLevel(logging.DEBUG)
|
35 |
+
stream_handler.setFormatter(formatter)
|
36 |
+
logger = logging.getLogger(__name__)
|
37 |
+
logger.setLevel(logging.DEBUG)
|
38 |
+
logger.propagate = False
|
39 |
+
logger.addHandler(stream_handler)
|
40 |
+
|
41 |
+
ICETK_MODEL_DIR = app_dir / 'icetk_models'
|
42 |
+
|
43 |
+
|
44 |
+
def get_masks_and_position_ids_coglm(
|
45 |
+
seq: torch.Tensor, context_length: int
|
46 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
47 |
+
tokens = seq.unsqueeze(0)
|
48 |
+
|
49 |
+
attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
|
50 |
+
attention_mask.tril_()
|
51 |
+
attention_mask[..., :context_length] = 1
|
52 |
+
attention_mask.unsqueeze_(1)
|
53 |
+
|
54 |
+
position_ids = torch.zeros(len(seq),
|
55 |
+
device=tokens.device,
|
56 |
+
dtype=torch.long)
|
57 |
+
torch.arange(0, context_length, out=position_ids[:context_length])
|
58 |
+
torch.arange(512,
|
59 |
+
512 + len(seq) - context_length,
|
60 |
+
out=position_ids[context_length:])
|
61 |
+
|
62 |
+
position_ids = position_ids.unsqueeze(0)
|
63 |
+
return tokens, attention_mask, position_ids
|
64 |
+
|
65 |
+
|
66 |
+
class InferenceModel(CachedAutoregressiveModel):
|
67 |
+
def final_forward(self, logits, **kwargs):
|
68 |
+
logits_parallel = logits
|
69 |
+
logits_parallel = torch.nn.functional.linear(
|
70 |
+
logits_parallel.float(),
|
71 |
+
self.transformer.word_embeddings.weight[:20000].float())
|
72 |
+
return logits_parallel
|
73 |
+
|
74 |
+
|
75 |
+
def get_recipe(name: str) -> dict[str, Any]:
|
76 |
+
r = {
|
77 |
+
'attn_plus': 1.4,
|
78 |
+
'temp_all_gen': 1.15,
|
79 |
+
'topk_gen': 16,
|
80 |
+
'temp_cluster_gen': 1.,
|
81 |
+
'temp_all_dsr': 1.5,
|
82 |
+
'topk_dsr': 100,
|
83 |
+
'temp_cluster_dsr': 0.89,
|
84 |
+
'temp_all_itersr': 1.3,
|
85 |
+
'topk_itersr': 16,
|
86 |
+
'query_template': '{}<start_of_image>',
|
87 |
+
}
|
88 |
+
if name == 'none':
|
89 |
+
pass
|
90 |
+
elif name == 'mainbody':
|
91 |
+
r['query_template'] = '{} 高清摄影 隔绝<start_of_image>'
|
92 |
+
|
93 |
+
elif name == 'photo':
|
94 |
+
r['query_template'] = '{} 高清摄影<start_of_image>'
|
95 |
+
|
96 |
+
elif name == 'flat':
|
97 |
+
r['query_template'] = '{} 平面风格<start_of_image>'
|
98 |
+
# r['attn_plus'] = 1.8
|
99 |
+
# r['temp_cluster_gen'] = 0.75
|
100 |
+
r['temp_all_gen'] = 1.1
|
101 |
+
r['topk_dsr'] = 5
|
102 |
+
r['temp_cluster_dsr'] = 0.4
|
103 |
+
|
104 |
+
r['temp_all_itersr'] = 1
|
105 |
+
r['topk_itersr'] = 5
|
106 |
+
elif name == 'comics':
|
107 |
+
r['query_template'] = '{} 漫画 隔绝<start_of_image>'
|
108 |
+
r['topk_dsr'] = 5
|
109 |
+
r['temp_cluster_dsr'] = 0.4
|
110 |
+
r['temp_all_gen'] = 1.1
|
111 |
+
r['temp_all_itersr'] = 1
|
112 |
+
r['topk_itersr'] = 5
|
113 |
+
elif name == 'oil':
|
114 |
+
r['query_template'] = '{} 油画风格<start_of_image>'
|
115 |
+
pass
|
116 |
+
elif name == 'sketch':
|
117 |
+
r['query_template'] = '{} 素描风格<start_of_image>'
|
118 |
+
r['temp_all_gen'] = 1.1
|
119 |
+
elif name == 'isometric':
|
120 |
+
r['query_template'] = '{} 等距矢量图<start_of_image>'
|
121 |
+
r['temp_all_gen'] = 1.1
|
122 |
+
elif name == 'chinese':
|
123 |
+
r['query_template'] = '{} 水墨国画<start_of_image>'
|
124 |
+
r['temp_all_gen'] = 1.12
|
125 |
+
elif name == 'watercolor':
|
126 |
+
r['query_template'] = '{} 水彩画风格<start_of_image>'
|
127 |
+
return r
|
128 |
+
|
129 |
+
|
130 |
+
def get_default_args() -> argparse.Namespace:
|
131 |
+
arg_list = ['--mode', 'inference', '--fp16']
|
132 |
+
args = get_args(arg_list)
|
133 |
+
known = argparse.Namespace(img_size=160,
|
134 |
+
only_first_stage=False,
|
135 |
+
inverse_prompt=False,
|
136 |
+
style='mainbody')
|
137 |
+
args = argparse.Namespace(**vars(args), **vars(known),
|
138 |
+
**get_recipe(known.style))
|
139 |
+
return args
|
140 |
+
|
141 |
+
|
142 |
+
class Model:
|
143 |
+
def __init__(self, only_first_stage: bool = False):
|
144 |
+
self.args = get_default_args()
|
145 |
+
self.args.only_first_stage = only_first_stage
|
146 |
+
|
147 |
+
self.tokenizer = self.load_tokenizer()
|
148 |
+
|
149 |
+
self.model, self.args = self.load_model()
|
150 |
+
self.strategy = self.load_strategy()
|
151 |
+
self.srg = self.load_srg()
|
152 |
+
|
153 |
+
self.query_template = self.args.query_template
|
154 |
+
self.style = self.args.style
|
155 |
+
self.device = torch.device(self.args.device)
|
156 |
+
self.fp16 = self.args.fp16
|
157 |
+
self.max_batch_size = self.args.max_inference_batch_size
|
158 |
+
self.only_first_stage = self.args.only_first_stage
|
159 |
+
|
160 |
+
def load_tokenizer(self) -> IceTokenizer:
|
161 |
+
logger.info('--- load_tokenizer ---')
|
162 |
+
start = time.perf_counter()
|
163 |
+
|
164 |
+
tokenizer = IceTokenizer(ICETK_MODEL_DIR.as_posix())
|
165 |
+
tokenizer.add_special_tokens(
|
166 |
+
['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
|
167 |
+
|
168 |
+
elapsed = time.perf_counter() - start
|
169 |
+
logger.info(f'Elapsed: {elapsed}')
|
170 |
+
logger.info('--- done ---')
|
171 |
+
return tokenizer
|
172 |
+
|
173 |
+
def load_model(self) -> tuple[InferenceModel, argparse.Namespace]:
|
174 |
+
logger.info('--- load_model ---')
|
175 |
+
start = time.perf_counter()
|
176 |
+
|
177 |
+
model, args = InferenceModel.from_pretrained(self.args, 'coglm')
|
178 |
+
|
179 |
+
elapsed = time.perf_counter() - start
|
180 |
+
logger.info(f'Elapsed: {elapsed}')
|
181 |
+
logger.info('--- done ---')
|
182 |
+
return model, args
|
183 |
+
|
184 |
+
def load_strategy(self) -> CoglmStrategy:
|
185 |
+
logger.info('--- load_strategy ---')
|
186 |
+
start = time.perf_counter()
|
187 |
+
|
188 |
+
invalid_slices = [slice(self.tokenizer.num_image_tokens, None)]
|
189 |
+
strategy = CoglmStrategy(invalid_slices,
|
190 |
+
temperature=self.args.temp_all_gen,
|
191 |
+
top_k=self.args.topk_gen,
|
192 |
+
top_k_cluster=self.args.temp_cluster_gen)
|
193 |
+
|
194 |
+
elapsed = time.perf_counter() - start
|
195 |
+
logger.info(f'Elapsed: {elapsed}')
|
196 |
+
logger.info('--- done ---')
|
197 |
+
return strategy
|
198 |
+
|
199 |
+
def load_srg(self) -> SRGroup:
|
200 |
+
logger.info('--- load_srg ---')
|
201 |
+
start = time.perf_counter()
|
202 |
+
|
203 |
+
srg = None if self.args.only_first_stage else SRGroup(self.args)
|
204 |
+
|
205 |
+
elapsed = time.perf_counter() - start
|
206 |
+
logger.info(f'Elapsed: {elapsed}')
|
207 |
+
logger.info('--- done ---')
|
208 |
+
return srg
|
209 |
+
|
210 |
+
def update_style(self, style: str) -> None:
|
211 |
+
if style == self.style:
|
212 |
+
return
|
213 |
+
logger.info('--- update_style ---')
|
214 |
+
start = time.perf_counter()
|
215 |
+
|
216 |
+
self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style)))
|
217 |
+
self.query_template = self.args.query_template
|
218 |
+
logger.info(f'{self.query_template=}')
|
219 |
+
|
220 |
+
self.strategy.temperature = self.args.temp_all_gen
|
221 |
+
|
222 |
+
if self.srg is not None:
|
223 |
+
self.srg.dsr.strategy.temperature = self.args.temp_all_dsr
|
224 |
+
self.srg.dsr.strategy.topk = self.args.topk_dsr
|
225 |
+
self.srg.dsr.strategy.temperature2 = self.args.temp_cluster_dsr
|
226 |
+
|
227 |
+
self.srg.itersr.strategy.temperature = self.args.temp_all_itersr
|
228 |
+
self.srg.itersr.strategy.topk = self.args.topk_itersr
|
229 |
+
|
230 |
+
elapsed = time.perf_counter() - start
|
231 |
+
logger.info(f'Elapsed: {elapsed}')
|
232 |
+
logger.info('--- done ---')
|
233 |
+
|
234 |
+
def run(self, text: str, style: str, seed: int, only_first_stage: bool,
|
235 |
+
num: int) -> list[np.ndarray] | None:
|
236 |
+
set_random_seed(seed)
|
237 |
+
seq, txt_len = self.preprocess_text(text)
|
238 |
+
if seq is None:
|
239 |
+
return None
|
240 |
+
self.update_style(style)
|
241 |
+
self.only_first_stage = only_first_stage
|
242 |
+
tokens = self.generate_tokens(seq, txt_len, num)
|
243 |
+
res = self.generate_images(seq, txt_len, tokens)
|
244 |
+
return res
|
245 |
+
|
246 |
+
@torch.inference_mode()
|
247 |
+
def preprocess_text(
|
248 |
+
self, text: str) -> tuple[torch.Tensor, int] | tuple[None, None]:
|
249 |
+
logger.info('--- preprocess_text ---')
|
250 |
+
start = time.perf_counter()
|
251 |
+
|
252 |
+
text = self.query_template.format(text)
|
253 |
+
logger.info(f'{text=}')
|
254 |
+
seq = self.tokenizer.encode(text)
|
255 |
+
logger.info(f'{len(seq)=}')
|
256 |
+
if len(seq) > 110:
|
257 |
+
logger.info('The input text is too long.')
|
258 |
+
return None, None
|
259 |
+
txt_len = len(seq) - 1
|
260 |
+
seq = torch.tensor(seq + [-1] * 400, device=self.device)
|
261 |
+
|
262 |
+
elapsed = time.perf_counter() - start
|
263 |
+
logger.info(f'Elapsed: {elapsed}')
|
264 |
+
logger.info('--- done ---')
|
265 |
+
return seq, txt_len
|
266 |
+
|
267 |
+
@torch.inference_mode()
|
268 |
+
def generate_tokens(self,
|
269 |
+
seq: torch.Tensor,
|
270 |
+
txt_len: int,
|
271 |
+
num: int = 8) -> torch.Tensor:
|
272 |
+
logger.info('--- generate_tokens ---')
|
273 |
+
start = time.perf_counter()
|
274 |
+
|
275 |
+
# calibrate text length
|
276 |
+
log_attention_weights = torch.zeros(
|
277 |
+
len(seq),
|
278 |
+
len(seq),
|
279 |
+
device=self.device,
|
280 |
+
dtype=torch.half if self.fp16 else torch.float32)
|
281 |
+
log_attention_weights[:, :txt_len] = self.args.attn_plus
|
282 |
+
get_func = functools.partial(get_masks_and_position_ids_coglm,
|
283 |
+
context_length=txt_len)
|
284 |
+
|
285 |
+
output_list = []
|
286 |
+
remaining = num
|
287 |
+
for _ in range((num + self.max_batch_size - 1) // self.max_batch_size):
|
288 |
+
self.strategy.start_pos = txt_len + 1
|
289 |
+
coarse_samples = filling_sequence(
|
290 |
+
self.model,
|
291 |
+
seq.clone(),
|
292 |
+
batch_size=min(remaining, self.max_batch_size),
|
293 |
+
strategy=self.strategy,
|
294 |
+
log_attention_weights=log_attention_weights,
|
295 |
+
get_masks_and_position_ids=get_func)[0]
|
296 |
+
output_list.append(coarse_samples)
|
297 |
+
remaining -= self.max_batch_size
|
298 |
+
output_tokens = torch.cat(output_list, dim=0)
|
299 |
+
logger.info(f'{output_tokens.shape=}')
|
300 |
+
|
301 |
+
elapsed = time.perf_counter() - start
|
302 |
+
logger.info(f'Elapsed: {elapsed}')
|
303 |
+
logger.info('--- done ---')
|
304 |
+
return output_tokens
|
305 |
+
|
306 |
+
@staticmethod
|
307 |
+
def postprocess(tensor: torch.Tensor) -> np.ndarray:
|
308 |
+
return tensor.cpu().mul(255).add_(0.5).clamp_(0, 255).permute(
|
309 |
+
1, 2, 0).to(torch.uint8).numpy()
|
310 |
+
|
311 |
+
@torch.inference_mode()
|
312 |
+
def generate_images(self, seq: torch.Tensor, txt_len: int,
|
313 |
+
tokens: torch.Tensor) -> list[np.ndarray]:
|
314 |
+
logger.info('--- generate_images ---')
|
315 |
+
start = time.perf_counter()
|
316 |
+
|
317 |
+
logger.info(f'{self.only_first_stage=}')
|
318 |
+
res = []
|
319 |
+
if self.only_first_stage:
|
320 |
+
for i in range(len(tokens)):
|
321 |
+
seq = tokens[i]
|
322 |
+
decoded_img = self.tokenizer.decode(image_ids=seq[-400:])
|
323 |
+
decoded_img = torch.nn.functional.interpolate(decoded_img,
|
324 |
+
size=(480, 480))
|
325 |
+
decoded_img = self.postprocess(decoded_img[0])
|
326 |
+
res.append(decoded_img) # only the last image (target)
|
327 |
+
else: # sr
|
328 |
+
iter_tokens = self.srg.sr_base(tokens[:, -400:], seq[:txt_len])
|
329 |
+
for seq in iter_tokens:
|
330 |
+
decoded_img = self.tokenizer.decode(image_ids=seq[-3600:])
|
331 |
+
decoded_img = torch.nn.functional.interpolate(decoded_img,
|
332 |
+
size=(480, 480))
|
333 |
+
decoded_img = self.postprocess(decoded_img[0])
|
334 |
+
res.append(decoded_img) # only the last image (target)
|
335 |
+
|
336 |
+
elapsed = time.perf_counter() - start
|
337 |
+
logger.info(f'Elapsed: {elapsed}')
|
338 |
+
logger.info('--- done ---')
|
339 |
+
return res
|
340 |
+
|
341 |
+
|
342 |
+
class AppModel(Model):
|
343 |
+
def __init__(self, only_first_stage: bool):
|
344 |
+
super().__init__(only_first_stage)
|
345 |
+
self.translator = gr.Interface.load(
|
346 |
+
'spaces/chinhon/translation_eng2ch')
|
347 |
+
|
348 |
+
def make_grid(self, images: list[np.ndarray] | None) -> np.ndarray | None:
|
349 |
+
if images is None or len(images) == 0:
|
350 |
+
return None
|
351 |
+
ncols = 1
|
352 |
+
while True:
|
353 |
+
if ncols**2 >= len(images):
|
354 |
+
break
|
355 |
+
ncols += 1
|
356 |
+
nrows = (len(images) + ncols - 1) // ncols
|
357 |
+
h, w = images[0].shape[:2]
|
358 |
+
grid = np.zeros((h * nrows, w * ncols, 3), dtype=np.uint8)
|
359 |
+
for i in range(nrows):
|
360 |
+
for j in range(ncols):
|
361 |
+
index = ncols * i + j
|
362 |
+
if index >= len(images):
|
363 |
+
break
|
364 |
+
grid[h * i:h * (i + 1), w * j:w * (j + 1)] = images[index]
|
365 |
+
return grid
|
366 |
+
|
367 |
+
def run_with_translation(
|
368 |
+
self, text: str, translate: bool, style: str, seed: int,
|
369 |
+
only_first_stage: bool, num: int
|
370 |
+
) -> tuple[str | None, np.ndarray | None, list[np.ndarray] | None]:
|
371 |
+
if translate:
|
372 |
+
text = translated_text = self.translator(text)
|
373 |
+
else:
|
374 |
+
translated_text = None
|
375 |
+
results = self.run(text, style, seed, only_first_stage, num)
|
376 |
+
grid_image = self.make_grid(results)
|
377 |
+
return translated_text, grid_image, results
|
patch
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diff --git a/coglm_strategy.py b/coglm_strategy.py
|
2 |
+
index cba87ce..40e4ece 100755
|
3 |
+
--- a/coglm_strategy.py
|
4 |
+
+++ b/coglm_strategy.py
|
5 |
+
@@ -8,6 +8,7 @@
|
6 |
+
|
7 |
+
# here put the import lib
|
8 |
+
import os
|
9 |
+
+import pathlib
|
10 |
+
import sys
|
11 |
+
import math
|
12 |
+
import random
|
13 |
+
@@ -57,7 +58,8 @@ class CoglmStrategy:
|
14 |
+
self._is_done = False
|
15 |
+
self.outlier_count_down = 5
|
16 |
+
self.vis_list = [[]for i in range(16)]
|
17 |
+
- self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long)
|
18 |
+
+ cluster_label_path = pathlib.Path(__file__).parent / 'cluster_label.npy'
|
19 |
+
+ self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
|
20 |
+
self.top_k_cluster = top_k_cluster
|
21 |
+
|
22 |
+
@property
|
23 |
+
@@ -91,4 +93,4 @@ class CoglmStrategy:
|
24 |
+
|
25 |
+
def finalize(self, tokens, mems):
|
26 |
+
self._is_done = False
|
27 |
+
- return tokens, mems
|
28 |
+
|
29 |
+
+ return tokens, mems
|
30 |
+
diff --git a/sr_pipeline/dsr_sampling.py b/sr_pipeline/dsr_sampling.py
|
31 |
+
index a0d0298..f721573 100755
|
32 |
+
--- a/sr_pipeline/dsr_sampling.py
|
33 |
+
+++ b/sr_pipeline/dsr_sampling.py
|
34 |
+
@@ -8,6 +8,7 @@
|
35 |
+
|
36 |
+
# here put the import lib
|
37 |
+
import os
|
38 |
+
+import pathlib
|
39 |
+
import sys
|
40 |
+
import math
|
41 |
+
import random
|
42 |
+
@@ -27,7 +28,8 @@ class IterativeEntfilterStrategy:
|
43 |
+
self.invalid_slices = invalid_slices
|
44 |
+
self.temperature = temperature
|
45 |
+
self.topk = topk
|
46 |
+
- self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long)
|
47 |
+
+ cluster_label_path = pathlib.Path(__file__).parents[1] / 'cluster_label.npy'
|
48 |
+
+ self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
|
49 |
+
self.temperature2 = temperature2
|
50 |
+
|
51 |
+
|
requirements.txt
CHANGED
@@ -1,3 +1,7 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/Sleepychord/Image-Local-Attention@43fee31
|
2 |
+
gradio==3.0.17
|
3 |
+
icetk==0.0.3
|
4 |
+
numpy==1.22.4
|
5 |
+
SwissArmyTransformer==0.2.4
|
6 |
+
torch==1.11.0
|
7 |
+
torchvision==0.12.0
|
samples.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
A beautiful girl is hugging a husky.
|
2 |
+
A lion teacher wearing a suit is in front of a blackboard.
|
3 |
+
A robot is riding under the blue and cloudy sky.
|
4 |
+
Several youths are talking in a bar.
|
5 |
+
A lion man is typing in the office.
|
6 |
+
A young woman is taking photos.
|
7 |
+
A pirate captain with a skull.
|
8 |
+
A girl holding an oil-paper umbrella in a rainy lane.
|
9 |
+
Earth in the eye.
|
10 |
+
A magnificent church. Sketch.
|
11 |
+
Mount Fuji, cherry blossom and Akita dog. Oil painting.
|
12 |
+
A tiger with angel's wings.
|
13 |
+
A fox is sitting on the books.
|
style.css
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
h1 {
|
2 |
+
text-align: center;
|
3 |
+
}
|
4 |
+
img#visitor-badge {
|
5 |
+
display: block;
|
6 |
+
margin: auto;
|
7 |
+
}
|