zongxiang commited on
Commit
7fe0374
1 Parent(s): ab9ae2e

Upload 116 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. SOWA/.DS_Store +0 -0
  2. SOWA/.env.example +6 -0
  3. SOWA/.gitignore +154 -0
  4. SOWA/.pre-commit-config.yaml +147 -0
  5. SOWA/.project-root +2 -0
  6. SOWA/LICENSE +21 -0
  7. SOWA/Makefile +30 -0
  8. SOWA/README.md +153 -0
  9. SOWA/configs/__init__.py +1 -0
  10. SOWA/configs/callbacks/default.yaml +27 -0
  11. SOWA/configs/callbacks/early_stopping.yaml +15 -0
  12. SOWA/configs/callbacks/model_checkpoint.yaml +17 -0
  13. SOWA/configs/callbacks/model_summary.yaml +5 -0
  14. SOWA/configs/callbacks/none.yaml +0 -0
  15. SOWA/configs/callbacks/rich_progress_bar.yaml +4 -0
  16. SOWA/configs/callbacks/visualization.yaml +4 -0
  17. SOWA/configs/data/sowa_infer.yaml +56 -0
  18. SOWA/configs/data/sowa_mvt.yaml +53 -0
  19. SOWA/configs/data/sowa_overfit.yaml +51 -0
  20. SOWA/configs/data/sowa_visa.yaml +52 -0
  21. SOWA/configs/debug/default.yaml +35 -0
  22. SOWA/configs/debug/fdr.yaml +9 -0
  23. SOWA/configs/debug/limit.yaml +12 -0
  24. SOWA/configs/debug/overfit.yaml +13 -0
  25. SOWA/configs/debug/profiler.yaml +12 -0
  26. SOWA/configs/eval.yaml +27 -0
  27. SOWA/configs/experiment/example.yaml +194 -0
  28. SOWA/configs/extras/default.yaml +8 -0
  29. SOWA/configs/hparams_search/anomaly_clip_optuna.yaml +61 -0
  30. SOWA/configs/hydra/default.yaml +19 -0
  31. SOWA/configs/local/.gitkeep +0 -0
  32. SOWA/configs/logger/aim.yaml +28 -0
  33. SOWA/configs/logger/comet.yaml +12 -0
  34. SOWA/configs/logger/csv.yaml +7 -0
  35. SOWA/configs/logger/many_loggers.yaml +9 -0
  36. SOWA/configs/logger/mlflow.yaml +12 -0
  37. SOWA/configs/logger/neptune.yaml +9 -0
  38. SOWA/configs/logger/tensorboard.yaml +10 -0
  39. SOWA/configs/logger/wandb.yaml +16 -0
  40. SOWA/configs/model/sowa_hfwa.yaml +71 -0
  41. SOWA/configs/model/sowa_linear.yaml +63 -0
  42. SOWA/configs/model/sparc_hfwa.yaml +75 -0
  43. SOWA/configs/model/sparc_linear.yaml +74 -0
  44. SOWA/configs/model/sparc_prompt.yaml +74 -0
  45. SOWA/configs/paths/default.yaml +18 -0
  46. SOWA/configs/prompt/default.yaml +5 -0
  47. SOWA/configs/prompt/object.yaml +29 -0
  48. SOWA/configs/prompt/state_template.yaml +91 -0
  49. SOWA/configs/prompt/template.yaml +51 -0
  50. SOWA/configs/train.yaml +52 -0
SOWA/.DS_Store ADDED
Binary file (6.15 kB). View file
 
SOWA/.env.example ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # example of file for storing private and user specific environment variables, like keys or system paths
2
+ # rename it to ".env" (excluded from version control by default)
3
+ # .env is loaded by train.py automatically
4
+ # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR}
5
+
6
+ MY_VAR="/home/user/my/system/path"
SOWA/.gitignore ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .venv
106
+ env/
107
+ venv/
108
+ ENV/
109
+ env.bak/
110
+ venv.bak/
111
+
112
+ # Spyder project settings
113
+ .spyderproject
114
+ .spyproject
115
+
116
+ # Rope project settings
117
+ .ropeproject
118
+
119
+ # mkdocs documentation
120
+ /site
121
+
122
+ # mypy
123
+ .mypy_cache/
124
+ .dmypy.json
125
+ dmypy.json
126
+
127
+ # Pyre type checker
128
+ .pyre/
129
+
130
+ ### VisualStudioCode
131
+ .vscode/*
132
+ !.vscode/settings.json
133
+ !.vscode/tasks.json
134
+ !.vscode/launch.json
135
+ !.vscode/extensions.json
136
+ *.code-workspace
137
+ **/.vscode
138
+
139
+ # JetBrains
140
+ .idea/
141
+
142
+ # Data & Models
143
+ *.h5
144
+ *.tar
145
+ *.tar.gz
146
+
147
+ # Lightning-Hydra-Template
148
+ configs/local/default.yaml
149
+ /data/
150
+ /logs/
151
+ .env
152
+
153
+ # Aim logging
154
+ .aim
SOWA/.pre-commit-config.yaml ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ default_language_version:
2
+ python: python3
3
+
4
+ repos:
5
+ - repo: https://github.com/pre-commit/pre-commit-hooks
6
+ rev: v4.4.0
7
+ hooks:
8
+ # list of supported hooks: https://pre-commit.com/hooks.html
9
+ - id: trailing-whitespace
10
+ - id: end-of-file-fixer
11
+ - id: check-docstring-first
12
+ - id: check-yaml
13
+ - id: debug-statements
14
+ - id: detect-private-key
15
+ - id: check-executables-have-shebangs
16
+ - id: check-toml
17
+ - id: check-case-conflict
18
+ - id: check-added-large-files
19
+
20
+ # python code formatting
21
+ - repo: https://github.com/psf/black
22
+ rev: 23.1.0
23
+ hooks:
24
+ - id: black
25
+ args: [--line-length, "99"]
26
+
27
+ # python import sorting
28
+ - repo: https://github.com/PyCQA/isort
29
+ rev: 5.12.0
30
+ hooks:
31
+ - id: isort
32
+ args: ["--profile", "black", "--filter-files"]
33
+
34
+ # python upgrading syntax to newer version
35
+ - repo: https://github.com/asottile/pyupgrade
36
+ rev: v3.3.1
37
+ hooks:
38
+ - id: pyupgrade
39
+ args: [--py38-plus]
40
+
41
+ # python docstring formatting
42
+ - repo: https://github.com/myint/docformatter
43
+ rev: v1.7.4
44
+ hooks:
45
+ - id: docformatter
46
+ args:
47
+ [
48
+ --in-place,
49
+ --wrap-summaries=99,
50
+ --wrap-descriptions=99,
51
+ --style=sphinx,
52
+ --black,
53
+ ]
54
+
55
+ # python docstring coverage checking
56
+ - repo: https://github.com/econchick/interrogate
57
+ rev: 1.5.0 # or master if you're bold
58
+ hooks:
59
+ - id: interrogate
60
+ args:
61
+ [
62
+ --verbose,
63
+ --fail-under=80,
64
+ --ignore-init-module,
65
+ --ignore-init-method,
66
+ --ignore-module,
67
+ --ignore-nested-functions,
68
+ -vv,
69
+ ]
70
+
71
+ # python check (PEP8), programming errors and code complexity
72
+ - repo: https://github.com/PyCQA/flake8
73
+ rev: 6.0.0
74
+ hooks:
75
+ - id: flake8
76
+ args:
77
+ [
78
+ "--extend-ignore",
79
+ "E203,E402,E501,F401,F841,RST2,RST301",
80
+ "--exclude",
81
+ "logs/*,data/*",
82
+ ]
83
+ additional_dependencies: [flake8-rst-docstrings==0.3.0]
84
+
85
+ # python security linter
86
+ - repo: https://github.com/PyCQA/bandit
87
+ rev: "1.7.5"
88
+ hooks:
89
+ - id: bandit
90
+ args: ["-s", "B101"]
91
+
92
+ # yaml formatting
93
+ - repo: https://github.com/pre-commit/mirrors-prettier
94
+ rev: v3.0.0-alpha.6
95
+ hooks:
96
+ - id: prettier
97
+ types: [yaml]
98
+ exclude: "environment.yaml"
99
+
100
+ # shell scripts linter
101
+ - repo: https://github.com/shellcheck-py/shellcheck-py
102
+ rev: v0.9.0.2
103
+ hooks:
104
+ - id: shellcheck
105
+
106
+ # md formatting
107
+ - repo: https://github.com/executablebooks/mdformat
108
+ rev: 0.7.16
109
+ hooks:
110
+ - id: mdformat
111
+ args: ["--number"]
112
+ additional_dependencies:
113
+ - mdformat-gfm
114
+ - mdformat-tables
115
+ - mdformat_frontmatter
116
+ # - mdformat-toc
117
+ # - mdformat-black
118
+
119
+ # word spelling linter
120
+ - repo: https://github.com/codespell-project/codespell
121
+ rev: v2.2.4
122
+ hooks:
123
+ - id: codespell
124
+ args:
125
+ - --skip=logs/**,data/**,*.ipynb
126
+ # - --ignore-words-list=abc,def
127
+
128
+ # jupyter notebook cell output clearing
129
+ - repo: https://github.com/kynan/nbstripout
130
+ rev: 0.6.1
131
+ hooks:
132
+ - id: nbstripout
133
+
134
+ # jupyter notebook linting
135
+ - repo: https://github.com/nbQA-dev/nbQA
136
+ rev: 1.6.3
137
+ hooks:
138
+ - id: nbqa-black
139
+ args: ["--line-length=99"]
140
+ - id: nbqa-isort
141
+ args: ["--profile=black"]
142
+ - id: nbqa-flake8
143
+ args:
144
+ [
145
+ "--extend-ignore=E203,E402,E501,F401,F841",
146
+ "--exclude=logs/*,data/*",
147
+ ]
SOWA/.project-root ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # this file is required for inferring the project root directory
2
+ # do not delete
SOWA/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 huzongxiang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
SOWA/Makefile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ help: ## Show help
3
+ @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
4
+
5
+ clean: ## Clean autogenerated files
6
+ rm -rf dist
7
+ find . -type f -name "*.DS_Store" -ls -delete
8
+ find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
9
+ find . | grep -E ".pytest_cache" | xargs rm -rf
10
+ find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
11
+ rm -f .coverage
12
+
13
+ clean-logs: ## Clean logs
14
+ rm -rf logs/**
15
+
16
+ format: ## Run pre-commit hooks
17
+ pre-commit run -a
18
+
19
+ sync: ## Merge changes from main branch to your current branch
20
+ git pull
21
+ git pull origin main
22
+
23
+ test: ## Run not slow tests
24
+ pytest -k "not slow"
25
+
26
+ test-full: ## Run all tests
27
+ pytest
28
+
29
+ train: ## Train the model
30
+ python src/train.py
SOWA/README.md ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # Soldier-Offier Window self-Attention (SOWA)
4
+
5
+ <a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
6
+ <a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
7
+ <a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a>
8
+ <a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a><br>
9
+ [![Paper](http://img.shields.io/badge/paper-arxiv.2407.03634-B31B1B.svg)](https://arxiv.org/abs/2407.03634)
10
+ [![Conference](http://img.shields.io/badge/AnyConference-year-4b44ce.svg)](https://papers.nips.cc/paper/2020)
11
+
12
+ </div>
13
+
14
+ ## Description
15
+
16
+ <div align="center">
17
+ <img src="https://github.com/huzongxiang/sowa/blob/resources/fig1.png" alt="concept" style="width: 50%;">
18
+ </div>
19
+
20
+ Visual anomaly detection is critical in industrial manufacturing, but traditional methods often rely on extensive
21
+ normal datasets and custom models, limiting scalability.
22
+ Recent advancements in large-scale visual-language models have significantly improved zero/few-shot anomaly detection. However, these approaches may not fully utilize hierarchical features, potentially missing nuanced details. We
23
+ introduce a window self-attention mechanism based on the
24
+ CLIP model, combined with learnable prompts to process
25
+ multi-level features within a Soldier-Offier Window selfAttention (SOWA) framework. Our method has been tested
26
+ on five benchmark datasets, demonstrating superior performance by leading in 18 out of 20 metrics compared to existing state-of-the-art techniques.
27
+
28
+ ![architecture](https://github.com/huzongxiang/sowa/blob/resources/fig2.png)
29
+
30
+ ## Installation
31
+
32
+ #### Pip
33
+
34
+ ```bash
35
+ # clone project
36
+ git clone https://github.com/huzongxiang/sowa
37
+ cd sowa
38
+
39
+ # [OPTIONAL] create conda environment
40
+ conda create -n sowa python=3.9
41
+ conda activate sowa
42
+
43
+ # install pytorch according to instructions
44
+ # https://pytorch.org/get-started/
45
+
46
+ # install requirements
47
+ pip install -r requirements.txt
48
+ ```
49
+
50
+ #### Conda
51
+
52
+ ```bash
53
+ # clone project
54
+ git clone https://github.com/huzongxiang/sowa
55
+ cd sowa
56
+
57
+ # create conda environment and install dependencies
58
+ conda env create -f environment.yaml -n sowa
59
+
60
+ # activate conda environment
61
+ conda activate sowa
62
+ ```
63
+
64
+ ## How to run
65
+
66
+ Train model with default configuration
67
+
68
+ ```bash
69
+ # train on CPU
70
+ python src/train.py trainer=cpu data=sowa_visa model=sowa_hfwa
71
+
72
+ # train on GPU
73
+ python src/train.py trainer=gpu data=sowa_visa model=sowa_hfwa
74
+ ```
75
+
76
+ ## Results
77
+
78
+ Comparisons with few-shot (K=4) anomaly detection methods on datasets of MVTec-AD, Visa, BTAD, DAGM and DTD Synthetic.
79
+ | Metric | Dataset | WinCLIP | April-GAN | Ours |
80
+ |-----------|----------------|-------------|-------------|-------------|
81
+ | AC AUROC | MVTec-AD | 95.2±1.3 | 92.8±0.2 | 96.8±0.3 |
82
+ | | Visa | 87.3±1.8 | 92.6±0.4 | 92.9±0.2 |
83
+ | | BTAD | 87.0±0.2 | 92.1±0.2 | 94.8±0.2 |
84
+ | | DAGM | 93.8±0.2 | 96.2±1.1 | 98.9±0.3 |
85
+ | | DTD-Synthetic | 98.1±0.2 | 98.5±0.1 | 99.1±0.0 |
86
+ | AC AP | MVTec-AD | 97.3±0.6 | 96.3±0.1 | 98.3±0.3 |
87
+ | | Visa | 88.8±1.8 | 94.5±0.3 | 94.5±0.2 |
88
+ | | BTAD | 86.8±0.0 | 95.2±0.5 | 95.5±0.7 |
89
+ | | DAGM | 83.8±1.1 | 86.7±4.5 | 95.2±1.7 |
90
+ | | DTD-Synthetic | 99.1±0.1 | 99.4±0.0 | 99.6±0.0 |
91
+ | AS AUROC | MVTec-AD | 96.2±0.3 | 95.9±0.0 | 95.7±0.1 |
92
+ | | Visa | 97.2±0.2 | 96.2±0.0 | 97.1±0.0 |
93
+ | | BTAD | 95.8±0.0 | 94.4±0.1 | 97.1±0.0 |
94
+ | | DAGM | 93.8±0.1 | 88.9±0.4 | 96.9±0.0 |
95
+ | | DTD-Synthetic | 96.8±0.2 | 96.7±0.0 | 98.7±0.0 |
96
+ | AS AUPRO | MVTec-AD | 89.0±0.8 | 91.8±0.1 | 92.4±0.2 |
97
+ | | Visa | 87.6±0.9 | 90.2±0.1 | 91.4±0.0 |
98
+ | | BTAD | 66.6±0.2 | 78.2±0.1 | 81.2±0.2 |
99
+ | | DAGM | 82.4±0.3 | 77.8±0.9 | 94.4±0.1 |
100
+ | | DTD-Synthetic | 90.1±0.5 | 92.2±0.0 | 96.6±0.1 |
101
+
102
+ ​<!-- 零宽空格 -->
103
+
104
+ Performance Comparison on MVTec-AD and Visa Datasets.
105
+ | Method | Source | MVTec-AD AC AUROC | MVTec-AD AS AUROC | MVTec-AD AS PRO | Visa AC AUROC | Visa AS AUROC | Visa AS PRO |
106
+ |---------------|-------------------------|-------------------|-------------------|-----------------|---------------|---------------|-------------|
107
+ | SPADE | arXiv 2020 | 84.8±2.5 | 92.7±0.3 | 87.0±0.5 | 81.7±3.4 | 96.6±0.3 | 87.3±0.8 |
108
+ | PaDiM | ICPR 2021 | 80.4±2.4 | 92.6±0.7 | 81.3±1.9 | 72.8±2.9 | 93.2±0.5 | 72.6±1.9 |
109
+ | PatchCore | CVPR 2022 | 88.8±2.6 | 94.3±0.5 | 84.3±1.6 | 85.3±2.1 | 96.8±0.3 | 84.9±1.4 |
110
+ | WinCLIP | CVPR 2023 | 95.2±1.3 | 96.2±0.3 | 89.0±0.8 | 87.3±1.8 | 97.2±0.2 | 87.6±0.9 |
111
+ | April-GAN | CVPR 2023 VAND workshop | 92.8±0.2 | 95.9±0.0 | 91.8±0.1 | 92.6±0.4 | 96.2±0.0 | 90.2±0.1 |
112
+ | PromptAD | CVPR 2024 | 96.6±0.9 | 96.5±0.2 | - | 89.1±1.7 | 97.4±0.3 | - |
113
+ | InCTRL | CVPR 2024 | 94.5±1.8 | - | - | 87.7±1.9 | - | - |
114
+ | SOWA | Ours | 96.8±0.3 | 95.7±0.1 | 92.4±0.2 | 92.9±0.2 | 97.1±0.0 | 91.4±0.0 |
115
+
116
+
117
+ ​<!-- 零宽空格 -->
118
+
119
+ Comparisons with few-shot anomaly detection methods on datasets of MVTec-AD, Visa, BTAD, DAGM and DTD Synthetic.
120
+ <div align="center">
121
+ <img src="https://github.com/huzongxiang/sowa/blob/resources/fig5.png" alt="few-shot" style="width: 70%;">
122
+ </div>
123
+
124
+
125
+ ## Visualization
126
+ Visualization results under the few-shot setting (K=4).
127
+ <div align="center">
128
+ <img src="https://github.com/huzongxiang/sowa/blob/resources/fig6.png" alt="concept" style="width: 70%;">
129
+ </div>
130
+
131
+
132
+ ## Mechanism
133
+ Hierarchical Results on MVTec-AD Dataset. A set of images showing the real outputs of the model, illustrating how different layers (H1 to H4) process various feature modes. Each row represents a different sample, with columns showing the original image, segmentation mask, heatmap, and feature outputs from H1 to H4, and fusion.
134
+ ![mechanism](https://github.com/huzongxiang/sowa/blob/resources/fig7.png)
135
+
136
+
137
+ ## Inference Speed
138
+ Inference performance comparison of different methods on a single NVIDIA RTX3070 8GB GPU.
139
+ <div align="center">
140
+ <img src="https://github.com/huzongxiang/sowa/blob/resources/fig9.png" alt="speed" style="width: 80%;">
141
+ </div>
142
+
143
+
144
+ ## Citation
145
+ Please cite the following paper if this work helps your project:
146
+ ```
147
+ @article{hu2024sowa,
148
+ title={SOWA: Adapting Hierarchical Frozen Window Self-Attention to Visual-Language Models for Better Anomaly Detection},
149
+ author={Hu, Zongxiang and Zhang, zhaosheng},
150
+ journal={arXiv preprint arXiv:2407.03634},
151
+ year={2024}
152
+ }
153
+ ```
SOWA/configs/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # this file is needed here to include configs when building project as a package
SOWA/configs/callbacks/default.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model_checkpoint
3
+ - early_stopping
4
+ - model_summary
5
+ - rich_progress_bar
6
+ - visualization
7
+ - _self_
8
+
9
+ model_checkpoint:
10
+ dirpath: ${paths.output_dir}/checkpoints
11
+ filename: "epoch_{epoch:03d}"
12
+ monitor: "train/loss"
13
+ mode: "min"
14
+ save_last: True
15
+ auto_insert_metric_name: False
16
+
17
+ early_stopping:
18
+ monitor: "train/loss"
19
+ patience: 10
20
+ mode: "min"
21
+
22
+ model_summary:
23
+ max_depth: -1
24
+
25
+ visualization:
26
+ dirpath: ${paths.output_dir}/visualizations
27
+ visualize: True
SOWA/configs/callbacks/early_stopping.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
2
+
3
+ early_stopping:
4
+ _target_: lightning.pytorch.callbacks.EarlyStopping
5
+ monitor: ??? # quantity to be monitored, must be specified !!!
6
+ min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
7
+ patience: 3 # number of checks with no improvement after which training will be stopped
8
+ verbose: False # verbosity mode
9
+ mode: "min" # "max" means higher metric value is better, can be also "min"
10
+ strict: True # whether to crash the training if monitor is not found in the validation metrics
11
+ check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
12
+ stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
13
+ divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
14
+ check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
15
+ # log_rank_zero_only: False # this keyword argument isn't available in stable version
SOWA/configs/callbacks/model_checkpoint.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
2
+
3
+ model_checkpoint:
4
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
5
+ dirpath: null # directory to save the model file
6
+ filename: null # checkpoint filename
7
+ monitor: null # name of the logged metric which determines when model is improving
8
+ verbose: False # verbosity mode
9
+ save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
10
+ save_top_k: 1 # save k best models (determined by above metric)
11
+ mode: "min" # "max" means higher metric value is better, can be also "min"
12
+ auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
13
+ save_weights_only: False # if True, then only the model’s weights will be saved
14
+ every_n_train_steps: null # number of training steps between checkpoints
15
+ train_time_interval: null # checkpoints are monitored at the specified time interval
16
+ every_n_epochs: null # number of epochs between checkpoints
17
+ save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
SOWA/configs/callbacks/model_summary.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
2
+
3
+ model_summary:
4
+ _target_: lightning.pytorch.callbacks.RichModelSummary
5
+ max_depth: 1 # the maximum depth of layer nesting that the summary will include
SOWA/configs/callbacks/none.yaml ADDED
File without changes
SOWA/configs/callbacks/rich_progress_bar.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
2
+
3
+ rich_progress_bar:
4
+ _target_: lightning.pytorch.callbacks.RichProgressBar
SOWA/configs/callbacks/visualization.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ visualization:
2
+ _target_: src.models.components.callback.AnomalyVisualizationCallback
3
+ dirpath: ${paths.output_dir}/visualizations
4
+ visualize: True
SOWA/configs/data/sowa_infer.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.data.anomaly_clip_datamodule.AnomalyCLIPDataModule
2
+ data_dir:
3
+ train: /home/hzx/Projects/Data/MVTec-AD
4
+ valid: /home/hzx/Projects/Data/MVTec-AD
5
+ # test: /home/hzx/Projects/Data/BTAD
6
+ # test: /home/hzx/Projects/Data/DAGM
7
+ test: /home/hzx/Projects/Data/DTD-Synthetic
8
+ # test: /home/hzx/Projects/Data/MPDD
9
+ # test: /home/hzx/Projects/Data/SDD
10
+ dataset:
11
+ train:
12
+ _target_: src.data.components.anomal_dataset.MVTecDataset
13
+ _partial_: true
14
+ transform:
15
+ _target_: src.data.components.transform.ImageTransform
16
+ image_size: 336
17
+ mask_transform:
18
+ _target_: src.data.components.transform.MaskTransform
19
+ image_size: ${data.image_size}
20
+ preload: false
21
+ aug_rate: 0.2
22
+ valid:
23
+ _target_: src.data.components.anomal_dataset.VisaDataset
24
+ _partial_: true
25
+ transform:
26
+ _target_: src.data.components.transform.ImageTransform
27
+ image_size: 336
28
+ mask_transform:
29
+ _target_: src.data.components.transform.MaskTransform
30
+ image_size: ${data.image_size}
31
+ preload: false
32
+ test:
33
+ _target_: src.data.components.anomal_dataset.VisaDataset
34
+ _partial_: true
35
+ transform:
36
+ _target_: src.data.components.transform.ImageTransform
37
+ image_size: 336
38
+ mask_transform:
39
+ _target_: src.data.components.transform.MaskTransform
40
+ image_size: ${data.image_size}
41
+ preload: false
42
+ kshot:
43
+ _target_: src.data.components.kshot_dataset.VisaKShotDataset
44
+ _partial_: true
45
+ k_shot: 4
46
+ transform:
47
+ _target_: src.data.components.transform.ImageTransform
48
+ image_size: 336
49
+ mask_transform:
50
+ _target_: src.data.components.transform.MaskTransform
51
+ image_size: ${data.image_size}
52
+ preload: false
53
+ image_size: 336
54
+ num_workers: 6
55
+ pin_memory: False
56
+ batch_size: 8
SOWA/configs/data/sowa_mvt.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.data.anomaly_clip_datamodule.AnomalyCLIPDataModule
2
+ data_dir:
3
+ train: /home/hzx/Projects/Data/Visa
4
+ valid: /home/hzx/Projects/Data/MVTec-AD
5
+ test: /home/hzx/Projects/Data/MVTec-AD
6
+ dataset:
7
+ train:
8
+ _target_: src.data.components.anomal_dataset.VisaDataset
9
+ _partial_: true
10
+ transform:
11
+ _target_: src.data.components.transform.ImageTransform
12
+ image_size: 336
13
+ mask_transform:
14
+ _target_: src.data.components.transform.MaskTransform
15
+ image_size: ${data.image_size}
16
+ preload: false
17
+ valid:
18
+ _target_: src.data.components.anomal_dataset.MVTecDataset
19
+ _partial_: true
20
+ transform:
21
+ _target_: src.data.components.transform.ImageTransform
22
+ image_size: 336
23
+ mask_transform:
24
+ _target_: src.data.components.transform.MaskTransform
25
+ image_size: ${data.image_size}
26
+ preload: false
27
+ aug_rate: 0.0
28
+ test:
29
+ _target_: src.data.components.anomal_dataset.MVTecDataset
30
+ _partial_: true
31
+ transform:
32
+ _target_: src.data.components.transform.ImageTransform
33
+ image_size: 336
34
+ mask_transform:
35
+ _target_: src.data.components.transform.MaskTransform
36
+ image_size: ${data.image_size}
37
+ preload: false
38
+ aug_rate: 0.0
39
+ kshot:
40
+ _target_: src.data.components.kshot_dataset.MVTecKShotDataset
41
+ _partial_: true
42
+ k_shot: 1
43
+ transform:
44
+ _target_: src.data.components.transform.ImageTransform
45
+ image_size: 336
46
+ mask_transform:
47
+ _target_: src.data.components.transform.MaskTransform
48
+ image_size: ${data.image_size}
49
+ preload: false
50
+ image_size: 336
51
+ num_workers: 6
52
+ pin_memory: False
53
+ batch_size: 8
SOWA/configs/data/sowa_overfit.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.data.anomaly_clip_datamodule.AnomalyCLIPDataModule
2
+ data_dir:
3
+ train: /home/hzx/Projects/Data/MVTec-AD
4
+ valid: /home/hzx/Projects/Data/MVTec-AD
5
+ test: /home/hzx/Projects/Data/MVTec-AD
6
+ dataset:
7
+ train:
8
+ _target_: src.data.components.anomal_dataset.MVTecDataset
9
+ _partial_: true
10
+ transform:
11
+ _target_: src.data.components.transform.ImageTransform
12
+ image_size: 336
13
+ mask_transform:
14
+ _target_: src.data.components.transform.MaskTransform
15
+ image_size: ${data.image_size}
16
+ preload: false
17
+ valid:
18
+ _target_: src.data.components.anomal_dataset.MVTecDataset
19
+ _partial_: true
20
+ transform:
21
+ _target_: src.data.components.transform.ImageTransform
22
+ image_size: 336
23
+ mask_transform:
24
+ _target_: src.data.components.transform.MaskTransform
25
+ image_size: ${data.image_size}
26
+ preload: false
27
+ test:
28
+ _target_: src.data.components.anomal_dataset.MVTecDataset
29
+ _partial_: true
30
+ transform:
31
+ _target_: src.data.components.transform.ImageTransform
32
+ image_size: 336
33
+ mask_transform:
34
+ _target_: src.data.components.transform.MaskTransform
35
+ image_size: ${data.image_size}
36
+ preload: false
37
+ kshot:
38
+ _target_: src.data.components.kshot_dataset.MVTecKShotDataset
39
+ _partial_: true
40
+ k_shot: 1
41
+ transform:
42
+ _target_: src.data.components.transform.ImageTransform
43
+ image_size: 336
44
+ mask_transform:
45
+ _target_: src.data.components.transform.MaskTransform
46
+ image_size: ${data.image_size}
47
+ preload: false
48
+ image_size: 336
49
+ num_workers: 6
50
+ pin_memory: False
51
+ batch_size: 8
SOWA/configs/data/sowa_visa.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.data.anomaly_clip_datamodule.AnomalyCLIPDataModule
2
+ data_dir:
3
+ train: /home/hzx/Projects/Data/MVTec-AD
4
+ valid: /home/hzx/Projects/Data/Visa
5
+ test: /home/hzx/Projects/Data/Visa
6
+ dataset:
7
+ train:
8
+ _target_: src.data.components.anomal_dataset.MVTecDataset
9
+ _partial_: true
10
+ transform:
11
+ _target_: src.data.components.transform.ImageTransform
12
+ image_size: 336
13
+ mask_transform:
14
+ _target_: src.data.components.transform.MaskTransform
15
+ image_size: ${data.image_size}
16
+ preload: false
17
+ aug_rate: 0.2
18
+ valid:
19
+ _target_: src.data.components.anomal_dataset.VisaDataset
20
+ _partial_: true
21
+ transform:
22
+ _target_: src.data.components.transform.ImageTransform
23
+ image_size: 336
24
+ mask_transform:
25
+ _target_: src.data.components.transform.MaskTransform
26
+ image_size: ${data.image_size}
27
+ preload: false
28
+ test:
29
+ _target_: src.data.components.anomal_dataset.VisaDataset
30
+ _partial_: true
31
+ transform:
32
+ _target_: src.data.components.transform.ImageTransform
33
+ image_size: 336
34
+ mask_transform:
35
+ _target_: src.data.components.transform.MaskTransform
36
+ image_size: ${data.image_size}
37
+ preload: false
38
+ kshot:
39
+ _target_: src.data.components.kshot_dataset.VisaKShotDataset
40
+ _partial_: true
41
+ k_shot: 1
42
+ transform:
43
+ _target_: src.data.components.transform.ImageTransform
44
+ image_size: 336
45
+ mask_transform:
46
+ _target_: src.data.components.transform.MaskTransform
47
+ image_size: ${data.image_size}
48
+ preload: false
49
+ image_size: 336
50
+ num_workers: 6
51
+ pin_memory: False
52
+ batch_size: 8
SOWA/configs/debug/default.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # default debugging setup, runs 1 full epoch
4
+ # other debugging configs can inherit from this one
5
+
6
+ # overwrite task name so debugging logs are stored in separate folder
7
+ task_name: "debug"
8
+
9
+ # disable callbacks and loggers during debugging
10
+ callbacks: null
11
+ logger: null
12
+
13
+ extras:
14
+ ignore_warnings: False
15
+ enforce_tags: False
16
+
17
+ # sets level of all command line loggers to 'DEBUG'
18
+ # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
19
+ hydra:
20
+ job_logging:
21
+ root:
22
+ level: DEBUG
23
+
24
+ # use this to also set hydra loggers to 'DEBUG'
25
+ # verbose: True
26
+
27
+ trainer:
28
+ max_epochs: 1
29
+ accelerator: cpu # debuggers don't like gpus
30
+ devices: 1 # debuggers don't like multiprocessing
31
+ detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
32
+
33
+ data:
34
+ num_workers: 0 # debuggers don't like multiprocessing
35
+ pin_memory: False # disable gpu memory pin
SOWA/configs/debug/fdr.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # runs 1 train, 1 validation and 1 test step
4
+
5
+ defaults:
6
+ - default
7
+
8
+ trainer:
9
+ fast_dev_run: true
SOWA/configs/debug/limit.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # uses only 1% of the training data and 5% of validation/test data
4
+
5
+ defaults:
6
+ - default
7
+
8
+ trainer:
9
+ max_epochs: 3
10
+ limit_train_batches: 0.01
11
+ limit_val_batches: 0.05
12
+ limit_test_batches: 0.05
SOWA/configs/debug/overfit.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # overfits to 3 batches
4
+
5
+ defaults:
6
+ - default
7
+
8
+ trainer:
9
+ max_epochs: 20
10
+ overfit_batches: 3
11
+
12
+ # model ckpt and early stopping need to be disabled during overfitting
13
+ callbacks: null
SOWA/configs/debug/profiler.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # runs with execution time profiling
4
+
5
+ defaults:
6
+ - default
7
+
8
+ trainer:
9
+ max_epochs: 1
10
+ profiler: "simple"
11
+ # profiler: "advanced"
12
+ # profiler: "pytorch"
SOWA/configs/eval.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - _self_
5
+ - data: anomaly_clip # choose datamodule with `test_dataloader()` for evaluation
6
+ - model: anomaly_clip
7
+ - callbacks: default
8
+ - logger: many_loggers
9
+ - trainer: default
10
+ - paths: default
11
+ - extras: default
12
+ - hydra: default
13
+
14
+ # information of object and prompt template
15
+ - prompt: default
16
+
17
+ task_name: "eval"
18
+
19
+ tags: ["dev"]
20
+
21
+ # seed for random number generators in pytorch, numpy and python.random
22
+ seed: 42
23
+
24
+ # passing checkpoint path is necessary for evaluation
25
+ # ckpt_path: /home/hzx/Projects/SPARC/logs/train/runs/2024-06-14_16-32-14/checkpoints/epoch_000.ckpt
26
+ # ckpt_path: /home/hzx/Projects/Weight/mvtech-kshot/learnable/0-shot/2024-05-27_13-08-55/checkpoints/epoch_000.ckpt
27
+ ckpt_path: /home/hzx/Projects/Weight/visa-kshot/learnable/0-shot/2024-05-27_11-07-16/checkpoints/epoch_000.ckpt
SOWA/configs/experiment/example.yaml ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task_name: train
2
+ tags:
3
+ - dev
4
+ train: true
5
+ test: true
6
+ ckpt_path: null
7
+ seed: 2025
8
+ data:
9
+ _target_: src.data.anomaly_clip_datamodule.AnomalyCLIPDataModule
10
+ data_dir:
11
+ train: /home/hzx/Projects/Data/Visa
12
+ valid: /home/hzx/Projects/Data/MVTec-AD
13
+ test: /home/hzx/Projects/Data/MVTec-AD
14
+ dataset:
15
+ train:
16
+ _target_: src.data.components.anomal_dataset.VisaDataset
17
+ _partial_: true
18
+ transform:
19
+ _target_: src.data.components.transform.ImageTransform
20
+ image_size: 336
21
+ mask_transform:
22
+ _target_: src.data.components.transform.MaskTransform
23
+ image_size: ${data.image_size}
24
+ preload: false
25
+ valid:
26
+ _target_: src.data.components.anomal_dataset.MVTecDataset
27
+ _partial_: true
28
+ transform:
29
+ _target_: src.data.components.transform.ImageTransform
30
+ image_size: 336
31
+ mask_transform:
32
+ _target_: src.data.components.transform.MaskTransform
33
+ image_size: ${data.image_size}
34
+ preload: false
35
+ aug_rate: 0.0
36
+ test:
37
+ _target_: src.data.components.anomal_dataset.MVTecDataset
38
+ _partial_: true
39
+ transform:
40
+ _target_: src.data.components.transform.ImageTransform
41
+ image_size: 336
42
+ mask_transform:
43
+ _target_: src.data.components.transform.MaskTransform
44
+ image_size: ${data.image_size}
45
+ preload: false
46
+ aug_rate: 0.0
47
+ kshot:
48
+ _target_: src.data.components.kshot_dataset.MVTecKShotDataset
49
+ _partial_: true
50
+ k_shot: 1
51
+ transform:
52
+ _target_: src.data.components.transform.ImageTransform
53
+ image_size: 336
54
+ mask_transform:
55
+ _target_: src.data.components.transform.MaskTransform
56
+ image_size: ${data.image_size}
57
+ preload: false
58
+ image_size: 336
59
+ num_workers: 4
60
+ pin_memory: false
61
+ batch_size: 8
62
+ model:
63
+ _target_: src.models.anomaly_clip_module.AnomalyCLIPModule
64
+ optimizer:
65
+ _target_: torch.optim.AdamW
66
+ _partial_: true
67
+ lr: 0.001
68
+ weight_decay: 0.2
69
+ scheduler:
70
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
71
+ _partial_: true
72
+ mode: min
73
+ factor: 0.1
74
+ patience: 5
75
+ net:
76
+ _target_: src.models.components.anomaly_clip.AnomalyCLIP
77
+ arch: ViT-L/14@336px
78
+ image_size: 336
79
+ class_names:
80
+ - object
81
+ temperature: 0.05
82
+ prompt_length: 24
83
+ context_length: 77
84
+ truncate: false
85
+ feature_map_idx:
86
+ - 5
87
+ - 11
88
+ - 17
89
+ - 23
90
+ share_weight: false
91
+ state_template:
92
+ normal:
93
+ - '{}'
94
+ anomaly:
95
+ - damaged {}
96
+ tokenizer:
97
+ _target_: src.models.components.clip.simple_tokenizer.SimpleTokenizer
98
+ adapter:
99
+ _target_: src.models.components.adapter.BasicLayer
100
+ _partial_: true
101
+ input_resolution:
102
+ - 24
103
+ - 24
104
+ window_size: 6
105
+ depth: 1
106
+ num_heads: 8
107
+ hidden_features: null
108
+ cpb_dim: 64
109
+ value_only: true
110
+ drop: 0.0
111
+ attn_drop: 0.2
112
+ loss:
113
+ cross_entropy:
114
+ _target_: torch.nn.CrossEntropyLoss
115
+ focal:
116
+ _target_: src.models.components.loss.FocalLoss
117
+ dice:
118
+ _target_: src.models.components.loss.BinaryDiceLoss
119
+ k_shot: false
120
+ enable_validation: false
121
+ compile: false
122
+ callbacks:
123
+ model_checkpoint:
124
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
125
+ dirpath: ${paths.output_dir}/checkpoints
126
+ filename: epoch_{epoch:03d}
127
+ monitor: train/loss
128
+ verbose: false
129
+ save_last: true
130
+ save_top_k: 1
131
+ mode: min
132
+ auto_insert_metric_name: false
133
+ save_weights_only: false
134
+ every_n_train_steps: null
135
+ train_time_interval: null
136
+ every_n_epochs: null
137
+ save_on_train_epoch_end: null
138
+ early_stopping:
139
+ _target_: lightning.pytorch.callbacks.EarlyStopping
140
+ monitor: train/loss
141
+ min_delta: 0.0
142
+ patience: 10
143
+ verbose: false
144
+ mode: min
145
+ strict: true
146
+ check_finite: true
147
+ stopping_threshold: null
148
+ divergence_threshold: null
149
+ check_on_train_epoch_end: null
150
+ model_summary:
151
+ _target_: lightning.pytorch.callbacks.RichModelSummary
152
+ max_depth: -1
153
+ rich_progress_bar:
154
+ _target_: lightning.pytorch.callbacks.RichProgressBar
155
+ visualization:
156
+ _target_: src.models.components.callback.AnomalyVisualizationCallback
157
+ dirpath: ${paths.output_dir}/visualizations
158
+ visualize: true
159
+ visulization:
160
+ dirpath: ${paths.output_dir}/visualizations
161
+ visualize: true
162
+ logger:
163
+ wandb:
164
+ _target_: lightning.pytorch.loggers.wandb.WandbLogger
165
+ save_dir: ${paths.output_dir}
166
+ offline: false
167
+ id: null
168
+ anonymous: null
169
+ project: mvt_optuna
170
+ log_model: false
171
+ prefix: ''
172
+ group: ''
173
+ tags: []
174
+ job_type: ''
175
+ trainer:
176
+ _target_: lightning.pytorch.trainer.Trainer
177
+ default_root_dir: ${paths.output_dir}
178
+ min_epochs: 1
179
+ max_epochs: 2
180
+ accelerator: gpu
181
+ devices: 1
182
+ check_val_every_n_epoch: 1
183
+ deterministic: false
184
+ paths:
185
+ root_dir: ${oc.env:PROJECT_ROOT}
186
+ data_dir: ${paths.root_dir}/data/
187
+ log_dir: ${paths.root_dir}/logs/
188
+ output_dir: ${hydra:runtime.output_dir}
189
+ work_dir: ${hydra:runtime.cwd}
190
+ extras:
191
+ ignore_warnings: false
192
+ enforce_tags: true
193
+ print_config: true
194
+ optimized_metric: test/objective
SOWA/configs/extras/default.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # disable python warnings if they annoy you
2
+ ignore_warnings: False
3
+
4
+ # ask user for tags if none are provided in the config
5
+ enforce_tags: True
6
+
7
+ # pretty print config tree at the start of the run using Rich library
8
+ print_config: True
SOWA/configs/hparams_search/anomaly_clip_optuna.yaml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # example hyperparameter optimization of some experiment with Optuna:
4
+ # python train.py -m hparams_search=mnist_optuna experiment=example
5
+
6
+ defaults:
7
+ - override /hydra/sweeper: optuna
8
+
9
+ # choose metric which will be optimized by Optuna
10
+ # make sure this is the correct name of some metric logged in lightning module!
11
+ optimized_metric: test/objective
12
+
13
+ # here we define Optuna hyperparameter search
14
+ # it optimizes for value returned from function with @hydra.main decorator
15
+ # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
16
+ hydra:
17
+ mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
18
+
19
+ sweeper:
20
+ _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
21
+
22
+ # storage URL to persist optimization results
23
+ # for example, you can use SQLite if you set 'sqlite:///example.db'
24
+ storage: null
25
+
26
+ # name of the study to persist optimization results
27
+ study_name: null
28
+
29
+ # number of parallel workers
30
+ n_jobs: 1
31
+
32
+ # 'minimize' or 'maximize' the objective
33
+ direction: maximize
34
+
35
+ # total number of runs that will be executed
36
+ n_trials: 50
37
+
38
+ # choose Optuna hyperparameter sampler
39
+ # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
40
+ # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
41
+ sampler:
42
+ _target_: optuna.samplers.TPESampler
43
+ seed: 1234
44
+ n_startup_trials: 50 # number of random sampling runs before optimization starts
45
+
46
+ # define hyperparameter search space
47
+ params:
48
+ trainer.max_epochs: choice(1, 5)
49
+ # model.optimizer.lr: choice(0.0001, 0.001)
50
+ # model.net.temperature: choice(0.1, 0.05)
51
+ model.net.prompt_length: choice(8, 12, 16, 24)
52
+ model.net.share_weight: choice(true, false)
53
+ model.net.feature_map_idx : choice([5, 11, 17, 23], [0, 11, 23])
54
+ # model.net.adapter.hidden_features: choice([1024])
55
+ model.net.adapter.window_size: choice(6, 12, 24)
56
+ model.net.adapter.depth: choice(1, 2)
57
+ model.net.adapter.num_heads: choice(8)
58
+ # model.net.adapter.cpb_dim: choice(64, 128, 512)
59
+ model.net.adapter.value_only: choice(true, false)
60
+ model.net.adapter.drop: choice(0.0, 0.1, 0.2)
61
+ model.net.adapter.attn_drop: choice(0.0, 0.1, 0.2)
SOWA/configs/hydra/default.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://hydra.cc/docs/configure_hydra/intro/
2
+
3
+ # enable color logging
4
+ defaults:
5
+ - override hydra_logging: default
6
+ - override job_logging: default
7
+
8
+ # output directory, generated dynamically on each run
9
+ run:
10
+ dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
11
+ sweep:
12
+ dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
13
+ subdir: ${hydra.job.num}
14
+
15
+ job_logging:
16
+ handlers:
17
+ file:
18
+ # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
19
+ filename: ${hydra.runtime.output_dir}/${task_name}.log
SOWA/configs/local/.gitkeep ADDED
File without changes
SOWA/configs/logger/aim.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://aimstack.io/
2
+
3
+ # example usage in lightning module:
4
+ # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
5
+
6
+ # open the Aim UI with the following command (run in the folder containing the `.aim` folder):
7
+ # `aim up`
8
+
9
+ aim:
10
+ _target_: aim.pytorch_lightning.AimLogger
11
+ repo: ${paths.root_dir} # .aim folder will be created here
12
+ # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
13
+
14
+ # aim allows to group runs under experiment name
15
+ experiment: null # any string, set to "default" if not specified
16
+
17
+ train_metric_prefix: "train/"
18
+ val_metric_prefix: "val/"
19
+ test_metric_prefix: "test/"
20
+
21
+ # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
22
+ system_tracking_interval: 10 # set to null to disable system metrics tracking
23
+
24
+ # enable/disable logging of system params such as installed packages, git info, env vars, etc.
25
+ log_system_params: true
26
+
27
+ # enable/disable tracking console logs (default value is true)
28
+ capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
SOWA/configs/logger/comet.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://www.comet.ml
2
+
3
+ comet:
4
+ _target_: lightning.pytorch.loggers.comet.CometLogger
5
+ api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
6
+ save_dir: "${paths.output_dir}"
7
+ project_name: "lightning-hydra-template"
8
+ rest_api_key: null
9
+ # experiment_name: ""
10
+ experiment_key: null # set to resume experiment
11
+ offline: False
12
+ prefix: ""
SOWA/configs/logger/csv.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # csv logger built in lightning
2
+
3
+ csv:
4
+ _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
5
+ save_dir: "${paths.output_dir}"
6
+ name: "csv/"
7
+ prefix: ""
SOWA/configs/logger/many_loggers.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # train with many loggers at once
2
+
3
+ defaults:
4
+ # - comet
5
+ - csv
6
+ # - mlflow
7
+ # - neptune
8
+ # - tensorboard
9
+ - wandb
SOWA/configs/logger/mlflow.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://mlflow.org
2
+
3
+ mlflow:
4
+ _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
5
+ # experiment_name: ""
6
+ # run_name: ""
7
+ tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
8
+ tags: null
9
+ # save_dir: "./mlruns"
10
+ prefix: ""
11
+ artifact_location: null
12
+ # run_id: ""
SOWA/configs/logger/neptune.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # https://neptune.ai
2
+
3
+ neptune:
4
+ _target_: lightning.pytorch.loggers.neptune.NeptuneLogger
5
+ api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
6
+ project: username/lightning-hydra-template
7
+ # name: ""
8
+ log_model_checkpoints: True
9
+ prefix: ""
SOWA/configs/logger/tensorboard.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://www.tensorflow.org/tensorboard/
2
+
3
+ tensorboard:
4
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
5
+ save_dir: "${paths.output_dir}/tensorboard/"
6
+ name: null
7
+ log_graph: False
8
+ default_hp_metric: True
9
+ prefix: ""
10
+ # version: ""
SOWA/configs/logger/wandb.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://wandb.ai
2
+
3
+ wandb:
4
+ _target_: lightning.pytorch.loggers.wandb.WandbLogger
5
+ # name: "" # name of the run (normally generated by wandb)
6
+ save_dir: "${paths.output_dir}"
7
+ offline: False
8
+ id: null # pass correct id to resume experiment!
9
+ anonymous: null # enable anonymous logging
10
+ project: mvt_optuna
11
+ log_model: False # upload lightning ckpts
12
+ prefix: "" # a string to put at the beginning of metric keys
13
+ # entity: "" # set to name of your wandb team
14
+ group: ""
15
+ tags: []
16
+ job_type: ""
SOWA/configs/model/sowa_hfwa.yaml ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.models.anomaly_clip_module.AnomalyCLIPModule
2
+
3
+ optimizer:
4
+ _target_: torch.optim.AdamW
5
+ _partial_: true
6
+ lr: 0.001
7
+ weight_decay: 0.2
8
+
9
+ scheduler:
10
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
11
+ _partial_: true
12
+ mode: min
13
+ factor: 0.1
14
+ patience: 5
15
+
16
+ # scheduler:
17
+ # _target_: src.models.components.scheduler.WarmupCosineAnnealingLR
18
+ # _partial_: true
19
+ # warmup_epochs: 10
20
+ # total_epoch: 50
21
+
22
+ net:
23
+ _target_: src.models.components.anomaly_clip.AnomalyCLIP
24
+ arch: ViT-L/14@336px
25
+ image_size: 336
26
+ class_names: ["object"]
27
+ # class_names: ${prompt.class_names}
28
+ temperature: 0.07 # softmax
29
+ prompt_length: 12 # length of learnable prompts
30
+ context_length: 77 # default 77 for openai clip
31
+ truncate: false
32
+ feature_map_idx: [5, 11, 17, 23] # [0, 12, 23] [6, 12, 18] [5, 11, 17, 23] index of resnetblock in ViT
33
+ share_weight: true # whether the adapter shares weights for different feature maps
34
+ # state_template: ${prompt.state_template}
35
+ state_template:
36
+ normal: ["{}"]
37
+ anomaly: ["damaged {}"]
38
+ tokenizer:
39
+ _target_: src.models.components.clip.simple_tokenizer.SimpleTokenizer
40
+ adapter:
41
+ _target_: src.models.components.adapter.BasicLayer
42
+ _partial_: true
43
+ input_resolution: [24, 24] # (image_size - kernel_size) / stride + 1. e.g. 24 = (224 - 14) / 14 + 1
44
+ window_size: 12
45
+ depth: 1 # if depth < 2, there is no window shift
46
+ num_heads: 8
47
+ hidden_features: null # set null, same as nn.Linear
48
+ cpb_dim: 64
49
+ value_only: true
50
+ drop: 0.0
51
+ attn_drop: 0.1
52
+ # shift_size: 1
53
+ fusion:
54
+ _target_: src.models.components.cross_modal.DotProductFusion
55
+ embedding_dim: 768 # clip fusion feature dim, default 768, only effective for non null
56
+
57
+ loss:
58
+ cross_entropy:
59
+ _target_: torch.nn.CrossEntropyLoss
60
+ focal:
61
+ _target_: src.models.components.loss.FocalLoss
62
+ dice:
63
+ _target_: src.models.components.loss.BinaryDiceLoss
64
+
65
+ k_shot: false
66
+
67
+ filter: true
68
+
69
+ enable_validation: false
70
+
71
+ compile: false
SOWA/configs/model/sowa_linear.yaml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.models.anomaly_clip_module.AnomalyCLIPModule
2
+
3
+ optimizer:
4
+ _target_: torch.optim.AdamW
5
+ _partial_: true
6
+ lr: 0.001
7
+ weight_decay: 0.2
8
+
9
+ scheduler:
10
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
11
+ _partial_: true
12
+ mode: min
13
+ factor: 0.1
14
+ patience: 5
15
+
16
+ net:
17
+ _target_: src.models.components.anomaly_clip.AnomalyCLIP
18
+ arch: ViT-L/14@336px
19
+ image_size: 336
20
+ class_names: ["object"]
21
+ # class_names: ${prompt.class_names}
22
+ temperature: 0.07 # softmax
23
+ prompt_length: 12 # length of learnable prompts
24
+ context_length: 77 # default 77 for openai clip
25
+ truncate: false
26
+ feature_map_idx: [5, 11, 17, 23] # [0, 12, 23] [6, 12, 18] [5, 11, 17, 23] index of resnetblock in ViT
27
+ share_weight: false # whether the adapter shares weights for different feature maps
28
+ # state_template: ${prompt.state_template}
29
+ state_template:
30
+ normal: ["{}"]
31
+ anomaly: ["damaged {}"]
32
+ tokenizer:
33
+ _target_: src.models.components.clip.simple_tokenizer.SimpleTokenizer
34
+ adapter:
35
+ # _target_: torch.nn.Linear
36
+ # in_features: 1024 # clip vit feature dim, default 1024 for openai clip
37
+ # out_features: 1024
38
+ # bias: false
39
+ _target_: src.models.components.adapter.Linear
40
+ in_features: 1024 # clip vit feature dim, default 1024 for openai clip
41
+ out_features: 1024
42
+ hidden_features: null # set null, same as nn.Linear
43
+ dropout_prob: 0.0
44
+ bias: false
45
+ fusion:
46
+ _target_: src.models.components.cross_modal.DotProductFusion
47
+ embedding_dim: null # clip fusion feature dim, only effective for learnable
48
+
49
+ loss:
50
+ cross_entropy:
51
+ _target_: torch.nn.CrossEntropyLoss
52
+ focal:
53
+ _target_: src.models.components.loss.FocalLoss
54
+ dice:
55
+ _target_: src.models.components.loss.BinaryDiceLoss
56
+
57
+ k_shot: false
58
+
59
+ filter: true
60
+
61
+ enable_validation: false
62
+
63
+ compile: false
SOWA/configs/model/sparc_hfwa.yaml ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.models.anomaly_clip_module.AnomalyCLIPModule
2
+
3
+ optimizer:
4
+ _target_: torch.optim.AdamW
5
+ _partial_: true
6
+ lr: 0.001
7
+ weight_decay: 0.2
8
+
9
+ scheduler:
10
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
11
+ _partial_: true
12
+ mode: min
13
+ factor: 0.1
14
+ patience: 5
15
+
16
+ # scheduler:
17
+ # _target_: src.models.components.scheduler.WarmupCosineAnnealingLR
18
+ # _partial_: true
19
+ # warmup_epochs: 10
20
+ # total_epoch: 50
21
+
22
+ net:
23
+ _target_: src.models.components.sparc.SPARC
24
+ arch: ViT-L/14@336px
25
+ image_size: 336
26
+ temperature: 0.07 # softmax
27
+ feature_map_idx: [5, 11, 17, 23] # [0, 12, 23] [6, 12, 18] [5, 11, 17, 23] index of resnetblock in ViT
28
+ prompt_learner:
29
+ _target_: src.models.components.coop.AnomalyPromptLearner
30
+ _partial_: true
31
+ tokenizer:
32
+ _target_: src.models.components.clip.simple_tokenizer.SimpleTokenizer
33
+ prompt_length: 12 # length of learnable prompts
34
+ context_length: 77 # default 77 for openai clip
35
+ truncate: false
36
+ class_names: ["object"]
37
+ # class_names: ${prompt.class_names}
38
+ # state_template: ${prompt.state_template}
39
+ state_template:
40
+ normal: ["{}"]
41
+ anomaly: ["damaged {}"]
42
+ text_encoder:
43
+ _target_: src.models.components.text_encoder.TextMapEncoder
44
+ _partial_: true
45
+ adapter:
46
+ _target_: src.models.components.adapter.BasicLayer
47
+ _partial_: true
48
+ input_resolution: [24, 24] # (image_size - kernel_size) / stride + 1. e.g. 24 = (224 - 14) / 14 + 1
49
+ window_size: 12
50
+ depth: 1 # if depth < 2, there is no window shift
51
+ num_heads: 8
52
+ hidden_features: null # set null, same as nn.Linear
53
+ cpb_dim: 64
54
+ value_only: true
55
+ drop: 0.0
56
+ attn_drop: 0.1
57
+ fusion:
58
+ _target_: src.models.components.cross_modal.DotProductFusion
59
+ embedding_dim: 768 # clip fusion feature dim, default 768, only effective for non null
60
+
61
+ loss:
62
+ cross_entropy:
63
+ _target_: torch.nn.CrossEntropyLoss
64
+ focal:
65
+ _target_: src.models.components.loss.FocalLoss
66
+ dice:
67
+ _target_: src.models.components.loss.BinaryDiceLoss
68
+
69
+ k_shot: false
70
+
71
+ filter: true
72
+
73
+ enable_validation: false
74
+
75
+ compile: false
SOWA/configs/model/sparc_linear.yaml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.models.anomaly_clip_module.AnomalyCLIPModule
2
+
3
+ optimizer:
4
+ _target_: torch.optim.AdamW
5
+ _partial_: true
6
+ lr: 0.001
7
+ weight_decay: 0.2
8
+
9
+ scheduler:
10
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
11
+ _partial_: true
12
+ mode: min
13
+ factor: 0.1
14
+ patience: 5
15
+
16
+ # scheduler:
17
+ # _target_: src.models.components.scheduler.WarmupCosineAnnealingLR
18
+ # _partial_: true
19
+ # warmup_epochs: 10
20
+ # total_epoch: 50
21
+
22
+ net:
23
+ _target_: src.models.components.sparc.SPARC
24
+ arch: ViT-L/14@336px
25
+ image_size: 336
26
+ temperature: 0.07 # softmax
27
+ feature_map_idx: [5, 11, 17, 23] # [0, 12, 23] [6, 12, 18] [5, 11, 17, 23] index of resnetblock in ViT
28
+ prompt_learner:
29
+ _target_: src.models.components.coop.AnomalyPromptLearner
30
+ _partial_: true
31
+ tokenizer:
32
+ _target_: src.models.components.clip.simple_tokenizer.SimpleTokenizer
33
+ prompt_length: 12 # length of learnable prompts
34
+ context_length: 77 # default 77 for openai clip
35
+ truncate: false
36
+ class_names: ["object"]
37
+ # class_names: ${prompt.class_names}
38
+ # state_template: ${prompt.state_template}
39
+ state_template:
40
+ normal: ["{}"]
41
+ anomaly: ["damaged {}"]
42
+ text_encoder:
43
+ _target_: src.models.components.text_encoder.TextMapEncoder
44
+ _partial_: true
45
+ adapter:
46
+ # _target_: torch.nn.Linear
47
+ # in_features: 1024 # clip vit feature dim, default 1024 for openai clip
48
+ # out_features: 1024
49
+ # bias: false
50
+ _target_: src.models.components.adapter.Linear
51
+ in_features: 1024 # clip vit feature dim, default 1024 for openai clip
52
+ out_features: 1024
53
+ hidden_features: null # set null, same as nn.Linear
54
+ dropout_prob: 0.0
55
+ bias: false
56
+ fusion:
57
+ _target_: src.models.components.cross_modal.DotProductFusion
58
+ embedding_dim: 768 # clip fusion feature dim, default 768, only effective for non null
59
+
60
+ loss:
61
+ cross_entropy:
62
+ _target_: torch.nn.CrossEntropyLoss
63
+ focal:
64
+ _target_: src.models.components.loss.FocalLoss
65
+ dice:
66
+ _target_: src.models.components.loss.BinaryDiceLoss
67
+
68
+ k_shot: false
69
+
70
+ filter: true
71
+
72
+ enable_validation: false
73
+
74
+ compile: false
SOWA/configs/model/sparc_prompt.yaml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.models.anomaly_clip_module.AnomalyCLIPModule
2
+
3
+ optimizer:
4
+ _target_: torch.optim.AdamW
5
+ _partial_: true
6
+ lr: 0.001
7
+ weight_decay: 0.2
8
+
9
+ scheduler:
10
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
11
+ _partial_: true
12
+ mode: min
13
+ factor: 0.1
14
+ patience: 5
15
+
16
+ # scheduler:
17
+ # _target_: src.models.components.scheduler.WarmupCosineAnnealingLR
18
+ # _partial_: true
19
+ # warmup_epochs: 10
20
+ # total_epoch: 50
21
+
22
+ net:
23
+ _target_: src.models.components.sparc.SPARC
24
+ arch: ViT-L/14@336px
25
+ image_size: 336
26
+ temperature: 0.07 # softmax
27
+ feature_map_idx: [5, 11, 17, 23] # [0, 12, 23] [6, 12, 18] [5, 11, 17, 23] index of resnetblock in ViT
28
+ share_weight: true # whether the adapter shares weights for different feature maps
29
+ prompt_learner:
30
+ _target_: src.models.components.coop.PromptEncoder
31
+ _partial_: true
32
+ tokenizer:
33
+ _target_: src.models.components.clip.simple_tokenizer.SimpleTokenizer
34
+ context_length: 77 # default 77 for openai clip
35
+ truncate: false
36
+ class_names: ${prompt.class_names}
37
+ prompt_normal: ${prompt.template.normal}
38
+ prompt_abnormal: ${prompt.template.abnormal}
39
+ prompt_templates: ${prompt.template.templates}
40
+ adapter:
41
+ _target_: torch.nn.Linear
42
+ in_features: 1024 # clip vit feature dim, default 1024 for openai clip
43
+ out_features: 1024
44
+ bias: false
45
+ # _target_: src.models.components.adapter.BasicLayer
46
+ # _partial_: true
47
+ # input_resolution: [24, 24] # (image_size - kernel_size) / stride + 1. e.g. 24 = (224 - 14) / 14 + 1
48
+ # window_size: 12
49
+ # depth: 1 # if depth < 2, there is no window shift
50
+ # num_heads: 8
51
+ # hidden_features: null # set null, same as nn.Linear
52
+ # cpb_dim: 64
53
+ # value_only: true
54
+ # drop: 0.0
55
+ # attn_drop: 0.1
56
+ fusion:
57
+ _target_: src.models.components.cross_modal.DotProductFusion
58
+ embedding_dim: 768 # clip fusion feature dim, default 768, only effective for non null
59
+
60
+ loss:
61
+ cross_entropy:
62
+ _target_: torch.nn.CrossEntropyLoss
63
+ focal:
64
+ _target_: src.models.components.loss.FocalLoss
65
+ dice:
66
+ _target_: src.models.components.loss.BinaryDiceLoss
67
+
68
+ k_shot: false
69
+
70
+ filter: true
71
+
72
+ enable_validation: false
73
+
74
+ compile: false
SOWA/configs/paths/default.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # path to root directory
2
+ # this requires PROJECT_ROOT environment variable to exist
3
+ # you can replace it with "." if you want the root to be the current working directory
4
+ root_dir: ${oc.env:PROJECT_ROOT}
5
+
6
+ # path to data directory
7
+ data_dir: ${paths.root_dir}/data/
8
+
9
+ # path to logging directory
10
+ log_dir: ${paths.root_dir}/logs/
11
+
12
+ # path to output directory, created dynamically by hydra
13
+ # path generation pattern is specified in `configs/hydra/default.yaml`
14
+ # use it to store all files generated during the run, like ckpts and metrics
15
+ output_dir: ${hydra:runtime.output_dir}
16
+
17
+ # path to working directory
18
+ work_dir: ${hydra:runtime.cwd}
SOWA/configs/prompt/default.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # configs/prompt/default.yaml
2
+ defaults:
3
+ - object
4
+ - state_template
5
+ - template
SOWA/configs/prompt/object.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class_names:
2
+ - object
3
+ - pipe_fryum
4
+ - metal_nut
5
+ - pcb1
6
+ - tile
7
+ - screw
8
+ - pcb2
9
+ - wood
10
+ - zipper
11
+ - cable
12
+ - fryum
13
+ - pill
14
+ - capsule
15
+ - hazelnut
16
+ - pcb4
17
+ - leather
18
+ - bottle
19
+ - cashew
20
+ - macaroni2
21
+ - grid
22
+ - chewinggum
23
+ - transistor
24
+ - macaroni1
25
+ - candle
26
+ - capsules
27
+ - pcb3
28
+ - carpet
29
+ - toothbrush
SOWA/configs/prompt/state_template.yaml ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ state_template:
2
+ anomaly:
3
+ - a photo of the damaged {}.
4
+ - a bright photo of the damaged {}.
5
+ - a dark photo of the damaged {}.
6
+ - a close-up photo of the damaged {}.
7
+ - a black and white photo of the damaged {}.
8
+ - a blurry photo of the damaged {}.
9
+ - a blurry photo of a damaged {}.
10
+ - a photo of the small damaged {}.
11
+ - a photo of the large damaged {}.
12
+ - there is a damaged {} in the scene.
13
+ - this is one damaged {} in the scene.
14
+ - a photo of the broken {}.
15
+ - a bright photo of the broken {}.
16
+ - a dark photo of the broken {}.
17
+ - a close-up photo of the broken {}.
18
+ - a black and white photo of the broken {}.
19
+ - a blurry photo of the broken {}.
20
+ - a blurry photo of a broken {}.
21
+ - a photo of the small broken {}.
22
+ - a photo of the large broken {}.
23
+ - there is a broken {} in the scene.
24
+ - this is one broken {} in the scene.
25
+ - a photo of the {} with flaw.
26
+ - a bright photo of the {} with flaw.
27
+ - a dark photo of the {} with flaw.
28
+ - a close-up photo of the {} with flaw.
29
+ - a black and white photo of the {} with flaw.
30
+ - a blurry photo of the {} with flaw.
31
+ - a blurry photo of a {} with flaw.
32
+ - a photo of the small {} with flaw.
33
+ - a photo of the large {} with flaw.
34
+ - there is a {} with flaw in the scene.
35
+ - this is one {} with flaw in the scene.
36
+ - a photo of the {} with defect.
37
+ - a bright photo of the {} with defect.
38
+ - a dark photo of the {} with defect.
39
+ - a close-up photo of the {} with defect.
40
+ - a black and white photo of the {} with defect.
41
+ - a blurry photo of the {} with defect.
42
+ - a blurry photo of a {} with defect.
43
+ - a photo of the small {} with defect.
44
+ - a photo of the large {} with defect.
45
+ - there is a {} with defect in the scene.
46
+ - this is one {} with defect in the scene.
47
+ normal:
48
+ - a photo of the {}.
49
+ - a bright photo of the {}.
50
+ - a dark photo of the {}.
51
+ - a close-up photo of the {}.
52
+ - a black and white photo of the {}.
53
+ - a blurry photo of the {}.
54
+ - a blurry photo of a {}.
55
+ - a photo of the small {}.
56
+ - a photo of the large {}.
57
+ - there is a {} in the scene.
58
+ - this is one {} in the scene.
59
+ - a photo of the flawless {}.
60
+ - a bright photo of the flawless {}.
61
+ - a dark photo of the flawless {}.
62
+ - a close-up photo of the flawless {}.
63
+ - a black and white photo of the flawless {}.
64
+ - a blurry photo of the flawless {}.
65
+ - a blurry photo of a flawless {}.
66
+ - a photo of the small flawless {}.
67
+ - a photo of the large flawless {}.
68
+ - there is a flawless {} in the scene.
69
+ - this is one flawless {} in the scene.
70
+ - a photo of the {} without flaw.
71
+ - a bright photo of the {} without flaw.
72
+ - a dark photo of the {} without flaw.
73
+ - a close-up photo of the {} without flaw.
74
+ - a black and white photo of the {} without flaw.
75
+ - a blurry photo of the {} without flaw.
76
+ - a blurry photo of a {} without flaw.
77
+ - a photo of the small {} without flaw.
78
+ - a photo of the large {} without flaw.
79
+ - there is a {} without flaw in the scene.
80
+ - this is one {} without flaw in the scene.
81
+ - a photo of the {} without defect.
82
+ - a bright photo of the {} without defect.
83
+ - a dark photo of the {} without defect.
84
+ - a close-up photo of the {} without defect.
85
+ - a black and white photo of the {} without defect.
86
+ - a blurry photo of the {} without defect.
87
+ - a blurry photo of a {} without defect.
88
+ - a photo of the small {} without defect.
89
+ - a photo of the large {} without defect.
90
+ - there is a {} without defect in the scene.
91
+ - this is one {} without defect in the scene.
SOWA/configs/prompt/template.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ template:
2
+ normal:
3
+ - '{}'
4
+ - 'flawless {}'
5
+ - 'perfect {}'
6
+ - 'unblemished {}'
7
+ - '{} without flaw'
8
+ - '{} without defect'
9
+ - '{} without damage'
10
+ abnormal:
11
+ - 'damaged {}'
12
+ - 'broken {}'
13
+ - '{} with flaw'
14
+ - '{} with defect'
15
+ - '{} with damage'
16
+ templates:
17
+ - 'a bad photo of a {}.'
18
+ - 'a low resolution photo of the {}.'
19
+ - 'a bad photo of the {}.'
20
+ - 'a cropped photo of the {}.'
21
+ - 'a bright photo of a {}.'
22
+ - 'a dark photo of the {}.'
23
+ - 'a photo of my {}.'
24
+ - 'a photo of the cool {}.'
25
+ - 'a close-up photo of a {}.'
26
+ - 'a black and white photo of the {}.'
27
+ - 'a bright photo of the {}.'
28
+ - 'a cropped photo of a {}.'
29
+ - 'a jpeg corrupted photo of a {}.'
30
+ - 'a blurry photo of the {}.'
31
+ - 'a photo of the {}.'
32
+ - 'a good photo of the {}.'
33
+ - 'a photo of one {}.'
34
+ - 'a close-up photo of the {}.'
35
+ - 'a photo of a {}.'
36
+ - 'a low resolution photo of a {}.'
37
+ - 'a photo of a large {}.'
38
+ - 'a blurry photo of a {}.'
39
+ - 'a jpeg corrupted photo of the {}.'
40
+ - 'a good photo of a {}.'
41
+ - 'a photo of the small {}.'
42
+ - 'a photo of the large {}.'
43
+ - 'a black and white photo of a {}.'
44
+ - 'a dark photo of a {}.'
45
+ - 'a photo of a cool {}.'
46
+ - 'a photo of a small {}.'
47
+ - 'there is a {} in the scene.'
48
+ - 'there is the {} in the scene.'
49
+ - 'this is a {} in the scene.'
50
+ - 'this is the {} in the scene.'
51
+ - 'this is one {} in the scene.'
SOWA/configs/train.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # specify here default configuration
4
+ # order of defaults determines the order in which configs override each other
5
+ defaults:
6
+ - _self_
7
+ - data: anomaly_clip
8
+ - model: anomaly_clip
9
+ - callbacks: default
10
+ - logger: many_loggers # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
11
+ - trainer: default
12
+ - paths: default
13
+ - extras: default
14
+ - hydra: default
15
+
16
+ # information of object and prompt template
17
+ - prompt: default
18
+
19
+ # experiment configs allow for version control of specific hyperparameters
20
+ # e.g. best hyperparameters for given model and datamodule
21
+ - experiment: null
22
+
23
+ # config for hyperparameter optimization
24
+ - hparams_search: null
25
+
26
+ # optional local config for machine/user specific settings
27
+ # it's optional since it doesn't need to exist and is excluded from version control
28
+ - optional local: default
29
+
30
+ # debugging config (enable through command line, e.g. `python train.py debug=default`)
31
+ - debug: null
32
+
33
+ # task name, determines output directory path
34
+ task_name: "train"
35
+
36
+ # tags to help you identify your experiments
37
+ # you can overwrite this in experiment configs
38
+ # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
39
+ tags: ["dev"]
40
+
41
+ # set False to skip model training
42
+ train: True
43
+
44
+ # evaluate on test set, using best model weights achieved during training
45
+ # lightning chooses best weights based on the metric specified in checkpoint callback
46
+ test: True
47
+
48
+ # simply provide checkpoint path to resume training
49
+ ckpt_path: null
50
+
51
+ # seed for random number generators in pytorch, numpy and python.random
52
+ seed: 2025