neggles committed on
Commit
bb1671a
1 Parent(s): 2cb658e

make space happen

.editorconfig ADDED
@@ -0,0 +1,34 @@
# http://editorconfig.org

root = true

[*]
indent_style = space
indent_size = 4
trim_trailing_whitespace = true
insert_final_newline = true
charset = utf-8
end_of_line = lf

[*.bat]
indent_style = tab
end_of_line = crlf

[*.{json,jsonc}]
indent_style = space
indent_size = 2

[.vscode/*.{json,jsonc}]
indent_style = space
indent_size = 4

[*.{yml,yaml,toml}]
indent_style = space
indent_size = 2

[*.md]
trim_trailing_whitespace = false

[Makefile]
indent_style = tab
indent_size = 8
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,264 @@
# Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
# Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python

### Linux ###
*~

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon


# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace

# Local History for Visual Studio Code
.history/

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db

# Dump file
*.stackdump

# Folder config file
[Dd]esktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp

# Windows shortcuts
*.lnk

# End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python

# setuptools-scm _version file
src/neurosis/_version.py

# temp and misc
/misc/
/temp/

# external repos
/repos/

# wandb
/wandb/

# outputs and such
/logs/
/cache/
/outputs/
/projects/

# direnv
.envrc
.envrc.*

# dotenv
.env
.env.*

# temp files
**/tmp_*.*
**/*.tmp.*

# but keep examples
!*.example
.pre-commit-config.yaml ADDED
@@ -0,0 +1,27 @@
# See https://pre-commit.com for more information
ci:
  autofix_prs: true
  autoupdate_branch: "main"
  autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate"
  autoupdate_schedule: weekly

repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.2.0
    hooks:
      # Run the linter.
      - id: ruff
        types_or: [python, pyi, jupyter]
        args: [--fix, --exit-non-zero-on-fix]
      # Run the formatter.
      - id: ruff-format
        types_or: [python, pyi, jupyter]

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
        exclude_types:
          - "markdown"
      - id: end-of-file-fixer
      - id: check-yaml
.vscode/settings.json ADDED
@@ -0,0 +1,94 @@
{
    "editor.insertSpaces": true,
    "editor.tabSize": 4,
    "files.trimTrailingWhitespace": true,
    "editor.rulers": [100, 120],

    "files.associations": {
        "*.yaml": "yaml"
    },
    "files.exclude": {
        "**/.git": true,
        "**/.svn": true,
        "**/.hg": true,
        "**/CVS": true,
        "**/.DS_Store": true,
        "**/Thumbs.db": true,
        "**/.ruff_cache": true,
        "**/__pycache__": true,
        "**/*.egg-info": true
    },

    "[shellscript]": {
        "files.eol": "\n",
        "editor.tabSize": 4,
        "editor.detectIndentation": false
    },

    "[python]": {
        "editor.wordBasedSuggestions": "off",
        "editor.formatOnSave": true,
        "editor.defaultFormatter": "charliermarsh.ruff",
        "editor.codeActionsOnSave": {
            "source.organizeImports": "always"
        }
    },
    "python.analysis.include": ["./src", "./scripts", "./tests"],

    "[json]": {
        "editor.defaultFormatter": "esbenp.prettier-vscode",
        "editor.detectIndentation": false,
        "editor.formatOnSaveMode": "file",
        "editor.formatOnSave": true,
        "editor.tabSize": 2
    },
    "[jsonc]": {
        "editor.defaultFormatter": "esbenp.prettier-vscode",
        "editor.detectIndentation": false,
        "editor.formatOnSaveMode": "file",
        "editor.formatOnSave": true,
        "editor.tabSize": 2
    },

    "[toml]": {
        "editor.tabSize": 2,
        "editor.detectIndentation": false,
        "editor.formatOnSave": true,
        "editor.formatOnSaveMode": "file",
        "editor.defaultFormatter": "tamasfe.even-better-toml",
        "editor.rulers": [80, 100]
    },
    "evenBetterToml.formatter.columnWidth": 88,

    "[yaml]": {
        "editor.detectIndentation": false,
        "editor.tabSize": 2,
        "editor.formatOnSave": true,
        "editor.formatOnSaveMode": "file",
        "diffEditor.ignoreTrimWhitespace": false,
        "editor.defaultFormatter": "redhat.vscode-yaml"
    },
    "yaml.format.bracketSpacing": true,
    "yaml.format.proseWrap": "preserve",
    "yaml.format.singleQuote": false,
    "yaml.format.printWidth": 110,

    "[hcl]": {
        "editor.detectIndentation": false,
        "editor.formatOnSave": true,
        "editor.formatOnSaveMode": "file",
        "editor.defaultFormatter": "fredwangwang.vscode-hcl-format"
    },

    "[markdown]": {
        "files.trimTrailingWhitespace": false
    },

    "css.lint.validProperties": ["dock", "content-align", "content-justify"],
    "[css]": {
        "editor.formatOnSave": true
    },

    "remote.autoForwardPorts": false,
    "remote.autoForwardPortsSource": "process"
}
LICENSE.md ADDED
@@ -0,0 +1,25 @@
The MIT License (MIT)
=====================

Copyright © 2024 Andi Powers-Holmes <aholmes@omnom.net>

Permission is hereby granted, free of charge, to any person
obtaining a copy of this software and associated documentation
files (the “Software”), to deal in the Software without
restriction, including without limitation the rights to use,
copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following
conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
app.py ADDED
@@ -0,0 +1,161 @@
from os import getenv
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from torchvision.transforms import v2 as T

from dreamsim import DreamsimBackbone, DreamsimEnsemble, DreamsimModel

_ = torch.set_grad_enabled(False)
torchdev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_float32_matmul_precision("high")

HF_TOKEN = getenv("HF_TOKEN", None)
MODEL_REPO = "neggles/dreamsim"
MODEL_VARIANTS: dict[str, str] = {
    "Ensemble": "ensemble_vitb16",
    "CLIP ViT-B/32": "clip_vitb32",
    "OpenCLIP ViT-B/32": "open_clip_vitb32",
    "DINO ViT-B/16": "dino_vitb16",
}

loaded_models: dict[str, Optional[DreamsimBackbone]] = {
    "ensemble_vitb16": None,
    "clip_vitb32": None,
    "open_clip_vitb32": None,
    "dino_vitb16": None,
}


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    # convert to RGB/RGBA if not already (deals with palette images etc.)
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    # convert RGBA to RGB with white background
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    w, h = image.size
    # get the largest dimension so we can pad to a square
    px = max(image.size)
    # pad to square with white background
    canvas = Image.new("RGB", (px, px), fill)
    canvas.paste(image, ((px - w) // 2, (px - h) // 2))
    return canvas


def load_model(variant: str) -> DreamsimBackbone:
    global loaded_models

    if variant in MODEL_VARIANTS:
        # resolve the repo branch for the model variant
        variant = MODEL_VARIANTS[variant]

    match variant:
        case "ensemble_vitb16":
            if loaded_models[variant] is None:
                model: DreamsimEnsemble = DreamsimEnsemble.from_pretrained(
                    MODEL_REPO,
                    token=HF_TOKEN,
                    revision=variant,
                )
                model.do_resize = False
                loaded_models[variant] = model

        case "clip_vitb32" | "open_clip_vitb32" | "dino_vitb16":
            if loaded_models[variant] is None:
                model: DreamsimModel = DreamsimModel.from_pretrained(
                    MODEL_REPO,
                    token=HF_TOKEN,
                    revision=variant,
                )
                model.do_resize = False
                loaded_models[variant] = model

        case _:
            raise ValueError(f"Unknown model variant: {variant}")

    return loaded_models[variant]


def predict(
    variant: str,
    resize_to: Optional[int],
    image_a: Image.Image,
    image_b: Image.Image,
):
    # Load model
    model: DreamsimModel | DreamsimEnsemble = load_model(variant)
    model = model.eval().to(torchdev)

    # yeet alpha, make white background
    image_a, image_b = pil_ensure_rgb(image_a), pil_ensure_rgb(image_b)
    # pad to square
    image_a, image_b = pil_pad_square(image_a), pil_pad_square(image_b)

    # Resize images, if necessary
    if resize_to is not None:
        image_a.thumbnail((resize_to, resize_to), resample=Image.Resampling.BICUBIC)
        image_b.thumbnail((resize_to, resize_to), resample=Image.Resampling.BICUBIC)

    # Preprocess images
    transforms = T.Compose([T.ToImage(), T.ToDtype(torch.float32, scale=True)])
    batch = torch.stack([transforms(image_a).unsqueeze(0), transforms(image_b).unsqueeze(0)], dim=0)

    loss = model(batch.to(model.device, model.dtype)).cpu().item()
    score = 1.0 - loss
    return score, variant


def main():
    with gr.Blocks(title="DreamSIM Perceptual Similarity") as demo:
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(label="Input", type="pil", image_mode="RGB", scale=1)
            with gr.Column():
                img_target = gr.Image(label="Target", type="pil", image_mode="RGB", scale=1)
        with gr.Row(equal_height=True):
            with gr.Column():
                variant = gr.Radio(
                    choices=list(MODEL_VARIANTS.keys()), label="Model Variant", value="Ensemble"
                )
                resize_to = gr.Dropdown(label="Resize To", choices=[224, 384, 512, None], value=224)
            with gr.Column():
                score = gr.Number(label="Similarity Score", precision=8, minimum=0, maximum=1)
                variant_out = gr.Textbox(label="Variant", interactive=False)
        with gr.Row():
            clear = gr.ClearButton(
                components=[img_input, img_target, score], variant="secondary", size="lg"
            )
            submit = gr.Button(value="Submit", variant="primary", size="lg")

        submit.click(
            predict,
            inputs=[variant, resize_to, img_input, img_target],
            outputs=[score, variant_out],
            api_name=False,
        )
        examples = gr.Examples(
            [
                ["examples/img_a_1.png", "examples/ref_1.png", "Ensemble", 224],
                ["examples/img_b_1.png", "examples/ref_1.png", "Ensemble", 224],
            ],
            inputs=[img_input, img_target, variant, resize_to],
        )

    demo.queue(max_size=10)
    demo.launch()


if __name__ == "__main__":
    main()
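Not part of this commit: a minimal sketch of driving the same scoring path without the Gradio UI, assuming the repo is checked out locally, its dependencies are installed, and the model weights on the neggles/dreamsim repo are reachable. It reuses the example images added below.

    # hypothetical local smoke test for app.py's predict() helper
    from PIL import Image
    from app import predict

    img_a = Image.open("examples/img_a_1.png")
    ref = Image.open("examples/ref_1.png")

    # predict() returns (1 - DreamSim distance, the variant that was used)
    score, variant = predict("Ensemble", 224, img_a, ref)
    print(f"{variant}: {score:.6f}")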
dreamsim/.gitattributes ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
dreamsim/README.md ADDED
@@ -0,0 +1,10 @@
---
license: mit
datasets:
- PerceptionEval/DreamSim
library_name: transformers
---

dreamsim! now in quasi-transformers quasi-diffusers form.

this probably won't work for you! but if it works for what i'm experimenting with, i'll try to get it upstreamed.
dreamsim/__init__.py ADDED
@@ -0,0 +1,10 @@
from .model import DreamsimBackbone, DreamsimEnsemble, DreamsimModel
from .vit import VisionTransformer, vit_base_dreamsim

__all__ = [
    "DreamsimBackbone",
    "DreamsimEnsemble",
    "DreamsimModel",
    "VisionTransformer",
    "vit_base_dreamsim",
]
dreamsim/common.py ADDED
@@ -0,0 +1,38 @@
from typing import Callable

import torch
from torch import Tensor, nn
from torch.nn import functional as F


def ensure_tuple(val: int | tuple[int, ...], n: int = 2) -> tuple[int, ...]:
    if isinstance(val, int):
        return (val,) * n
    elif len(val) != n:
        raise ValueError(f"Expected a tuple of {n} values, but got {len(val)}: {val}")
    return val


def use_fused_attn():
    if hasattr(F, "scaled_dot_product_attention"):
        return True
    return False


class QuickGELU(nn.Module):
    """
    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        return input * torch.sigmoid(1.702 * input)


def get_act_layer(name: str) -> Callable[[], nn.Module]:
    match name:
        case "gelu":
            return nn.GELU
        case "quick_gelu":
            return QuickGELU
        case _:
            raise ValueError(f"Activation layer {name} not supported.")
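To illustrate the "fast but somewhat inaccurate" note on QuickGELU, a small sketch (assuming this `dreamsim` package is importable) comparing the sigmoid approximation against PyTorch's exact GELU:

    import torch
    from torch import nn
    from dreamsim.common import QuickGELU

    x = torch.linspace(-4, 4, steps=401)
    exact = nn.GELU()(x)      # exact erf-based GELU
    approx = QuickGELU()(x)   # x * sigmoid(1.702 * x)
    # the two curves differ by roughly 1e-2 at worst over this range
    print((exact - approx).abs().max())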
dreamsim/model.py ADDED
@@ -0,0 +1,188 @@
from abc import abstractmethod

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import v2 as T

from .common import ensure_tuple
from .vit import VisionTransformer, vit_base_dreamsim


class DreamsimBackbone(ModelMixin, ConfigMixin):
    @abstractmethod
    def forward_features(self, x: Tensor) -> Tensor:
        raise NotImplementedError("abstract base class was called ;_;")

    def forward(self, x: Tensor) -> Tensor:
        """Dreamsim forward pass for similarity computation.
        Args:
            x (Tensor): Input tensor of shape [2, B, 3, H, W].

        Returns:
            sim (torch.Tensor): dreamsim similarity score of shape [B].
        """
        inputs = x.view(-1, 3, *x.shape[-2:])

        x = self.forward_features(inputs).view(*x.shape[:2], -1)

        return 1 - F.cosine_similarity(x[0], x[1], dim=1)

    def compile(self, *args, **kwargs):
        """Compile the model with Inductor. This is a no-op unless overridden by a subclass."""
        return self


class DreamsimModel(DreamsimBackbone):
    @register_to_config
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        layer_norm_eps: float = 1e-6,
        pre_norm: bool = False,
        act_layer: str = "gelu",
        img_mean: tuple[float, float, float] = (0.485, 0.456, 0.406),
        img_std: tuple[float, float, float] = (0.229, 0.224, 0.225),
        do_resize: bool = False,
    ) -> None:
        super().__init__()

        self.image_size = ensure_tuple(image_size, 2)
        self.patch_size = ensure_tuple(patch_size, 2)
        self.layer_norm_eps = layer_norm_eps
        self.pre_norm = pre_norm
        self.do_resize = do_resize
        self.img_mean = img_mean
        self.img_std = img_std

        num_classes = 512 if self.pre_norm else 0
        self.extractor: VisionTransformer = vit_base_dreamsim(
            image_size=image_size,
            patch_size=patch_size,
            layer_norm_eps=layer_norm_eps,
            num_classes=num_classes,
            pre_norm=pre_norm,
            act_layer=act_layer,
        )

        self.resize = T.Resize(
            self.image_size,
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True,
        )
        self.img_norm = T.Normalize(mean=self.img_mean, std=self.img_std)

        self._compiled = False  # tracked by compile() below (mirrors DreamsimEnsemble)

    def compile(self, *, mode: str = "reduce-overhead", force: bool = False, **kwargs):
        if (not self._compiled) or force:
            self.extractor = torch.compile(self.extractor, mode=mode, **kwargs)
            self._compiled = True
        return self

    def transforms(self, x: Tensor) -> Tensor:
        if self.do_resize:
            x = self.resize(x)
        return self.img_norm(x)

    def forward_features(self, x: Tensor) -> Tensor:
        if x.ndim == 3:
            x = x.unsqueeze(0)
        x = self.transforms(x)
        x = self.extractor.forward(x, norm=self.pre_norm)

        x = x.div(x.norm(dim=1, keepdim=True))
        x = x.sub(x.mean(dim=1, keepdim=True))
        return x


class DreamsimEnsemble(DreamsimBackbone):
    @register_to_config
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        layer_norm_eps: float | tuple[float, ...] = (1e-6, 1e-5, 1e-5),
        num_classes: int | tuple[int, ...] = (0, 512, 512),
        do_resize: bool = False,
    ) -> None:
        super().__init__()
        if isinstance(layer_norm_eps, float):
            layer_norm_eps = (layer_norm_eps,) * 3
        if isinstance(num_classes, int):
            num_classes = (num_classes,) * 3

        self.image_size = ensure_tuple(image_size, 2)
        self.patch_size = ensure_tuple(patch_size, 2)
        self.do_resize = do_resize

        self.dino: VisionTransformer = vit_base_dreamsim(
            image_size=self.image_size,
            patch_size=self.patch_size,
            layer_norm_eps=layer_norm_eps[0],
            num_classes=num_classes[0],
            pre_norm=False,
            act_layer="gelu",
        )
        self.clip1: VisionTransformer = vit_base_dreamsim(
            image_size=self.image_size,
            patch_size=self.patch_size,
            layer_norm_eps=layer_norm_eps[1],
            num_classes=num_classes[1],
            pre_norm=True,
            act_layer="quick_gelu",
        )
        self.clip2: VisionTransformer = vit_base_dreamsim(
            image_size=self.image_size,
            patch_size=self.patch_size,
            layer_norm_eps=layer_norm_eps[2],
            num_classes=num_classes[2],
            pre_norm=True,
            act_layer="gelu",
        )

        self.resize = T.Resize(
            self.image_size,
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True,
        )
        self.dino_norm = T.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        )
        self.clip_norm = T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )

        self._compiled = False

    def compile(self, *, mode: str = "reduce-overhead", force: bool = False, **kwargs):
        if (not self._compiled) or force:
            self.dino = torch.compile(self.dino, mode=mode, **kwargs)
            self.clip1 = torch.compile(self.clip1, mode=mode, **kwargs)
            self.clip2 = torch.compile(self.clip2, mode=mode, **kwargs)
            self._compiled = True
        return self

    def transforms(self, x: Tensor, resize: bool = False) -> tuple[Tensor, Tensor, Tensor]:
        if resize:
            x = self.resize(x)
        x = self.dino_norm(x), self.clip_norm(x), self.clip_norm(x)
        return x

    def forward_features(self, x: Tensor) -> Tensor:
        if x.ndim == 3:
            x = x.unsqueeze(0)
        x_dino, x_clip1, x_clip2 = self.transforms(x, self.do_resize)

        # these expect to always receive a batch, and will return a batch
        x_dino = self.dino.forward(x_dino, norm=False)
        x_clip1 = self.clip1.forward(x_clip1, norm=True)
        x_clip2 = self.clip2.forward(x_clip2, norm=True)

        z: Tensor = torch.cat([x_dino, x_clip1, x_clip2], dim=1)
        z = z.div(z.norm(dim=1, keepdim=True))
        z = z.sub(z.mean(dim=1, keepdim=True))
        return z
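The forward() docstring above defines the contract: a stacked pair of image batches of shape [2, B, 3, H, W] goes in, and one DreamSim distance (1 - cosine similarity of the two embeddings) per pair comes out. A minimal sketch exercising it with random data and a randomly initialised model, assuming the package is importable:

    import torch
    from dreamsim import DreamsimModel

    model = DreamsimModel().eval()        # config defaults, untrained weights
    pair = torch.rand(2, 4, 3, 224, 224)  # [2, B, 3, H, W]: two batches of 4 images
    with torch.no_grad():
        dist = model(pair)                # distance per image pair
    print(dist.shape)                     # torch.Size([4])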
dreamsim/utils.py ADDED
@@ -0,0 +1,160 @@
"""
Functions in this file are courtesy of @ashen-sensored on GitHub - thank you so much! <3

Used to merge DreamSim LoRA weights into the base ViT models manually, so we don't need
to use an ancient version of PeFT that is no longer supported (and kind of broken)
"""
import logging
from os import PathLike
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch import Tensor, nn

from .model import DreamsimModel

logger = logging.getLogger(__name__)


@torch.no_grad()
def calculate_merged_weight(
    lora_a: Tensor,
    lora_b: Tensor,
    base: Tensor,
    scale: float,
    qkv_switches: list[bool],
) -> Tensor:
    n_switches = len(qkv_switches)
    n_groups = sum(qkv_switches)

    qkv_mask = torch.tensor(qkv_switches, dtype=torch.bool).reshape(len(qkv_switches), -1)
    qkv_mask = qkv_mask.broadcast_to((-1, base.shape[0] // n_switches)).reshape(-1)

    lora_b = lora_b.squeeze()
    delta_w = base.new_zeros(lora_b.shape[0], base.shape[1])

    grp_in_ch = lora_a.shape[0] // n_groups
    grp_out_ch = lora_b.shape[0] // n_groups
    for i in range(n_groups):
        islice = slice(i * grp_in_ch, (i + 1) * grp_in_ch)
        oslice = slice(i * grp_out_ch, (i + 1) * grp_out_ch)
        delta_w[oslice, :] = lora_b[oslice, :] @ lora_a[islice, :]

    delta_w_full = base.new_zeros(base.shape)
    delta_w_full[qkv_mask, :] = delta_w

    merged = base + scale * delta_w_full
    return merged.to(base)


@torch.no_grad()
def merge_dreamsim_lora(
    base_model: nn.Module,
    lora_path: PathLike,
    torch_device: torch.device | str = torch.device("cpu"),
):
    lora_path = Path(lora_path)
    # make sure model is on device
    base_model = base_model.eval().requires_grad_(False).to(torch_device)

    # load the lora
    if lora_path.suffix.lower() in [".pt", ".pth", ".bin"]:
        lora_sd = torch.load(lora_path, map_location=torch_device, weights_only=True)
    elif lora_path.suffix.lower() == ".safetensors":
        lora_sd = load_file(lora_path)
    else:
        raise ValueError(f"Unsupported file extension '{lora_path.suffix}'")

    # these loras were created by a cursed PEFT version, okay? so we have to do some crimes.
    group_prefix = "base_model.model.base_model.model.model."
    # get all lora weights for qkv layers, stripping the insane prefix
    group_weights = {k.replace(group_prefix, ""): v for k, v in lora_sd.items() if k.startswith(group_prefix)}
    # strip ".lora_X.weight" from keys to match against base model keys
    group_layers = set([k.rsplit(".", 2)[0] for k in group_weights.keys()])

    base_weights = base_model.state_dict()
    for key in [x for x in base_weights.keys() if "attn.qkv.weight" in x]:
        param_name = key.rsplit(".", 1)[0]
        if param_name not in group_layers:
            logger.warning(f"QKV param '{param_name}' not found in lora weights")
            continue
        new_weight = calculate_merged_weight(
            group_weights[f"{param_name}.lora_A.weight"],
            group_weights[f"{param_name}.lora_B.weight"],
            base_weights[key],
            0.5 / 16,
            [True, False, True],
        )
        base_weights[key] = new_weight

    base_model.load_state_dict(base_weights)
    return base_model.requires_grad_(False)


def remap_clip(state_dict: dict[str, Tensor], variant: str) -> dict[str, Tensor]:
    """Remap keys from the original DreamSim checkpoint to match new model structure."""

    def prepend_extractor(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        if variant.endswith("single"):
            return {f"extractor.{k}": v for k, v in state_dict.items()}
        return state_dict

    if "clip" not in variant:
        return prepend_extractor(state_dict)

    if "patch_embed.proj.bias" in state_dict:
        _ = state_dict.pop("patch_embed.proj.bias", None)
    if "pos_drop.weight" in state_dict:
        state_dict["norm_pre.weight"] = state_dict.pop("pos_drop.weight")
        state_dict["norm_pre.bias"] = state_dict.pop("pos_drop.bias")
    if "head.weight" in state_dict and "head.bias" not in state_dict:
        state_dict["head.bias"] = torch.zeros(state_dict["head.weight"].shape[0])

    return prepend_extractor(state_dict)


def convert_dreamsim_single(
    ckpt_path: PathLike,
    variant: str,
    ensemble: bool = False,
) -> DreamsimModel:
    ckpt_path = Path(ckpt_path)
    if ckpt_path.exists():
        if ckpt_path.is_dir():
            ckpt_path = ckpt_path.joinpath("ensemble" if ensemble else variant)
            ckpt_path = ckpt_path.joinpath(f"{variant}_merged.safetensors")

    # defaults are for dino, overridden as needed below
    patch_size = 16
    layer_norm_eps = 1e-6
    pre_norm = False
    act_layer = "gelu"

    match variant:
        case "open_clip_vitb16" | "open_clip_vitb32" | "clip_vitb16" | "clip_vitb32":
            patch_size = 32 if "b32" in variant else 16
            layer_norm_eps = 1e-5
            pre_norm = True
            img_mean = (0.48145466, 0.4578275, 0.40821073)
            img_std = (0.26862954, 0.26130258, 0.27577711)
            act_layer = "quick_gelu" if variant.startswith("clip_") else "gelu"
        case "dino_vitb16":
            img_mean = (0.485, 0.456, 0.406)
            img_std = (0.229, 0.224, 0.225)
        case _:
            raise NotImplementedError(f"Unsupported model variant '{variant}'")

    model: DreamsimModel = DreamsimModel(
        image_size=224,
        patch_size=patch_size,
        layer_norm_eps=layer_norm_eps,
        pre_norm=pre_norm,
        act_layer=act_layer,
        img_mean=img_mean,
        img_std=img_std,
    )
    state_dict = load_file(ckpt_path, device="cpu")
    state_dict = remap_clip(state_dict, variant)  # remap_clip needs the variant to pick the right key fixes
    model.extractor.load_state_dict(state_dict)
    return model
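A hedged sketch of how these helpers fit together when converting an original DreamSim release (not part of this commit; the checkpoint and LoRA paths below are hypothetical placeholders that depend on how the upstream files are laid out on disk):

    # hypothetical conversion flow; paths are placeholders
    from dreamsim.utils import convert_dreamsim_single, merge_dreamsim_lora

    # build a DreamsimModel and load already-merged single-branch weights from a directory
    model = convert_dreamsim_single("checkpoints/", variant="dino_vitb16")

    # alternatively, fold a PEFT-era LoRA into a base ViT's qkv weights in place
    model.extractor = merge_dreamsim_lora(model.extractor, "checkpoints/dino_vitb16_lora.safetensors")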
dreamsim/vit.py ADDED
@@ -0,0 +1,375 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
from functools import partial
from typing import Callable, Final, Optional, Sequence

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from .common import ensure_tuple, get_act_layer, use_fused_attn


def vit_weights_init(module: nn.Module) -> None:
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x: Tensor) -> Tensor:
        if self.drop_prob == 0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

    def extra_repr(self):
        return f"drop_prob={self.drop_prob:0.3f}"


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        drop: float = 0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop) if drop > 0.0 else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim**-0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0.0 else nn.Identity()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        qkv: Tensor = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.fused_attn:
            dropout_p = getattr(self.attn_drop, "p", 0.0) if self.training else 0.0
            x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[], nn.Module] = nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(
        self,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        bias: bool = True,
        dynamic_pad: bool = False,
    ):
        super().__init__()
        self.img_size = ensure_tuple(img_size, 2)
        self.patch_size = ensure_tuple(patch_size, 2)
        self.num_patches = (self.img_size[0] // self.patch_size[0]) * (self.img_size[1] // self.patch_size[1])

        self.dynamic_pad = dynamic_pad

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        if self.dynamic_pad:
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, pad_h))
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        return x


class VisionTransformer(nn.Module):
    """Vision Transformer"""

    def __init__(
        self,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int] = 16,
        in_chans: int = 3,
        num_classes: int = 0,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        pre_norm: bool = False,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: Callable[[], nn.Module] = nn.LayerNorm,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        skip_init: bool = False,
        dynamic_pad: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.depth = depth

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            dynamic_pad=dynamic_pad,
        )
        num_patches = self.patch_embed.num_patches
        embed_len = num_patches + 1  # num_patches + 1 for the [CLS] token

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_len, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.depth)]  # stochastic depth decay rule
        self.blocks: list[Block] = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                )
                for i in range(self.depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if not skip_init:
            self.reset_parameters()

    def reset_parameters(self):
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(vit_weights_init)

    def interpolate_pos_encoding(self, x: Tensor, w: Tensor, h: Tensor) -> Tensor:
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size[0]
        h0 = h // self.patch_embed.patch_size[0]
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )
        if int(w0) != patch_pos_embed.shape[-2] or int(h0) != patch_pos_embed.shape[-1]:
            raise ValueError("Error in positional encoding interpolation.")
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x: Tensor) -> Tensor:
        B, _, W, H = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, W, H)

        return self.pos_drop(x)

    def forward(self, x: Tensor, norm: bool = True) -> Tensor:
        x = self.forward_features(x, norm=norm)
        x = self.forward_head(x)
        return x

    def forward_features(self, x: Tensor, norm: bool = True) -> Tensor:
        x = self.prepare_tokens(x)
        x = self.norm_pre(x)
        for blk in self.blocks:
            x = blk(x)
        if norm:
            x = self.norm(x)
        return x[:, 0]

    def forward_head(self, x: Tensor) -> Tensor:
        x = self.head(x)
        return x

    def get_intermediate_layers(
        self,
        x: Tensor,
        n: int | Sequence[int] = 1,
        norm: bool = True,
    ) -> list[Tensor]:
        # we return the output tokens from the `n` last blocks
        outputs = []
        layer_indices = set(range(self.depth - n, self.depth) if isinstance(n, int) else n)
        x = self.prepare_tokens(x)
        x = self.norm_pre(x)

        for idx, blk in enumerate(self.blocks):
            x = blk(x)
            if idx in layer_indices:
                outputs.append(x)
        if norm:
            outputs = [self.norm(x) for x in outputs]
        return outputs


def vit_base_dreamsim(
    patch_size: int = 16,
    layer_norm_eps: float = 1e-6,
    num_classes: int = 512,
    act_layer: str | Callable[[], nn.Module] = "gelu",
    **kwargs,
):
    if isinstance(act_layer, str):
        act_layer = get_act_layer(act_layer)

    model = VisionTransformer(
        patch_size=patch_size,
        num_classes=num_classes,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=layer_norm_eps),
        act_layer=act_layer,
        **kwargs,
    )
    return model
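A small sketch of the backbone in isolation (assuming the package is importable): vit_base_dreamsim builds a ViT-B backbone, and forward() returns the optionally normalised [CLS] token after the classifier head, so the output width follows num_classes.

    import torch
    from dreamsim.vit import vit_base_dreamsim

    # CLIP-flavoured configuration, as used by the ensemble's clip1 branch
    vit = vit_base_dreamsim(patch_size=16, num_classes=512, pre_norm=True, act_layer="quick_gelu")
    feats = vit(torch.rand(2, 3, 224, 224), norm=True)  # [CLS] token -> 512-d head output
    print(feats.shape)  # torch.Size([2, 512])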
examples/img_a_1.png ADDED

Git LFS Details

  • SHA256: 1f2ec9cb3cc239c8b37ac8f47508b09a043664ca311559f03295c6ff76bdbadd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
examples/img_b_1.png ADDED

Git LFS Details

  • SHA256: 963392f8698a2defc04cf7d4aaacbce41a63ebaea03c69f0979ff1f2ed8982b0
  • Pointer size: 131 Bytes
  • Size of remote file: 898 kB
examples/ref_1.png ADDED

Git LFS Details

  • SHA256: b694282ab12110455ccf23650aa745048ffdaf3f80c15ede95cf11528b7741d1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
pyproject.toml ADDED
@@ -0,0 +1,98 @@
[project]
name = "dreamsim-space"
version = "0.1.0"
authors = [
    { name = "Stephanie Fu" },
    { name = "Netanel Tamir" },
    { name = "Shobhita Sundaram" },
    { name = "Lucy Chai" },
    { name = "Richard Zhang" },
    { name = "Tali Dekel" },
    { name = "Phillip Isola" },
]
maintainers = [
    { name = "Andi Powers-Holmes", email = "aholmes@omnom.net" },
]
description = "DreamSim Gradio Space"
readme = "README.md"
requires-python = ">=3.9, <3.11"
keywords = [
    "deep-learning",
    "machine-learning",
    "pytorch",
]
license = { file = "LICENSE.md" }
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
]
dependencies = [
    "accelerate",
    "diffusers",
    "gradio >=4.19.1, < 5.0.0",
    "numpy",
    "pandas",
    "Pillow",
    "PyYAML",
    "safetensors",
    "simple-parsing >= 0.1.0",
    "torch",
    "torchvision",
    "transformers",
    'xformers; sys_platform != "win32"',
]

[project.urls]
Repository = "https://huggingface.co/spaces/neggles/dreamsim"

[project.optional-dependencies]
dev = [
    "ruff >=0.0.289",
    "setuptools-scm >= 8.0.0",
    "pre-commit >= 3.0.0", # remember to run `pre-commit install` after installing
    "tabulate >= 0.8.9",   # for inductor log prettyprinting
]
all = [
    "dreamsim-space[dev]",
]

[build-system]
build-backend = "setuptools.build_meta"
requires = ["setuptools>=64", "wheel"]

[tool.setuptools.packages.find]
namespaces = true
where = ["."]
include = ["dreamsim"]


[tool.ruff]
line-length = 110
target-version = "py310"
extend-exclude = ["/usr/lib/*"]

[tool.ruff.lint]
ignore = [
    "F841", # local variable assigned but never used
    "F842", # local variable annotated but never used
    "E501", # line too long - will be fixed in format
]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
line-ending = "auto"
skip-magic-trailing-comma = false
docstring-code-format = true

[tool.ruff.lint.isort]
combine-as-imports = true
force-wrap-aliases = true
known-local-folder = ["dreamsim"]
known-first-party = ["dreamsim"]


[tool.pyright]
include = ["src/**"]
exclude = ["/usr/lib/**"]
stubPath = "./typings"
requirements.txt ADDED
@@ -0,0 +1 @@
-e .[all]