Maykeye committed on
Commit
1a030c8
1 Parent(s): 6a2e483

Initial commit

Files changed (15)
  1. .gitignore +167 -0
  2. LICENSE +201 -0
  3. README.md +63 -3
  4. cli_imgen3_flip.py +54 -0
  5. image_utils.py +130 -0
  6. imgen3.py +100 -0
  7. imgen3flip.py +132 -0
  8. imgen3test.ipynb +0 -0
  9. imgen3test_flip.ipynb +0 -0
  10. krita-flip.png +0 -0
  11. krita-nonflip.png +0 -0
  12. krita/face1.png +0 -0
  13. krita/face2.png +0 -0
  14. torch_utils.py +13 -0
  15. valid.png +0 -0
.gitignore ADDED
@@ -0,0 +1,167 @@
# Will go to HF someday
data/
data/all_images_8.bin
data/all_images_64.bin

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
LICENSE ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md CHANGED
@@ -1,3 +1,63 @@
- ---
- license: apache-2.0
- ---
---
license: apache-2.0
datasets:
- huggan/anime-faces
---

# Mamba face kiss

## KISS

This repo contains two Keep It Simple Stupid anime face generators that generate 64x64 faces from provided 8x8 images.

The basic idea was to take the 64x64 anime faces dataset (https://huggingface.co/datasets/huggan/anime-faces), resize it to 8x8, then teach the model to restore the original images. The intuition is that, once trained, the model will produce some face even when given new, unseen 8x8 images.

![Validation](./valid.png)

Mamba is fed a sequence `[A][A]...[A][SEP][B][B][B]...[B]`: there are 64 `[A]` tokens that come from the 8x8 draft, and 64x64 `[B]` tokens that initially are the upscaled draft (nearest neighbor) with PAE added. The model runs several layers of mamba and projects the last 64x64 tokens into an RGB image. (`[SEP]` is not used for anything significant; BERT has it to separate sentences, so I used it too, as a placeholder for the command "Upscale from here".)

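For reference, here is a rough sketch of that sequence layout, mirroring `Model.forward` and `Model.zoom` in `imgen3.py` (shapes only; the tensors below are random stand-ins, and in the actual code the separator token `s0` is prepended to the sequence rather than placed between the two halves):

```python
import torch

B, d_model = 1, 768
a = torch.randn(B, 8 * 8, d_model)    # 64 [A] tokens: the 8x8 draft after RGBToModel
sep = torch.randn(B, 1, d_model)      # learned separator token (s0 in the code)
b = torch.randn(B, 64 * 64, d_model)  # 4096 [B] tokens: nearest-neighbor upscale plus learned per-position offsets
x = torch.cat((sep, a, b), dim=1)     # (B, 1 + 64 + 4096, d_model); the last 64*64 positions become the RGB output
```
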
Two models are used.

### RNN goes brr (one way)

One (`imgen3test.ipynb` and `imgen3.py`) always feeds images from the top-left pixel to the bottom-right pixel, row by row.

![Non-flip image](./krita-nonflip.png)

### "Bi-directional"

Another take (`imgen3test_flip.ipynb` and `imgen3flip.py`) feeds from the top-left pixel to the bottom-right pixel in every even layer, while every odd layer sees the upscaled image in reverse order.

![Flip image](./krita-flip.png)

This flip version also uses way more parameters and a different dtype. I didn't notice that much of a difference.

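The reversal only touches the last 64x64 positions of the sequence; the prefix (`[SEP]` and the 64 draft tokens) keeps its order. A minimal sketch of that reordering (this is what `swap_order` in `imgen3flip.py` does between the forward and backward mamba passes):

```python
import torch

def swap_tail(x: torch.Tensor, n_values: int) -> torch.Tensor:
    # Keep the first T - n_values positions as-is, reverse the last n_values.
    T = x.shape[1]
    head = torch.arange(0, T - n_values)
    tail = torch.arange(T - 1, T - n_values - 1, -1)
    return x[:, torch.cat((head, tail))]

x = torch.randn(2, 1 + 64 + 64 * 64, 8)                           # toy (B, T, C) sequence
assert torch.equal(swap_tail(swap_tail(x, 64 * 64), 64 * 64), x)  # applying it twice restores the order
```
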
#### Command line tool

A simple script can be used to call the model on a single image:

```console
$ python cli_imgen3_flip.py ./krita/face1.png /tmp/face1.png
Weight path is data/image-flip-weights-1024x4-torch.bfloat16.bin
Loading the model
Loading 8x8 input image from ./krita/face1.png
Writing 64x64 image to /tmp/face1.png
```

It's not a great way to use the model compared to calling it from Jupyter, though: mamba2 is implemented using triton and it takes around 30 seconds to initialize the model each time (on a Raider GE76).

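If you need to upscale more than one image, it is cheaper to load the model once and reuse it, e.g. from a notebook. A minimal sketch (it follows `cli_imgen3_flip.py`; the `upscale` helper is just for illustration, and it assumes trained weights already exist at `weights_path`):

```python
import torch
import torchvision as TV
from imgen3flip import Model, ImageBatch, OPTS, weights_path

model = Model()
model.load_state_dict(torch.load(weights_path))

def upscale(path: str) -> torch.Tensor:
    im = TV.io.read_image(path)[:3] / 255.0    # drop alpha channel, scale to 0..1
    im = im.permute(1, 2, 0).view(1, 8, 8, 3)  # C H W -> 1 H W C
    batch = ImageBatch(im8=im.to(**OPTS),
                       im64=torch.zeros(1, 64, 64, 3, **OPTS),  # dummy target, unused at inference
                       loss=torch.tensor(-1, **OPTS))
    return model(batch).im64[0].detach().float().cpu()          # (64, 64, 3) in 0..1

face = upscale("./krita/face1.png")
```
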
## Recreating

Training is done in `imgen3.py` / `imgen3flip.py`. Testing is done in the notebooks. `image_utils.py` should provide the path to the anime faces dataset.

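Roughly, recreating the weights boils down to the commands below (the dataset path `RAW_IMAGES_PATH` is hardcoded in `image_utils.py`, so point it at your local copy of the anime-faces images first):

```console
$ python imgen3.py        # trains the one-way model, saves to data/weights-768.bin
$ python imgen3flip.py    # trains the flip model, saves to data/image-flip-weights-1024x4-torch.bfloat16.bin
```
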
## Naming and configuring

The name imgen3 comes from "image generation 3".
The two other attempts are not interesting enough to even back up.

I'm too lazy to pass configuration around, so parameters are hardcoded at the beginning of each file.
cli_imgen3_flip.py ADDED
@@ -0,0 +1,54 @@
from imgen3flip import weights_path, Model, ImageBatch, OPTS
import torch
import torchvision as TV
import torchvision.transforms.functional as VF
import sys


assert weights_path.exists(), "Model weights do not exist"

assert len(sys.argv) == 3, f"Usage: {sys.argv[0]} <input-filename> <output-filename>"

input_filename = sys.argv[1]
output_filename = sys.argv[2]

assert input_filename != output_filename, "Use different file names"

print("Loading the model")
model = Model()
model.load_state_dict(torch.load(weights_path))

print(f"Loading 8x8 input image from {input_filename}")
# Read the image and ditch the alpha channel if present
image = TV.io.read_image(input_filename)[:3]
# Convert range from 0..255 to 0.0..1.0
image = image / 255.0
assert image.shape[0] == 3, "RGB image expected"
# Convert C H W -> H W C
image = image.permute(1, 2, 0)
# Now add batch dimension (B=1): H W C -> 1 H W C
# We also specify H, W, C explicitly as the model expects them to be 8x8x3
image = image.view(1, 8, 8, 3)

# Now construct the batch that the model uses.
# Target and loss are not used for inference, but the model code always calculates a loss.
dummy_target = torch.zeros(1, 64, 64, 3, **OPTS)
dummy_loss = torch.tensor(-1, **OPTS)
inference_batch = ImageBatch(
    im8=image.to(**OPTS),
    im64=dummy_target,
    loss=dummy_loss)
result = model(inference_batch)

# Now convert the image to PIL format so we can save it
new_image = result.im64.detach().float().cpu()
# new_image: 1 H W C -> H W C
new_image = new_image[0]
# new_image: H W C -> C H W
new_image = new_image.permute(2, 0, 1)
assert new_image.shape == (3, 64, 64)
img = VF.to_pil_image(new_image)
# Save
print(f"Writing {img.height}x{img.width} image to {output_filename}")
img.save(output_filename)
image_utils.py ADDED
@@ -0,0 +1,130 @@
import torch.nn as nn
from torch import Tensor
from pathlib import Path
import torch
import random
import torchvision.io as VIO
import torchvision.transforms.functional as VF
from dataclasses import dataclass
from tqdm.auto import tqdm

# https://huggingface.co/datasets/huggan/anime-faces
RAW_IMAGES_PATH = Path(
    "~/Downloads/datasets/anime/anime-faces/images").expanduser()
RESOLUTIONS = [64, 8]

AS_TENSORS_64 = Path("data/all_images_64.bin")
AS_TENSORS_8 = Path("data/all_images_8.bin")


@dataclass
class ImageBatch:
    im8: Tensor    # 8x8 draft images
    im64: Tensor   # 64x64 target (or generated) images
    loss: Tensor

    @property
    def n_batch(self):
        return self.im8.shape[0]

    def as_1d(self):
        # Flatten H, W into a single sequence dimension: B H W C -> B (H W) C
        return ImageBatch(
            im8=self.im8.view(self.n_batch, 8*8, self.im8.shape[-1]),
            im64=self.im64.view(self.n_batch, 64*64, self.im64.shape[-1]),
            loss=self.loss
        )

    def as_2d(self):
        # Inverse of as_1d: B (H W) C -> B H W C
        return ImageBatch(
            im8=self.im8.view(self.n_batch, 8, 8, self.im8.shape[-1]),
            im64=self.im64.view(self.n_batch, 64, 64, self.im64.shape[-1]),
            loss=self.loss
        )


class ImageDB:
    def __init__(self, val_ratio=0.05, dtype=None) -> None:
        if not AS_TENSORS_64.exists():
            self.make_tensor_version()
        print("Load tensors file")
        self.dtype = dtype or torch.bfloat16
        self.all_images_64 = torch.load(AS_TENSORS_64).to(self.dtype)
        self.all_images_8 = torch.load(AS_TENSORS_8).to(self.dtype)
        self.n_val = int(len(self.all_images_64) * val_ratio)

    def split(self, s: str):
        if s == "train":
            return {
                8: self.all_images_8[:-self.n_val],
                64: self.all_images_64[:-self.n_val]
            }
        if s == "valid":
            return {
                8: self.all_images_8[-self.n_val:],
                64: self.all_images_64[-self.n_val:]
            }
        raise ValueError(f"Invalid split {s}")

    @property
    def train_ds(self):
        return self.split("train")

    @property
    def valid_ds(self):
        return self.split("valid")

    @torch.no_grad()
    def make_tensor_version(self, path=RAW_IMAGES_PATH):
        # Convert the raw PNG dataset into two stacked tensors (64x64 and 8x8) and cache them on disk.
        items = list(path.glob("*.png"))
        all_tensors = [load_single_image(item) for item in tqdm(items)]
        t64 = torch.stack([t[64] for t in all_tensors])
        t8 = torch.stack([t[8] for t in all_tensors])
        torch.save(t64, AS_TENSORS_64)
        torch.save(t8, AS_TENSORS_8)
        return {8: t8, 64: t64}

    def random_batch(self, bs: int, split: str = "train"):
        split_dict = self.split(split)
        im8 = split_dict[8]
        im64 = split_dict[64]
        keys = list(range(len(im8)))
        random.shuffle(keys)
        keys = keys[:bs]
        return ImageBatch(
            im64=im64[keys].cuda(),
            im8=im8[keys].cuda(),
            loss=torch.tensor(-1))


def load_single_image(path: Path):
    im = VIO.read_image(str(path))
    im = im / 255.0
    # Resize to 8x8
    im8 = VF.resize(im, [8, 8], VF.InterpolationMode.NEAREST_EXACT)
    # C H W -> H W C
    im = im.permute(1, 2, 0).contiguous()
    im8 = im8.permute(1, 2, 0).contiguous()

    return {64: im, 8: im8}


class RGBToModel(nn.Module):
    # Project 3 RGB channels into the model dimension.
    def __init__(self, d_model, device=None, dtype=None):
        super().__init__()
        self.fc = nn.Linear(3, d_model, device=device, dtype=dtype)

    def forward(self, x):
        return self.fc(x)


class ModelToRGB(nn.Module):
    # Project the model dimension back to 3 RGB channels in 0..1.
    def __init__(self, d_model, device=None, dtype=None):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, device=device, dtype=dtype)
        self.fc = nn.Linear(d_model, 3, device=device, dtype=dtype)

    def forward(self, x):
        x = self.norm(x)
        x = self.fc(x)
        x = x.sigmoid()
        return x
imgen3.py ADDED
@@ -0,0 +1,100 @@
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.modules.mamba2_simple import Mamba2Simple
from mamba_ssm.modules.mamba2 import Mamba2
import torch
from torch import Tensor
import random
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from pathlib import Path
from einops import rearrange, repeat
from typing import Optional


from image_utils import ImageDB, ImageBatch, RGBToModel
from image_utils import ModelToRGB
from torch_utils import model_numel

epochs = 10_000
bs = 16
d_model = 768
weights_path = Path(f"data/weights-{d_model}.bin")

OPTS = {
    'device': "cuda",
    'dtype': torch.float32
}


class MambaWrap(nn.Module):
    # Pre-norm residual block around a single Mamba2 layer.
    def __init__(self) -> None:
        super().__init__()
        self.mamba = Mamba2Simple(d_model, **OPTS, headdim=64)
        self.norm = nn.LayerNorm(d_model, **OPTS)

    def forward(self, x):
        residual = x
        x = self.norm(x)
        x = self.mamba(x)
        x = residual + x
        return x


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.from_rgb = RGBToModel(d_model, **OPTS)
        self.to_rgb = ModelToRGB(d_model, **OPTS)
        # Separator token prepended to the sequence
        self.s0 = nn.Parameter(torch.randn(1, 1, d_model, **OPTS))
        # Learned per-position offsets added to the upscaled draft
        self.suffix = nn.Parameter(torch.randn(64*64, d_model, **OPTS))
        self.layers = nn.ModuleList([MambaWrap() for _ in range(4)])
        self.norm0 = nn.LayerNorm(d_model, **OPTS)

    def forward(self, batch: ImageBatch):
        B = batch.n_batch
        batch = batch.as_1d()
        batch.im8 = self.from_rgb(batch.im8)

        s0 = self.s0.repeat(B, 1, 1)
        s1 = self.zoom(batch.im8)

        # Sequence: [separator][64 draft tokens][64*64 upscaled tokens]
        x = torch.cat((s0, batch.im8, s1), 1)
        x = self.norm0(x)
        x = self.mamba(x)
        # Only the last 64*64 positions are decoded back to RGB
        x = x[:, -64*64:]
        y_hat = self.to_rgb(x)
        y_true = batch.im64
        batch.loss = F.mse_loss(y_hat, y_true)
        batch.im64 = y_hat
        return batch.as_2d()

    def zoom(self, im8):
        # Nearest-neighbor upscale 8x8 -> 64x64 plus learned positional offsets
        im8 = im8.view(im8.shape[0], 8, 8, im8.shape[-1])
        im8 = repeat(
            im8, "B H W C -> B (H 8) (W 8) C").view(im8.shape[0], 64*64, im8.shape[-1])
        im8 = im8 + self.suffix
        return im8

    def mamba(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


if __name__ == "__main__":
    image_db = ImageDB(dtype=OPTS["dtype"])
    model = Model()
    if weights_path.exists():
        print(f"*** Load {str(weights_path)}")
        model.load_state_dict(torch.load(weights_path))
    opt = torch.optim.AdamW(model.parameters(), fused=True)

    for e in (bar := tqdm(range(epochs))):
        b = model(image_db.random_batch(bs))
        b.loss.backward()
        opt.step()
        opt.zero_grad()
        bar.set_description(f'L:{b.loss.item():.4f}')
        if e and e % 100 == 0:
            torch.save(model.state_dict(), weights_path)
    torch.save(model.state_dict(), weights_path)
imgen3flip.py ADDED
@@ -0,0 +1,132 @@
from mamba_ssm.modules.mamba2_simple import Mamba2Simple
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from pathlib import Path
from einops import repeat


from image_utils import ImageDB, ImageBatch, RGBToModel
from image_utils import ModelToRGB

epochs = 10_000
bs = 16
# orig:
# bs = 16
# d_model = 768
# headdim = 64
# n_layer = 4

d_model = 1024
headdim = 64
n_layer = 4

OPTS = {
    'device': "cuda",
    'dtype': torch.bfloat16
}
# Since this is a KISS flip/flop, think of the effective number of mamba layers as 2x n_layer.
# This is somewhat comparable to an LLM block that has two mamba layers: one replacing ATTN, one replacing MLP.

weights_path = Path(
    f"data/image-flip-weights-{d_model}x{n_layer}-{str(OPTS['dtype'])}.bin")
print(f"Weight path is {str(weights_path)}")


class MambaWrap(nn.Module):
    # Pre-norm residual block around a single Mamba2 layer.
    def __init__(self) -> None:
        super().__init__()
        self.mamba = Mamba2Simple(d_model, **OPTS, headdim=headdim)
        self.norm = nn.LayerNorm(d_model, **OPTS)

    def forward(self, x):
        residual = x
        x = self.norm(x)
        x = self.mamba(x)
        x = residual + x
        return x


class MambaFlipFlop(nn.Module):
    # One mamba pass in the original order, one with the last n_values positions reversed.
    def __init__(self, n_values) -> None:
        super().__init__()
        self.mb_forward = MambaWrap()
        self.mb_backward = MambaWrap()
        self.n_values = n_values

    def forward(self, x):
        x = self.mb_forward(x)
        x = self.swap_order(x)
        x = self.mb_backward(x)
        x = self.swap_order(x)
        return x

    def swap_order(self, x):
        # Keep the prefix in order, reverse the last n_values positions.
        T = x.shape[1]
        head = torch.arange(0, T - self.n_values)
        tail = torch.arange(T - 1, T - self.n_values - 1, -1)
        seq = torch.cat((head, tail))
        x = x[:, seq]
        return x


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.from_rgb = RGBToModel(d_model, **OPTS)
        self.to_rgb = ModelToRGB(d_model, **OPTS)
        self.s0 = nn.Parameter(torch.randn(1, 1, d_model, **OPTS))
        self.suffix = nn.Parameter(torch.randn(64*64, d_model, **OPTS))
        self.layers = nn.ModuleList([MambaFlipFlop(64*64)
                                     for _ in range(n_layer)])
        self.norm0 = nn.LayerNorm(d_model, **OPTS)

    def forward(self, batch: ImageBatch):
        B = batch.n_batch
        batch = batch.as_1d()
        batch.im8 = self.from_rgb(batch.im8)

        s0 = self.s0.repeat(B, 1, 1)
        s1 = self.zoom(batch.im8)

        x = torch.cat((s0, batch.im8, s1), 1)
        x = self.norm0(x)
        x = self.mamba(x)
        x = x[:, -64*64:]
        y_hat = self.to_rgb(x)
        y_true = batch.im64
        batch.loss = F.mse_loss(y_hat, y_true)
        batch.im64 = y_hat
        return batch.as_2d()

    def zoom(self, im8):
        # Nearest-neighbor upscale 8x8 -> 64x64 plus learned positional offsets
        im8 = im8.view(im8.shape[0], 8, 8, im8.shape[-1])
        im8 = repeat(im8, "B H W C -> B (H 8) (W 8) C")
        im8 = im8.view(im8.shape[0], 64*64, im8.shape[-1])
        im8 = im8 + self.suffix
        return im8

    def mamba(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


if __name__ == "__main__":
    image_db = ImageDB(dtype=OPTS["dtype"])
    model = Model()
    if weights_path.exists():
        print(f"*** Load {str(weights_path)}")
        model.load_state_dict(torch.load(weights_path))
    opt = torch.optim.AdamW(model.parameters(), fused=True)

    for e in (bar := tqdm(range(epochs))):
        b = model(image_db.random_batch(bs))
        b.loss.backward()
        opt.step()
        opt.zero_grad()
        bar.set_description(f'L:{b.loss.item():.4f}')
        if e and e % 100 == 0:
            torch.save(model.state_dict(), weights_path)
    torch.save(model.state_dict(), weights_path)
imgen3test.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
imgen3test_flip.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
krita-flip.png ADDED
krita-nonflip.png ADDED
krita/face1.png ADDED
krita/face2.png ADDED
torch_utils.py ADDED
@@ -0,0 +1,13 @@
import torch.nn as nn
import torch


def model_device(m: nn.Module):
    return next(iter(m.parameters())).device


def model_numel(m: nn.Module, requires_grad=False):
    if requires_grad:
        return sum(p.numel() for p in m.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in m.parameters())
valid.png ADDED