Upload 116 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- SOWA/.DS_Store +0 -0
- SOWA/.env.example +6 -0
- SOWA/.gitignore +154 -0
- SOWA/.pre-commit-config.yaml +147 -0
- SOWA/.project-root +2 -0
- SOWA/LICENSE +21 -0
- SOWA/Makefile +30 -0
- SOWA/README.md +153 -0
- SOWA/configs/__init__.py +1 -0
- SOWA/configs/callbacks/default.yaml +27 -0
- SOWA/configs/callbacks/early_stopping.yaml +15 -0
- SOWA/configs/callbacks/model_checkpoint.yaml +17 -0
- SOWA/configs/callbacks/model_summary.yaml +5 -0
- SOWA/configs/callbacks/none.yaml +0 -0
- SOWA/configs/callbacks/rich_progress_bar.yaml +4 -0
- SOWA/configs/callbacks/visualization.yaml +4 -0
- SOWA/configs/data/sowa_infer.yaml +56 -0
- SOWA/configs/data/sowa_mvt.yaml +53 -0
- SOWA/configs/data/sowa_overfit.yaml +51 -0
- SOWA/configs/data/sowa_visa.yaml +52 -0
- SOWA/configs/debug/default.yaml +35 -0
- SOWA/configs/debug/fdr.yaml +9 -0
- SOWA/configs/debug/limit.yaml +12 -0
- SOWA/configs/debug/overfit.yaml +13 -0
- SOWA/configs/debug/profiler.yaml +12 -0
- SOWA/configs/eval.yaml +27 -0
- SOWA/configs/experiment/example.yaml +194 -0
- SOWA/configs/extras/default.yaml +8 -0
- SOWA/configs/hparams_search/anomaly_clip_optuna.yaml +61 -0
- SOWA/configs/hydra/default.yaml +19 -0
- SOWA/configs/local/.gitkeep +0 -0
- SOWA/configs/logger/aim.yaml +28 -0
- SOWA/configs/logger/comet.yaml +12 -0
- SOWA/configs/logger/csv.yaml +7 -0
- SOWA/configs/logger/many_loggers.yaml +9 -0
- SOWA/configs/logger/mlflow.yaml +12 -0
- SOWA/configs/logger/neptune.yaml +9 -0
- SOWA/configs/logger/tensorboard.yaml +10 -0
- SOWA/configs/logger/wandb.yaml +16 -0
- SOWA/configs/model/sowa_hfwa.yaml +71 -0
- SOWA/configs/model/sowa_linear.yaml +63 -0
- SOWA/configs/model/sparc_hfwa.yaml +75 -0
- SOWA/configs/model/sparc_linear.yaml +74 -0
- SOWA/configs/model/sparc_prompt.yaml +74 -0
- SOWA/configs/paths/default.yaml +18 -0
- SOWA/configs/prompt/default.yaml +5 -0
- SOWA/configs/prompt/object.yaml +29 -0
- SOWA/configs/prompt/state_template.yaml +91 -0
- SOWA/configs/prompt/template.yaml +51 -0
- SOWA/configs/train.yaml +52 -0
SOWA/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
SOWA/.env.example
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# example of file for storing private and user specific environment variables, like keys or system paths
|
2 |
+
# rename it to ".env" (excluded from version control by default)
|
3 |
+
# .env is loaded by train.py automatically
|
4 |
+
# hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR}
|
5 |
+
|
6 |
+
MY_VAR="/home/user/my/system/path"
|
SOWA/.gitignore
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.venv
|
106 |
+
env/
|
107 |
+
venv/
|
108 |
+
ENV/
|
109 |
+
env.bak/
|
110 |
+
venv.bak/
|
111 |
+
|
112 |
+
# Spyder project settings
|
113 |
+
.spyderproject
|
114 |
+
.spyproject
|
115 |
+
|
116 |
+
# Rope project settings
|
117 |
+
.ropeproject
|
118 |
+
|
119 |
+
# mkdocs documentation
|
120 |
+
/site
|
121 |
+
|
122 |
+
# mypy
|
123 |
+
.mypy_cache/
|
124 |
+
.dmypy.json
|
125 |
+
dmypy.json
|
126 |
+
|
127 |
+
# Pyre type checker
|
128 |
+
.pyre/
|
129 |
+
|
130 |
+
### VisualStudioCode
|
131 |
+
.vscode/*
|
132 |
+
!.vscode/settings.json
|
133 |
+
!.vscode/tasks.json
|
134 |
+
!.vscode/launch.json
|
135 |
+
!.vscode/extensions.json
|
136 |
+
*.code-workspace
|
137 |
+
**/.vscode
|
138 |
+
|
139 |
+
# JetBrains
|
140 |
+
.idea/
|
141 |
+
|
142 |
+
# Data & Models
|
143 |
+
*.h5
|
144 |
+
*.tar
|
145 |
+
*.tar.gz
|
146 |
+
|
147 |
+
# Lightning-Hydra-Template
|
148 |
+
configs/local/default.yaml
|
149 |
+
/data/
|
150 |
+
/logs/
|
151 |
+
.env
|
152 |
+
|
153 |
+
# Aim logging
|
154 |
+
.aim
|
SOWA/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
default_language_version:
|
2 |
+
python: python3
|
3 |
+
|
4 |
+
repos:
|
5 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
6 |
+
rev: v4.4.0
|
7 |
+
hooks:
|
8 |
+
# list of supported hooks: https://pre-commit.com/hooks.html
|
9 |
+
- id: trailing-whitespace
|
10 |
+
- id: end-of-file-fixer
|
11 |
+
- id: check-docstring-first
|
12 |
+
- id: check-yaml
|
13 |
+
- id: debug-statements
|
14 |
+
- id: detect-private-key
|
15 |
+
- id: check-executables-have-shebangs
|
16 |
+
- id: check-toml
|
17 |
+
- id: check-case-conflict
|
18 |
+
- id: check-added-large-files
|
19 |
+
|
20 |
+
# python code formatting
|
21 |
+
- repo: https://github.com/psf/black
|
22 |
+
rev: 23.1.0
|
23 |
+
hooks:
|
24 |
+
- id: black
|
25 |
+
args: [--line-length, "99"]
|
26 |
+
|
27 |
+
# python import sorting
|
28 |
+
- repo: https://github.com/PyCQA/isort
|
29 |
+
rev: 5.12.0
|
30 |
+
hooks:
|
31 |
+
- id: isort
|
32 |
+
args: ["--profile", "black", "--filter-files"]
|
33 |
+
|
34 |
+
# python upgrading syntax to newer version
|
35 |
+
- repo: https://github.com/asottile/pyupgrade
|
36 |
+
rev: v3.3.1
|
37 |
+
hooks:
|
38 |
+
- id: pyupgrade
|
39 |
+
args: [--py38-plus]
|
40 |
+
|
41 |
+
# python docstring formatting
|
42 |
+
- repo: https://github.com/myint/docformatter
|
43 |
+
rev: v1.7.4
|
44 |
+
hooks:
|
45 |
+
- id: docformatter
|
46 |
+
args:
|
47 |
+
[
|
48 |
+
--in-place,
|
49 |
+
--wrap-summaries=99,
|
50 |
+
--wrap-descriptions=99,
|
51 |
+
--style=sphinx,
|
52 |
+
--black,
|
53 |
+
]
|
54 |
+
|
55 |
+
# python docstring coverage checking
|
56 |
+
- repo: https://github.com/econchick/interrogate
|
57 |
+
rev: 1.5.0 # or master if you're bold
|
58 |
+
hooks:
|
59 |
+
- id: interrogate
|
60 |
+
args:
|
61 |
+
[
|
62 |
+
--verbose,
|
63 |
+
--fail-under=80,
|
64 |
+
--ignore-init-module,
|
65 |
+
--ignore-init-method,
|
66 |
+
--ignore-module,
|
67 |
+
--ignore-nested-functions,
|
68 |
+
-vv,
|
69 |
+
]
|
70 |
+
|
71 |
+
# python check (PEP8), programming errors and code complexity
|
72 |
+
- repo: https://github.com/PyCQA/flake8
|
73 |
+
rev: 6.0.0
|
74 |
+
hooks:
|
75 |
+
- id: flake8
|
76 |
+
args:
|
77 |
+
[
|
78 |
+
"--extend-ignore",
|
79 |
+
"E203,E402,E501,F401,F841,RST2,RST301",
|
80 |
+
"--exclude",
|
81 |
+
"logs/*,data/*",
|
82 |
+
]
|
83 |
+
additional_dependencies: [flake8-rst-docstrings==0.3.0]
|
84 |
+
|
85 |
+
# python security linter
|
86 |
+
- repo: https://github.com/PyCQA/bandit
|
87 |
+
rev: "1.7.5"
|
88 |
+
hooks:
|
89 |
+
- id: bandit
|
90 |
+
args: ["-s", "B101"]
|
91 |
+
|
92 |
+
# yaml formatting
|
93 |
+
- repo: https://github.com/pre-commit/mirrors-prettier
|
94 |
+
rev: v3.0.0-alpha.6
|
95 |
+
hooks:
|
96 |
+
- id: prettier
|
97 |
+
types: [yaml]
|
98 |
+
exclude: "environment.yaml"
|
99 |
+
|
100 |
+
# shell scripts linter
|
101 |
+
- repo: https://github.com/shellcheck-py/shellcheck-py
|
102 |
+
rev: v0.9.0.2
|
103 |
+
hooks:
|
104 |
+
- id: shellcheck
|
105 |
+
|
106 |
+
# md formatting
|
107 |
+
- repo: https://github.com/executablebooks/mdformat
|
108 |
+
rev: 0.7.16
|
109 |
+
hooks:
|
110 |
+
- id: mdformat
|
111 |
+
args: ["--number"]
|
112 |
+
additional_dependencies:
|
113 |
+
- mdformat-gfm
|
114 |
+
- mdformat-tables
|
115 |
+
- mdformat_frontmatter
|
116 |
+
# - mdformat-toc
|
117 |
+
# - mdformat-black
|
118 |
+
|
119 |
+
# word spelling linter
|
120 |
+
- repo: https://github.com/codespell-project/codespell
|
121 |
+
rev: v2.2.4
|
122 |
+
hooks:
|
123 |
+
- id: codespell
|
124 |
+
args:
|
125 |
+
- --skip=logs/**,data/**,*.ipynb
|
126 |
+
# - --ignore-words-list=abc,def
|
127 |
+
|
128 |
+
# jupyter notebook cell output clearing
|
129 |
+
- repo: https://github.com/kynan/nbstripout
|
130 |
+
rev: 0.6.1
|
131 |
+
hooks:
|
132 |
+
- id: nbstripout
|
133 |
+
|
134 |
+
# jupyter notebook linting
|
135 |
+
- repo: https://github.com/nbQA-dev/nbQA
|
136 |
+
rev: 1.6.3
|
137 |
+
hooks:
|
138 |
+
- id: nbqa-black
|
139 |
+
args: ["--line-length=99"]
|
140 |
+
- id: nbqa-isort
|
141 |
+
args: ["--profile=black"]
|
142 |
+
- id: nbqa-flake8
|
143 |
+
args:
|
144 |
+
[
|
145 |
+
"--extend-ignore=E203,E402,E501,F401,F841",
|
146 |
+
"--exclude=logs/*,data/*",
|
147 |
+
]
|
SOWA/.project-root
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# this file is required for inferring the project root directory
|
2 |
+
# do not delete
|
SOWA/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 huzongxiang
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
SOWA/Makefile
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
help: ## Show help
|
3 |
+
@grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
4 |
+
|
5 |
+
clean: ## Clean autogenerated files
|
6 |
+
rm -rf dist
|
7 |
+
find . -type f -name "*.DS_Store" -ls -delete
|
8 |
+
find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
|
9 |
+
find . | grep -E ".pytest_cache" | xargs rm -rf
|
10 |
+
find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
|
11 |
+
rm -f .coverage
|
12 |
+
|
13 |
+
clean-logs: ## Clean logs
|
14 |
+
rm -rf logs/**
|
15 |
+
|
16 |
+
format: ## Run pre-commit hooks
|
17 |
+
pre-commit run -a
|
18 |
+
|
19 |
+
sync: ## Merge changes from main branch to your current branch
|
20 |
+
git pull
|
21 |
+
git pull origin main
|
22 |
+
|
23 |
+
test: ## Run not slow tests
|
24 |
+
pytest -k "not slow"
|
25 |
+
|
26 |
+
test-full: ## Run all tests
|
27 |
+
pytest
|
28 |
+
|
29 |
+
train: ## Train the model
|
30 |
+
python src/train.py
|
SOWA/README.md
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
|
3 |
+
# Soldier-Offier Window self-Attention (SOWA)
|
4 |
+
|
5 |
+
<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
|
6 |
+
<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
|
7 |
+
<a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a>
|
8 |
+
<a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a><br>
|
9 |
+
[![Paper](http://img.shields.io/badge/paper-arxiv.2407.03634-B31B1B.svg)](https://arxiv.org/abs/2407.03634)
|
10 |
+
[![Conference](http://img.shields.io/badge/AnyConference-year-4b44ce.svg)](https://papers.nips.cc/paper/2020)
|
11 |
+
|
12 |
+
</div>
|
13 |
+
|
14 |
+
## Description
|
15 |
+
|
16 |
+
<div align="center">
|
17 |
+
<img src="https://github.com/huzongxiang/sowa/blob/resources/fig1.png" alt="concept" style="width: 50%;">
|
18 |
+
</div>
|
19 |
+
|
20 |
+
Visual anomaly detection is critical in industrial manufacturing, but traditional methods often rely on extensive
|
21 |
+
normal datasets and custom models, limiting scalability.
|
22 |
+
Recent advancements in large-scale visual-language models have significantly improved zero/few-shot anomaly detection. However, these approaches may not fully utilize hierarchical features, potentially missing nuanced details. We
|
23 |
+
introduce a window self-attention mechanism based on the
|
24 |
+
CLIP model, combined with learnable prompts to process
|
25 |
+
multi-level features within a Soldier-Offier Window selfAttention (SOWA) framework. Our method has been tested
|
26 |
+
on five benchmark datasets, demonstrating superior performance by leading in 18 out of 20 metrics compared to existing state-of-the-art techniques.
|
27 |
+
|
28 |
+
![architecture](https://github.com/huzongxiang/sowa/blob/resources/fig2.png)
|
29 |
+
|
30 |
+
## Installation
|
31 |
+
|
32 |
+
#### Pip
|
33 |
+
|
34 |
+
```bash
|
35 |
+
# clone project
|
36 |
+
git clone https://github.com/huzongxiang/sowa
|
37 |
+
cd sowa
|
38 |
+
|
39 |
+
# [OPTIONAL] create conda environment
|
40 |
+
conda create -n sowa python=3.9
|
41 |
+
conda activate sowa
|
42 |
+
|
43 |
+
# install pytorch according to instructions
|
44 |
+
# https://pytorch.org/get-started/
|
45 |
+
|
46 |
+
# install requirements
|
47 |
+
pip install -r requirements.txt
|
48 |
+
```
|
49 |
+
|
50 |
+
#### Conda
|
51 |
+
|
52 |
+
```bash
|
53 |
+
# clone project
|
54 |
+
git clone https://github.com/huzongxiang/sowa
|
55 |
+
cd sowa
|
56 |
+
|
57 |
+
# create conda environment and install dependencies
|
58 |
+
conda env create -f environment.yaml -n sowa
|
59 |
+
|
60 |
+
# activate conda environment
|
61 |
+
conda activate sowa
|
62 |
+
```
|
63 |
+
|
64 |
+
## How to run
|
65 |
+
|
66 |
+
Train model with default configuration
|
67 |
+
|
68 |
+
```bash
|
69 |
+
# train on CPU
|
70 |
+
python src/train.py trainer=cpu data=sowa_visa model=sowa_hfwa
|
71 |
+
|
72 |
+
# train on GPU
|
73 |
+
python src/train.py trainer=gpu data=sowa_visa model=sowa_hfwa
|
74 |
+
```
|
75 |
+
|
76 |
+
## Results
|
77 |
+
|
78 |
+
Comparisons with few-shot (K=4) anomaly detection methods on datasets of MVTec-AD, Visa, BTAD, DAGM and DTD Synthetic.
|
79 |
+
| Metric | Dataset | WinCLIP | April-GAN | Ours |
|
80 |
+
|-----------|----------------|-------------|-------------|-------------|
|
81 |
+
| AC AUROC | MVTec-AD | 95.2±1.3 | 92.8±0.2 | 96.8±0.3 |
|
82 |
+
| | Visa | 87.3±1.8 | 92.6±0.4 | 92.9±0.2 |
|
83 |
+
| | BTAD | 87.0±0.2 | 92.1±0.2 | 94.8±0.2 |
|
84 |
+
| | DAGM | 93.8±0.2 | 96.2±1.1 | 98.9±0.3 |
|
85 |
+
| | DTD-Synthetic | 98.1±0.2 | 98.5±0.1 | 99.1±0.0 |
|
86 |
+
| AC AP | MVTec-AD | 97.3±0.6 | 96.3±0.1 | 98.3±0.3 |
|
87 |
+
| | Visa | 88.8±1.8 | 94.5±0.3 | 94.5±0.2 |
|
88 |
+
| | BTAD | 86.8±0.0 | 95.2±0.5 | 95.5±0.7 |
|
89 |
+
| | DAGM | 83.8±1.1 | 86.7±4.5 | 95.2±1.7 |
|
90 |
+
| | DTD-Synthetic | 99.1±0.1 | 99.4±0.0 | 99.6±0.0 |
|
91 |
+
| AS AUROC | MVTec-AD | 96.2±0.3 | 95.9±0.0 | 95.7±0.1 |
|
92 |
+
| | Visa | 97.2±0.2 | 96.2±0.0 | 97.1±0.0 |
|
93 |
+
| | BTAD | 95.8±0.0 | 94.4±0.1 | 97.1±0.0 |
|
94 |
+
| | DAGM | 93.8±0.1 | 88.9±0.4 | 96.9±0.0 |
|
95 |
+
| | DTD-Synthetic | 96.8±0.2 | 96.7±0.0 | 98.7±0.0 |
|
96 |
+
| AS AUPRO | MVTec-AD | 89.0±0.8 | 91.8±0.1 | 92.4±0.2 |
|
97 |
+
| | Visa | 87.6±0.9 | 90.2±0.1 | 91.4±0.0 |
|
98 |
+
| | BTAD | 66.6±0.2 | 78.2±0.1 | 81.2±0.2 |
|
99 |
+
| | DAGM | 82.4±0.3 | 77.8±0.9 | 94.4±0.1 |
|
100 |
+
| | DTD-Synthetic | 90.1±0.5 | 92.2±0.0 | 96.6±0.1 |
|
101 |
+
|
102 |
+
<!-- 零宽空格 -->
|
103 |
+
|
104 |
+
Performance Comparison on MVTec-AD and Visa Datasets.
|
105 |
+
| Method | Source | MVTec-AD AC AUROC | MVTec-AD AS AUROC | MVTec-AD AS PRO | Visa AC AUROC | Visa AS AUROC | Visa AS PRO |
|
106 |
+
|---------------|-------------------------|-------------------|-------------------|-----------------|---------------|---------------|-------------|
|
107 |
+
| SPADE | arXiv 2020 | 84.8±2.5 | 92.7±0.3 | 87.0±0.5 | 81.7±3.4 | 96.6±0.3 | 87.3±0.8 |
|
108 |
+
| PaDiM | ICPR 2021 | 80.4±2.4 | 92.6±0.7 | 81.3±1.9 | 72.8±2.9 | 93.2±0.5 | 72.6±1.9 |
|
109 |
+
| PatchCore | CVPR 2022 | 88.8±2.6 | 94.3±0.5 | 84.3±1.6 | 85.3±2.1 | 96.8±0.3 | 84.9±1.4 |
|
110 |
+
| WinCLIP | CVPR 2023 | 95.2±1.3 | 96.2±0.3 | 89.0±0.8 | 87.3±1.8 | 97.2±0.2 | 87.6±0.9 |
|
111 |
+
| April-GAN | CVPR 2023 VAND workshop | 92.8±0.2 | 95.9±0.0 | 91.8±0.1 | 92.6±0.4 | 96.2±0.0 | 90.2±0.1 |
|
112 |
+
| PromptAD | CVPR 2024 | 96.6±0.9 | 96.5±0.2 | - | 89.1±1.7 | 97.4±0.3 | - |
|
113 |
+
| InCTRL | CVPR 2024 | 94.5±1.8 | - | - | 87.7±1.9 | - | - |
|
114 |
+
| SOWA | Ours | 96.8±0.3 | 95.7±0.1 | 92.4±0.2 | 92.9±0.2 | 97.1±0.0 | 91.4±0.0 |
|
115 |
+
|
116 |
+
|
117 |
+
<!-- 零宽空格 -->
|
118 |
+
|
119 |
+
Comparisons with few-shot anomaly detection methods on datasets of MVTec-AD, Visa, BTAD, DAGM and DTD Synthetic.
|
120 |
+
<div align="center">
|
121 |
+
<img src="https://github.com/huzongxiang/sowa/blob/resources/fig5.png" alt="few-shot" style="width: 70%;">
|
122 |
+
</div>
|
123 |
+
|
124 |
+
|
125 |
+
## Visualization
|
126 |
+
Visualization results under the few-shot setting (K=4).
|
127 |
+
<div align="center">
|
128 |
+
<img src="https://github.com/huzongxiang/sowa/blob/resources/fig6.png" alt="concept" style="width: 70%;">
|
129 |
+
</div>
|
130 |
+
|
131 |
+
|
132 |
+
## Mechanism
|
133 |
+
Hierarchical Results on MVTec-AD Dataset. A set of images showing the real outputs of the model, illustrating how different layers (H1 to H4) process various feature modes. Each row represents a different sample, with columns showing the original image, segmentation mask, heatmap, and feature outputs from H1 to H4, and fusion.
|
134 |
+
![mechanism](https://github.com/huzongxiang/sowa/blob/resources/fig7.png)
|
135 |
+
|
136 |
+
|
137 |
+
## Inference Speed
|
138 |
+
Inference performance comparison of different methods on a single NVIDIA RTX3070 8GB GPU.
|
139 |
+
<div align="center">
|
140 |
+
<img src="https://github.com/huzongxiang/sowa/blob/resources/fig9.png" alt="speed" style="width: 80%;">
|
141 |
+
</div>
|
142 |
+
|
143 |
+
|
144 |
+
## Citation
|
145 |
+
Please cite the following paper if this work helps your project:
|
146 |
+
```
|
147 |
+
@article{hu2024sowa,
|
148 |
+
title={SOWA: Adapting Hierarchical Frozen Window Self-Attention to Visual-Language Models for Better Anomaly Detection},
|
149 |
+
author={Hu, Zongxiang and Zhang, zhaosheng},
|
150 |
+
journal={arXiv preprint arXiv:2407.03634},
|
151 |
+
year={2024}
|
152 |
+
}
|
153 |
+
```
|
SOWA/configs/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# this file is needed here to include configs when building project as a package
|
SOWA/configs/callbacks/default.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- model_checkpoint
|
3 |
+
- early_stopping
|
4 |
+
- model_summary
|
5 |
+
- rich_progress_bar
|
6 |
+
- visualization
|
7 |
+
- _self_
|
8 |
+
|
9 |
+
model_checkpoint:
|
10 |
+
dirpath: ${paths.output_dir}/checkpoints
|
11 |
+
filename: "epoch_{epoch:03d}"
|
12 |
+
monitor: "train/loss"
|
13 |
+
mode: "min"
|
14 |
+
save_last: True
|
15 |
+
auto_insert_metric_name: False
|
16 |
+
|
17 |
+
early_stopping:
|
18 |
+
monitor: "train/loss"
|
19 |
+
patience: 10
|
20 |
+
mode: "min"
|
21 |
+
|
22 |
+
model_summary:
|
23 |
+
max_depth: -1
|
24 |
+
|
25 |
+
visualization:
|
26 |
+
dirpath: ${paths.output_dir}/visualizations
|
27 |
+
visualize: True
|
SOWA/configs/callbacks/early_stopping.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
|
2 |
+
|
3 |
+
early_stopping:
|
4 |
+
_target_: lightning.pytorch.callbacks.EarlyStopping
|
5 |
+
monitor: ??? # quantity to be monitored, must be specified !!!
|
6 |
+
min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
|
7 |
+
patience: 3 # number of checks with no improvement after which training will be stopped
|
8 |
+
verbose: False # verbosity mode
|
9 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
10 |
+
strict: True # whether to crash the training if monitor is not found in the validation metrics
|
11 |
+
check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
|
12 |
+
stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
|
13 |
+
divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
|
14 |
+
check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
|
15 |
+
# log_rank_zero_only: False # this keyword argument isn't available in stable version
|
SOWA/configs/callbacks/model_checkpoint.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
|
2 |
+
|
3 |
+
model_checkpoint:
|
4 |
+
_target_: lightning.pytorch.callbacks.ModelCheckpoint
|
5 |
+
dirpath: null # directory to save the model file
|
6 |
+
filename: null # checkpoint filename
|
7 |
+
monitor: null # name of the logged metric which determines when model is improving
|
8 |
+
verbose: False # verbosity mode
|
9 |
+
save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
|
10 |
+
save_top_k: 1 # save k best models (determined by above metric)
|
11 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
12 |
+
auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
|
13 |
+
save_weights_only: False # if True, then only the model’s weights will be saved
|
14 |
+
every_n_train_steps: null # number of training steps between checkpoints
|
15 |
+
train_time_interval: null # checkpoints are monitored at the specified time interval
|
16 |
+
every_n_epochs: null # number of epochs between checkpoints
|
17 |
+
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
|
SOWA/configs/callbacks/model_summary.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
|
2 |
+
|
3 |
+
model_summary:
|
4 |
+
_target_: lightning.pytorch.callbacks.RichModelSummary
|
5 |
+
max_depth: 1 # the maximum depth of layer nesting that the summary will include
|
SOWA/configs/callbacks/none.yaml
ADDED
File without changes
|
SOWA/configs/callbacks/rich_progress_bar.yaml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
|
2 |
+
|
3 |
+
rich_progress_bar:
|
4 |
+
_target_: lightning.pytorch.callbacks.RichProgressBar
|
SOWA/configs/callbacks/visualization.yaml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
visualization:
|
2 |
+
_target_: src.models.components.callback.AnomalyVisualizationCallback
|
3 |
+
dirpath: ${paths.output_dir}/visualizations
|
4 |
+
visualize: True
|
SOWA/configs/data/sowa_infer.yaml
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: src.data.anomaly_clip_datamodule.AnomalyCLIPDataModule
|
2 |
+
data_dir:
|
3 |
+
train: /home/hzx/Projects/Data/MVTec-AD
|
4 |
+
valid: /home/hzx/Projects/Data/MVTec-AD
|
5 |
+
# test: /home/hzx/Projects/Data/BTAD
|
6 |
+
# test: /home/hzx/Projects/Data/DAGM
|
7 |
+
test: /home/hzx/Projects/Data/DTD-Synthetic
|
8 |
+
# test: /home/hzx/Projects/Data/MPDD
|
9 |
+
# test: /home/hzx/Projects/Data/SDD
|
10 |
+
dataset:
|
11 |
+
train:
|
12 |
+
_target_: src.data.components.anomal_dataset.MVTecDataset
|
13 |
+
_partial_: true
|
14 |
+
transform:
|
15 |
+
_target_: src.data.components.transform.ImageTransform
|
16 |
+
image_size: 336
|
17 |
+
mask_transform:
|
18 |
+
_target_: src.data.components.transform.MaskTransform
|
19 |
+
image_size: ${data.image_size}
|
20 |
+
preload: false
|
21 |
+
aug_rate: 0.2
|
22 |
+
valid:
|
23 |
+
_target_: src.data.components.anomal_dataset.VisaDataset
|
24 |
+
_partial_: true
|
25 |
+
transform:
|
26 |
+
_target_: src.data.components.transform.ImageTransform
|
27 |
+
image_size: 336
|
28 |
+
mask_transform:
|
29 |
+
_target_: src.data.components.transform.MaskTransform
|
30 |
+
image_size: ${data.image_size}
|
31 |
+
preload: false
|
32 |
+
test:
|
33 |
+
_target_: src.data.components.anomal_dataset.VisaDataset
|
34 |
+
_partial_: true
|
35 |
+
transform:
|
36 |
+
_target_: src.data.components.transform.ImageTransform
|
37 |
+
image_size: 336
|
38 |
+
mask_transform:
|
39 |
+
_target_: src.data.components.transform.MaskTransform
|
40 |
+
image_size: ${data.image_size}
|
41 |
+
preload: false
|
42 |
+
kshot:
|
43 |
+
_target_: src.data.components.kshot_dataset.VisaKShotDataset
|
44 |
+
_partial_: true
|
45 |
+
k_shot: 4
|
46 |
+
transform:
|
47 |
+
_target_: src.data.components.transform.ImageTransform
|
48 |
+
image_size: 336
|
49 |
+
mask_transform:
|
50 |
+
_target_: src.data.components.transform.MaskTransform
|
51 |
+
image_size: ${data.image_size}
|
52 |
+
preload: false
|
53 |
+
image_size: 336
|
54 |
+
num_workers: 6
|
55 |
+
pin_memory: False
|
56 |
+
batch_size: 8
|
SOWA/configs/data/sowa_mvt.yaml
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: src.data.anomaly_clip_datamodule.AnomalyCLIPDataModule
|
2 |
+
data_dir:
|
3 |
+
train: /home/hzx/Projects/Data/Visa
|
4 |
+
valid: /home/hzx/Projects/Data/MVTec-AD
|
5 |
+
test: /home/hzx/Projects/Data/MVTec-AD
|
6 |
+
dataset:
|
7 |
+
train:
|
8 |
+
_target_: src.data.components.anomal_dataset.VisaDataset
|
9 |
+
_partial_: true
|
10 |
+
transform:
|
11 |
+
_target_: src.data.components.transform.ImageTransform
|
12 |
+
image_size: 336
|
13 |
+
mask_transform:
|
14 |
+
_target_: src.data.components.transform.MaskTransform
|
15 |
+
image_size: ${data.image_size}
|
16 |
+
preload: false
|
17 |
+
valid:
|
18 |
+
_target_: src.data.components.anomal_dataset.MVTecDataset
|
19 |
+
_partial_: true
|
20 |
+
transform:
|
21 |
+
_target_: src.data.components.transform.ImageTransform
|
22 |
+
image_size: 336
|
23 |
+
mask_transform:
|
24 |
+
_target_: src.data.components.transform.MaskTransform
|
25 |
+
image_size: ${data.image_size}
|
26 |
+
preload: false
|
27 |
+
aug_rate: 0.0
|
28 |
+
test:
|
29 |
+
_target_: src.data.components.anomal_dataset.MVTecDataset
|
30 |
+
_partial_: true
|
31 |
+
transform:
|
32 |
+
_target_: src.data.components.transform.ImageTransform
|
33 |
+
image_size: 336
|
34 |
+
mask_transform:
|
35 |
+
_target_: src.data.components.transform.MaskTransform
|
36 |
+
image_size: ${data.image_size}
|
37 |
+
preload: false
|
38 |
+
aug_rate: 0.0
|
39 |
+
kshot:
|
40 |
+
_target_: src.data.components.kshot_dataset.MVTecKShotDataset
|
41 |
+
_partial_: true
|
42 |
+
k_shot: 1
|
43 |
+
transform:
|
44 |
+
_target_: src.data.components.transform.ImageTransform
|
45 |
+
image_size: 336
|
46 |
+
mask_transform:
|
47 |
+
_target_: src.data.components.transform.MaskTransform
|
48 |
+
image_size: ${data.image_size}
|
49 |
+
preload: false
|
50 |
+
image_size: 336
|
51 |
+
num_workers: 6
|
52 |
+
pin_memory: False
|
53 |
+
batch_size: 8
|
SOWA/configs/data/sowa_overfit.yaml
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: src.data.anomaly_clip_datamodule.AnomalyCLIPDataModule
|
2 |
+
data_dir:
|
3 |
+
train: /home/hzx/Projects/Data/MVTec-AD
|
4 |
+
valid: /home/hzx/Projects/Data/MVTec-AD
|
5 |
+
test: /home/hzx/Projects/Data/MVTec-AD
|
6 |
+
dataset:
|
7 |
+
train:
|
8 |
+
_target_: src.data.components.anomal_dataset.MVTecDataset
|
9 |
+
_partial_: true
|
10 |
+
transform:
|
11 |
+
_target_: src.data.components.transform.ImageTransform
|
12 |
+
image_size: 336
|
13 |
+
mask_transform:
|
14 |
+
_target_: src.data.components.transform.MaskTransform
|
15 |
+
image_size: ${data.image_size}
|
16 |
+
preload: false
|
17 |
+
valid:
|
18 |
+
_target_: src.data.components.anomal_dataset.MVTecDataset
|
19 |
+
_partial_: true
|
20 |
+
transform:
|
21 |
+
_target_: src.data.components.transform.ImageTransform
|
22 |
+
image_size: 336
|
23 |
+
mask_transform:
|
24 |
+
_target_: src.data.components.transform.MaskTransform
|
25 |
+
image_size: ${data.image_size}
|
26 |
+
preload: false
|
27 |
+
test:
|
28 |
+
_target_: src.data.components.anomal_dataset.MVTecDataset
|
29 |
+
_partial_: true
|
30 |
+
transform:
|
31 |
+
_target_: src.data.components.transform.ImageTransform
|
32 |
+
image_size: 336
|
33 |
+
mask_transform:
|
34 |
+
_target_: src.data.components.transform.MaskTransform
|
35 |
+
image_size: ${data.image_size}
|
36 |
+
preload: false
|
37 |
+
kshot:
|
38 |
+
_target_: src.data.components.kshot_dataset.MVTecKShotDataset
|
39 |
+
_partial_: true
|
40 |
+
k_shot: 1
|
41 |
+
transform:
|
42 |
+
_target_: src.data.components.transform.ImageTransform
|
43 |
+
image_size: 336
|
44 |
+
mask_transform:
|
45 |
+
_target_: src.data.components.transform.MaskTransform
|
46 |
+
image_size: ${data.image_size}
|
47 |
+
preload: false
|
48 |
+
image_size: 336
|
49 |
+
num_workers: 6
|
50 |
+
pin_memory: False
|
51 |
+
batch_size: 8
|
SOWA/configs/data/sowa_visa.yaml
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: src.data.anomaly_clip_datamodule.AnomalyCLIPDataModule
|
2 |
+
data_dir:
|
3 |
+
train: /home/hzx/Projects/Data/MVTec-AD
|
4 |
+
valid: /home/hzx/Projects/Data/Visa
|
5 |
+
test: /home/hzx/Projects/Data/Visa
|
6 |
+
dataset:
|
7 |
+
train:
|
8 |
+
_target_: src.data.components.anomal_dataset.MVTecDataset
|
9 |
+
_partial_: true
|
10 |
+
transform:
|
11 |
+
_target_: src.data.components.transform.ImageTransform
|
12 |
+
image_size: 336
|
13 |
+
mask_transform:
|
14 |
+
_target_: src.data.components.transform.MaskTransform
|
15 |
+
image_size: ${data.image_size}
|
16 |
+
preload: false
|
17 |
+
aug_rate: 0.2
|
18 |
+
valid:
|
19 |
+
_target_: src.data.components.anomal_dataset.VisaDataset
|
20 |
+
_partial_: true
|
21 |
+
transform:
|
22 |
+
_target_: src.data.components.transform.ImageTransform
|
23 |
+
image_size: 336
|
24 |
+
mask_transform:
|
25 |
+
_target_: src.data.components.transform.MaskTransform
|
26 |
+
image_size: ${data.image_size}
|
27 |
+
preload: false
|
28 |
+
test:
|
29 |
+
_target_: src.data.components.anomal_dataset.VisaDataset
|
30 |
+
_partial_: true
|
31 |
+
transform:
|
32 |
+
_target_: src.data.components.transform.ImageTransform
|
33 |
+
image_size: 336
|
34 |
+
mask_transform:
|
35 |
+
_target_: src.data.components.transform.MaskTransform
|
36 |
+
image_size: ${data.image_size}
|
37 |
+
preload: false
|
38 |
+
kshot:
|
39 |
+
_target_: src.data.components.kshot_dataset.VisaKShotDataset
|
40 |
+
_partial_: true
|
41 |
+
k_shot: 1
|
42 |
+
transform:
|
43 |
+
_target_: src.data.components.transform.ImageTransform
|
44 |
+
image_size: 336
|
45 |
+
mask_transform:
|
46 |
+
_target_: src.data.components.transform.MaskTransform
|
47 |
+
image_size: ${data.image_size}
|
48 |
+
preload: false
|
49 |
+
image_size: 336
|
50 |
+
num_workers: 6
|
51 |
+
pin_memory: False
|
52 |
+
batch_size: 8
|
SOWA/configs/debug/default.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# default debugging setup, runs 1 full epoch
|
4 |
+
# other debugging configs can inherit from this one
|
5 |
+
|
6 |
+
# overwrite task name so debugging logs are stored in separate folder
|
7 |
+
task_name: "debug"
|
8 |
+
|
9 |
+
# disable callbacks and loggers during debugging
|
10 |
+
callbacks: null
|
11 |
+
logger: null
|
12 |
+
|
13 |
+
extras:
|
14 |
+
ignore_warnings: False
|
15 |
+
enforce_tags: False
|
16 |
+
|
17 |
+
# sets level of all command line loggers to 'DEBUG'
|
18 |
+
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
|
19 |
+
hydra:
|
20 |
+
job_logging:
|
21 |
+
root:
|
22 |
+
level: DEBUG
|
23 |
+
|
24 |
+
# use this to also set hydra loggers to 'DEBUG'
|
25 |
+
# verbose: True
|
26 |
+
|
27 |
+
trainer:
|
28 |
+
max_epochs: 1
|
29 |
+
accelerator: cpu # debuggers don't like gpus
|
30 |
+
devices: 1 # debuggers don't like multiprocessing
|
31 |
+
detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
|
32 |
+
|
33 |
+
data:
|
34 |
+
num_workers: 0 # debuggers don't like multiprocessing
|
35 |
+
pin_memory: False # disable gpu memory pin
|
SOWA/configs/debug/fdr.yaml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# runs 1 train, 1 validation and 1 test step
|
4 |
+
|
5 |
+
defaults:
|
6 |
+
- default
|
7 |
+
|
8 |
+
trainer:
|
9 |
+
fast_dev_run: true
|
SOWA/configs/debug/limit.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# uses only 1% of the training data and 5% of validation/test data
|
4 |
+
|
5 |
+
defaults:
|
6 |
+
- default
|
7 |
+
|
8 |
+
trainer:
|
9 |
+
max_epochs: 3
|
10 |
+
limit_train_batches: 0.01
|
11 |
+
limit_val_batches: 0.05
|
12 |
+
limit_test_batches: 0.05
|
SOWA/configs/debug/overfit.yaml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# overfits to 3 batches
|
4 |
+
|
5 |
+
defaults:
|
6 |
+
- default
|
7 |
+
|
8 |
+
trainer:
|
9 |
+
max_epochs: 20
|
10 |
+
overfit_batches: 3
|
11 |
+
|
12 |
+
# model ckpt and early stopping need to be disabled during overfitting
|
13 |
+
callbacks: null
|
SOWA/configs/debug/profiler.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# runs with execution time profiling
|
4 |
+
|
5 |
+
defaults:
|
6 |
+
- default
|
7 |
+
|
8 |
+
trainer:
|
9 |
+
max_epochs: 1
|
10 |
+
profiler: "simple"
|
11 |
+
# profiler: "advanced"
|
12 |
+
# profiler: "pytorch"
|
SOWA/configs/eval.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- _self_
|
5 |
+
- data: anomaly_clip # choose datamodule with `test_dataloader()` for evaluation
|
6 |
+
- model: anomaly_clip
|
7 |
+
- callbacks: default
|
8 |
+
- logger: many_loggers
|
9 |
+
- trainer: default
|
10 |
+
- paths: default
|
11 |
+
- extras: default
|
12 |
+
- hydra: default
|
13 |
+
|
14 |
+
# information of object and prompt template
|
15 |
+
- prompt: default
|
16 |
+
|
17 |
+
task_name: "eval"
|
18 |
+
|
19 |
+
tags: ["dev"]
|
20 |
+
|
21 |
+
# seed for random number generators in pytorch, numpy and python.random
|
22 |
+
seed: 42
|
23 |
+
|
24 |
+
# passing checkpoint path is necessary for evaluation
|
25 |
+
# ckpt_path: /home/hzx/Projects/SPARC/logs/train/runs/2024-06-14_16-32-14/checkpoints/epoch_000.ckpt
|
26 |
+
# ckpt_path: /home/hzx/Projects/Weight/mvtech-kshot/learnable/0-shot/2024-05-27_13-08-55/checkpoints/epoch_000.ckpt
|
27 |
+
ckpt_path: /home/hzx/Projects/Weight/visa-kshot/learnable/0-shot/2024-05-27_11-07-16/checkpoints/epoch_000.ckpt
|
SOWA/configs/experiment/example.yaml
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
task_name: train
|
2 |
+
tags:
|
3 |
+
- dev
|
4 |
+
train: true
|
5 |
+
test: true
|
6 |
+
ckpt_path: null
|
7 |
+
seed: 2025
|
8 |
+
data:
|
9 |
+
_target_: src.data.anomaly_clip_datamodule.AnomalyCLIPDataModule
|
10 |
+
data_dir:
|
11 |
+
train: /home/hzx/Projects/Data/Visa
|
12 |
+
valid: /home/hzx/Projects/Data/MVTec-AD
|
13 |
+
test: /home/hzx/Projects/Data/MVTec-AD
|
14 |
+
dataset:
|
15 |
+
train:
|
16 |
+
_target_: src.data.components.anomal_dataset.VisaDataset
|
17 |
+
_partial_: true
|
18 |
+
transform:
|
19 |
+
_target_: src.data.components.transform.ImageTransform
|
20 |
+
image_size: 336
|
21 |
+
mask_transform:
|
22 |
+
_target_: src.data.components.transform.MaskTransform
|
23 |
+
image_size: ${data.image_size}
|
24 |
+
preload: false
|
25 |
+
valid:
|
26 |
+
_target_: src.data.components.anomal_dataset.MVTecDataset
|
27 |
+
_partial_: true
|
28 |
+
transform:
|
29 |
+
_target_: src.data.components.transform.ImageTransform
|
30 |
+
image_size: 336
|
31 |
+
mask_transform:
|
32 |
+
_target_: src.data.components.transform.MaskTransform
|
33 |
+
image_size: ${data.image_size}
|
34 |
+
preload: false
|
35 |
+
aug_rate: 0.0
|
36 |
+
test:
|
37 |
+
_target_: src.data.components.anomal_dataset.MVTecDataset
|
38 |
+
_partial_: true
|
39 |
+
transform:
|
40 |
+
_target_: src.data.components.transform.ImageTransform
|
41 |
+
image_size: 336
|
42 |
+
mask_transform:
|
43 |
+
_target_: src.data.components.transform.MaskTransform
|
44 |
+
image_size: ${data.image_size}
|
45 |
+
preload: false
|
46 |
+
aug_rate: 0.0
|
47 |
+
kshot:
|
48 |
+
_target_: src.data.components.kshot_dataset.MVTecKShotDataset
|
49 |
+
_partial_: true
|
50 |
+
k_shot: 1
|
51 |
+
transform:
|
52 |
+
_target_: src.data.components.transform.ImageTransform
|
53 |
+
image_size: 336
|
54 |
+
mask_transform:
|
55 |
+
_target_: src.data.components.transform.MaskTransform
|
56 |
+
image_size: ${data.image_size}
|
57 |
+
preload: false
|
58 |
+
image_size: 336
|
59 |
+
num_workers: 4
|
60 |
+
pin_memory: false
|
61 |
+
batch_size: 8
|
62 |
+
model:
|
63 |
+
_target_: src.models.anomaly_clip_module.AnomalyCLIPModule
|
64 |
+
optimizer:
|
65 |
+
_target_: torch.optim.AdamW
|
66 |
+
_partial_: true
|
67 |
+
lr: 0.001
|
68 |
+
weight_decay: 0.2
|
69 |
+
scheduler:
|
70 |
+
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
|
71 |
+
_partial_: true
|
72 |
+
mode: min
|
73 |
+
factor: 0.1
|
74 |
+
patience: 5
|
75 |
+
net:
|
76 |
+
_target_: src.models.components.anomaly_clip.AnomalyCLIP
|
77 |
+
arch: ViT-L/14@336px
|
78 |
+
image_size: 336
|
79 |
+
class_names:
|
80 |
+
- object
|
81 |
+
temperature: 0.05
|
82 |
+
prompt_length: 24
|
83 |
+
context_length: 77
|
84 |
+
truncate: false
|
85 |
+
feature_map_idx:
|
86 |
+
- 5
|
87 |
+
- 11
|
88 |
+
- 17
|
89 |
+
- 23
|
90 |
+
share_weight: false
|
91 |
+
state_template:
|
92 |
+
normal:
|
93 |
+
- '{}'
|
94 |
+
anomaly:
|
95 |
+
- damaged {}
|
96 |
+
tokenizer:
|
97 |
+
_target_: src.models.components.clip.simple_tokenizer.SimpleTokenizer
|
98 |
+
adapter:
|
99 |
+
_target_: src.models.components.adapter.BasicLayer
|
100 |
+
_partial_: true
|
101 |
+
input_resolution:
|
102 |
+
- 24
|
103 |
+
- 24
|
104 |
+
window_size: 6
|
105 |
+
depth: 1
|
106 |
+
num_heads: 8
|
107 |
+
hidden_features: null
|
108 |
+
cpb_dim: 64
|
109 |
+
value_only: true
|
110 |
+
drop: 0.0
|
111 |
+
attn_drop: 0.2
|
112 |
+
loss:
|
113 |
+
cross_entropy:
|
114 |
+
_target_: torch.nn.CrossEntropyLoss
|
115 |
+
focal:
|
116 |
+
_target_: src.models.components.loss.FocalLoss
|
117 |
+
dice:
|
118 |
+
_target_: src.models.components.loss.BinaryDiceLoss
|
119 |
+
k_shot: false
|
120 |
+
enable_validation: false
|
121 |
+
compile: false
|
122 |
+
callbacks:
|
123 |
+
model_checkpoint:
|
124 |
+
_target_: lightning.pytorch.callbacks.ModelCheckpoint
|
125 |
+
dirpath: ${paths.output_dir}/checkpoints
|
126 |
+
filename: epoch_{epoch:03d}
|
127 |
+
monitor: train/loss
|
128 |
+
verbose: false
|
129 |
+
save_last: true
|
130 |
+
save_top_k: 1
|
131 |
+
mode: min
|
132 |
+
auto_insert_metric_name: false
|
133 |
+
save_weights_only: false
|
134 |
+
every_n_train_steps: null
|
135 |
+
train_time_interval: null
|
136 |
+
every_n_epochs: null
|
137 |
+
save_on_train_epoch_end: null
|
138 |
+
early_stopping:
|
139 |
+
_target_: lightning.pytorch.callbacks.EarlyStopping
|
140 |
+
monitor: train/loss
|
141 |
+
min_delta: 0.0
|
142 |
+
patience: 10
|
143 |
+
verbose: false
|
144 |
+
mode: min
|
145 |
+
strict: true
|
146 |
+
check_finite: true
|
147 |
+
stopping_threshold: null
|
148 |
+
divergence_threshold: null
|
149 |
+
check_on_train_epoch_end: null
|
150 |
+
model_summary:
|
151 |
+
_target_: lightning.pytorch.callbacks.RichModelSummary
|
152 |
+
max_depth: -1
|
153 |
+
rich_progress_bar:
|
154 |
+
_target_: lightning.pytorch.callbacks.RichProgressBar
|
155 |
+
visualization:
|
156 |
+
_target_: src.models.components.callback.AnomalyVisualizationCallback
|
157 |
+
dirpath: ${paths.output_dir}/visualizations
|
158 |
+
visualize: true
|
159 |
+
visulization:
|
160 |
+
dirpath: ${paths.output_dir}/visualizations
|
161 |
+
visualize: true
|
162 |
+
logger:
|
163 |
+
wandb:
|
164 |
+
_target_: lightning.pytorch.loggers.wandb.WandbLogger
|
165 |
+
save_dir: ${paths.output_dir}
|
166 |
+
offline: false
|
167 |
+
id: null
|
168 |
+
anonymous: null
|
169 |
+
project: mvt_optuna
|
170 |
+
log_model: false
|
171 |
+
prefix: ''
|
172 |
+
group: ''
|
173 |
+
tags: []
|
174 |
+
job_type: ''
|
175 |
+
trainer:
|
176 |
+
_target_: lightning.pytorch.trainer.Trainer
|
177 |
+
default_root_dir: ${paths.output_dir}
|
178 |
+
min_epochs: 1
|
179 |
+
max_epochs: 2
|
180 |
+
accelerator: gpu
|
181 |
+
devices: 1
|
182 |
+
check_val_every_n_epoch: 1
|
183 |
+
deterministic: false
|
184 |
+
paths:
|
185 |
+
root_dir: ${oc.env:PROJECT_ROOT}
|
186 |
+
data_dir: ${paths.root_dir}/data/
|
187 |
+
log_dir: ${paths.root_dir}/logs/
|
188 |
+
output_dir: ${hydra:runtime.output_dir}
|
189 |
+
work_dir: ${hydra:runtime.cwd}
|
190 |
+
extras:
|
191 |
+
ignore_warnings: false
|
192 |
+
enforce_tags: true
|
193 |
+
print_config: true
|
194 |
+
optimized_metric: test/objective
|
SOWA/configs/extras/default.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# disable python warnings if they annoy you
|
2 |
+
ignore_warnings: False
|
3 |
+
|
4 |
+
# ask user for tags if none are provided in the config
|
5 |
+
enforce_tags: True
|
6 |
+
|
7 |
+
# pretty print config tree at the start of the run using Rich library
|
8 |
+
print_config: True
|
SOWA/configs/hparams_search/anomaly_clip_optuna.yaml
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# example hyperparameter optimization of some experiment with Optuna:
|
4 |
+
# python train.py -m hparams_search=mnist_optuna experiment=example
|
5 |
+
|
6 |
+
defaults:
|
7 |
+
- override /hydra/sweeper: optuna
|
8 |
+
|
9 |
+
# choose metric which will be optimized by Optuna
|
10 |
+
# make sure this is the correct name of some metric logged in lightning module!
|
11 |
+
optimized_metric: test/objective
|
12 |
+
|
13 |
+
# here we define Optuna hyperparameter search
|
14 |
+
# it optimizes for value returned from function with @hydra.main decorator
|
15 |
+
# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
|
16 |
+
hydra:
|
17 |
+
mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
|
18 |
+
|
19 |
+
sweeper:
|
20 |
+
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
|
21 |
+
|
22 |
+
# storage URL to persist optimization results
|
23 |
+
# for example, you can use SQLite if you set 'sqlite:///example.db'
|
24 |
+
storage: null
|
25 |
+
|
26 |
+
# name of the study to persist optimization results
|
27 |
+
study_name: null
|
28 |
+
|
29 |
+
# number of parallel workers
|
30 |
+
n_jobs: 1
|
31 |
+
|
32 |
+
# 'minimize' or 'maximize' the objective
|
33 |
+
direction: maximize
|
34 |
+
|
35 |
+
# total number of runs that will be executed
|
36 |
+
n_trials: 50
|
37 |
+
|
38 |
+
# choose Optuna hyperparameter sampler
|
39 |
+
# you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
|
40 |
+
# docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
|
41 |
+
sampler:
|
42 |
+
_target_: optuna.samplers.TPESampler
|
43 |
+
seed: 1234
|
44 |
+
n_startup_trials: 50 # number of random sampling runs before optimization starts
|
45 |
+
|
46 |
+
# define hyperparameter search space
|
47 |
+
params:
|
48 |
+
trainer.max_epochs: choice(1, 5)
|
49 |
+
# model.optimizer.lr: choice(0.0001, 0.001)
|
50 |
+
# model.net.temperature: choice(0.1, 0.05)
|
51 |
+
model.net.prompt_length: choice(8, 12, 16, 24)
|
52 |
+
model.net.share_weight: choice(true, false)
|
53 |
+
model.net.feature_map_idx : choice([5, 11, 17, 23], [0, 11, 23])
|
54 |
+
# model.net.adapter.hidden_features: choice([1024])
|
55 |
+
model.net.adapter.window_size: choice(6, 12, 24)
|
56 |
+
model.net.adapter.depth: choice(1, 2)
|
57 |
+
model.net.adapter.num_heads: choice(8)
|
58 |
+
# model.net.adapter.cpb_dim: choice(64, 128, 512)
|
59 |
+
model.net.adapter.value_only: choice(true, false)
|
60 |
+
model.net.adapter.drop: choice(0.0, 0.1, 0.2)
|
61 |
+
model.net.adapter.attn_drop: choice(0.0, 0.1, 0.2)
|
SOWA/configs/hydra/default.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://hydra.cc/docs/configure_hydra/intro/
|
2 |
+
|
3 |
+
# enable color logging
|
4 |
+
defaults:
|
5 |
+
- override hydra_logging: default
|
6 |
+
- override job_logging: default
|
7 |
+
|
8 |
+
# output directory, generated dynamically on each run
|
9 |
+
run:
|
10 |
+
dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
|
11 |
+
sweep:
|
12 |
+
dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
|
13 |
+
subdir: ${hydra.job.num}
|
14 |
+
|
15 |
+
job_logging:
|
16 |
+
handlers:
|
17 |
+
file:
|
18 |
+
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
|
19 |
+
filename: ${hydra.runtime.output_dir}/${task_name}.log
|
SOWA/configs/local/.gitkeep
ADDED
File without changes
|
SOWA/configs/logger/aim.yaml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://aimstack.io/
|
2 |
+
|
3 |
+
# example usage in lightning module:
|
4 |
+
# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
|
5 |
+
|
6 |
+
# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
|
7 |
+
# `aim up`
|
8 |
+
|
9 |
+
aim:
|
10 |
+
_target_: aim.pytorch_lightning.AimLogger
|
11 |
+
repo: ${paths.root_dir} # .aim folder will be created here
|
12 |
+
# repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
|
13 |
+
|
14 |
+
# aim allows to group runs under experiment name
|
15 |
+
experiment: null # any string, set to "default" if not specified
|
16 |
+
|
17 |
+
train_metric_prefix: "train/"
|
18 |
+
val_metric_prefix: "val/"
|
19 |
+
test_metric_prefix: "test/"
|
20 |
+
|
21 |
+
# sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
|
22 |
+
system_tracking_interval: 10 # set to null to disable system metrics tracking
|
23 |
+
|
24 |
+
# enable/disable logging of system params such as installed packages, git info, env vars, etc.
|
25 |
+
log_system_params: true
|
26 |
+
|
27 |
+
# enable/disable tracking console logs (default value is true)
|
28 |
+
capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
|
SOWA/configs/logger/comet.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://www.comet.ml
|
2 |
+
|
3 |
+
comet:
|
4 |
+
_target_: lightning.pytorch.loggers.comet.CometLogger
|
5 |
+
api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
|
6 |
+
save_dir: "${paths.output_dir}"
|
7 |
+
project_name: "lightning-hydra-template"
|
8 |
+
rest_api_key: null
|
9 |
+
# experiment_name: ""
|
10 |
+
experiment_key: null # set to resume experiment
|
11 |
+
offline: False
|
12 |
+
prefix: ""
|
SOWA/configs/logger/csv.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# csv logger built in lightning
|
2 |
+
|
3 |
+
csv:
|
4 |
+
_target_: lightning.pytorch.loggers.csv_logs.CSVLogger
|
5 |
+
save_dir: "${paths.output_dir}"
|
6 |
+
name: "csv/"
|
7 |
+
prefix: ""
|
SOWA/configs/logger/many_loggers.yaml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# train with many loggers at once
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
# - comet
|
5 |
+
- csv
|
6 |
+
# - mlflow
|
7 |
+
# - neptune
|
8 |
+
# - tensorboard
|
9 |
+
- wandb
|
SOWA/configs/logger/mlflow.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://mlflow.org
|
2 |
+
|
3 |
+
mlflow:
|
4 |
+
_target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
|
5 |
+
# experiment_name: ""
|
6 |
+
# run_name: ""
|
7 |
+
tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
|
8 |
+
tags: null
|
9 |
+
# save_dir: "./mlruns"
|
10 |
+
prefix: ""
|
11 |
+
artifact_location: null
|
12 |
+
# run_id: ""
|
SOWA/configs/logger/neptune.yaml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://neptune.ai
|
2 |
+
|
3 |
+
neptune:
|
4 |
+
_target_: lightning.pytorch.loggers.neptune.NeptuneLogger
|
5 |
+
api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
|
6 |
+
project: username/lightning-hydra-template
|
7 |
+
# name: ""
|
8 |
+
log_model_checkpoints: True
|
9 |
+
prefix: ""
|
SOWA/configs/logger/tensorboard.yaml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://www.tensorflow.org/tensorboard/
|
2 |
+
|
3 |
+
tensorboard:
|
4 |
+
_target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
|
5 |
+
save_dir: "${paths.output_dir}/tensorboard/"
|
6 |
+
name: null
|
7 |
+
log_graph: False
|
8 |
+
default_hp_metric: True
|
9 |
+
prefix: ""
|
10 |
+
# version: ""
|
SOWA/configs/logger/wandb.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://wandb.ai
|
2 |
+
|
3 |
+
wandb:
|
4 |
+
_target_: lightning.pytorch.loggers.wandb.WandbLogger
|
5 |
+
# name: "" # name of the run (normally generated by wandb)
|
6 |
+
save_dir: "${paths.output_dir}"
|
7 |
+
offline: False
|
8 |
+
id: null # pass correct id to resume experiment!
|
9 |
+
anonymous: null # enable anonymous logging
|
10 |
+
project: mvt_optuna
|
11 |
+
log_model: False # upload lightning ckpts
|
12 |
+
prefix: "" # a string to put at the beginning of metric keys
|
13 |
+
# entity: "" # set to name of your wandb team
|
14 |
+
group: ""
|
15 |
+
tags: []
|
16 |
+
job_type: ""
|
SOWA/configs/model/sowa_hfwa.yaml
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: src.models.anomaly_clip_module.AnomalyCLIPModule
|
2 |
+
|
3 |
+
optimizer:
|
4 |
+
_target_: torch.optim.AdamW
|
5 |
+
_partial_: true
|
6 |
+
lr: 0.001
|
7 |
+
weight_decay: 0.2
|
8 |
+
|
9 |
+
scheduler:
|
10 |
+
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
|
11 |
+
_partial_: true
|
12 |
+
mode: min
|
13 |
+
factor: 0.1
|
14 |
+
patience: 5
|
15 |
+
|
16 |
+
# scheduler:
|
17 |
+
# _target_: src.models.components.scheduler.WarmupCosineAnnealingLR
|
18 |
+
# _partial_: true
|
19 |
+
# warmup_epochs: 10
|
20 |
+
# total_epoch: 50
|
21 |
+
|
22 |
+
net:
|
23 |
+
_target_: src.models.components.anomaly_clip.AnomalyCLIP
|
24 |
+
arch: ViT-L/14@336px
|
25 |
+
image_size: 336
|
26 |
+
class_names: ["object"]
|
27 |
+
# class_names: ${prompt.class_names}
|
28 |
+
temperature: 0.07 # softmax
|
29 |
+
prompt_length: 12 # length of learnable prompts
|
30 |
+
context_length: 77 # default 77 for openai clip
|
31 |
+
truncate: false
|
32 |
+
feature_map_idx: [5, 11, 17, 23] # [0, 12, 23] [6, 12, 18] [5, 11, 17, 23] index of resnetblock in ViT
|
33 |
+
share_weight: true # whether the adapter shares weights for different feature maps
|
34 |
+
# state_template: ${prompt.state_template}
|
35 |
+
state_template:
|
36 |
+
normal: ["{}"]
|
37 |
+
anomaly: ["damaged {}"]
|
38 |
+
tokenizer:
|
39 |
+
_target_: src.models.components.clip.simple_tokenizer.SimpleTokenizer
|
40 |
+
adapter:
|
41 |
+
_target_: src.models.components.adapter.BasicLayer
|
42 |
+
_partial_: true
|
43 |
+
input_resolution: [24, 24] # (image_size - kernel_size) / stride + 1. e.g. 24 = (224 - 14) / 14 + 1
|
44 |
+
window_size: 12
|
45 |
+
depth: 1 # if depth < 2, there is no window shift
|
46 |
+
num_heads: 8
|
47 |
+
hidden_features: null # set null, same as nn.Linear
|
48 |
+
cpb_dim: 64
|
49 |
+
value_only: true
|
50 |
+
drop: 0.0
|
51 |
+
attn_drop: 0.1
|
52 |
+
# shift_size: 1
|
53 |
+
fusion:
|
54 |
+
_target_: src.models.components.cross_modal.DotProductFusion
|
55 |
+
embedding_dim: 768 # clip fusion feature dim, default 768, only effective for non null
|
56 |
+
|
57 |
+
loss:
|
58 |
+
cross_entropy:
|
59 |
+
_target_: torch.nn.CrossEntropyLoss
|
60 |
+
focal:
|
61 |
+
_target_: src.models.components.loss.FocalLoss
|
62 |
+
dice:
|
63 |
+
_target_: src.models.components.loss.BinaryDiceLoss
|
64 |
+
|
65 |
+
k_shot: false
|
66 |
+
|
67 |
+
filter: true
|
68 |
+
|
69 |
+
enable_validation: false
|
70 |
+
|
71 |
+
compile: false
|
SOWA/configs/model/sowa_linear.yaml
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: src.models.anomaly_clip_module.AnomalyCLIPModule
|
2 |
+
|
3 |
+
optimizer:
|
4 |
+
_target_: torch.optim.AdamW
|
5 |
+
_partial_: true
|
6 |
+
lr: 0.001
|
7 |
+
weight_decay: 0.2
|
8 |
+
|
9 |
+
scheduler:
|
10 |
+
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
|
11 |
+
_partial_: true
|
12 |
+
mode: min
|
13 |
+
factor: 0.1
|
14 |
+
patience: 5
|
15 |
+
|
16 |
+
net:
|
17 |
+
_target_: src.models.components.anomaly_clip.AnomalyCLIP
|
18 |
+
arch: ViT-L/14@336px
|
19 |
+
image_size: 336
|
20 |
+
class_names: ["object"]
|
21 |
+
# class_names: ${prompt.class_names}
|
22 |
+
temperature: 0.07 # softmax
|
23 |
+
prompt_length: 12 # length of learnable prompts
|
24 |
+
context_length: 77 # default 77 for openai clip
|
25 |
+
truncate: false
|
26 |
+
feature_map_idx: [5, 11, 17, 23] # [0, 12, 23] [6, 12, 18] [5, 11, 17, 23] index of resnetblock in ViT
|
27 |
+
share_weight: false # whether the adapter shares weights for different feature maps
|
28 |
+
# state_template: ${prompt.state_template}
|
29 |
+
state_template:
|
30 |
+
normal: ["{}"]
|
31 |
+
anomaly: ["damaged {}"]
|
32 |
+
tokenizer:
|
33 |
+
_target_: src.models.components.clip.simple_tokenizer.SimpleTokenizer
|
34 |
+
adapter:
|
35 |
+
# _target_: torch.nn.Linear
|
36 |
+
# in_features: 1024 # clip vit feature dim, default 1024 for openai clip
|
37 |
+
# out_features: 1024
|
38 |
+
# bias: false
|
39 |
+
_target_: src.models.components.adapter.Linear
|
40 |
+
in_features: 1024 # clip vit feature dim, default 1024 for openai clip
|
41 |
+
out_features: 1024
|
42 |
+
hidden_features: null # set null, same as nn.Linear
|
43 |
+
dropout_prob: 0.0
|
44 |
+
bias: false
|
45 |
+
fusion:
|
46 |
+
_target_: src.models.components.cross_modal.DotProductFusion
|
47 |
+
embedding_dim: null # clip fusion feature dim, only effective for learnable
|
48 |
+
|
49 |
+
loss:
|
50 |
+
cross_entropy:
|
51 |
+
_target_: torch.nn.CrossEntropyLoss
|
52 |
+
focal:
|
53 |
+
_target_: src.models.components.loss.FocalLoss
|
54 |
+
dice:
|
55 |
+
_target_: src.models.components.loss.BinaryDiceLoss
|
56 |
+
|
57 |
+
k_shot: false
|
58 |
+
|
59 |
+
filter: true
|
60 |
+
|
61 |
+
enable_validation: false
|
62 |
+
|
63 |
+
compile: false
|
SOWA/configs/model/sparc_hfwa.yaml
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: src.models.anomaly_clip_module.AnomalyCLIPModule
|
2 |
+
|
3 |
+
optimizer:
|
4 |
+
_target_: torch.optim.AdamW
|
5 |
+
_partial_: true
|
6 |
+
lr: 0.001
|
7 |
+
weight_decay: 0.2
|
8 |
+
|
9 |
+
scheduler:
|
10 |
+
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
|
11 |
+
_partial_: true
|
12 |
+
mode: min
|
13 |
+
factor: 0.1
|
14 |
+
patience: 5
|
15 |
+
|
16 |
+
# scheduler:
|
17 |
+
# _target_: src.models.components.scheduler.WarmupCosineAnnealingLR
|
18 |
+
# _partial_: true
|
19 |
+
# warmup_epochs: 10
|
20 |
+
# total_epoch: 50
|
21 |
+
|
22 |
+
net:
|
23 |
+
_target_: src.models.components.sparc.SPARC
|
24 |
+
arch: ViT-L/14@336px
|
25 |
+
image_size: 336
|
26 |
+
temperature: 0.07 # softmax
|
27 |
+
feature_map_idx: [5, 11, 17, 23] # [0, 12, 23] [6, 12, 18] [5, 11, 17, 23] index of resnetblock in ViT
|
28 |
+
prompt_learner:
|
29 |
+
_target_: src.models.components.coop.AnomalyPromptLearner
|
30 |
+
_partial_: true
|
31 |
+
tokenizer:
|
32 |
+
_target_: src.models.components.clip.simple_tokenizer.SimpleTokenizer
|
33 |
+
prompt_length: 12 # length of learnable prompts
|
34 |
+
context_length: 77 # default 77 for openai clip
|
35 |
+
truncate: false
|
36 |
+
class_names: ["object"]
|
37 |
+
# class_names: ${prompt.class_names}
|
38 |
+
# state_template: ${prompt.state_template}
|
39 |
+
state_template:
|
40 |
+
normal: ["{}"]
|
41 |
+
anomaly: ["damaged {}"]
|
42 |
+
text_encoder:
|
43 |
+
_target_: src.models.components.text_encoder.TextMapEncoder
|
44 |
+
_partial_: true
|
45 |
+
adapter:
|
46 |
+
_target_: src.models.components.adapter.BasicLayer
|
47 |
+
_partial_: true
|
48 |
+
input_resolution: [24, 24] # (image_size - kernel_size) / stride + 1. e.g. 24 = (224 - 14) / 14 + 1
|
49 |
+
window_size: 12
|
50 |
+
depth: 1 # if depth < 2, there is no window shift
|
51 |
+
num_heads: 8
|
52 |
+
hidden_features: null # set null, same as nn.Linear
|
53 |
+
cpb_dim: 64
|
54 |
+
value_only: true
|
55 |
+
drop: 0.0
|
56 |
+
attn_drop: 0.1
|
57 |
+
fusion:
|
58 |
+
_target_: src.models.components.cross_modal.DotProductFusion
|
59 |
+
embedding_dim: 768 # clip fusion feature dim, default 768, only effective for non null
|
60 |
+
|
61 |
+
loss:
|
62 |
+
cross_entropy:
|
63 |
+
_target_: torch.nn.CrossEntropyLoss
|
64 |
+
focal:
|
65 |
+
_target_: src.models.components.loss.FocalLoss
|
66 |
+
dice:
|
67 |
+
_target_: src.models.components.loss.BinaryDiceLoss
|
68 |
+
|
69 |
+
k_shot: false
|
70 |
+
|
71 |
+
filter: true
|
72 |
+
|
73 |
+
enable_validation: false
|
74 |
+
|
75 |
+
compile: false
|
SOWA/configs/model/sparc_linear.yaml
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: src.models.anomaly_clip_module.AnomalyCLIPModule
|
2 |
+
|
3 |
+
optimizer:
|
4 |
+
_target_: torch.optim.AdamW
|
5 |
+
_partial_: true
|
6 |
+
lr: 0.001
|
7 |
+
weight_decay: 0.2
|
8 |
+
|
9 |
+
scheduler:
|
10 |
+
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
|
11 |
+
_partial_: true
|
12 |
+
mode: min
|
13 |
+
factor: 0.1
|
14 |
+
patience: 5
|
15 |
+
|
16 |
+
# scheduler:
|
17 |
+
# _target_: src.models.components.scheduler.WarmupCosineAnnealingLR
|
18 |
+
# _partial_: true
|
19 |
+
# warmup_epochs: 10
|
20 |
+
# total_epoch: 50
|
21 |
+
|
22 |
+
net:
|
23 |
+
_target_: src.models.components.sparc.SPARC
|
24 |
+
arch: ViT-L/14@336px
|
25 |
+
image_size: 336
|
26 |
+
temperature: 0.07 # softmax
|
27 |
+
feature_map_idx: [5, 11, 17, 23] # [0, 12, 23] [6, 12, 18] [5, 11, 17, 23] index of resnetblock in ViT
|
28 |
+
prompt_learner:
|
29 |
+
_target_: src.models.components.coop.AnomalyPromptLearner
|
30 |
+
_partial_: true
|
31 |
+
tokenizer:
|
32 |
+
_target_: src.models.components.clip.simple_tokenizer.SimpleTokenizer
|
33 |
+
prompt_length: 12 # length of learnable prompts
|
34 |
+
context_length: 77 # default 77 for openai clip
|
35 |
+
truncate: false
|
36 |
+
class_names: ["object"]
|
37 |
+
# class_names: ${prompt.class_names}
|
38 |
+
# state_template: ${prompt.state_template}
|
39 |
+
state_template:
|
40 |
+
normal: ["{}"]
|
41 |
+
anomaly: ["damaged {}"]
|
42 |
+
text_encoder:
|
43 |
+
_target_: src.models.components.text_encoder.TextMapEncoder
|
44 |
+
_partial_: true
|
45 |
+
adapter:
|
46 |
+
# _target_: torch.nn.Linear
|
47 |
+
# in_features: 1024 # clip vit feature dim, default 1024 for openai clip
|
48 |
+
# out_features: 1024
|
49 |
+
# bias: false
|
50 |
+
_target_: src.models.components.adapter.Linear
|
51 |
+
in_features: 1024 # clip vit feature dim, default 1024 for openai clip
|
52 |
+
out_features: 1024
|
53 |
+
hidden_features: null # set null, same as nn.Linear
|
54 |
+
dropout_prob: 0.0
|
55 |
+
bias: false
|
56 |
+
fusion:
|
57 |
+
_target_: src.models.components.cross_modal.DotProductFusion
|
58 |
+
embedding_dim: 768 # clip fusion feature dim, default 768, only effective for non null
|
59 |
+
|
60 |
+
loss:
|
61 |
+
cross_entropy:
|
62 |
+
_target_: torch.nn.CrossEntropyLoss
|
63 |
+
focal:
|
64 |
+
_target_: src.models.components.loss.FocalLoss
|
65 |
+
dice:
|
66 |
+
_target_: src.models.components.loss.BinaryDiceLoss
|
67 |
+
|
68 |
+
k_shot: false
|
69 |
+
|
70 |
+
filter: true
|
71 |
+
|
72 |
+
enable_validation: false
|
73 |
+
|
74 |
+
compile: false
|
SOWA/configs/model/sparc_prompt.yaml
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: src.models.anomaly_clip_module.AnomalyCLIPModule
|
2 |
+
|
3 |
+
optimizer:
|
4 |
+
_target_: torch.optim.AdamW
|
5 |
+
_partial_: true
|
6 |
+
lr: 0.001
|
7 |
+
weight_decay: 0.2
|
8 |
+
|
9 |
+
scheduler:
|
10 |
+
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
|
11 |
+
_partial_: true
|
12 |
+
mode: min
|
13 |
+
factor: 0.1
|
14 |
+
patience: 5
|
15 |
+
|
16 |
+
# scheduler:
|
17 |
+
# _target_: src.models.components.scheduler.WarmupCosineAnnealingLR
|
18 |
+
# _partial_: true
|
19 |
+
# warmup_epochs: 10
|
20 |
+
# total_epoch: 50
|
21 |
+
|
22 |
+
net:
|
23 |
+
_target_: src.models.components.sparc.SPARC
|
24 |
+
arch: ViT-L/14@336px
|
25 |
+
image_size: 336
|
26 |
+
temperature: 0.07 # softmax
|
27 |
+
feature_map_idx: [5, 11, 17, 23] # [0, 12, 23] [6, 12, 18] [5, 11, 17, 23] index of resnetblock in ViT
|
28 |
+
share_weight: true # whether the adapter shares weights for different feature maps
|
29 |
+
prompt_learner:
|
30 |
+
_target_: src.models.components.coop.PromptEncoder
|
31 |
+
_partial_: true
|
32 |
+
tokenizer:
|
33 |
+
_target_: src.models.components.clip.simple_tokenizer.SimpleTokenizer
|
34 |
+
context_length: 77 # default 77 for openai clip
|
35 |
+
truncate: false
|
36 |
+
class_names: ${prompt.class_names}
|
37 |
+
prompt_normal: ${prompt.template.normal}
|
38 |
+
prompt_abnormal: ${prompt.template.abnormal}
|
39 |
+
prompt_templates: ${prompt.template.templates}
|
40 |
+
adapter:
|
41 |
+
_target_: torch.nn.Linear
|
42 |
+
in_features: 1024 # clip vit feature dim, default 1024 for openai clip
|
43 |
+
out_features: 1024
|
44 |
+
bias: false
|
45 |
+
# _target_: src.models.components.adapter.BasicLayer
|
46 |
+
# _partial_: true
|
47 |
+
# input_resolution: [24, 24] # (image_size - kernel_size) / stride + 1. e.g. 24 = (336 - 14) / 14 + 1
|
48 |
+
# window_size: 12
|
49 |
+
# depth: 1 # if depth < 2, there is no window shift
|
50 |
+
# num_heads: 8
|
51 |
+
# hidden_features: null # set null, same as nn.Linear
|
52 |
+
# cpb_dim: 64
|
53 |
+
# value_only: true
|
54 |
+
# drop: 0.0
|
55 |
+
# attn_drop: 0.1
|
56 |
+
fusion:
|
57 |
+
_target_: src.models.components.cross_modal.DotProductFusion
|
58 |
+
embedding_dim: 768 # clip fusion feature dim, default 768, only effective for non null
|
59 |
+
|
60 |
+
loss:
|
61 |
+
cross_entropy:
|
62 |
+
_target_: torch.nn.CrossEntropyLoss
|
63 |
+
focal:
|
64 |
+
_target_: src.models.components.loss.FocalLoss
|
65 |
+
dice:
|
66 |
+
_target_: src.models.components.loss.BinaryDiceLoss
|
67 |
+
|
68 |
+
k_shot: false
|
69 |
+
|
70 |
+
filter: true
|
71 |
+
|
72 |
+
enable_validation: false
|
73 |
+
|
74 |
+
compile: false
|
SOWA/configs/paths/default.yaml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# path to root directory
|
2 |
+
# this requires PROJECT_ROOT environment variable to exist
|
3 |
+
# you can replace it with "." if you want the root to be the current working directory
|
4 |
+
root_dir: ${oc.env:PROJECT_ROOT}
|
5 |
+
|
6 |
+
# path to data directory
|
7 |
+
data_dir: ${paths.root_dir}/data/
|
8 |
+
|
9 |
+
# path to logging directory
|
10 |
+
log_dir: ${paths.root_dir}/logs/
|
11 |
+
|
12 |
+
# path to output directory, created dynamically by hydra
|
13 |
+
# path generation pattern is specified in `configs/hydra/default.yaml`
|
14 |
+
# use it to store all files generated during the run, like ckpts and metrics
|
15 |
+
output_dir: ${hydra:runtime.output_dir}
|
16 |
+
|
17 |
+
# path to working directory
|
18 |
+
work_dir: ${hydra:runtime.cwd}
|
SOWA/configs/prompt/default.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# configs/prompt/default.yaml
|
2 |
+
defaults:
|
3 |
+
- object
|
4 |
+
- state_template
|
5 |
+
- template
|
SOWA/configs/prompt/object.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class_names:
|
2 |
+
- object
|
3 |
+
- pipe_fryum
|
4 |
+
- metal_nut
|
5 |
+
- pcb1
|
6 |
+
- tile
|
7 |
+
- screw
|
8 |
+
- pcb2
|
9 |
+
- wood
|
10 |
+
- zipper
|
11 |
+
- cable
|
12 |
+
- fryum
|
13 |
+
- pill
|
14 |
+
- capsule
|
15 |
+
- hazelnut
|
16 |
+
- pcb4
|
17 |
+
- leather
|
18 |
+
- bottle
|
19 |
+
- cashew
|
20 |
+
- macaroni2
|
21 |
+
- grid
|
22 |
+
- chewinggum
|
23 |
+
- transistor
|
24 |
+
- macaroni1
|
25 |
+
- candle
|
26 |
+
- capsules
|
27 |
+
- pcb3
|
28 |
+
- carpet
|
29 |
+
- toothbrush
|
SOWA/configs/prompt/state_template.yaml
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
state_template:
|
2 |
+
anomaly:
|
3 |
+
- a photo of the damaged {}.
|
4 |
+
- a bright photo of the damaged {}.
|
5 |
+
- a dark photo of the damaged {}.
|
6 |
+
- a close-up photo of the damaged {}.
|
7 |
+
- a black and white photo of the damaged {}.
|
8 |
+
- a blurry photo of the damaged {}.
|
9 |
+
- a blurry photo of a damaged {}.
|
10 |
+
- a photo of the small damaged {}.
|
11 |
+
- a photo of the large damaged {}.
|
12 |
+
- there is a damaged {} in the scene.
|
13 |
+
- this is one damaged {} in the scene.
|
14 |
+
- a photo of the broken {}.
|
15 |
+
- a bright photo of the broken {}.
|
16 |
+
- a dark photo of the broken {}.
|
17 |
+
- a close-up photo of the broken {}.
|
18 |
+
- a black and white photo of the broken {}.
|
19 |
+
- a blurry photo of the broken {}.
|
20 |
+
- a blurry photo of a broken {}.
|
21 |
+
- a photo of the small broken {}.
|
22 |
+
- a photo of the large broken {}.
|
23 |
+
- there is a broken {} in the scene.
|
24 |
+
- this is one broken {} in the scene.
|
25 |
+
- a photo of the {} with flaw.
|
26 |
+
- a bright photo of the {} with flaw.
|
27 |
+
- a dark photo of the {} with flaw.
|
28 |
+
- a close-up photo of the {} with flaw.
|
29 |
+
- a black and white photo of the {} with flaw.
|
30 |
+
- a blurry photo of the {} with flaw.
|
31 |
+
- a blurry photo of a {} with flaw.
|
32 |
+
- a photo of the small {} with flaw.
|
33 |
+
- a photo of the large {} with flaw.
|
34 |
+
- there is a {} with flaw in the scene.
|
35 |
+
- this is one {} with flaw in the scene.
|
36 |
+
- a photo of the {} with defect.
|
37 |
+
- a bright photo of the {} with defect.
|
38 |
+
- a dark photo of the {} with defect.
|
39 |
+
- a close-up photo of the {} with defect.
|
40 |
+
- a black and white photo of the {} with defect.
|
41 |
+
- a blurry photo of the {} with defect.
|
42 |
+
- a blurry photo of a {} with defect.
|
43 |
+
- a photo of the small {} with defect.
|
44 |
+
- a photo of the large {} with defect.
|
45 |
+
- there is a {} with defect in the scene.
|
46 |
+
- this is one {} with defect in the scene.
|
47 |
+
normal:
|
48 |
+
- a photo of the {}.
|
49 |
+
- a bright photo of the {}.
|
50 |
+
- a dark photo of the {}.
|
51 |
+
- a close-up photo of the {}.
|
52 |
+
- a black and white photo of the {}.
|
53 |
+
- a blurry photo of the {}.
|
54 |
+
- a blurry photo of a {}.
|
55 |
+
- a photo of the small {}.
|
56 |
+
- a photo of the large {}.
|
57 |
+
- there is a {} in the scene.
|
58 |
+
- this is one {} in the scene.
|
59 |
+
- a photo of the flawless {}.
|
60 |
+
- a bright photo of the flawless {}.
|
61 |
+
- a dark photo of the flawless {}.
|
62 |
+
- a close-up photo of the flawless {}.
|
63 |
+
- a black and white photo of the flawless {}.
|
64 |
+
- a blurry photo of the flawless {}.
|
65 |
+
- a blurry photo of a flawless {}.
|
66 |
+
- a photo of the small flawless {}.
|
67 |
+
- a photo of the large flawless {}.
|
68 |
+
- there is a flawless {} in the scene.
|
69 |
+
- this is one flawless {} in the scene.
|
70 |
+
- a photo of the {} without flaw.
|
71 |
+
- a bright photo of the {} without flaw.
|
72 |
+
- a dark photo of the {} without flaw.
|
73 |
+
- a close-up photo of the {} without flaw.
|
74 |
+
- a black and white photo of the {} without flaw.
|
75 |
+
- a blurry photo of the {} without flaw.
|
76 |
+
- a blurry photo of a {} without flaw.
|
77 |
+
- a photo of the small {} without flaw.
|
78 |
+
- a photo of the large {} without flaw.
|
79 |
+
- there is a {} without flaw in the scene.
|
80 |
+
- this is one {} without flaw in the scene.
|
81 |
+
- a photo of the {} without defect.
|
82 |
+
- a bright photo of the {} without defect.
|
83 |
+
- a dark photo of the {} without defect.
|
84 |
+
- a close-up photo of the {} without defect.
|
85 |
+
- a black and white photo of the {} without defect.
|
86 |
+
- a blurry photo of the {} without defect.
|
87 |
+
- a blurry photo of a {} without defect.
|
88 |
+
- a photo of the small {} without defect.
|
89 |
+
- a photo of the large {} without defect.
|
90 |
+
- there is a {} without defect in the scene.
|
91 |
+
- this is one {} without defect in the scene.
|
SOWA/configs/prompt/template.yaml
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
template:
|
2 |
+
normal:
|
3 |
+
- '{}'
|
4 |
+
- 'flawless {}'
|
5 |
+
- 'perfect {}'
|
6 |
+
- 'unblemished {}'
|
7 |
+
- '{} without flaw'
|
8 |
+
- '{} without defect'
|
9 |
+
- '{} without damage'
|
10 |
+
abnormal:
|
11 |
+
- 'damaged {}'
|
12 |
+
- 'broken {}'
|
13 |
+
- '{} with flaw'
|
14 |
+
- '{} with defect'
|
15 |
+
- '{} with damage'
|
16 |
+
templates:
|
17 |
+
- 'a bad photo of a {}.'
|
18 |
+
- 'a low resolution photo of the {}.'
|
19 |
+
- 'a bad photo of the {}.'
|
20 |
+
- 'a cropped photo of the {}.'
|
21 |
+
- 'a bright photo of a {}.'
|
22 |
+
- 'a dark photo of the {}.'
|
23 |
+
- 'a photo of my {}.'
|
24 |
+
- 'a photo of the cool {}.'
|
25 |
+
- 'a close-up photo of a {}.'
|
26 |
+
- 'a black and white photo of the {}.'
|
27 |
+
- 'a bright photo of the {}.'
|
28 |
+
- 'a cropped photo of a {}.'
|
29 |
+
- 'a jpeg corrupted photo of a {}.'
|
30 |
+
- 'a blurry photo of the {}.'
|
31 |
+
- 'a photo of the {}.'
|
32 |
+
- 'a good photo of the {}.'
|
33 |
+
- 'a photo of one {}.'
|
34 |
+
- 'a close-up photo of the {}.'
|
35 |
+
- 'a photo of a {}.'
|
36 |
+
- 'a low resolution photo of a {}.'
|
37 |
+
- 'a photo of a large {}.'
|
38 |
+
- 'a blurry photo of a {}.'
|
39 |
+
- 'a jpeg corrupted photo of the {}.'
|
40 |
+
- 'a good photo of a {}.'
|
41 |
+
- 'a photo of the small {}.'
|
42 |
+
- 'a photo of the large {}.'
|
43 |
+
- 'a black and white photo of a {}.'
|
44 |
+
- 'a dark photo of a {}.'
|
45 |
+
- 'a photo of a cool {}.'
|
46 |
+
- 'a photo of a small {}.'
|
47 |
+
- 'there is a {} in the scene.'
|
48 |
+
- 'there is the {} in the scene.'
|
49 |
+
- 'this is a {} in the scene.'
|
50 |
+
- 'this is the {} in the scene.'
|
51 |
+
- 'this is one {} in the scene.'
|
SOWA/configs/train.yaml
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# specify here default configuration
|
4 |
+
# order of defaults determines the order in which configs override each other
|
5 |
+
defaults:
|
6 |
+
- _self_
|
7 |
+
- data: anomaly_clip
|
8 |
+
- model: anomaly_clip
|
9 |
+
- callbacks: default
|
10 |
+
- logger: many_loggers # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
|
11 |
+
- trainer: default
|
12 |
+
- paths: default
|
13 |
+
- extras: default
|
14 |
+
- hydra: default
|
15 |
+
|
16 |
+
# information of object and prompt template
|
17 |
+
- prompt: default
|
18 |
+
|
19 |
+
# experiment configs allow for version control of specific hyperparameters
|
20 |
+
# e.g. best hyperparameters for given model and datamodule
|
21 |
+
- experiment: null
|
22 |
+
|
23 |
+
# config for hyperparameter optimization
|
24 |
+
- hparams_search: null
|
25 |
+
|
26 |
+
# optional local config for machine/user specific settings
|
27 |
+
# it's optional since it doesn't need to exist and is excluded from version control
|
28 |
+
- optional local: default
|
29 |
+
|
30 |
+
# debugging config (enable through command line, e.g. `python train.py debug=default)
|
31 |
+
- debug: null
|
32 |
+
|
33 |
+
# task name, determines output directory path
|
34 |
+
task_name: "train"
|
35 |
+
|
36 |
+
# tags to help you identify your experiments
|
37 |
+
# you can overwrite this in experiment configs
|
38 |
+
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
|
39 |
+
tags: ["dev"]
|
40 |
+
|
41 |
+
# set False to skip model training
|
42 |
+
train: True
|
43 |
+
|
44 |
+
# evaluate on test set, using best model weights achieved during training
|
45 |
+
# lightning chooses best weights based on the metric specified in checkpoint callback
|
46 |
+
test: True
|
47 |
+
|
48 |
+
# simply provide checkpoint path to resume training
|
49 |
+
ckpt_path: null
|
50 |
+
|
51 |
+
# seed for random number generators in pytorch, numpy and python.random
|
52 |
+
seed: 2025
|