AleksanderObuchowski
commited on
Commit
•
5ceacbc
1
Parent(s):
9ee70d2
Add files using upload-large-folder tool
Browse files- .idea/.gitignore +8 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +7 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- .idea/workspace.xml +204 -0
- 2024.09.27/config.yaml +211 -0
- 2024.09.27/language_model/clip_tokenizer_4.16.2/merges.txt +0 -0
- 2024.09.27/language_model/clip_tokenizer_4.16.2/special_tokens_map.json +27 -0
- 2024.09.27/language_model/clip_tokenizer_4.16.2/tokenizer_config.json +38 -0
- 2024.09.27/language_model/clip_tokenizer_4.16.2/vocab.json +0 -0
- MedImageInsight/Distributed/Utils.py +344 -0
- MedImageInsight/Distributed/__init__.py +6 -0
- MedImageInsight/ImageDataLoader/__init__.py +8 -0
- MedImageInsight/ImageDataLoader/blob_storage.py +244 -0
- MedImageInsight/ImageDataLoader/build.py +260 -0
- MedImageInsight/ImageDataLoader/constants.py +85 -0
- MedImageInsight/ImageDataLoader/languages/__init__.py +0 -0
- MedImageInsight/ImageDataLoader/languages/prompt_engineering.py +101 -0
- MedImageInsight/ImageDataLoader/transforms/__init__.py +1 -0
- MedImageInsight/ImageDataLoader/transforms/autoaugment.py +447 -0
- MedImageInsight/ImageDataLoader/transforms/build.py +261 -0
- MedImageInsight/ImageDataLoader/transforms/threeaugment.py +54 -0
- MedImageInsight/ImageDataLoader/tsv.py +351 -0
- MedImageInsight/ImageDataLoader/tsv_file.py +290 -0
- MedImageInsight/ImageDataLoader/zipdata.py +98 -0
- MedImageInsight/ImageEncoder/__init__.py +8 -0
- MedImageInsight/ImageEncoder/build.py +13 -0
- MedImageInsight/ImageEncoder/coswin.py +779 -0
- MedImageInsight/ImageEncoder/davit_v1.py +727 -0
- MedImageInsight/ImageEncoder/registry.py +18 -0
- MedImageInsight/LangEncoder/__init__.py +13 -0
- MedImageInsight/LangEncoder/build.py +108 -0
- MedImageInsight/LangEncoder/registry.py +18 -0
- MedImageInsight/LangEncoder/transformer.py +210 -0
- MedImageInsight/UniCLModel.py +293 -0
- MedImageInsight/Utils/Arguments.py +134 -0
- MedImageInsight/Utils/GeneraUtils.py +263 -0
- MedImageInsight/Utils/GlobalExceptHook.py +61 -0
- MedImageInsight/Utils/MPIAdapter.py +147 -0
- MedImageInsight/Utils/Utils.py +141 -0
- MedImageInsight/Utils/__init__.py +7 -0
- MedImageInsight/__init__.py +9 -0
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Default ignored files
|
2 |
+
/shelf/
|
3 |
+
/workspace.xml
|
4 |
+
# Editor-based HTTP Client requests
|
5 |
+
/httpRequests/
|
6 |
+
# Datasource local storage ignored files
|
7 |
+
/dataSources/
|
8 |
+
/dataSources.local.xml
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<settings>
|
3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
4 |
+
<version value="1.0" />
|
5 |
+
</settings>
|
6 |
+
</component>
|
.idea/misc.xml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="Black">
|
4 |
+
<option name="sdkName" value="Python 3.12" />
|
5 |
+
</component>
|
6 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 virtualenv at ~/medatlas/.venv" project-jdk-type="Python SDK" />
|
7 |
+
</project>
|
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectModuleManager">
|
4 |
+
<modules>
|
5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/medatlas.iml" filepath="$PROJECT_DIR$/.idea/medatlas.iml" />
|
6 |
+
</modules>
|
7 |
+
</component>
|
8 |
+
</project>
|
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="VcsDirectoryMappings">
|
4 |
+
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
5 |
+
</component>
|
6 |
+
</project>
|
.idea/workspace.xml
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="AutoImportSettings">
|
4 |
+
<option name="autoReloadType" value="SELECTIVE" />
|
5 |
+
</component>
|
6 |
+
<component name="ChangeListManager">
|
7 |
+
<list default="true" id="9ec92c76-0e74-4c49-9687-c62749296b88" name="Changes" comment="" />
|
8 |
+
<option name="SHOW_DIALOG" value="false" />
|
9 |
+
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
10 |
+
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
11 |
+
<option name="LAST_RESOLUTION" value="IGNORE" />
|
12 |
+
</component>
|
13 |
+
<component name="FileTemplateManagerImpl">
|
14 |
+
<option name="RECENT_TEMPLATES">
|
15 |
+
<list>
|
16 |
+
<option value="Python Script" />
|
17 |
+
</list>
|
18 |
+
</option>
|
19 |
+
</component>
|
20 |
+
<component name="FlaskConsoleOptions" custom-start-script="import sys; print('Python %s on %s' % (sys.version, sys.platform)); sys.path.extend([WORKING_DIR_AND_PYTHON_PATHS]) from flask.cli import ScriptInfo, NoAppException for module in ["main.py", "wsgi.py", "app.py"]: try: locals().update(ScriptInfo(app_import_path=module, create_app=None).load_app().make_shell_context()); print("\nFlask App: %s" % app.import_name); break except NoAppException: pass">
|
21 |
+
<envs>
|
22 |
+
<env key="FLASK_APP" value="app" />
|
23 |
+
</envs>
|
24 |
+
<option name="myCustomStartScript" value="import sys; print('Python %s on %s' % (sys.version, sys.platform)); sys.path.extend([WORKING_DIR_AND_PYTHON_PATHS]) from flask.cli import ScriptInfo, NoAppException for module in ["main.py", "wsgi.py", "app.py"]: try: locals().update(ScriptInfo(app_import_path=module, create_app=None).load_app().make_shell_context()); print("\nFlask App: %s" % app.import_name); break except NoAppException: pass" />
|
25 |
+
<option name="myEnvs">
|
26 |
+
<map>
|
27 |
+
<entry key="FLASK_APP" value="app" />
|
28 |
+
</map>
|
29 |
+
</option>
|
30 |
+
</component>
|
31 |
+
<component name="Git.Settings">
|
32 |
+
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
33 |
+
</component>
|
34 |
+
<component name="HighlightingSettingsPerFile">
|
35 |
+
<setting file="file://$PROJECT_DIR$/.venv/lib/python3.8/site-packages/safetensors/torch.py" root0="SKIP_INSPECTION" />
|
36 |
+
</component>
|
37 |
+
<component name="ProjectColorInfo">{
|
38 |
+
"associatedIndex": 3
|
39 |
+
}</component>
|
40 |
+
<component name="ProjectId" id="2nytZGYw1NHCwZYyKjVZHmbmsFp" />
|
41 |
+
<component name="ProjectLevelVcsManager">
|
42 |
+
<ConfirmationsSetting value="1" id="Add" />
|
43 |
+
</component>
|
44 |
+
<component name="ProjectViewState">
|
45 |
+
<option name="hideEmptyMiddlePackages" value="true" />
|
46 |
+
<option name="showLibraryContents" value="true" />
|
47 |
+
</component>
|
48 |
+
<component name="PropertiesComponent"><![CDATA[{
|
49 |
+
"keyToString": {
|
50 |
+
"Python.example.executor": "Debug",
|
51 |
+
"Python.explainability.executor": "Run",
|
52 |
+
"Python.main.executor": "Run",
|
53 |
+
"Python.medimageinsightmodel.executor": "Run",
|
54 |
+
"Python.push_to_hub.executor": "Run",
|
55 |
+
"RunOnceActivity.ShowReadmeOnStart": "true",
|
56 |
+
"RunOnceActivity.git.unshallow": "true",
|
57 |
+
"git-widget-placeholder": "master",
|
58 |
+
"last_opened_file_path": "/home/olek/medatlas/2024.09.27/vision_model",
|
59 |
+
"node.js.detected.package.eslint": "true",
|
60 |
+
"node.js.detected.package.tslint": "true",
|
61 |
+
"node.js.selected.package.eslint": "(autodetect)",
|
62 |
+
"node.js.selected.package.tslint": "(autodetect)",
|
63 |
+
"nodejs_package_manager_path": "npm",
|
64 |
+
"vue.rearranger.settings.migration": "true"
|
65 |
+
}
|
66 |
+
}]]></component>
|
67 |
+
<component name="RdControllerToolWindowsLayoutState" isNewUi="true">
|
68 |
+
<layout>
|
69 |
+
<window_info id="Bookmarks" side_tool="true" />
|
70 |
+
<window_info id="Merge Requests" />
|
71 |
+
<window_info id="Commit_Guest" show_stripe_button="false" />
|
72 |
+
<window_info id="Pull Requests" />
|
73 |
+
<window_info id="Learn" />
|
74 |
+
<window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.16933593" />
|
75 |
+
<window_info id="Commit" order="1" weight="0.25" />
|
76 |
+
<window_info id="Structure" order="2" side_tool="true" weight="0.25" />
|
77 |
+
<window_info anchor="bottom" id="Database Changes" />
|
78 |
+
<window_info anchor="bottom" id="TypeScript" />
|
79 |
+
<window_info anchor="bottom" id="TODO" />
|
80 |
+
<window_info anchor="bottom" id="File Transfer" />
|
81 |
+
<window_info anchor="bottom" id="Version Control" order="0" />
|
82 |
+
<window_info anchor="bottom" id="Problems" order="1" />
|
83 |
+
<window_info anchor="bottom" id="Problems View" order="2" />
|
84 |
+
<window_info active="true" anchor="bottom" id="Terminal" order="3" visible="true" weight="0.3795139" />
|
85 |
+
<window_info anchor="bottom" id="Services" order="4" />
|
86 |
+
<window_info anchor="bottom" id="Python Packages" order="5" weight="0.1" />
|
87 |
+
<window_info anchor="bottom" id="Debug" order="6" weight="0.29618055" />
|
88 |
+
<window_info anchor="bottom" id="Python Console" order="7" weight="0.1" />
|
89 |
+
<window_info anchor="bottom" id="HfCacheToolWindow" order="8" weight="0.44131944" />
|
90 |
+
<window_info anchor="bottom" id="Run" order="9" weight="0.6490499" />
|
91 |
+
<window_info anchor="bottom" id="Find" order="10" weight="0.33020833" />
|
92 |
+
<window_info anchor="right" id="Endpoints" />
|
93 |
+
<window_info anchor="right" id="Coverage" side_tool="true" />
|
94 |
+
<window_info anchor="right" id="SciView" />
|
95 |
+
<window_info anchor="right" content_ui="combo" id="Notifications" order="0" weight="0.25" />
|
96 |
+
<window_info anchor="right" id="AIAssistant" order="1" weight="0.25" />
|
97 |
+
<window_info anchor="right" id="Database" order="2" weight="0.25" />
|
98 |
+
<window_info anchor="right" id="Gradle" order="3" weight="0.25" />
|
99 |
+
<window_info anchor="right" id="Maven" order="4" weight="0.25" />
|
100 |
+
<window_info anchor="right" id="CodeGPT" order="5" weight="0.30566406" />
|
101 |
+
<window_info anchor="right" id="Plots" order="6" weight="0.1" />
|
102 |
+
</layout>
|
103 |
+
</component>
|
104 |
+
<component name="RecentsManager">
|
105 |
+
<key name="CopyFile.RECENT_KEYS">
|
106 |
+
<recent name="$PROJECT_DIR$/2024.09.27/vision_model" />
|
107 |
+
<recent name="$PROJECT_DIR$" />
|
108 |
+
</key>
|
109 |
+
<key name="MoveFile.RECENT_KEYS">
|
110 |
+
<recent name="$PROJECT_DIR$" />
|
111 |
+
<recent name="$PROJECT_DIR$/MedImageInsight/ImageEncoder" />
|
112 |
+
<recent name="$PROJECT_DIR$/MedImageInsights" />
|
113 |
+
</key>
|
114 |
+
</component>
|
115 |
+
<component name="RunManager">
|
116 |
+
<configuration name="main" type="PythonConfigurationType" factoryName="Python" nameIsGenerated="true">
|
117 |
+
<module name="medatlas" />
|
118 |
+
<option name="ENV_FILES" value="" />
|
119 |
+
<option name="INTERPRETER_OPTIONS" value="" />
|
120 |
+
<option name="PARENT_ENVS" value="true" />
|
121 |
+
<envs>
|
122 |
+
<env name="PYTHONUNBUFFERED" value="1" />
|
123 |
+
</envs>
|
124 |
+
<option name="SDK_HOME" value="" />
|
125 |
+
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
|
126 |
+
<option name="IS_MODULE_SDK" value="true" />
|
127 |
+
<option name="ADD_CONTENT_ROOTS" value="true" />
|
128 |
+
<option name="ADD_SOURCE_ROOTS" value="true" />
|
129 |
+
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
130 |
+
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
|
131 |
+
<option name="PARAMETERS" value="" />
|
132 |
+
<option name="SHOW_COMMAND_LINE" value="false" />
|
133 |
+
<option name="EMULATE_TERMINAL" value="false" />
|
134 |
+
<option name="MODULE_MODE" value="false" />
|
135 |
+
<option name="REDIRECT_INPUT" value="false" />
|
136 |
+
<option name="INPUT_FILE" value="" />
|
137 |
+
<method v="2" />
|
138 |
+
</configuration>
|
139 |
+
<configuration name="medimageinsightmodel" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
140 |
+
<module name="medatlas" />
|
141 |
+
<option name="ENV_FILES" value="" />
|
142 |
+
<option name="INTERPRETER_OPTIONS" value="" />
|
143 |
+
<option name="PARENT_ENVS" value="true" />
|
144 |
+
<envs>
|
145 |
+
<env name="PYTHONUNBUFFERED" value="1" />
|
146 |
+
</envs>
|
147 |
+
<option name="SDK_HOME" value="" />
|
148 |
+
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
|
149 |
+
<option name="IS_MODULE_SDK" value="true" />
|
150 |
+
<option name="ADD_CONTENT_ROOTS" value="true" />
|
151 |
+
<option name="ADD_SOURCE_ROOTS" value="true" />
|
152 |
+
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
153 |
+
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/medimageinsightmodel.py" />
|
154 |
+
<option name="PARAMETERS" value="" />
|
155 |
+
<option name="SHOW_COMMAND_LINE" value="false" />
|
156 |
+
<option name="EMULATE_TERMINAL" value="false" />
|
157 |
+
<option name="MODULE_MODE" value="false" />
|
158 |
+
<option name="REDIRECT_INPUT" value="false" />
|
159 |
+
<option name="INPUT_FILE" value="" />
|
160 |
+
<method v="2" />
|
161 |
+
</configuration>
|
162 |
+
<recent_temporary>
|
163 |
+
<list>
|
164 |
+
<item itemvalue="Python.medimageinsightmodel" />
|
165 |
+
</list>
|
166 |
+
</recent_temporary>
|
167 |
+
</component>
|
168 |
+
<component name="SharedIndexes">
|
169 |
+
<attachedChunks>
|
170 |
+
<set>
|
171 |
+
<option value="bundled-js-predefined-d6986cc7102b-bed05e336f61-JavaScript-PY-243.21155.22" />
|
172 |
+
<option value="bundled-python-sdk-5ff8a29a62a8-ca77fbc60dd9-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-243.21155.22" />
|
173 |
+
</set>
|
174 |
+
</attachedChunks>
|
175 |
+
</component>
|
176 |
+
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
|
177 |
+
<component name="TaskManager">
|
178 |
+
<task active="true" id="Default" summary="Default task">
|
179 |
+
<changelist id="9ec92c76-0e74-4c49-9687-c62749296b88" name="Changes" comment="" />
|
180 |
+
<created>1729957197525</created>
|
181 |
+
<option name="number" value="Default" />
|
182 |
+
<option name="presentableId" value="Default" />
|
183 |
+
<updated>1729957197525</updated>
|
184 |
+
<workItem from="1729957199944" duration="8141000" />
|
185 |
+
<workItem from="1729970018757" duration="142000" />
|
186 |
+
<workItem from="1729970174785" duration="25000" />
|
187 |
+
<workItem from="1729970270429" duration="53000" />
|
188 |
+
<workItem from="1729970419018" duration="9867000" />
|
189 |
+
<workItem from="1730030408588" duration="2251000" />
|
190 |
+
<workItem from="1730037237796" duration="27583000" />
|
191 |
+
</task>
|
192 |
+
<servers />
|
193 |
+
</component>
|
194 |
+
<component name="TypeScriptGeneratedFilesManager">
|
195 |
+
<option name="version" value="3" />
|
196 |
+
</component>
|
197 |
+
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
198 |
+
<SUITE FILE_PATH="coverage/medatlas$explainability.coverage" NAME="explainability Coverage Results" MODIFIED="1730155021389" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
199 |
+
<SUITE FILE_PATH="coverage/medatlas$push_to_hub.coverage" NAME="push_to_hub Coverage Results" MODIFIED="1730031227719" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
200 |
+
<SUITE FILE_PATH="coverage/medatlas$example.coverage" NAME="example Coverage Results" MODIFIED="1730041646094" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
201 |
+
<SUITE FILE_PATH="coverage/medatlas$main.coverage" NAME="main Coverage Results" MODIFIED="1730153590829" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
202 |
+
<SUITE FILE_PATH="coverage/medatlas$medimageinsightmodel.coverage" NAME="medimageinsightmodel Coverage Results" MODIFIED="1730037368621" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
203 |
+
</component>
|
204 |
+
</project>
|
2024.09.27/config.yaml
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
##################
|
2 |
+
# Trainer settings
|
3 |
+
##################
|
4 |
+
|
5 |
+
|
6 |
+
TASK: UniCLTask
|
7 |
+
|
8 |
+
NAME: 'Example Eval Configuration'
|
9 |
+
SAVE_TIMER_LOG: true
|
10 |
+
|
11 |
+
# TUTORIAL STEP 1: CHOOSE SAVE DIR
|
12 |
+
SAVE_DIR: ''
|
13 |
+
LOG_EVERY: 10
|
14 |
+
LOGLEVEL_OVERRIDE: INFO
|
15 |
+
LOG_GPU_MEM: true
|
16 |
+
RESUME: False
|
17 |
+
RESET_DATA_LOADER: false
|
18 |
+
|
19 |
+
FP16: true
|
20 |
+
ZERO_STAGE: 0
|
21 |
+
DEEPSPEED: false
|
22 |
+
# ZERO_STAGE: 1
|
23 |
+
AMP: PYTORCH
|
24 |
+
# USE_APEX_DDP: false
|
25 |
+
# USE_APEX_AMP: false
|
26 |
+
# USE_HIT: false
|
27 |
+
|
28 |
+
FIND_UNUSED_PARAMETERS: false
|
29 |
+
|
30 |
+
SAVE_PER_OPTIM_STEPS: 500
|
31 |
+
EVAL_PER_OPTIM_STEPS: 250
|
32 |
+
EVAL_AT_START: False
|
33 |
+
# SAVE_PER_UPDATE_NUM: -1
|
34 |
+
# EVAL_PER_UPDATE_NUM: 0 # 0: do evaluation when saving checkpoint, -1: don't do evaluation
|
35 |
+
|
36 |
+
NO_AUTO_LR_SCALING: true
|
37 |
+
GRAD_CLIPPING: 1.0 #0.07
|
38 |
+
|
39 |
+
SET_SAMPLER_EPOCH: true
|
40 |
+
|
41 |
+
DONT_LOAD_MODEL: true
|
42 |
+
|
43 |
+
user_dir: "./MainzVision" # lower case due to it is used in mainz as such
|
44 |
+
|
45 |
+
##################
|
46 |
+
# Task settings
|
47 |
+
##################
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
VERBOSE: true
|
52 |
+
WORKERS: 6
|
53 |
+
PIN_MEMORY: true
|
54 |
+
IMAGE_ENCODER:
|
55 |
+
NAME: davit_v1
|
56 |
+
NUM_CLASSES: 0
|
57 |
+
#IMAGE_SIZE: [384, 384]
|
58 |
+
IMAGE_SIZE: [480, 480]
|
59 |
+
LOAD_PRETRAINED: true
|
60 |
+
PRETRAINED: ''
|
61 |
+
PRETRAINED_LAYERS: '*'
|
62 |
+
IMAGE_MEAN: [0.485, 0.456, 0.406]
|
63 |
+
IMAGE_STD: [0.229, 0.224, 0.225]
|
64 |
+
SPEC:
|
65 |
+
DROP_RATE: 0.1
|
66 |
+
DROP_PATH_RATE: 0.2
|
67 |
+
PATCH_SIZE: [7, 3, 3, 3]
|
68 |
+
PATCH_STRIDE: [4, 2, 2, 2]
|
69 |
+
PATCH_PADDING: [3, 1, 1, 1]
|
70 |
+
PATCH_PRENORM: [false, true, true, true]
|
71 |
+
DIM_EMBED: [256, 512, 1024, 2048]
|
72 |
+
NUM_HEADS: [8, 16, 32, 64]
|
73 |
+
NUM_GROUPS: [8, 16, 32, 64]
|
74 |
+
DEPTHS: [1, 1, 9, 1]
|
75 |
+
WINDOW_SIZE: 12
|
76 |
+
ENABLE_CHECKPOINT: true
|
77 |
+
|
78 |
+
LANG_ENCODER:
|
79 |
+
NAME: transformer
|
80 |
+
LOAD_PRETRAINED: false
|
81 |
+
PRETRAINED: ''
|
82 |
+
PRETRAINED_LAYERS: '*'
|
83 |
+
TOKENIZER: clip
|
84 |
+
CONTEXT_LENGTH: 77
|
85 |
+
WIDTH: 1024
|
86 |
+
HEADS: 16
|
87 |
+
LAYERS: 16
|
88 |
+
AUTOGRESSIVE: false
|
89 |
+
|
90 |
+
UNICL_MODEL:
|
91 |
+
DIM_PROJECTION: 1024
|
92 |
+
GATHER_TENSORS: true
|
93 |
+
LOAD_PRETRAINED: true
|
94 |
+
|
95 |
+
# TUTORIAL STEP 2: CHOOSE MODEL PATH
|
96 |
+
PRETRAINED: ''
|
97 |
+
|
98 |
+
PRETRAINED_LAYERS: '*'
|
99 |
+
|
100 |
+
AUG:
|
101 |
+
MIXUP_PROB: 0.0
|
102 |
+
MIXUP: 0.8
|
103 |
+
MIXCUT: 1.0
|
104 |
+
MIXCUT_MINMAX: []
|
105 |
+
MIXUP_SWITCH_PROB: 0.5
|
106 |
+
MIXUP_MODE: 'batch'
|
107 |
+
SCALE: [0.8, 1.0]
|
108 |
+
RATIO: [0.75, 1.3333333]
|
109 |
+
INTERPOLATION: 'bicubic'
|
110 |
+
TORCHVISION_AUG:
|
111 |
+
AUTO_AUGMENT: ta_wide
|
112 |
+
RE_PROB: 0.25
|
113 |
+
HFLIP: 0.0
|
114 |
+
VFLIP: 0.0
|
115 |
+
|
116 |
+
LOSS:
|
117 |
+
LOSS: UniCL
|
118 |
+
DATASET:
|
119 |
+
DATASET: 'image_text_pairs_v2'
|
120 |
+
TEXT_FORMAT: 'json'
|
121 |
+
ROOT: ''
|
122 |
+
TRAIN_SET: 'mimic_cxr_v2-chestxray14-chexpertv4-irma2009_v2-rsnaboneage-mura-bingmedicalfewshot'
|
123 |
+
DATA_FORMAT: 'tsv'
|
124 |
+
SAMPLER: 'default'
|
125 |
+
LOADER: 'default'
|
126 |
+
TOKEN_FILE: ''
|
127 |
+
#PROMPT_ENGINEERING: False
|
128 |
+
#SAMPLER: 'chunk'
|
129 |
+
#LOADER: 'azcopy'
|
130 |
+
#TOKEN_FILE: 'cliptrainingpairs.txt'
|
131 |
+
#TEST_SET: 'MarsAtrain'
|
132 |
+
|
133 |
+
|
134 |
+
# TUTORIAL STEP 3: CHOOSE ALL BELOW EVAL PATHS (THESE ARE ALL OPTIONAL EXTRA EVALS)
|
135 |
+
# Note how one eval is ZIP format and the other is TSV format.
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
|
140 |
+
EVALDATASET_LTCXR_S100_N100_TEXT_CLASSIFIER:
|
141 |
+
TEXT_FORMAT: json
|
142 |
+
FORMAT: 'zip'
|
143 |
+
SPLIT: 'NIH-CXR-LT'
|
144 |
+
ZIP_FILE: ''
|
145 |
+
ZIP_MAP_FILE: ''
|
146 |
+
LABEL_FILE: ''
|
147 |
+
IMAGE_TSV: ''
|
148 |
+
TEXT_TSV: ''
|
149 |
+
CWEIGHT_FILE: ''
|
150 |
+
ZS_MODE: 2
|
151 |
+
ZS_WEIGHT: 1.0
|
152 |
+
KNN: 100
|
153 |
+
# CLASSIFICATION_SETS: ['NIH-CXR-LT']
|
154 |
+
# NUM_CLASSES: [20]
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
|
159 |
+
# TUTORIAL STEP 4: SET THE DEFAULT ZEROSHOT EVAL (THIS IS THE MANDATORY EVAL)
|
160 |
+
|
161 |
+
ZEROSHOT_EVAL_DATASET:
|
162 |
+
FORMAT: 'zip'
|
163 |
+
SPLIT: 'NIH-CXR-LT'
|
164 |
+
ZIP_FILE: ''
|
165 |
+
ZIP_MAP_FILE: ''
|
166 |
+
LABEL_FILE: ''
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
EVALUATION_SPLITS: ['cls-zeroshot-eval']
|
171 |
+
TEST:
|
172 |
+
BATCH_SIZE_PER_GPU: 8
|
173 |
+
MODEL_FILE: ''
|
174 |
+
CENTER_CROP: false
|
175 |
+
TRAIN:
|
176 |
+
BATCH_SIZE_TOTAL: 1024
|
177 |
+
BATCH_SIZE_PER_GPU: 16
|
178 |
+
|
179 |
+
SHUFFLE: true
|
180 |
+
|
181 |
+
WEIGHT_SMOOTHING:
|
182 |
+
decay: 0.999
|
183 |
+
use_cpu: False
|
184 |
+
eval_smoothed_weight: True
|
185 |
+
|
186 |
+
START_LEARNING_RATE: 0.00001
|
187 |
+
# MAX_NUM_EPOCHS: 2
|
188 |
+
MAX_NUM_EPOCHS: 100
|
189 |
+
OPTIMIZER: AdamW # adam
|
190 |
+
OPTIMIZER_PARAMS:
|
191 |
+
weight_decay: 0.2 #0.1
|
192 |
+
CUSTOMIZED_PARAMS_CONF:
|
193 |
+
NO_WEIGHT_DECAY_MODULES: ['dw', 'norm']
|
194 |
+
WEIGHT_DECAY_PATTERNS:
|
195 |
+
"\\.bias$": 0.0
|
196 |
+
"logit_scale": 0.0
|
197 |
+
"positional_embedding": 0.0
|
198 |
+
"token_embedding": 0.0
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
LR_SCHEDULER: TimmScheduler
|
203 |
+
LR_SCHEDULER_PARAMS:
|
204 |
+
sched: cosine
|
205 |
+
warmup_steps: 5
|
206 |
+
warmup_lr: 0.000000001
|
207 |
+
min_lr: 0.000000001
|
208 |
+
|
209 |
+
# GRADIENT_ACCUMULATE_STEP will be updated by:
|
210 |
+
# BATCH_SIZE_TOTAL // (BATCH_SIZE_PER_GPU * world_size)
|
211 |
+
GRADIENT_ACCUMULATE_STEP: -1
|
2024.09.27/language_model/clip_tokenizer_4.16.2/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
2024.09.27/language_model/clip_tokenizer_4.16.2/special_tokens_map.json
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<|startoftext|>",
|
4 |
+
"single_word": false,
|
5 |
+
"lstrip": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"normalized": true,
|
8 |
+
"special": false
|
9 |
+
},
|
10 |
+
"eos_token": {
|
11 |
+
"content": "<|endoftext|>",
|
12 |
+
"single_word": false,
|
13 |
+
"lstrip": false,
|
14 |
+
"rstrip": false,
|
15 |
+
"normalized": true,
|
16 |
+
"special": false
|
17 |
+
},
|
18 |
+
"unk_token": {
|
19 |
+
"content": "<|endoftext|>",
|
20 |
+
"single_word": false,
|
21 |
+
"lstrip": false,
|
22 |
+
"rstrip": false,
|
23 |
+
"normalized": true,
|
24 |
+
"special": false
|
25 |
+
},
|
26 |
+
"pad_token": "<|endoftext|>"
|
27 |
+
}
|
2024.09.27/language_model/clip_tokenizer_4.16.2/tokenizer_config.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"errors": "replace",
|
3 |
+
"unk_token": {
|
4 |
+
"content": "<|endoftext|>",
|
5 |
+
"single_word": false,
|
6 |
+
"lstrip": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"normalized": true,
|
9 |
+
"special": false,
|
10 |
+
"__type": "AddedToken"
|
11 |
+
},
|
12 |
+
"bos_token": {
|
13 |
+
"content": "<|startoftext|>",
|
14 |
+
"single_word": false,
|
15 |
+
"lstrip": false,
|
16 |
+
"rstrip": false,
|
17 |
+
"normalized": true,
|
18 |
+
"special": false,
|
19 |
+
"__type": "AddedToken"
|
20 |
+
},
|
21 |
+
"eos_token": {
|
22 |
+
"content": "<|endoftext|>",
|
23 |
+
"single_word": false,
|
24 |
+
"lstrip": false,
|
25 |
+
"rstrip": false,
|
26 |
+
"normalized": true,
|
27 |
+
"special": false,
|
28 |
+
"__type": "AddedToken"
|
29 |
+
},
|
30 |
+
"pad_token": "<|endoftext|>",
|
31 |
+
"add_prefix_space": false,
|
32 |
+
"do_lower_case": true,
|
33 |
+
"name_or_path": "openai/clip-vit-base-patch32",
|
34 |
+
"model_max_length": 77,
|
35 |
+
"special_tokens_map_file": "/home/ncodella/.cache/huggingface/transformers/18a566598f286c9139f88160c99f84eec492a26bd22738fa9cb44d5b7e0a5c76.cce1206abbad28826f000510f22f354e53e66a97f7c23745a7dfe27609cc07f5",
|
36 |
+
"tokenizer_file": "/home/ncodella/.cache/huggingface/transformers/7811def0c53be25ba790cb67ac785669b508a8d1cf8c912b8ac046c5f08aee68.20428ea8b6821af2719b760af844a371643ff49f255c73285f6ea448e15597fe",
|
37 |
+
"tokenizer_class": "CLIPTokenizer"
|
38 |
+
}
|
2024.09.27/language_model/clip_tokenizer_4.16.2/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
MedImageInsight/Distributed/Utils.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
import requests
|
5 |
+
import tenacity
|
6 |
+
import time
|
7 |
+
import shutil
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.distributed as dist
|
11 |
+
|
12 |
+
from PIL import Image
|
13 |
+
from torchvision.utils import make_grid
|
14 |
+
|
15 |
+
|
16 |
+
from fvcore.nn import FlopCountAnalysis
|
17 |
+
from fvcore.nn import flop_count_table
|
18 |
+
from fvcore.nn import flop_count_str
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
NORM_MODULES = [
|
23 |
+
torch.nn.BatchNorm1d,
|
24 |
+
torch.nn.BatchNorm2d,
|
25 |
+
torch.nn.BatchNorm3d,
|
26 |
+
torch.nn.SyncBatchNorm,
|
27 |
+
# NaiveSyncBatchNorm inherits from BatchNorm2d
|
28 |
+
torch.nn.GroupNorm,
|
29 |
+
torch.nn.InstanceNorm1d,
|
30 |
+
torch.nn.InstanceNorm2d,
|
31 |
+
torch.nn.InstanceNorm3d,
|
32 |
+
torch.nn.LayerNorm,
|
33 |
+
torch.nn.LocalResponseNorm,
|
34 |
+
]
|
35 |
+
|
36 |
+
|
37 |
+
def register_norm_module(cls):
|
38 |
+
NORM_MODULES.append(cls)
|
39 |
+
|
40 |
+
return cls
|
41 |
+
|
42 |
+
|
43 |
+
def is_main_process():
|
44 |
+
rank = 0
|
45 |
+
if 'OMPI_COMM_WORLD_SIZE' in os.environ:
|
46 |
+
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
47 |
+
|
48 |
+
return rank == 0
|
49 |
+
|
50 |
+
|
51 |
+
@torch.no_grad()
|
52 |
+
def analysis_model(model, dump_input, verbose=False):
|
53 |
+
model.eval()
|
54 |
+
flops = FlopCountAnalysis(model, dump_input)
|
55 |
+
total = flops.total()
|
56 |
+
model.train()
|
57 |
+
params_total = sum(p.numel() for p in model.parameters())
|
58 |
+
params_learned = sum(
|
59 |
+
p.numel() for p in model.parameters() if p.requires_grad
|
60 |
+
)
|
61 |
+
logger.info(f"flop count table:\n {flop_count_table(flops)}")
|
62 |
+
if verbose:
|
63 |
+
logger.info(f"flop count str:\n {flop_count_str(flops)}")
|
64 |
+
logger.info(f" Total flops: {total / 1000 / 1000:.3f}M,")
|
65 |
+
logger.info(f" Total params: {params_total / 1000 / 1000:.3f}M,")
|
66 |
+
logger.info(f" Learned params: {params_learned / 1000 / 1000:.3f}M")
|
67 |
+
|
68 |
+
return total, flop_count_table(flops), flop_count_str(flops)
|
69 |
+
|
70 |
+
|
71 |
+
def gather_tensors(tensor):
|
72 |
+
"""
|
73 |
+
Performs all_gather operation on the provided tensors.
|
74 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
75 |
+
"""
|
76 |
+
tensors_gather = [
|
77 |
+
torch.ones_like(tensor)
|
78 |
+
for _ in range(int(os.environ['WORLD_SIZE']))
|
79 |
+
]
|
80 |
+
|
81 |
+
dist.all_gather(tensors_gather, tensor, async_op=False)
|
82 |
+
# need to do this to restore propagation of the gradients
|
83 |
+
tensors_gather[int(os.environ['RANK'])] = tensor
|
84 |
+
output = torch.cat(tensors_gather, dim=0)
|
85 |
+
return output
|
86 |
+
|
87 |
+
|
88 |
+
def is_valid_url(url):
|
89 |
+
try:
|
90 |
+
from urllib import parse
|
91 |
+
return parse.urlparse(str(url)).scheme != ''
|
92 |
+
except Exception:
|
93 |
+
return False
|
94 |
+
|
95 |
+
|
96 |
+
@tenacity.retry(stop=tenacity.stop_after_attempt(3))
|
97 |
+
def download_file(url, filepath):
|
98 |
+
logger.info(f'Downloading from {url} to {filepath.absolute()}.')
|
99 |
+
with requests.get(url, stream=True, allow_redirects=True, timeout=60) as r:
|
100 |
+
if r.status_code > 200:
|
101 |
+
raise RuntimeError(f'Failed in downloading from {url}, status code {r.status_code}.')
|
102 |
+
|
103 |
+
with open(filepath, 'wb') as f:
|
104 |
+
shutil.copyfileobj(r.raw, f, length=4194304)
|
105 |
+
|
106 |
+
|
107 |
+
class DistributionGridFactory:
|
108 |
+
"""
|
109 |
+
DistributionGrid Factory for helping create, cache and share the DistributionGrid based on the usage.
|
110 |
+
The DistributionGrid con be shared cross modules only the when this 3 conditions:
|
111 |
+
1. expert parallel group size
|
112 |
+
2. expert parallel replica group size,
|
113 |
+
are the same.
|
114 |
+
"""
|
115 |
+
distribution_grid_cache = {}
|
116 |
+
|
117 |
+
@classmethod
|
118 |
+
def get_distribution_grid(cls,
|
119 |
+
expert_parallel_group_size,
|
120 |
+
expert_parallel_replica_group_size,
|
121 |
+
ddp_type):
|
122 |
+
"""
|
123 |
+
Get the DistributionGrid by the conditions.
|
124 |
+
Args:
|
125 |
+
expert_parallel_group_size: expert parallel group size
|
126 |
+
expert_parallel_replica_group_size: expert parallel replica group size
|
127 |
+
ddp_type: distributed data parallel type. "DDP" of the recipe, only allow ddp_type is "MAINZ", "OSS" or "ShardedDDP".
|
128 |
+
|
129 |
+
Returns: new created DistributionGrid or shared DistributionGrid.
|
130 |
+
|
131 |
+
Notes: Currently get_distribution_grid only support "DDP" is "MAINZ", "OSS" or "ShardedDDP".
|
132 |
+
"""
|
133 |
+
# TODO: Support cases that "DDP" is "FSDP".
|
134 |
+
# For "FSDP", we use the DG of self.opt['fsdp_expert_grid'] which is initialize in DistributedTrainer directly.
|
135 |
+
ddp_type = ddp_type.upper()
|
136 |
+
assert ddp_type in ["MAINZ", "OSS", "SHARDEDDDP"], f'DistributionGrid Factory only support "DDP" is "MAINZ",' \
|
137 |
+
f' "OSS" or "ShardedDDP".' \
|
138 |
+
f' But currently "DDP" is {ddp_type}'
|
139 |
+
|
140 |
+
cached_distributed_grid = cls.distribution_grid_cache.get(
|
141 |
+
(expert_parallel_group_size, expert_parallel_replica_group_size), None)
|
142 |
+
|
143 |
+
if cached_distributed_grid is not None:
|
144 |
+
return cached_distributed_grid
|
145 |
+
else:
|
146 |
+
from ort_moe.grids import DistributionGrid
|
147 |
+
distributed_grid = DistributionGrid(expert_parallel_group_size=expert_parallel_group_size,
|
148 |
+
expert_parallel_replica_group_size=expert_parallel_replica_group_size)
|
149 |
+
|
150 |
+
cls.distribution_grid_cache[expert_parallel_group_size,
|
151 |
+
expert_parallel_replica_group_size] = distributed_grid
|
152 |
+
return distributed_grid
|
153 |
+
|
154 |
+
|
155 |
+
def get_world_size():
|
156 |
+
if not dist.is_available():
|
157 |
+
return 1
|
158 |
+
if not dist.is_initialized():
|
159 |
+
return 1
|
160 |
+
return dist.get_world_size()
|
161 |
+
|
162 |
+
|
163 |
+
def get_rank():
|
164 |
+
if not dist.is_available():
|
165 |
+
return 0
|
166 |
+
if not dist.is_initialized():
|
167 |
+
return 0
|
168 |
+
return dist.get_rank()
|
169 |
+
|
170 |
+
|
171 |
+
def synchronize():
|
172 |
+
"""
|
173 |
+
Helper function to synchronize (barrier) among all processes when
|
174 |
+
using distributed training
|
175 |
+
"""
|
176 |
+
if not dist.is_available():
|
177 |
+
return
|
178 |
+
if not dist.is_initialized():
|
179 |
+
return
|
180 |
+
world_size = dist.get_world_size()
|
181 |
+
rank = dist.get_rank()
|
182 |
+
if world_size == 1:
|
183 |
+
return
|
184 |
+
|
185 |
+
def _send_and_wait(r):
|
186 |
+
if rank == r:
|
187 |
+
tensor = torch.tensor(0, device="cuda")
|
188 |
+
else:
|
189 |
+
tensor = torch.tensor(1, device="cuda")
|
190 |
+
dist.broadcast(tensor, r)
|
191 |
+
while tensor.item() == 1:
|
192 |
+
time.sleep(1)
|
193 |
+
|
194 |
+
_send_and_wait(0)
|
195 |
+
# now sync on the main process
|
196 |
+
_send_and_wait(1)
|
197 |
+
|
198 |
+
|
199 |
+
def all_gather(data):
|
200 |
+
"""
|
201 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
202 |
+
Args:
|
203 |
+
data: any picklable object
|
204 |
+
Returns:
|
205 |
+
list[data]: list of data gathered from each rank
|
206 |
+
"""
|
207 |
+
world_size = get_world_size()
|
208 |
+
if world_size == 1:
|
209 |
+
return [data]
|
210 |
+
|
211 |
+
# serialized to a Tensor
|
212 |
+
buffer = pickle.dumps(data)
|
213 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
214 |
+
tensor = torch.ByteTensor(storage).to("cuda")
|
215 |
+
|
216 |
+
# obtain Tensor size of each rank
|
217 |
+
local_size = torch.LongTensor([tensor.numel()]).to("cuda")
|
218 |
+
size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
|
219 |
+
dist.all_gather(size_list, local_size)
|
220 |
+
size_list = [int(size.item()) for size in size_list]
|
221 |
+
max_size = max(size_list)
|
222 |
+
|
223 |
+
# receiving Tensor from all ranks
|
224 |
+
# we pad the tensor because torch all_gather does not support
|
225 |
+
# gathering tensors of different shapes
|
226 |
+
tensor_list = []
|
227 |
+
for _ in size_list:
|
228 |
+
tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
|
229 |
+
if local_size != max_size:
|
230 |
+
padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
|
231 |
+
tensor = torch.cat((tensor, padding), dim=0)
|
232 |
+
dist.all_gather(tensor_list, tensor)
|
233 |
+
|
234 |
+
data_list = []
|
235 |
+
for size, tensor in zip(size_list, tensor_list):
|
236 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
237 |
+
data_list.append(pickle.loads(buffer))
|
238 |
+
|
239 |
+
return data_list
|
240 |
+
|
241 |
+
|
242 |
+
def all_gather_cpu(data):
|
243 |
+
"""
|
244 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors).
|
245 |
+
Args:
|
246 |
+
data: any picklable object
|
247 |
+
group: a torch process group. By default, will use a group which
|
248 |
+
contains all ranks on gloo backend.
|
249 |
+
Returns:
|
250 |
+
list[data]: list of data gathered from each rank
|
251 |
+
"""
|
252 |
+
|
253 |
+
def _get_global_gloo_group():
|
254 |
+
"""
|
255 |
+
Return a process group based on gloo backend, containing all the ranks
|
256 |
+
The result is cached.
|
257 |
+
"""
|
258 |
+
if dist.get_backend() == "nccl":
|
259 |
+
return dist.new_group(backend="gloo")
|
260 |
+
else:
|
261 |
+
return dist.group.WORLD
|
262 |
+
|
263 |
+
if get_world_size() == 1:
|
264 |
+
return [data]
|
265 |
+
group = _get_global_gloo_group() # use CPU group by default, to reduce GPU RAM usage.
|
266 |
+
world_size = dist.get_world_size(group)
|
267 |
+
if world_size == 1:
|
268 |
+
return [data]
|
269 |
+
|
270 |
+
output = [None for _ in range(world_size)]
|
271 |
+
dist.all_gather_object(output, data, group=group)
|
272 |
+
return output
|
273 |
+
|
274 |
+
|
275 |
+
def reduce_dict(input_dict, average=True):
|
276 |
+
"""
|
277 |
+
Args:
|
278 |
+
input_dict (dict): all the values will be reduced
|
279 |
+
average (bool): whether to do average or sum
|
280 |
+
Reduce the values in the dictionary from all processes so that process with rank
|
281 |
+
0 has the averaged results. Returns a dict with the same fields as
|
282 |
+
input_dict, after reduction.
|
283 |
+
"""
|
284 |
+
world_size = get_world_size()
|
285 |
+
if world_size < 2:
|
286 |
+
return input_dict
|
287 |
+
with torch.no_grad():
|
288 |
+
names = []
|
289 |
+
values = []
|
290 |
+
# sort the keys so that they are consistent across processes
|
291 |
+
for k in sorted(input_dict.keys()):
|
292 |
+
names.append(k)
|
293 |
+
values.append(input_dict[k])
|
294 |
+
values = torch.stack(values, dim=0)
|
295 |
+
dist.reduce(values, dst=0)
|
296 |
+
if dist.get_rank() == 0 and average:
|
297 |
+
# only main process gets accumulated, so only divide by
|
298 |
+
# world_size in this case
|
299 |
+
values /= world_size
|
300 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
301 |
+
return reduced_dict
|
302 |
+
|
303 |
+
|
304 |
+
def broadcast_data(data):
|
305 |
+
if not torch.distributed.is_initialized():
|
306 |
+
return data
|
307 |
+
rank = dist.get_rank()
|
308 |
+
if rank == 0:
|
309 |
+
data_tensor = torch.tensor(data + [0], device="cuda")
|
310 |
+
else:
|
311 |
+
data_tensor = torch.tensor(data + [1], device="cuda")
|
312 |
+
torch.distributed.broadcast(data_tensor, 0)
|
313 |
+
while data_tensor.cpu().numpy()[-1] == 1:
|
314 |
+
time.sleep(1)
|
315 |
+
|
316 |
+
return data_tensor.cpu().numpy().tolist()[:-1]
|
317 |
+
|
318 |
+
|
319 |
+
def reduce_sum(tensor):
|
320 |
+
if get_world_size() <= 1:
|
321 |
+
return tensor
|
322 |
+
|
323 |
+
tensor = tensor.clone()
|
324 |
+
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
|
325 |
+
return tensor
|
326 |
+
|
327 |
+
|
328 |
+
def save_result(result, filename):
|
329 |
+
output_folder = os.path.dirname(filename)
|
330 |
+
basename = os.path.splitext(os.path.basename(filename))[0]
|
331 |
+
os.makedirs(output_folder, exist_ok=True)
|
332 |
+
|
333 |
+
if isinstance(result, torch.Tensor) and result.ndim in [3,4]:
|
334 |
+
if result.ndim==3 and result.size(0) not in [1,3]:
|
335 |
+
result = make_grid(result.unsqueeze(1))
|
336 |
+
elif result.ndim==4:
|
337 |
+
result = make_grid(result)
|
338 |
+
else:
|
339 |
+
result = make_grid([result])
|
340 |
+
|
341 |
+
im = Image.fromarray(result.clamp_(0, 255).permute(1, 2, 0).to(torch.uint8).numpy())
|
342 |
+
im.save(os.path.join(output_folder, '{}.png'.format(basename)))
|
343 |
+
else:
|
344 |
+
torch.save(result, os.path.join(output_folder, '{}.pth'.format(basename)))
|
MedImageInsight/Distributed/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .Utils import analysis_model
|
2 |
+
from .Utils import is_main_process
|
3 |
+
from .Utils import gather_tensors
|
4 |
+
from .Utils import register_norm_module
|
5 |
+
from .Utils import NORM_MODULES
|
6 |
+
from .Utils import DistributionGridFactory
|
MedImageInsight/ImageDataLoader/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .build import build_dataloader
|
2 |
+
#from .build import build_multitask_dataloader
|
3 |
+
from .transforms import build_transforms
|
4 |
+
#from .imagenet.real_labels import RealLabelsImagenet
|
5 |
+
from .constants import IMAGENET_CLASSES
|
6 |
+
from .constants import IMAGENET_DEFAULT_TEMPLATES
|
7 |
+
from .zipdata import ZipData
|
8 |
+
#from .vision_dataset import VDImageTextDataset, MultiClassTorchDatasetWrapper
|
MedImageInsight/ImageDataLoader/blob_storage.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import shutil
|
4 |
+
import logging
|
5 |
+
import subprocess
|
6 |
+
import os.path as op
|
7 |
+
from typing import List
|
8 |
+
from collections import OrderedDict
|
9 |
+
|
10 |
+
import torch.distributed as distributed
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
DEFAULT_AZCOPY_PATH = 'azcopy/azcopy'
|
15 |
+
|
16 |
+
|
17 |
+
def disk_usage(path: str) -> float:
|
18 |
+
stat = shutil.disk_usage(path)
|
19 |
+
return stat.used / stat.total
|
20 |
+
|
21 |
+
|
22 |
+
def is_download_successful(stdout: str) -> bool:
|
23 |
+
for line in stdout.split('\n'):
|
24 |
+
if line == "Number of Transfers Failed: 0":
|
25 |
+
return True
|
26 |
+
logger.info("Azcopy message:\n %s" % stdout)
|
27 |
+
return False
|
28 |
+
|
29 |
+
|
30 |
+
def ensure_directory(path):
|
31 |
+
"""Check existence of the given directory path. If not, create a new directory.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
path (str): path of a given directory.
|
35 |
+
"""
|
36 |
+
if path == '' or path == '.':
|
37 |
+
return
|
38 |
+
if path is not None and len(path) > 0:
|
39 |
+
assert not op.isfile(path), '{} is a file'.format(path)
|
40 |
+
if not op.exists(path) and not op.islink(path):
|
41 |
+
os.makedirs(path, exist_ok=True)
|
42 |
+
# we should always check if it succeeds.
|
43 |
+
assert op.isdir(op.abspath(path)), path
|
44 |
+
|
45 |
+
|
46 |
+
class LRU(OrderedDict):
|
47 |
+
def __init__(self, maxsize=3):
|
48 |
+
self.maxsize = maxsize
|
49 |
+
|
50 |
+
def __getitem__(self, key):
|
51 |
+
value = super().__getitem__(key)
|
52 |
+
self.move_to_end(key)
|
53 |
+
return value
|
54 |
+
|
55 |
+
def __setitem__(self, key, value):
|
56 |
+
if key in self:
|
57 |
+
if self[key] is not None:
|
58 |
+
self[key].close()
|
59 |
+
self.move_to_end(key)
|
60 |
+
|
61 |
+
logger.debug('=> Cache {}'.format(key))
|
62 |
+
super().__setitem__(key, value)
|
63 |
+
|
64 |
+
if len(self) > self.maxsize:
|
65 |
+
oldest = next(iter(self))
|
66 |
+
if self[oldest] is not None:
|
67 |
+
self[oldest].close()
|
68 |
+
logger.debug('=> Purged {}'.format(oldest))
|
69 |
+
del self[oldest]
|
70 |
+
|
71 |
+
|
72 |
+
class BlobStorage(OrderedDict):
|
73 |
+
""" Pseudo Blob Storage manager
|
74 |
+
|
75 |
+
The registered blobs are maintained in a LRU cache.
|
76 |
+
Limit size, evicting the least recently looked-up key when full.
|
77 |
+
https://docs.python.org/3/library/collections.html#collections.OrderedDict
|
78 |
+
|
79 |
+
Input argument:
|
80 |
+
sas_token (str): path to SAS token.
|
81 |
+
"""
|
82 |
+
def __init__(self,
|
83 |
+
is_train: bool,
|
84 |
+
sas_token_path: str = None,
|
85 |
+
azcopy_path: str = None,
|
86 |
+
*args, **kwds):
|
87 |
+
super().__init__(*args, **kwds)
|
88 |
+
self.maxsize = 2 if is_train else 10 # Set maxsize to large number such val data never get purged.
|
89 |
+
self.is_train = is_train
|
90 |
+
|
91 |
+
if sas_token_path:
|
92 |
+
self.sas_token = BlobStorage.read_sas_token(sas_token_path)
|
93 |
+
self.base_url = self.sas_token[:self.sas_token.index("?")]
|
94 |
+
self.query_string = self.sas_token[self.sas_token.index("?"):]
|
95 |
+
self.container = BlobStorage.extract_container(self.sas_token)
|
96 |
+
else:
|
97 |
+
self.sas_token = None
|
98 |
+
self.base_url = None
|
99 |
+
self.query_string = None
|
100 |
+
self.container = None
|
101 |
+
|
102 |
+
logger.debug(
|
103 |
+
f"=> [BlobStorage] Base url: {self.base_url}"
|
104 |
+
f"=> [BlobStorage] Query string: {self.query_string}"
|
105 |
+
f"=> [BlobStorage] Container name: {self.container}"
|
106 |
+
)
|
107 |
+
|
108 |
+
self.azcopy_path = azcopy_path if azcopy_path else DEFAULT_AZCOPY_PATH
|
109 |
+
self._cached_files = LRU(3)
|
110 |
+
|
111 |
+
def __getitem__(self, key):
|
112 |
+
value = super().__getitem__(key)
|
113 |
+
self.move_to_end(key)
|
114 |
+
return value
|
115 |
+
|
116 |
+
def __setitem__(self, key, value):
|
117 |
+
if key in self:
|
118 |
+
self.move_to_end(key)
|
119 |
+
super().__setitem__(key, value)
|
120 |
+
# NOTE: purge the least recently used data if the disk usage is high.
|
121 |
+
# ITP restarts GPU clusters when disk usage reaches 80%.
|
122 |
+
if len(self) > self.maxsize:
|
123 |
+
oldest = next(iter(self))
|
124 |
+
del self[oldest]
|
125 |
+
|
126 |
+
@staticmethod
|
127 |
+
def read_sas_token(path: str) -> str:
|
128 |
+
with open(path, 'r') as f:
|
129 |
+
token = f.readline().strip()
|
130 |
+
return token
|
131 |
+
|
132 |
+
@staticmethod
|
133 |
+
def extract_container(token: str) -> str:
|
134 |
+
"""
|
135 |
+
Input argument:
|
136 |
+
token (str): the full URI of Shared Access Signature (SAS) in the following format.
|
137 |
+
https://[storage_account].blob.core.windows.net/[container_name][SAS_token]
|
138 |
+
"""
|
139 |
+
return os.path.basename(token.split('?')[0])
|
140 |
+
|
141 |
+
def _convert_to_blob_url(self, local_path: str):
|
142 |
+
return self.base_url + local_path.split("azcopy")[1] + self.query_string
|
143 |
+
|
144 |
+
def _convert_to_blob_folder_url(self, local_path: str):
|
145 |
+
return self.base_url + local_path.split("azcopy")[1] + "/*" + self.query_string
|
146 |
+
|
147 |
+
def fetch_blob(self, local_path: str) -> None:
|
148 |
+
if op.exists(local_path):
|
149 |
+
logger.info('=> Try to open {}'.format(local_path))
|
150 |
+
fp = open(local_path, 'r')
|
151 |
+
self._cached_files[local_path] = fp
|
152 |
+
logger.debug("=> %s downloaded. Skip." % local_path)
|
153 |
+
return
|
154 |
+
blob_url = self._convert_to_blob_url(local_path)
|
155 |
+
rank = '0' if 'RANK' not in os.environ else os.environ['RANK']
|
156 |
+
cmd = [self.azcopy_path, "copy", blob_url, local_path + rank]
|
157 |
+
curr_usage = disk_usage('/')
|
158 |
+
logger.info(
|
159 |
+
"=> Downloading %s with azcopy ... (disk usage: %.2f%%)"
|
160 |
+
% (local_path, curr_usage * 100)
|
161 |
+
)
|
162 |
+
proc = subprocess.run(cmd, stdout=subprocess.PIPE)
|
163 |
+
while not is_download_successful(proc.stdout.decode()):
|
164 |
+
logger.info("=> Azcopy failed to download {}. Retrying ...".format(blob_url))
|
165 |
+
proc = subprocess.run(cmd, stdout=subprocess.PIPE)
|
166 |
+
if not op.exists(local_path):
|
167 |
+
os.rename(local_path + rank, local_path)
|
168 |
+
else:
|
169 |
+
os.remove(local_path + rank)
|
170 |
+
logger.info(
|
171 |
+
"=> Downloaded %s with azcopy ... (disk usage: %.2f%% => %.2f%%)" %
|
172 |
+
(local_path, curr_usage * 100, disk_usage('/') * 100)
|
173 |
+
)
|
174 |
+
|
175 |
+
def fetch_blob_folder(self, local_path: str, azcopy_args: list=[]) -> None:
|
176 |
+
blob_url = self._convert_to_blob_folder_url(local_path)
|
177 |
+
cmd = [self.azcopy_path, "copy", blob_url, local_path] + azcopy_args
|
178 |
+
curr_usage = disk_usage('/')
|
179 |
+
logger.info(
|
180 |
+
"=> Downloading %s with azcopy args %s ... (disk usage: %.2f%%)"
|
181 |
+
% (local_path, ' '.join(azcopy_args), curr_usage * 100)
|
182 |
+
)
|
183 |
+
proc = subprocess.run(cmd, stdout=subprocess.PIPE)
|
184 |
+
while not is_download_successful(proc.stdout.decode()):
|
185 |
+
logger.info("=> Azcopy failed to download {} with args {}. Retrying ...".format(blob_url, ' '.join(azcopy_args)))
|
186 |
+
proc = subprocess.run(cmd, stdout=subprocess.PIPE)
|
187 |
+
logger.info(
|
188 |
+
"=> Downloaded %s with azcopy args %s ... (disk usage: %.2f%% => %.2f%%)" %
|
189 |
+
(local_path, ' '.join(azcopy_args), curr_usage * 100, disk_usage('/') * 100)
|
190 |
+
)
|
191 |
+
|
192 |
+
def register_local_tsv_paths(self, local_paths: List[str]) -> List[str]:
|
193 |
+
if self.sas_token:
|
194 |
+
tsv_paths_new = []
|
195 |
+
lineidx_paths = set()
|
196 |
+
linelist_paths = set()
|
197 |
+
for path in local_paths:
|
198 |
+
tsv_path_az = path.replace(self.container, 'azcopy')
|
199 |
+
tsv_paths_new.append(tsv_path_az)
|
200 |
+
logger.debug("=> Registering {}".format(tsv_path_az))
|
201 |
+
|
202 |
+
if not self.is_train:
|
203 |
+
logger.info('=> Downloading {}...'.format(tsv_path_az))
|
204 |
+
self.fetch_blob(tsv_path_az)
|
205 |
+
logger.info('=> Downloaded {}'.format(tsv_path_az))
|
206 |
+
|
207 |
+
lineidx = op.splitext(path)[0] + '.lineidx'
|
208 |
+
lineidx_ = lineidx.replace(self.container, 'azcopy')
|
209 |
+
if self.is_train:
|
210 |
+
if not op.isfile(lineidx_) and op.dirname(lineidx_) not in lineidx_paths:
|
211 |
+
lineidx_paths.add(op.dirname(lineidx_))
|
212 |
+
else:
|
213 |
+
if not op.isfile(lineidx_):
|
214 |
+
ensure_directory(op.dirname(lineidx_))
|
215 |
+
self.fetch_blob(lineidx_)
|
216 |
+
|
217 |
+
linelist = op.splitext(path)[0] + '.linelist'
|
218 |
+
linelist_ = linelist.replace(self.container, 'azcopy')
|
219 |
+
# .linelist does not always exist. Check existence before fetch
|
220 |
+
if self.is_train:
|
221 |
+
if op.isfile(linelist) and not op.isfile(linelist_) and op.dirname(linelist_) not in linelist_paths:
|
222 |
+
linelist_paths.add(op.dirname(linelist_))
|
223 |
+
else:
|
224 |
+
if op.isfile(linelist) and not op.isfile(linelist_):
|
225 |
+
ensure_directory(op.dirname(linelist_))
|
226 |
+
self.fetch_blob(linelist_)
|
227 |
+
|
228 |
+
if self.is_train:
|
229 |
+
for path in lineidx_paths:
|
230 |
+
self.fetch_blob_folder(path, azcopy_args=['--include-pattern', '*.lineidx'])
|
231 |
+
|
232 |
+
for path in linelist_paths:
|
233 |
+
self.fetch_blob_folder(path, azcopy_args=['--include-pattern', '*.linelist'])
|
234 |
+
|
235 |
+
return tsv_paths_new
|
236 |
+
else:
|
237 |
+
return local_paths
|
238 |
+
|
239 |
+
def open(self, local_path: str):
|
240 |
+
if self.sas_token and 'azcopy' in local_path:
|
241 |
+
while not op.exists(local_path):
|
242 |
+
time.sleep(1)
|
243 |
+
fid = open(local_path, 'r')
|
244 |
+
return fid
|
MedImageInsight/ImageDataLoader/build.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
import pathlib
|
9 |
+
from os.path import basename
|
10 |
+
|
11 |
+
from timm.data import create_loader
|
12 |
+
import torch
|
13 |
+
import torch.utils.data
|
14 |
+
import torch.distributed as dist
|
15 |
+
import torchvision.datasets as datasets
|
16 |
+
from torchvision.io import read_image
|
17 |
+
import torch.distributed as dist
|
18 |
+
from pathlib import Path
|
19 |
+
from yacs.config import CfgNode as CN
|
20 |
+
|
21 |
+
from ..LangEncoder import build_tokenizer
|
22 |
+
|
23 |
+
from .tsv import TSVImageTextDatasetV2
|
24 |
+
from .tsv import TSVMeta
|
25 |
+
from .transforms import build_transforms
|
26 |
+
|
27 |
+
logger = logging.getLogger(__name__)
|
28 |
+
|
29 |
+
|
30 |
+
def build_dataset(cfg, is_train):
|
31 |
+
if cfg['DATASET']['DATASET'] == 'image_text_pairs_v2':
|
32 |
+
dataset = _build_pairs_dataset_v2(cfg, is_train)
|
33 |
+
else:
|
34 |
+
raise ValueError(f'Unknown dataset: {cfg["DATASET"]["DATASET"]}')
|
35 |
+
return dataset
|
36 |
+
|
37 |
+
|
38 |
+
def _get_tsv_list(cfg, is_train):
|
39 |
+
tmp_list = []
|
40 |
+
if is_train and 'TRAIN_TSV_LIST' in cfg['DATASET']:
|
41 |
+
tmp_list = cfg['DATASET']['TRAIN_TSV_LIST']
|
42 |
+
elif 'TEST_TSV_LIST' in cfg['DATASET']:
|
43 |
+
tmp_list = cfg['DATASET']['TEST_TSV_LIST']
|
44 |
+
|
45 |
+
tsv_list = []
|
46 |
+
for l in tmp_list:
|
47 |
+
if l.endswith('.list'):
|
48 |
+
with open(l, 'r') as f:
|
49 |
+
tsv_list.extend([i.strip() for i in f])
|
50 |
+
else:
|
51 |
+
tsv_list.append(l)
|
52 |
+
|
53 |
+
logger.info(f'tsv list: {tsv_list}')
|
54 |
+
|
55 |
+
return tsv_list
|
56 |
+
|
57 |
+
|
58 |
+
def _get_token_file(cfg):
|
59 |
+
num_nodes = dist.get_world_size() // torch.cuda.device_count()
|
60 |
+
if isinstance(cfg['DATASET']['TOKEN_FILE'], list):
|
61 |
+
if num_nodes == 1:
|
62 |
+
logger.warning('=> Multi token files are provided, but only one node is used for training')
|
63 |
+
sas_token_file = cfg['DATASET']['TOKEN_FILE'][0]
|
64 |
+
else:
|
65 |
+
rank = dist.get_rank()
|
66 |
+
node_idx = rank // torch.cuda.device_count()
|
67 |
+
num_token_files = len(cfg['DATASET']['TOKEN_FILE'])
|
68 |
+
sas_token_file = cfg['DATASET']['TOKEN_FILE'][node_idx % num_token_files]
|
69 |
+
else:
|
70 |
+
sas_token_file = cfg['DATASET']['TOKEN_FILE']
|
71 |
+
|
72 |
+
sas_token_file = os.path.join(cfg['DATASET']['ROOT'], sas_token_file)
|
73 |
+
|
74 |
+
if (
|
75 |
+
cfg['DATASET']['LOADER'] == 'blobfuse'
|
76 |
+
or not os.path.isfile(sas_token_file)
|
77 |
+
):
|
78 |
+
sas_token_file = None
|
79 |
+
|
80 |
+
return sas_token_file
|
81 |
+
|
82 |
+
|
83 |
+
def _build_pairs_dataset_v2(cfg, is_train):
|
84 |
+
transforms = build_transforms(cfg, is_train)
|
85 |
+
logger.info('transforms: {}'.format(transforms))
|
86 |
+
|
87 |
+
dataset_name = cfg['DATASET']['TRAIN_SET'] \
|
88 |
+
if is_train else cfg['DATASET']['TEST_SET']
|
89 |
+
|
90 |
+
tokenobj = build_tokenizer(cfg['LANG_ENCODER'])
|
91 |
+
|
92 |
+
if cfg['DATASET']['DATA_FORMAT'] != 'tsv':
|
93 |
+
raise ValueError('Only support tsv format for pairs dataset v2')
|
94 |
+
|
95 |
+
tsv_list = _get_tsv_list(cfg, is_train)
|
96 |
+
|
97 |
+
if len(tsv_list) > 0:
|
98 |
+
tsv_filenames = sorted(
|
99 |
+
[
|
100 |
+
os.path.join(cfg['DATASET']['ROOT'], dataset_name, f)
|
101 |
+
for f in tsv_list
|
102 |
+
]
|
103 |
+
)
|
104 |
+
else:
|
105 |
+
dataset_path = os.path.join(cfg['DATASET']['ROOT'], dataset_name)
|
106 |
+
tsv_files = Path(dataset_path).glob('**/*.tsv')
|
107 |
+
|
108 |
+
tsv_filenames = sorted(
|
109 |
+
[
|
110 |
+
str(path)
|
111 |
+
for path in tsv_files
|
112 |
+
]
|
113 |
+
)
|
114 |
+
|
115 |
+
image_tsv_files = [
|
116 |
+
filename
|
117 |
+
for filename in tsv_filenames
|
118 |
+
if (
|
119 |
+
'image-' in basename(filename)
|
120 |
+
or 'image_' in basename(filename)
|
121 |
+
or '_image' in basename(filename)
|
122 |
+
or '-image' in basename(filename)
|
123 |
+
or 'images-' in basename(filename)
|
124 |
+
)
|
125 |
+
]
|
126 |
+
text_tsv_files = [
|
127 |
+
filename
|
128 |
+
for filename in tsv_filenames
|
129 |
+
if (
|
130 |
+
'text-' in basename(filename)
|
131 |
+
or 'text_' in basename(filename)
|
132 |
+
or '_text' in basename(filename)
|
133 |
+
or '-text' in basename(filename)
|
134 |
+
or 'texts-' in basename(filename)
|
135 |
+
)
|
136 |
+
]
|
137 |
+
|
138 |
+
logger.info(
|
139 |
+
"=> found %d/%d tsv file(s) to load.",
|
140 |
+
len(image_tsv_files), len(text_tsv_files)
|
141 |
+
)
|
142 |
+
|
143 |
+
num_captions = 1 \
|
144 |
+
if is_train else cfg['DATASET'].get('NUM_CAPTIONS', 1)
|
145 |
+
text_format = cfg['DATASET'].get('TEXT_FORMAT', 'json')
|
146 |
+
|
147 |
+
sas_token_file = _get_token_file(cfg)
|
148 |
+
logger.info("=> SAS token path: %s", sas_token_file)
|
149 |
+
|
150 |
+
metas = []
|
151 |
+
cfg_data = cfg['DATASET']
|
152 |
+
if 'CLASSIFICATION_SETS' in cfg_data and 'NUM_CLASSES' in cfg_data:
|
153 |
+
for source, num_classes in zip(cfg_data['CLASSIFICATION_SETS'], cfg_data['NUM_CLASSES']):
|
154 |
+
metas.append(
|
155 |
+
TSVMeta(
|
156 |
+
source=source,
|
157 |
+
num_classes=num_classes,
|
158 |
+
task='classification'
|
159 |
+
)
|
160 |
+
)
|
161 |
+
logger.info('=> add meta: {}'.format(metas[-1]))
|
162 |
+
|
163 |
+
if 'coco-caption' in dataset_name:
|
164 |
+
logger.info('=> coco caption data is used')
|
165 |
+
logger.info('=> update num_captions: 5, text_format: json')
|
166 |
+
logger.warning('=> set sas token to None for coco evaluation')
|
167 |
+
sas_token_file = None
|
168 |
+
num_captions = 5
|
169 |
+
text_format = 'json'
|
170 |
+
|
171 |
+
dataset = TSVImageTextDatasetV2(
|
172 |
+
image_tsv_files, text_tsv_files,
|
173 |
+
transform=transforms,
|
174 |
+
tokenize=tokenobj,
|
175 |
+
context_length=cfg['LANG_ENCODER']['CONTEXT_LENGTH'],
|
176 |
+
num_captions=num_captions,
|
177 |
+
text_format=text_format,
|
178 |
+
is_train=is_train,
|
179 |
+
sas_token_path=sas_token_file,
|
180 |
+
metas=metas,
|
181 |
+
prompt_engineering=cfg['DATASET'].get('PROMPT_ENGINEERING', True),
|
182 |
+
concat_queries=cfg['DATASET'].get('CONCAT_QUERIES', False)
|
183 |
+
)
|
184 |
+
|
185 |
+
logger.info(
|
186 |
+
"=> %s set size: %d", 'train'
|
187 |
+
if is_train else 'val', len(dataset)
|
188 |
+
)
|
189 |
+
|
190 |
+
return dataset
|
191 |
+
|
192 |
+
|
193 |
+
def build_dataloader(cfg, is_train=True, distributed=False):
|
194 |
+
dataset = build_dataset(cfg, is_train)
|
195 |
+
|
196 |
+
if (
|
197 |
+
is_train
|
198 |
+
and 'TIMM_AUG' in cfg['AUG']
|
199 |
+
and cfg['AUG']['TIMM_AUG']['USE_LOADER']
|
200 |
+
):
|
201 |
+
logger.info('=> use timm loader for training')
|
202 |
+
timm_cfg = CN(init_dict=cfg['AUG']['TIMM_AUG'])
|
203 |
+
data_loader = create_loader(
|
204 |
+
dataset,
|
205 |
+
input_size=cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0],
|
206 |
+
batch_size=cfg['TRAIN']['BATCH_SIZE_PER_GPU'],
|
207 |
+
is_training=True,
|
208 |
+
use_prefetcher=True,
|
209 |
+
no_aug=False,
|
210 |
+
re_prob=timm_cfg.RE_PROB,
|
211 |
+
re_mode=timm_cfg.RE_MODE,
|
212 |
+
re_count=timm_cfg.RE_COUNT,
|
213 |
+
re_split=timm_cfg.RE_SPLIT,
|
214 |
+
scale=cfg['AUG']['SCALE'],
|
215 |
+
ratio=cfg['AUG']['RATIO'],
|
216 |
+
hflip=timm_cfg.HFLIP,
|
217 |
+
vflip=timm_cfg.VFLIP,
|
218 |
+
color_jitter=timm_cfg.COLOR_JITTER,
|
219 |
+
auto_augment=timm_cfg.AUTO_AUGMENT,
|
220 |
+
num_aug_splits=0,
|
221 |
+
interpolation=cfg['AUG']['INTERPOLATION'],
|
222 |
+
mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
|
223 |
+
std=cfg['IMAGE_ENCODER']['IMAGE_STD'],
|
224 |
+
num_workers=cfg['WORKERS'],
|
225 |
+
distributed=distributed,
|
226 |
+
collate_fn=None,
|
227 |
+
pin_memory=cfg['PIN_MEMORY'],
|
228 |
+
use_multi_epochs_loader=True
|
229 |
+
)
|
230 |
+
else:
|
231 |
+
if is_train:
|
232 |
+
batch_size_per_gpu = cfg['TRAIN']['BATCH_SIZE_PER_GPU']
|
233 |
+
shuffle = cfg['TRAIN'].get('SHUFFLE', True)
|
234 |
+
else:
|
235 |
+
batch_size_per_gpu = cfg['TEST']['BATCH_SIZE_PER_GPU']
|
236 |
+
shuffle = cfg['TEST'].get('SHUFFLE', False)
|
237 |
+
|
238 |
+
if distributed or cfg.get('ALWAYS_ENABLE_SAMPLER', False):
|
239 |
+
# sampler = build_sampler(cfg, dataset, is_train, shuffle)
|
240 |
+
sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
|
241 |
+
shuffle = False
|
242 |
+
else:
|
243 |
+
sampler = None
|
244 |
+
|
245 |
+
data_loader = torch.utils.data.DataLoader(
|
246 |
+
dataset,
|
247 |
+
batch_size=batch_size_per_gpu,
|
248 |
+
shuffle=shuffle,
|
249 |
+
num_workers=cfg['WORKERS'],
|
250 |
+
pin_memory=cfg['PIN_MEMORY'],
|
251 |
+
sampler=sampler,
|
252 |
+
drop_last=True if is_train else False,
|
253 |
+
prefetch_factor=cfg.get('PREFETCH_FACTOR', 2)
|
254 |
+
)
|
255 |
+
|
256 |
+
return data_loader
|
257 |
+
|
258 |
+
|
259 |
+
|
260 |
+
|
MedImageInsight/ImageDataLoader/constants.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
IMAGENET_CLASSES = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "projectile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "dark glasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
|
2 |
+
|
3 |
+
IMAGENET_DEFAULT_TEMPLATES = [
|
4 |
+
'{}.',
|
5 |
+
'a bad photo of a {}.',
|
6 |
+
'a photo of many {}.',
|
7 |
+
'a sculpture of a {}.',
|
8 |
+
'a photo of the hard to see {}.',
|
9 |
+
'a low resolution photo of the {}.',
|
10 |
+
'a rendering of a {}.',
|
11 |
+
'graffiti of a {}.',
|
12 |
+
'a bad photo of the {}.',
|
13 |
+
'a cropped photo of the {}.',
|
14 |
+
'a tattoo of a {}.',
|
15 |
+
'the embroidered {}.',
|
16 |
+
'a photo of a hard to see {}.',
|
17 |
+
'a bright photo of a {}.',
|
18 |
+
'a photo of a clean {}.',
|
19 |
+
'a photo of a dirty {}.',
|
20 |
+
'a dark photo of the {}.',
|
21 |
+
'a drawing of a {}.',
|
22 |
+
'a photo of my {}.',
|
23 |
+
'the plastic {}.',
|
24 |
+
'a photo of the cool {}.',
|
25 |
+
'a close-up photo of a {}.',
|
26 |
+
'a black and white photo of the {}.',
|
27 |
+
'a painting of the {}.',
|
28 |
+
'a painting of a {}.',
|
29 |
+
'a pixelated photo of the {}.',
|
30 |
+
'a sculpture of the {}.',
|
31 |
+
'a bright photo of the {}.',
|
32 |
+
'a cropped photo of a {}.',
|
33 |
+
'a plastic {}.',
|
34 |
+
'a photo of the dirty {}.',
|
35 |
+
'a jpeg corrupted photo of a {}.',
|
36 |
+
'a blurry photo of the {}.',
|
37 |
+
'a photo of the {}.',
|
38 |
+
'a good photo of the {}.',
|
39 |
+
'a rendering of the {}.',
|
40 |
+
'a {} in a video game.',
|
41 |
+
'a photo of one {}.',
|
42 |
+
'a doodle of a {}.',
|
43 |
+
'a close-up photo of the {}.',
|
44 |
+
'a photo of a {}.',
|
45 |
+
'the origami {}.',
|
46 |
+
'the {} in a video game.',
|
47 |
+
'a sketch of a {}.',
|
48 |
+
'a doodle of the {}.',
|
49 |
+
'a origami {}.',
|
50 |
+
'a low resolution photo of a {}.',
|
51 |
+
'the toy {}.',
|
52 |
+
'a rendition of the {}.',
|
53 |
+
'a photo of the clean {}.',
|
54 |
+
'a photo of a large {}.',
|
55 |
+
'a rendition of a {}.',
|
56 |
+
'a photo of a nice {}.',
|
57 |
+
'a photo of a weird {}.',
|
58 |
+
'a blurry photo of a {}.',
|
59 |
+
'a cartoon {}.',
|
60 |
+
'art of a {}.',
|
61 |
+
'a sketch of the {}.',
|
62 |
+
'a embroidered {}.',
|
63 |
+
'a pixelated photo of a {}.',
|
64 |
+
'itap of the {}.',
|
65 |
+
'a jpeg corrupted photo of the {}.',
|
66 |
+
'a good photo of a {}.',
|
67 |
+
'a plushie {}.',
|
68 |
+
'a photo of the nice {}.',
|
69 |
+
'a photo of the small {}.',
|
70 |
+
'a photo of the weird {}.',
|
71 |
+
'the cartoon {}.',
|
72 |
+
'art of the {}.',
|
73 |
+
'a drawing of the {}.',
|
74 |
+
'a photo of the large {}.',
|
75 |
+
'a black and white photo of a {}.',
|
76 |
+
'the plushie {}.',
|
77 |
+
'a dark photo of a {}.',
|
78 |
+
'itap of a {}.',
|
79 |
+
'graffiti of the {}.',
|
80 |
+
'a toy {}.',
|
81 |
+
'itap of my {}.',
|
82 |
+
'a photo of a cool {}.',
|
83 |
+
'a photo of a small {}.',
|
84 |
+
'a tattoo of the {}.',
|
85 |
+
]
|
MedImageInsight/ImageDataLoader/languages/__init__.py
ADDED
File without changes
|
MedImageInsight/ImageDataLoader/languages/prompt_engineering.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
|
4 |
+
|
5 |
+
def get_prompt_templates():
|
6 |
+
prompt_templates = [
|
7 |
+
'{}.',
|
8 |
+
'a photo of a {}.',
|
9 |
+
'a bad photo of a {}.',
|
10 |
+
'a photo of many {}.',
|
11 |
+
'a sculpture of a {}.',
|
12 |
+
'a photo of the hard to see {}.',
|
13 |
+
'a low resolution photo of the {}.',
|
14 |
+
'a rendering of a {}.',
|
15 |
+
'graffiti of a {}.',
|
16 |
+
'a bad photo of the {}.',
|
17 |
+
'a cropped photo of the {}.',
|
18 |
+
'a tattoo of a {}.',
|
19 |
+
'the embroidered {}.',
|
20 |
+
'a photo of a hard to see {}.',
|
21 |
+
'a bright photo of a {}.',
|
22 |
+
'a photo of a clean {}.',
|
23 |
+
'a photo of a dirty {}.',
|
24 |
+
'a dark photo of the {}.',
|
25 |
+
'a drawing of a {}.',
|
26 |
+
'a photo of my {}.',
|
27 |
+
'the plastic {}.',
|
28 |
+
'a photo of the cool {}.',
|
29 |
+
'a close-up photo of a {}.',
|
30 |
+
'a black and white photo of the {}.',
|
31 |
+
'a painting of the {}.',
|
32 |
+
'a painting of a {}.',
|
33 |
+
'a pixelated photo of the {}.',
|
34 |
+
'a sculpture of the {}.',
|
35 |
+
'a bright photo of the {}.',
|
36 |
+
'a cropped photo of a {}.',
|
37 |
+
'a plastic {}.',
|
38 |
+
'a photo of the dirty {}.',
|
39 |
+
'a jpeg corrupted photo of a {}.',
|
40 |
+
'a blurry photo of the {}.',
|
41 |
+
'a photo of the {}.',
|
42 |
+
'a good photo of the {}.',
|
43 |
+
'a rendering of the {}.',
|
44 |
+
'a {} in a video game.',
|
45 |
+
'a photo of one {}.',
|
46 |
+
'a doodle of a {}.',
|
47 |
+
'a close-up photo of the {}.',
|
48 |
+
'the origami {}.',
|
49 |
+
'the {} in a video game.',
|
50 |
+
'a sketch of a {}.',
|
51 |
+
'a doodle of the {}.',
|
52 |
+
'a origami {}.',
|
53 |
+
'a low resolution photo of a {}.',
|
54 |
+
'the toy {}.',
|
55 |
+
'a rendition of the {}.',
|
56 |
+
'a photo of the clean {}.',
|
57 |
+
'a photo of a large {}.',
|
58 |
+
'a rendition of a {}.',
|
59 |
+
'a photo of a nice {}.',
|
60 |
+
'a photo of a weird {}.',
|
61 |
+
'a blurry photo of a {}.',
|
62 |
+
'a cartoon {}.',
|
63 |
+
'art of a {}.',
|
64 |
+
'a sketch of the {}.',
|
65 |
+
'a embroidered {}.',
|
66 |
+
'a pixelated photo of a {}.',
|
67 |
+
'itap of the {}.',
|
68 |
+
'a jpeg corrupted photo of the {}.',
|
69 |
+
'a good photo of a {}.',
|
70 |
+
'a plushie {}.',
|
71 |
+
'a photo of the nice {}.',
|
72 |
+
'a photo of the small {}.',
|
73 |
+
'a photo of the weird {}.',
|
74 |
+
'the cartoon {}.',
|
75 |
+
'art of the {}.',
|
76 |
+
'a drawing of the {}.',
|
77 |
+
'a photo of the large {}.',
|
78 |
+
'a black and white photo of a {}.',
|
79 |
+
'the plushie {}.',
|
80 |
+
'a dark photo of a {}.',
|
81 |
+
'itap of a {}.',
|
82 |
+
'graffiti of the {}.',
|
83 |
+
'a toy {}.',
|
84 |
+
'itap of my {}.',
|
85 |
+
'a photo of a cool {}.',
|
86 |
+
'a photo of a small {}.',
|
87 |
+
'a tattoo of the {}.',
|
88 |
+
]
|
89 |
+
return prompt_templates
|
90 |
+
|
91 |
+
|
92 |
+
def prompt_engineering(classnames):
|
93 |
+
prompt_templates = get_prompt_templates()
|
94 |
+
temp_idx = np.random.randint(len(prompt_templates))
|
95 |
+
|
96 |
+
if isinstance(classnames, list):
|
97 |
+
classname = random.choice(classnames)
|
98 |
+
else:
|
99 |
+
classname = classnames
|
100 |
+
|
101 |
+
return prompt_templates[temp_idx].replace('{}', classname.replace(',', '').replace('+', ' '))
|
MedImageInsight/ImageDataLoader/transforms/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .build import build_transforms
|
MedImageInsight/ImageDataLoader/transforms/autoaugment.py
ADDED
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from enum import Enum
|
3 |
+
from typing import List, Tuple, Optional, Dict
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
from torchvision.transforms import functional as F
|
9 |
+
from torchvision.transforms.functional import InterpolationMode
|
10 |
+
|
11 |
+
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
|
12 |
+
|
13 |
+
|
14 |
+
def _apply_op(
|
15 |
+
img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
|
16 |
+
):
|
17 |
+
if op_name == "ShearX":
|
18 |
+
img = F.affine(
|
19 |
+
img,
|
20 |
+
angle=0.0,
|
21 |
+
translate=[0, 0],
|
22 |
+
scale=1.0,
|
23 |
+
shear=[math.degrees(magnitude), 0.0],
|
24 |
+
interpolation=interpolation,
|
25 |
+
fill=fill,
|
26 |
+
)
|
27 |
+
elif op_name == "ShearY":
|
28 |
+
img = F.affine(
|
29 |
+
img,
|
30 |
+
angle=0.0,
|
31 |
+
translate=[0, 0],
|
32 |
+
scale=1.0,
|
33 |
+
shear=[0.0, math.degrees(magnitude)],
|
34 |
+
interpolation=interpolation,
|
35 |
+
fill=fill,
|
36 |
+
)
|
37 |
+
elif op_name == "TranslateX":
|
38 |
+
img = F.affine(
|
39 |
+
img,
|
40 |
+
angle=0.0,
|
41 |
+
translate=[int(magnitude), 0],
|
42 |
+
scale=1.0,
|
43 |
+
interpolation=interpolation,
|
44 |
+
shear=[0.0, 0.0],
|
45 |
+
fill=fill,
|
46 |
+
)
|
47 |
+
elif op_name == "TranslateY":
|
48 |
+
img = F.affine(
|
49 |
+
img,
|
50 |
+
angle=0.0,
|
51 |
+
translate=[0, int(magnitude)],
|
52 |
+
scale=1.0,
|
53 |
+
interpolation=interpolation,
|
54 |
+
shear=[0.0, 0.0],
|
55 |
+
fill=fill,
|
56 |
+
)
|
57 |
+
elif op_name == "Rotate":
|
58 |
+
img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
|
59 |
+
elif op_name == "Brightness":
|
60 |
+
img = F.adjust_brightness(img, 1.0 + magnitude)
|
61 |
+
elif op_name == "Color":
|
62 |
+
img = F.adjust_saturation(img, 1.0 + magnitude)
|
63 |
+
elif op_name == "Contrast":
|
64 |
+
img = F.adjust_contrast(img, 1.0 + magnitude)
|
65 |
+
elif op_name == "Sharpness":
|
66 |
+
img = F.adjust_sharpness(img, 1.0 + magnitude)
|
67 |
+
elif op_name == "Posterize":
|
68 |
+
img = F.posterize(img, int(magnitude))
|
69 |
+
elif op_name == "Solarize":
|
70 |
+
img = F.solarize(img, magnitude)
|
71 |
+
elif op_name == "AutoContrast":
|
72 |
+
img = F.autocontrast(img)
|
73 |
+
elif op_name == "Equalize":
|
74 |
+
img = F.equalize(img)
|
75 |
+
elif op_name == "Invert":
|
76 |
+
img = F.invert(img)
|
77 |
+
elif op_name == "Identity":
|
78 |
+
pass
|
79 |
+
else:
|
80 |
+
raise ValueError(f"The provided operator {op_name} is not recognized.")
|
81 |
+
return img
|
82 |
+
|
83 |
+
|
84 |
+
class AutoAugmentPolicy(Enum):
|
85 |
+
"""AutoAugment policies learned on different datasets.
|
86 |
+
Available policies are IMAGENET, CIFAR10 and SVHN.
|
87 |
+
"""
|
88 |
+
|
89 |
+
IMAGENET = "imagenet"
|
90 |
+
CIFAR10 = "cifar10"
|
91 |
+
SVHN = "svhn"
|
92 |
+
|
93 |
+
|
94 |
+
# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
|
95 |
+
class AutoAugment(torch.nn.Module):
|
96 |
+
r"""AutoAugment data augmentation method based on
|
97 |
+
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
|
98 |
+
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
|
99 |
+
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
|
100 |
+
If img is PIL Image, it is expected to be in mode "L" or "RGB".
|
101 |
+
|
102 |
+
Args:
|
103 |
+
policy (AutoAugmentPolicy): Desired policy enum defined by
|
104 |
+
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
|
105 |
+
interpolation (InterpolationMode): Desired interpolation enum defined by
|
106 |
+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
|
107 |
+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
|
108 |
+
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
|
109 |
+
image. If given a number, the value is used for all bands respectively.
|
110 |
+
"""
|
111 |
+
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
|
115 |
+
interpolation: InterpolationMode = InterpolationMode.NEAREST,
|
116 |
+
fill: Optional[List[float]] = None,
|
117 |
+
) -> None:
|
118 |
+
super().__init__()
|
119 |
+
self.policy = policy
|
120 |
+
self.interpolation = interpolation
|
121 |
+
self.fill = fill
|
122 |
+
self.policies = self._get_policies(policy)
|
123 |
+
|
124 |
+
def _get_policies(
|
125 |
+
self, policy: AutoAugmentPolicy
|
126 |
+
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
|
127 |
+
if policy == AutoAugmentPolicy.IMAGENET:
|
128 |
+
return [
|
129 |
+
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
|
130 |
+
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
|
131 |
+
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
|
132 |
+
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
|
133 |
+
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
|
134 |
+
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
|
135 |
+
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
|
136 |
+
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
|
137 |
+
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
|
138 |
+
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
|
139 |
+
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
|
140 |
+
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
|
141 |
+
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
|
142 |
+
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
|
143 |
+
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
|
144 |
+
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
|
145 |
+
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
|
146 |
+
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
|
147 |
+
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
|
148 |
+
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
|
149 |
+
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
|
150 |
+
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
|
151 |
+
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
|
152 |
+
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
|
153 |
+
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
|
154 |
+
]
|
155 |
+
elif policy == AutoAugmentPolicy.CIFAR10:
|
156 |
+
return [
|
157 |
+
(("Invert", 0.1, None), ("Contrast", 0.2, 6)),
|
158 |
+
(("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
|
159 |
+
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
|
160 |
+
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
|
161 |
+
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
|
162 |
+
(("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
|
163 |
+
(("Color", 0.4, 3), ("Brightness", 0.6, 7)),
|
164 |
+
(("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
|
165 |
+
(("Equalize", 0.6, None), ("Equalize", 0.5, None)),
|
166 |
+
(("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
|
167 |
+
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
|
168 |
+
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
|
169 |
+
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
|
170 |
+
(("Brightness", 0.9, 6), ("Color", 0.2, 8)),
|
171 |
+
(("Solarize", 0.5, 2), ("Invert", 0.0, None)),
|
172 |
+
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
|
173 |
+
(("Equalize", 0.2, None), ("Equalize", 0.6, None)),
|
174 |
+
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
|
175 |
+
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
|
176 |
+
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
|
177 |
+
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
|
178 |
+
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
|
179 |
+
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
|
180 |
+
(("Equalize", 0.8, None), ("Invert", 0.1, None)),
|
181 |
+
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
|
182 |
+
]
|
183 |
+
elif policy == AutoAugmentPolicy.SVHN:
|
184 |
+
return [
|
185 |
+
(("ShearX", 0.9, 4), ("Invert", 0.2, None)),
|
186 |
+
(("ShearY", 0.9, 8), ("Invert", 0.7, None)),
|
187 |
+
(("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
|
188 |
+
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
|
189 |
+
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
|
190 |
+
(("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
|
191 |
+
(("ShearY", 0.9, 8), ("Invert", 0.4, None)),
|
192 |
+
(("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
|
193 |
+
(("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
|
194 |
+
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
|
195 |
+
(("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
|
196 |
+
(("ShearY", 0.8, 8), ("Invert", 0.7, None)),
|
197 |
+
(("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
|
198 |
+
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
|
199 |
+
(("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
|
200 |
+
(("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
|
201 |
+
(("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
|
202 |
+
(("Invert", 0.6, None), ("Rotate", 0.8, 4)),
|
203 |
+
(("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
|
204 |
+
(("ShearX", 0.1, 6), ("Invert", 0.6, None)),
|
205 |
+
(("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
|
206 |
+
(("ShearY", 0.8, 4), ("Invert", 0.8, None)),
|
207 |
+
(("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
|
208 |
+
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
|
209 |
+
(("ShearX", 0.7, 2), ("Invert", 0.1, None)),
|
210 |
+
]
|
211 |
+
else:
|
212 |
+
raise ValueError(f"The provided policy {policy} is not recognized.")
|
213 |
+
|
214 |
+
def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
|
215 |
+
return {
|
216 |
+
# op_name: (magnitudes, signed)
|
217 |
+
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
|
218 |
+
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
|
219 |
+
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
|
220 |
+
"TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
|
221 |
+
"Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
|
222 |
+
"Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
|
223 |
+
"Color": (torch.linspace(0.0, 0.9, num_bins), True),
|
224 |
+
"Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
|
225 |
+
"Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
|
226 |
+
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
|
227 |
+
"Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
|
228 |
+
"AutoContrast": (torch.tensor(0.0), False),
|
229 |
+
"Equalize": (torch.tensor(0.0), False),
|
230 |
+
"Invert": (torch.tensor(0.0), False),
|
231 |
+
}
|
232 |
+
|
233 |
+
@staticmethod
|
234 |
+
def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
|
235 |
+
"""Get parameters for autoaugment transformation
|
236 |
+
|
237 |
+
Returns:
|
238 |
+
params required by the autoaugment transformation
|
239 |
+
"""
|
240 |
+
policy_id = int(torch.randint(transform_num, (1,)).item())
|
241 |
+
probs = torch.rand((2,))
|
242 |
+
signs = torch.randint(2, (2,))
|
243 |
+
|
244 |
+
return policy_id, probs, signs
|
245 |
+
|
246 |
+
def forward(self, img: Tensor) -> Tensor:
|
247 |
+
"""
|
248 |
+
img (PIL Image or Tensor): Image to be transformed.
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
PIL Image or Tensor: AutoAugmented image.
|
252 |
+
"""
|
253 |
+
fill = self.fill
|
254 |
+
if isinstance(img, Tensor):
|
255 |
+
if isinstance(fill, (int, float)):
|
256 |
+
fill = [float(fill)] * F.get_image_num_channels(img)
|
257 |
+
elif fill is not None:
|
258 |
+
fill = [float(f) for f in fill]
|
259 |
+
|
260 |
+
transform_id, probs, signs = self.get_params(len(self.policies))
|
261 |
+
|
262 |
+
for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
|
263 |
+
if probs[i] <= p:
|
264 |
+
op_meta = self._augmentation_space(10, F.get_image_size(img))
|
265 |
+
magnitudes, signed = op_meta[op_name]
|
266 |
+
magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
|
267 |
+
if signed and signs[i] == 0:
|
268 |
+
magnitude *= -1.0
|
269 |
+
img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
|
270 |
+
|
271 |
+
return img
|
272 |
+
|
273 |
+
def __repr__(self) -> str:
|
274 |
+
return self.__class__.__name__ + f"(policy={self.policy}, fill={self.fill})"
|
275 |
+
|
276 |
+
|
277 |
+
class RandAugment(torch.nn.Module):
|
278 |
+
r"""RandAugment data augmentation method based on
|
279 |
+
`"RandAugment: Practical automated data augmentation with a reduced search space"
|
280 |
+
<https://arxiv.org/abs/1909.13719>`_.
|
281 |
+
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
|
282 |
+
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
|
283 |
+
If img is PIL Image, it is expected to be in mode "L" or "RGB".
|
284 |
+
|
285 |
+
Args:
|
286 |
+
num_ops (int): Number of augmentation transformations to apply sequentially.
|
287 |
+
magnitude (int): Magnitude for all the transformations.
|
288 |
+
num_magnitude_bins (int): The number of different magnitude values.
|
289 |
+
interpolation (InterpolationMode): Desired interpolation enum defined by
|
290 |
+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
|
291 |
+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
|
292 |
+
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
|
293 |
+
image. If given a number, the value is used for all bands respectively.
|
294 |
+
"""
|
295 |
+
|
296 |
+
def __init__(
|
297 |
+
self,
|
298 |
+
num_ops: int = 2,
|
299 |
+
magnitude: int = 9,
|
300 |
+
num_magnitude_bins: int = 31,
|
301 |
+
interpolation: InterpolationMode = InterpolationMode.NEAREST,
|
302 |
+
fill: Optional[List[float]] = None,
|
303 |
+
) -> None:
|
304 |
+
super().__init__()
|
305 |
+
self.num_ops = num_ops
|
306 |
+
self.magnitude = magnitude
|
307 |
+
self.num_magnitude_bins = num_magnitude_bins
|
308 |
+
self.interpolation = interpolation
|
309 |
+
self.fill = fill
|
310 |
+
|
311 |
+
def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
|
312 |
+
return {
|
313 |
+
# op_name: (magnitudes, signed)
|
314 |
+
"Identity": (torch.tensor(0.0), False),
|
315 |
+
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
|
316 |
+
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
|
317 |
+
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
|
318 |
+
"TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
|
319 |
+
"Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
|
320 |
+
"Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
|
321 |
+
"Color": (torch.linspace(0.0, 0.9, num_bins), True),
|
322 |
+
"Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
|
323 |
+
"Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
|
324 |
+
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
|
325 |
+
"Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
|
326 |
+
"AutoContrast": (torch.tensor(0.0), False),
|
327 |
+
"Equalize": (torch.tensor(0.0), False),
|
328 |
+
}
|
329 |
+
|
330 |
+
def forward(self, img: Tensor) -> Tensor:
|
331 |
+
"""
|
332 |
+
img (PIL Image or Tensor): Image to be transformed.
|
333 |
+
|
334 |
+
Returns:
|
335 |
+
PIL Image or Tensor: Transformed image.
|
336 |
+
"""
|
337 |
+
fill = self.fill
|
338 |
+
if isinstance(img, Tensor):
|
339 |
+
if isinstance(fill, (int, float)):
|
340 |
+
fill = [float(fill)] * F.get_image_num_channels(img)
|
341 |
+
elif fill is not None:
|
342 |
+
fill = [float(f) for f in fill]
|
343 |
+
|
344 |
+
for _ in range(self.num_ops):
|
345 |
+
op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
|
346 |
+
op_index = int(torch.randint(len(op_meta), (1,)).item())
|
347 |
+
op_name = list(op_meta.keys())[op_index]
|
348 |
+
magnitudes, signed = op_meta[op_name]
|
349 |
+
magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
|
350 |
+
if signed and torch.randint(2, (1,)):
|
351 |
+
magnitude *= -1.0
|
352 |
+
img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
|
353 |
+
|
354 |
+
return img
|
355 |
+
|
356 |
+
def __repr__(self) -> str:
|
357 |
+
s = self.__class__.__name__ + "("
|
358 |
+
s += "num_ops={num_ops}"
|
359 |
+
s += ", magnitude={magnitude}"
|
360 |
+
s += ", num_magnitude_bins={num_magnitude_bins}"
|
361 |
+
s += ", interpolation={interpolation}"
|
362 |
+
s += ", fill={fill}"
|
363 |
+
s += ")"
|
364 |
+
return s.format(**self.__dict__)
|
365 |
+
|
366 |
+
|
367 |
+
class TrivialAugmentWide(torch.nn.Module):
|
368 |
+
r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
|
369 |
+
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
|
370 |
+
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
|
371 |
+
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
|
372 |
+
If img is PIL Image, it is expected to be in mode "L" or "RGB".
|
373 |
+
|
374 |
+
Args:
|
375 |
+
num_magnitude_bins (int): The number of different magnitude values.
|
376 |
+
interpolation (InterpolationMode): Desired interpolation enum defined by
|
377 |
+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
|
378 |
+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
|
379 |
+
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
|
380 |
+
image. If given a number, the value is used for all bands respectively.
|
381 |
+
"""
|
382 |
+
|
383 |
+
def __init__(
|
384 |
+
self,
|
385 |
+
num_magnitude_bins: int = 31,
|
386 |
+
interpolation: InterpolationMode = InterpolationMode.NEAREST,
|
387 |
+
fill: Optional[List[float]] = None,
|
388 |
+
) -> None:
|
389 |
+
super().__init__()
|
390 |
+
self.num_magnitude_bins = num_magnitude_bins
|
391 |
+
self.interpolation = interpolation
|
392 |
+
self.fill = fill
|
393 |
+
|
394 |
+
def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
|
395 |
+
return {
|
396 |
+
# op_name: (magnitudes, signed)
|
397 |
+
"Identity": (torch.tensor(0.0), False),
|
398 |
+
"ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
|
399 |
+
"ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
|
400 |
+
"TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
|
401 |
+
"TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
|
402 |
+
"Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
|
403 |
+
"Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
|
404 |
+
"Color": (torch.linspace(0.0, 0.99, num_bins), True),
|
405 |
+
"Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
|
406 |
+
"Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
|
407 |
+
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
|
408 |
+
"Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
|
409 |
+
"AutoContrast": (torch.tensor(0.0), False),
|
410 |
+
"Equalize": (torch.tensor(0.0), False),
|
411 |
+
}
|
412 |
+
|
413 |
+
def forward(self, img: Tensor) -> Tensor:
|
414 |
+
"""
|
415 |
+
img (PIL Image or Tensor): Image to be transformed.
|
416 |
+
|
417 |
+
Returns:
|
418 |
+
PIL Image or Tensor: Transformed image.
|
419 |
+
"""
|
420 |
+
fill = self.fill
|
421 |
+
if isinstance(img, Tensor):
|
422 |
+
if isinstance(fill, (int, float)):
|
423 |
+
fill = [float(fill)] * F.get_image_num_channels(img)
|
424 |
+
elif fill is not None:
|
425 |
+
fill = [float(f) for f in fill]
|
426 |
+
|
427 |
+
op_meta = self._augmentation_space(self.num_magnitude_bins)
|
428 |
+
op_index = int(torch.randint(len(op_meta), (1,)).item())
|
429 |
+
op_name = list(op_meta.keys())[op_index]
|
430 |
+
magnitudes, signed = op_meta[op_name]
|
431 |
+
magnitude = (
|
432 |
+
float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item())
|
433 |
+
if magnitudes.ndim > 0
|
434 |
+
else 0.0
|
435 |
+
)
|
436 |
+
if signed and torch.randint(2, (1,)):
|
437 |
+
magnitude *= -1.0
|
438 |
+
|
439 |
+
return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
|
440 |
+
|
441 |
+
def __repr__(self) -> str:
|
442 |
+
s = self.__class__.__name__ + "("
|
443 |
+
s += "num_magnitude_bins={num_magnitude_bins}"
|
444 |
+
s += ", interpolation={interpolation}"
|
445 |
+
s += ", fill={fill}"
|
446 |
+
s += ")"
|
447 |
+
return s.format(**self.__dict__)
|
MedImageInsight/ImageDataLoader/transforms/build.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import timm
|
6 |
+
from timm.data import create_transform
|
7 |
+
|
8 |
+
from yacs.config import CfgNode as CN
|
9 |
+
from PIL import ImageFilter
|
10 |
+
import logging
|
11 |
+
import random
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torchvision.transforms as T
|
15 |
+
|
16 |
+
|
17 |
+
from .autoaugment import AutoAugmentPolicy
|
18 |
+
from .autoaugment import AutoAugment
|
19 |
+
from .autoaugment import RandAugment
|
20 |
+
from .autoaugment import TrivialAugmentWide
|
21 |
+
from .threeaugment import deitIII_Solarization
|
22 |
+
from .threeaugment import deitIII_gray_scale
|
23 |
+
from .threeaugment import deitIII_GaussianBlur
|
24 |
+
|
25 |
+
from PIL import ImageOps
|
26 |
+
from timm.data.transforms import RandomResizedCropAndInterpolation
|
27 |
+
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
class GaussianBlur(object):
|
32 |
+
"""Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
|
33 |
+
|
34 |
+
def __init__(self, sigma=[.1, 2.]):
|
35 |
+
self.sigma = sigma
|
36 |
+
|
37 |
+
def __call__(self, x):
|
38 |
+
sigma = random.uniform(self.sigma[0], self.sigma[1])
|
39 |
+
x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
def get_resolution(original_resolution):
|
44 |
+
"""Takes (H,W) and returns (precrop, crop)."""
|
45 |
+
area = original_resolution[0] * original_resolution[1]
|
46 |
+
return (160, 128) if area < 96*96 else (512, 480)
|
47 |
+
|
48 |
+
|
49 |
+
INTERPOLATION_MODES = {
|
50 |
+
'bilinear': T.InterpolationMode.BILINEAR,
|
51 |
+
'bicubic': T.InterpolationMode.BICUBIC,
|
52 |
+
'nearest': T.InterpolationMode.NEAREST,
|
53 |
+
}
|
54 |
+
|
55 |
+
|
56 |
+
def build_transforms(cfg, is_train=True):
|
57 |
+
# assert isinstance(cfg.DATASET.OUTPUT_SIZE, (list, tuple)), 'DATASET.OUTPUT_SIZE should be list or tuple'
|
58 |
+
normalize = T.Normalize(
|
59 |
+
mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
|
60 |
+
std=cfg['IMAGE_ENCODER']['IMAGE_STD']
|
61 |
+
)
|
62 |
+
|
63 |
+
transforms = None
|
64 |
+
if is_train:
|
65 |
+
if 'THREE_AUG' in cfg['AUG']:
|
66 |
+
img_size = cfg['IMAGE_ENCODER']['IMAGE_SIZE']
|
67 |
+
remove_random_resized_crop = cfg['AUG']['THREE_AUG']['SRC']
|
68 |
+
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
69 |
+
primary_tfl = []
|
70 |
+
scale=(0.08, 1.0)
|
71 |
+
interpolation='bicubic'
|
72 |
+
if remove_random_resized_crop:
|
73 |
+
primary_tfl = [
|
74 |
+
T.Resize(img_size, interpolation=3),
|
75 |
+
T.RandomCrop(img_size, padding=4,padding_mode='reflect'),
|
76 |
+
T.RandomHorizontalFlip()
|
77 |
+
]
|
78 |
+
else:
|
79 |
+
primary_tfl = [
|
80 |
+
RandomResizedCropAndInterpolation(
|
81 |
+
img_size, scale=scale, interpolation=interpolation),
|
82 |
+
T.RandomHorizontalFlip()
|
83 |
+
]
|
84 |
+
secondary_tfl = [T.RandomChoice([gray_scale(p=1.0),
|
85 |
+
Solarization(p=1.0),
|
86 |
+
GaussianBlurDeiTv3(p=1.0)])]
|
87 |
+
color_jitter = cfg['AUG']['THREE_AUG']['COLOR_JITTER']
|
88 |
+
if color_jitter is not None and not color_jitter==0:
|
89 |
+
secondary_tfl.append(T.ColorJitter(color_jitter, color_jitter, color_jitter))
|
90 |
+
final_tfl = [
|
91 |
+
T.ToTensor(),
|
92 |
+
T.Normalize(
|
93 |
+
mean=torch.tensor(mean),
|
94 |
+
std=torch.tensor(std))
|
95 |
+
]
|
96 |
+
return T.Compose(primary_tfl+secondary_tfl+final_tfl)
|
97 |
+
elif 'TIMM_AUG' in cfg['AUG'] and cfg['AUG']['TIMM_AUG']['USE_TRANSFORM']:
|
98 |
+
logger.info('=> use timm transform for training')
|
99 |
+
timm_cfg = cfg['AUG']['TIMM_AUG']
|
100 |
+
transforms = create_transform(
|
101 |
+
input_size=cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0],
|
102 |
+
is_training=True,
|
103 |
+
use_prefetcher=False,
|
104 |
+
no_aug=False,
|
105 |
+
re_prob=timm_cfg.get('RE_PROB', 0.),
|
106 |
+
re_mode=timm_cfg.get('RE_MODE', 'const'),
|
107 |
+
re_count=timm_cfg.get('RE_COUNT', 1),
|
108 |
+
re_num_splits= 0 if not timm_cfg.get('RE_SPLITS', False) else timm_cfg['RE_SPLITS'], # if false or 0, return 0
|
109 |
+
scale=cfg['AUG'].get('SCALE', None),
|
110 |
+
ratio=cfg['AUG'].get('RATIO', None),
|
111 |
+
hflip=timm_cfg.get('HFLIP', 0.5),
|
112 |
+
vflip=timm_cfg.get('VFLIP', 0.),
|
113 |
+
color_jitter=timm_cfg.get('COLOR_JITTER', 0.4),
|
114 |
+
auto_augment=timm_cfg.get('AUTO_AUGMENT', None),
|
115 |
+
interpolation=cfg['AUG']['INTERPOLATION'],
|
116 |
+
mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
|
117 |
+
std=cfg['IMAGE_ENCODER']['IMAGE_STD'],
|
118 |
+
)
|
119 |
+
elif 'TORCHVISION_AUG' in cfg['AUG']:
|
120 |
+
logger.info('=> use torchvision transform fro training')
|
121 |
+
crop_size = cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]
|
122 |
+
interpolation = INTERPOLATION_MODES[cfg['AUG']['INTERPOLATION']]
|
123 |
+
trans = [
|
124 |
+
T.RandomResizedCrop(
|
125 |
+
crop_size, scale=cfg['AUG']['SCALE'], ratio=cfg['AUG']['RATIO'],
|
126 |
+
interpolation=interpolation
|
127 |
+
)
|
128 |
+
]
|
129 |
+
hflip_prob = cfg['AUG']['TORCHVISION_AUG']['HFLIP']
|
130 |
+
auto_augment_policy = cfg['AUG']['TORCHVISION_AUG'].get('AUTO_AUGMENT', None)
|
131 |
+
if hflip_prob > 0:
|
132 |
+
trans.append(T.RandomHorizontalFlip(hflip_prob))
|
133 |
+
if auto_augment_policy is not None:
|
134 |
+
if auto_augment_policy == "ra":
|
135 |
+
trans.append(RandAugment(interpolation=interpolation))
|
136 |
+
elif auto_augment_policy == "ta_wide":
|
137 |
+
trans.append(TrivialAugmentWide(interpolation=interpolation))
|
138 |
+
else:
|
139 |
+
aa_policy = AutoAugmentPolicy(auto_augment_policy)
|
140 |
+
trans.append(AutoAugment(policy=aa_policy, interpolation=interpolation))
|
141 |
+
trans.extend(
|
142 |
+
[
|
143 |
+
T.ToTensor(),
|
144 |
+
normalize,
|
145 |
+
]
|
146 |
+
)
|
147 |
+
random_erase_prob = cfg['AUG']['TORCHVISION_AUG']['RE_PROB']
|
148 |
+
random_erase_scale = cfg['AUG']['TORCHVISION_AUG'].get('RE_SCALE', 0.33)
|
149 |
+
if random_erase_prob > 0:
|
150 |
+
# NCFC (4/26/2023): Added scale parameter to random erasing for medical imaging
|
151 |
+
trans.append(T.RandomErasing(p=random_erase_prob, scale = (0.02, random_erase_scale)))
|
152 |
+
|
153 |
+
from torchvision.transforms import InterpolationMode
|
154 |
+
rotation = cfg['AUG']['TORCHVISION_AUG'].get('ROTATION', 0.0)
|
155 |
+
if (rotation > 0.0):
|
156 |
+
trans.append(T.RandomRotation(rotation, interpolation=InterpolationMode.BILINEAR))
|
157 |
+
logger.info(" TORCH AUG: Rotation: " + str(rotation))
|
158 |
+
|
159 |
+
transforms = T.Compose(trans)
|
160 |
+
elif cfg['AUG'].get('RANDOM_CENTER_CROP', False):
|
161 |
+
logger.info('=> use random center crop data augmenation')
|
162 |
+
# precrop, crop = get_resolution(cfg.TRAIN.IMAGE_SIZE)
|
163 |
+
crop = cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]
|
164 |
+
padding = cfg['AUG'].get('RANDOM_CENTER_CROP_PADDING', 32)
|
165 |
+
precrop = crop + padding
|
166 |
+
mode = INTERPOLATION_MODES[cfg['AUG']['INTERPOLATION']]
|
167 |
+
transforms = T.Compose([
|
168 |
+
T.Resize(
|
169 |
+
(precrop, precrop),
|
170 |
+
interpolation=mode
|
171 |
+
),
|
172 |
+
T.RandomCrop((crop, crop)),
|
173 |
+
T.RandomHorizontalFlip(),
|
174 |
+
T.ToTensor(),
|
175 |
+
normalize,
|
176 |
+
])
|
177 |
+
elif cfg['AUG'].get('MAE_FINETUNE_AUG', False):
|
178 |
+
mean = cfg['IMAGE_ENCODER']['IMAGE_MEAN']
|
179 |
+
std = cfg['IMAGE_ENCODER']['IMAGE_STD']
|
180 |
+
transforms = create_transform(
|
181 |
+
input_size=cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0],
|
182 |
+
is_training=True,
|
183 |
+
color_jitter=cfg['AUG'].get('COLOR_JITTER', None),
|
184 |
+
auto_augment=cfg['AUG'].get('AUTO_AUGMENT', 'rand-m9-mstd0.5-inc1'),
|
185 |
+
interpolation='bicubic',
|
186 |
+
re_prob=cfg['AUG'].get('RE_PROB', 0.25),
|
187 |
+
re_mode=cfg['AUG'].get('RE_MODE', "pixel"),
|
188 |
+
re_count=cfg['AUG'].get('RE_COUNT', 1),
|
189 |
+
mean=mean,
|
190 |
+
std=std,
|
191 |
+
)
|
192 |
+
elif cfg['AUG'].get('MAE_PRETRAIN_AUG', False):
|
193 |
+
mean = cfg['IMAGE_ENCODER']['IMAGE_MEAN']
|
194 |
+
std = cfg['IMAGE_ENCODER']['IMAGE_STD']
|
195 |
+
transforms = T.Compose([
|
196 |
+
T.RandomResizedCrop(cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0], scale=tuple(cfg['AUG']['SCALE']), interpolation=INTERPOLATION_MODES["bicubic"]), # 3 is bicubic
|
197 |
+
T.RandomHorizontalFlip(),
|
198 |
+
T.ToTensor(),
|
199 |
+
T.Normalize(mean=mean, std=std)])
|
200 |
+
elif cfg['AUG'].get('ThreeAugment', False): # from DeiT III
|
201 |
+
mean = cfg['IMAGE_ENCODER']['IMAGE_MEAN']
|
202 |
+
std = cfg['IMAGE_ENCODER']['IMAGE_STD']
|
203 |
+
img_size = cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]
|
204 |
+
remove_random_resized_crop = cfg['AUG'].get('src', False)
|
205 |
+
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
206 |
+
primary_tfl = []
|
207 |
+
scale=(0.08, 1.0)
|
208 |
+
interpolation='bicubic'
|
209 |
+
if remove_random_resized_crop:
|
210 |
+
primary_tfl = [
|
211 |
+
T.Resize(img_size, interpolation=3), # bicubic
|
212 |
+
T.RandomCrop(img_size, padding=4,padding_mode='reflect'),
|
213 |
+
T.RandomHorizontalFlip()
|
214 |
+
]
|
215 |
+
else:
|
216 |
+
primary_tfl = [
|
217 |
+
timm.data.transforms.RandomResizedCropAndInterpolation(
|
218 |
+
img_size, scale=scale, interpolation=interpolation),
|
219 |
+
T.RandomHorizontalFlip()
|
220 |
+
]
|
221 |
+
|
222 |
+
secondary_tfl = [T.RandomChoice([deitIII_gray_scale(p=1.0),
|
223 |
+
deitIII_Solarization(p=1.0),
|
224 |
+
deitIII_GaussianBlur(p=1.0)])]
|
225 |
+
color_jitter = cfg['AUG']['COLOR_JITTER']
|
226 |
+
secondary_tfl.append(T.ColorJitter(color_jitter, color_jitter, color_jitter))
|
227 |
+
final_tfl = [
|
228 |
+
T.ToTensor(),
|
229 |
+
T.Normalize(
|
230 |
+
mean=torch.tensor(mean),
|
231 |
+
std=torch.tensor(std))
|
232 |
+
]
|
233 |
+
transforms = T.Compose(primary_tfl+secondary_tfl+final_tfl)
|
234 |
+
logger.info('=> training transformers: {}'.format(transforms))
|
235 |
+
else:
|
236 |
+
mode = INTERPOLATION_MODES[cfg['AUG']['INTERPOLATION']]
|
237 |
+
if cfg['TEST']['CENTER_CROP']:
|
238 |
+
transforms = T.Compose([
|
239 |
+
T.Resize(
|
240 |
+
int(cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0] / 0.875),
|
241 |
+
# the same behavior as in deit: size = int((256 / 224) * args.input_size)
|
242 |
+
# 224 / 256 = 0.875
|
243 |
+
interpolation=mode
|
244 |
+
),
|
245 |
+
T.CenterCrop(cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]),
|
246 |
+
T.ToTensor(),
|
247 |
+
normalize,
|
248 |
+
])
|
249 |
+
else:
|
250 |
+
transforms = T.Compose([
|
251 |
+
T.Resize(
|
252 |
+
(cfg['IMAGE_ENCODER']['IMAGE_SIZE'][1], cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]),
|
253 |
+
interpolation=mode
|
254 |
+
),
|
255 |
+
T.ToTensor(),
|
256 |
+
normalize,
|
257 |
+
])
|
258 |
+
logger.info('=> testing transformers: {}'.format(transforms))
|
259 |
+
|
260 |
+
return transforms
|
261 |
+
|
MedImageInsight/ImageDataLoader/transforms/threeaugment.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from PIL import ImageFilter, ImageOps
|
3 |
+
from torchvision import transforms
|
4 |
+
|
5 |
+
|
6 |
+
class deitIII_GaussianBlur(object):
|
7 |
+
"""
|
8 |
+
Apply Gaussian Blur to the PIL image.
|
9 |
+
"""
|
10 |
+
def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
|
11 |
+
self.prob = p
|
12 |
+
self.radius_min = radius_min
|
13 |
+
self.radius_max = radius_max
|
14 |
+
|
15 |
+
def __call__(self, img):
|
16 |
+
do_it = random.random() <= self.prob
|
17 |
+
if not do_it:
|
18 |
+
return img
|
19 |
+
|
20 |
+
img = img.filter(
|
21 |
+
ImageFilter.GaussianBlur(
|
22 |
+
radius=random.uniform(self.radius_min, self.radius_max)
|
23 |
+
)
|
24 |
+
)
|
25 |
+
return img
|
26 |
+
|
27 |
+
|
28 |
+
class deitIII_Solarization(object):
|
29 |
+
"""
|
30 |
+
Apply Solarization to the PIL image.
|
31 |
+
"""
|
32 |
+
def __init__(self, p=0.2):
|
33 |
+
self.p = p
|
34 |
+
|
35 |
+
def __call__(self, img):
|
36 |
+
if random.random() < self.p:
|
37 |
+
return ImageOps.solarize(img)
|
38 |
+
else:
|
39 |
+
return img
|
40 |
+
|
41 |
+
|
42 |
+
class deitIII_gray_scale(object):
|
43 |
+
"""
|
44 |
+
Apply Solarization to the PIL image.
|
45 |
+
"""
|
46 |
+
def __init__(self, p=0.2):
|
47 |
+
self.p = p
|
48 |
+
self.transf = transforms.Grayscale(3)
|
49 |
+
|
50 |
+
def __call__(self, img):
|
51 |
+
if random.random() < self.p:
|
52 |
+
return self.transf(img)
|
53 |
+
else:
|
54 |
+
return img
|
MedImageInsight/ImageDataLoader/tsv.py
ADDED
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import os
|
6 |
+
from io import BytesIO
|
7 |
+
import json
|
8 |
+
import logging
|
9 |
+
import base64
|
10 |
+
import random
|
11 |
+
from typing import Callable, List, Tuple, Union, NamedTuple
|
12 |
+
from PIL import Image
|
13 |
+
from PIL import ImageFile
|
14 |
+
import torch.utils.data as data
|
15 |
+
from .languages.prompt_engineering import prompt_engineering
|
16 |
+
from .tsv_file import TSVFile, CompositeTSVFile
|
17 |
+
|
18 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
class TSVDataset(data.Dataset):
|
24 |
+
|
25 |
+
def __init__(self,
|
26 |
+
tsv_file: Union[str, List[str]],
|
27 |
+
transform: Callable = None,
|
28 |
+
map_file: str = None,
|
29 |
+
token_file: str = None,
|
30 |
+
is_train: bool = True,
|
31 |
+
azcopy_path: str = None):
|
32 |
+
self.transform = transform
|
33 |
+
self._chunk_sizes = None
|
34 |
+
self.label2idx = self._load_map(map_file)
|
35 |
+
self.class_selector = list(self.label2idx.keys()) if self.label2idx else None
|
36 |
+
|
37 |
+
if isinstance(tsv_file, str):
|
38 |
+
if os.path.splitext(tsv_file)[1] == '.tsv':
|
39 |
+
self.tsv_file = TSVFile(
|
40 |
+
tsv_file, class_selector=self.class_selector
|
41 |
+
)
|
42 |
+
else:
|
43 |
+
self.tsv_file = CompositeTSVFile(
|
44 |
+
tsv_file,
|
45 |
+
class_selector=self.class_selector,
|
46 |
+
is_train=is_train,
|
47 |
+
sas_token_path=token_file,
|
48 |
+
azcopy_path=azcopy_path
|
49 |
+
)
|
50 |
+
self._chunk_sizes = self.tsv_file.get_chunk_size()
|
51 |
+
elif isinstance(tsv_file, list):
|
52 |
+
self.tsv_file = CompositeTSVFile(
|
53 |
+
tsv_file,
|
54 |
+
class_selector=self.class_selector,
|
55 |
+
is_train=is_train,
|
56 |
+
sas_token_path=token_file,
|
57 |
+
azcopy_path=azcopy_path
|
58 |
+
)
|
59 |
+
self._chunk_sizes = self.tsv_file.get_chunk_size()
|
60 |
+
else:
|
61 |
+
raise ValueError("Invalid input! Please check the tsv filenames")
|
62 |
+
|
63 |
+
logger.debug('=> {}\titems: {}'.format(tsv_file, len(self.tsv_file)))
|
64 |
+
|
65 |
+
def fetch_blob(self, idx):
|
66 |
+
image_tsv = self.tsv_file.file_list[idx]
|
67 |
+
self.tsv_file.blob_storage.fetch_blob(image_tsv)
|
68 |
+
|
69 |
+
def num_classes(self):
|
70 |
+
return len(self.class_selector)
|
71 |
+
|
72 |
+
def get_chunk_sizes(self):
|
73 |
+
return self._chunk_sizes
|
74 |
+
|
75 |
+
def get_class_boundaries(self):
|
76 |
+
# The samples of each class are organized class-by-class.
|
77 |
+
# _class_boundaries stores the lower- and upper-bound of each class.
|
78 |
+
return self.tsv_file.get_class_boundaries()
|
79 |
+
|
80 |
+
def get_filenames(self):
|
81 |
+
filenames = [
|
82 |
+
self.tsv_file.get_key(i)
|
83 |
+
for i in range(self.tsv_file.num_rows())
|
84 |
+
]
|
85 |
+
|
86 |
+
return filenames
|
87 |
+
|
88 |
+
def _load_map(self, map_file: str):
|
89 |
+
if not map_file:
|
90 |
+
return None
|
91 |
+
|
92 |
+
label2idx = {}
|
93 |
+
with open(map_file) as f:
|
94 |
+
for line in f:
|
95 |
+
items = line.strip().split('\t')
|
96 |
+
label2idx[items[0]] = int(items[1])
|
97 |
+
|
98 |
+
return label2idx
|
99 |
+
|
100 |
+
def __getitem__(self, index: Union[int, Tuple[int, int]]):
|
101 |
+
items = self.tsv_file[index]
|
102 |
+
_, target, img = self._decode_data(items)
|
103 |
+
|
104 |
+
if self.transform:
|
105 |
+
img = self.transform(img)
|
106 |
+
|
107 |
+
return img, target
|
108 |
+
|
109 |
+
def _decode_data(self, items: Tuple[str, str, str]):
|
110 |
+
key = items[0]
|
111 |
+
label = self._get_label(items[1])
|
112 |
+
image = Image.open(BytesIO(base64.b64decode(items[2]))).convert('RGB')
|
113 |
+
|
114 |
+
return key, label, image
|
115 |
+
|
116 |
+
def _get_label(self, item: str):
|
117 |
+
if not self.label2idx:
|
118 |
+
return int(item)
|
119 |
+
|
120 |
+
js = json.loads(item)
|
121 |
+
return self.label2idx[js[0]['class']]
|
122 |
+
|
123 |
+
def __len__(self):
|
124 |
+
return len(self.tsv_file)
|
125 |
+
|
126 |
+
|
127 |
+
class TSVMeta(NamedTuple):
|
128 |
+
source: str
|
129 |
+
num_classes: int
|
130 |
+
task: str
|
131 |
+
|
132 |
+
|
133 |
+
class TSVImageTextDatasetV2(data.Dataset):
|
134 |
+
"""
|
135 |
+
This class is intended for encapsulating Image/Text pair data for contrastive learning described in
|
136 |
+
the following paper,
|
137 |
+
"Learning Transferable Visual Models From Natural Language Supervision" (a.k.a CLIP)
|
138 |
+
V2: support image text pairs and supervised classification data
|
139 |
+
"""
|
140 |
+
|
141 |
+
def __init__(self,
|
142 |
+
image_tsv_file: Union[str, List[str]],
|
143 |
+
text_tsv_file: Union[str, List[str]],
|
144 |
+
transform: Callable = None,
|
145 |
+
tokenize: Callable = None,
|
146 |
+
context_length: int = 77,
|
147 |
+
num_captions: int = 1,
|
148 |
+
text_format: str = 'txt',
|
149 |
+
is_train: bool = True,
|
150 |
+
sas_token_path: str = None,
|
151 |
+
azcopy_path: str = None,
|
152 |
+
metas: List[NamedTuple] = None,
|
153 |
+
prompt_engineering=True,
|
154 |
+
concat_queries=False):
|
155 |
+
self.transform = transform
|
156 |
+
self.tokenize = tokenize
|
157 |
+
self._chunk_sizes = None
|
158 |
+
self.context_length = context_length
|
159 |
+
self.num_captions = num_captions
|
160 |
+
self.text_format = text_format
|
161 |
+
self.tsv_file_list = []
|
162 |
+
self.metas = metas
|
163 |
+
self.label_offsets = self.build_label_offsets()
|
164 |
+
self.prompt_engineering = prompt_engineering
|
165 |
+
self.concat_queries = concat_queries
|
166 |
+
|
167 |
+
if isinstance(image_tsv_file, str) and isinstance(text_tsv_file, str):
|
168 |
+
# single tsv file
|
169 |
+
if (
|
170 |
+
os.path.splitext(image_tsv_file)[1].lower() == '.tsv'
|
171 |
+
and os.path.splitext(text_tsv_file)[1].lower() == '.tsv'
|
172 |
+
):
|
173 |
+
self.tsv_file_list.append((image_tsv_file, text_tsv_file))
|
174 |
+
self.image_tsv_file = TSVFile(
|
175 |
+
image_tsv_file, if_generate_lineidx=True
|
176 |
+
)
|
177 |
+
self.text_tsv_file = TSVFile(
|
178 |
+
text_tsv_file, if_generate_lineidx=True
|
179 |
+
)
|
180 |
+
else:
|
181 |
+
raise ValueError("Invalid input! Please check the tsv filenames.")
|
182 |
+
# multiple tsv files specified in a list
|
183 |
+
elif (
|
184 |
+
isinstance(image_tsv_file, list)
|
185 |
+
and isinstance(text_tsv_file, list)
|
186 |
+
):
|
187 |
+
assert len(image_tsv_file) == len(text_tsv_file), \
|
188 |
+
"Inconsistent number of Image/Text tsv files!"
|
189 |
+
self.tsv_file_list = [
|
190 |
+
(txt, img)
|
191 |
+
for img, txt in zip(image_tsv_file, text_tsv_file)
|
192 |
+
]
|
193 |
+
self.image_tsv_file = CompositeTSVFile(
|
194 |
+
image_tsv_file,
|
195 |
+
is_train=is_train,
|
196 |
+
sas_token_path=sas_token_path,
|
197 |
+
azcopy_path=azcopy_path
|
198 |
+
)
|
199 |
+
self.text_tsv_file = CompositeTSVFile(
|
200 |
+
text_tsv_file,
|
201 |
+
is_train=is_train,
|
202 |
+
sas_token_path=sas_token_path,
|
203 |
+
azcopy_path=azcopy_path
|
204 |
+
)
|
205 |
+
self._chunk_sizes = self.image_tsv_file.get_chunk_size()
|
206 |
+
else:
|
207 |
+
raise ValueError("Invalid input! Please check the tsv filenames.")
|
208 |
+
|
209 |
+
assert len(self.image_tsv_file) == len(self.text_tsv_file), \
|
210 |
+
"Inconsistent size of Image/Text ({}/{}) data!".format(
|
211 |
+
len(self.image_tsv_file), len(self.text_tsv_file)
|
212 |
+
)
|
213 |
+
|
214 |
+
def build_label_offsets(self):
|
215 |
+
if self.metas is None:
|
216 |
+
return None
|
217 |
+
|
218 |
+
label_offsets = {}
|
219 |
+
offset = 1
|
220 |
+
for meta in self.metas:
|
221 |
+
print(meta)
|
222 |
+
print(label_offsets)
|
223 |
+
label_offsets[meta.source] = offset
|
224 |
+
offset += meta.num_classes
|
225 |
+
|
226 |
+
return label_offsets
|
227 |
+
|
228 |
+
def fetch_blob(self, idx):
|
229 |
+
# image_tsv, text_tsv = self.tsv_file_list[idx]
|
230 |
+
image_tsv = self.image_tsv_file.file_list[idx]
|
231 |
+
text_tsv = self.text_tsv_file.file_list[idx]
|
232 |
+
self.image_tsv_file.blob_storage.fetch_blob(image_tsv)
|
233 |
+
self.text_tsv_file.blob_storage.fetch_blob(text_tsv)
|
234 |
+
|
235 |
+
def get_chunk_sizes(self):
|
236 |
+
return self._chunk_sizes
|
237 |
+
|
238 |
+
def __getitem__(self, index: Union[int, Tuple[int, int]]):
|
239 |
+
if index is None:
|
240 |
+
import torch
|
241 |
+
return torch.tensor([], dtype=torch.float32), \
|
242 |
+
torch.tensor([], dtype=torch.int64), \
|
243 |
+
torch.tensor([], dtype=torch.int64)
|
244 |
+
|
245 |
+
items_image = self.image_tsv_file[index]
|
246 |
+
items_text = self.text_tsv_file[index]
|
247 |
+
|
248 |
+
assert items_text[0] == items_image[0], \
|
249 |
+
'keys do not match for image and text {} vs {}'.format(
|
250 |
+
items_text[0], items_image[0]
|
251 |
+
)
|
252 |
+
|
253 |
+
_, img = self._decode_image(items_image)
|
254 |
+
_, txt, label = self._decode_text(items_text)
|
255 |
+
|
256 |
+
if self.transform:
|
257 |
+
img = self.transform(img)
|
258 |
+
|
259 |
+
tokens = self.tokenize(
|
260 |
+
txt, padding='max_length', truncation=True, max_length=self.context_length,
|
261 |
+
return_tensors='pt'
|
262 |
+
) if self.tokenize else txt
|
263 |
+
|
264 |
+
tokens['input_ids'].squeeze_()
|
265 |
+
tokens['attention_mask'].squeeze_()
|
266 |
+
|
267 |
+
return img, tokens, label
|
268 |
+
|
269 |
+
def _decode_image(self, items: Tuple[str, str]):
|
270 |
+
key = items[0]
|
271 |
+
image = Image.open(BytesIO(base64.b64decode(items[1]))).convert('RGB')
|
272 |
+
|
273 |
+
return key, image
|
274 |
+
|
275 |
+
def _decode_text(self, items: Tuple[str, Union[str, dict]]):
|
276 |
+
key = items[0]
|
277 |
+
text = ''
|
278 |
+
|
279 |
+
if self.text_format != 'json':
|
280 |
+
raise ValueError('Only support json format')
|
281 |
+
|
282 |
+
# Do some reasonable handing of occasionally bad data.
|
283 |
+
try:
|
284 |
+
js = json.loads(items[1])
|
285 |
+
except Exception as e:
|
286 |
+
|
287 |
+
# empty dictionary
|
288 |
+
js = {}
|
289 |
+
|
290 |
+
# Record the data error in the log.
|
291 |
+
logger.info("JSON parsing error on: " + items[1])
|
292 |
+
logger.info(str(e))
|
293 |
+
|
294 |
+
# do not raise the exception
|
295 |
+
# raise e
|
296 |
+
|
297 |
+
# put some text in and continue processing data (do not kill job)
|
298 |
+
sstr = items[1].find("\"")
|
299 |
+
if (sstr < 0):
|
300 |
+
sstr = 0
|
301 |
+
|
302 |
+
estr = items[1][sstr:].find("\"")
|
303 |
+
if (estr < 0):
|
304 |
+
estr = len(items[1])
|
305 |
+
|
306 |
+
text = items[1][sstr:estr]
|
307 |
+
if (len(text) < 2):
|
308 |
+
text = "A picture showing some content."
|
309 |
+
|
310 |
+
label = 0
|
311 |
+
|
312 |
+
if 'captions' in js:
|
313 |
+
captions = js['captions']
|
314 |
+
if isinstance(captions, list):
|
315 |
+
if self.num_captions == 1:
|
316 |
+
text = random.choice(captions)
|
317 |
+
else:
|
318 |
+
text = captions
|
319 |
+
if len(captions) > self.num_captions:
|
320 |
+
text = captions[:self.num_captions]
|
321 |
+
elif isinstance(captions, str):
|
322 |
+
text = captions
|
323 |
+
else:
|
324 |
+
raise ValueError('captions should be str or list')
|
325 |
+
label = 0
|
326 |
+
elif 'tags' in js:
|
327 |
+
text = prompt_engineering(js['tags'])
|
328 |
+
label = 0
|
329 |
+
elif 'task' in js and js['task'] == 'classification':
|
330 |
+
if (self.prompt_engineering):
|
331 |
+
text = prompt_engineering(js['class_name'])
|
332 |
+
else:
|
333 |
+
text = js['class_name']
|
334 |
+
label = js['class_id']
|
335 |
+
|
336 |
+
if (self.label_offsets is not None):
|
337 |
+
if (js['source'] in self.label_offsets):
|
338 |
+
label += self.label_offsets[js['source']]
|
339 |
+
|
340 |
+
if (self.concat_queries):
|
341 |
+
if ('queries' in js) and (len(js['queries']) > 0):
|
342 |
+
q = ''
|
343 |
+
for item in js['queries']:
|
344 |
+
q = q + item + ' '
|
345 |
+
|
346 |
+
text = q + ', ' + text
|
347 |
+
|
348 |
+
return key, text, label
|
349 |
+
|
350 |
+
def __len__(self):
|
351 |
+
return len(self.image_tsv_file)
|
MedImageInsight/ImageDataLoader/tsv_file.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import gc
|
3 |
+
import os
|
4 |
+
import os.path as op
|
5 |
+
import json
|
6 |
+
from typing import List
|
7 |
+
from .blob_storage import BlobStorage, disk_usage
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
|
12 |
+
def generate_lineidx(filein: str, idxout: str) -> None:
|
13 |
+
idxout_tmp = idxout + '.tmp'
|
14 |
+
with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
|
15 |
+
fsize = os.fstat(tsvin.fileno()).st_size
|
16 |
+
fpos = 0
|
17 |
+
while fpos != fsize:
|
18 |
+
tsvout.write(str(fpos) + "\n")
|
19 |
+
tsvin.readline()
|
20 |
+
fpos = tsvin.tell()
|
21 |
+
os.rename(idxout_tmp, idxout)
|
22 |
+
|
23 |
+
|
24 |
+
def read_to_character(fp, c):
|
25 |
+
result = []
|
26 |
+
while True:
|
27 |
+
s = fp.read(32)
|
28 |
+
assert s != ''
|
29 |
+
if c in s:
|
30 |
+
result.append(s[: s.index(c)])
|
31 |
+
break
|
32 |
+
else:
|
33 |
+
result.append(s)
|
34 |
+
return ''.join(result)
|
35 |
+
|
36 |
+
|
37 |
+
class TSVFile(object):
|
38 |
+
def __init__(self,
|
39 |
+
tsv_file: str,
|
40 |
+
if_generate_lineidx: bool = True,
|
41 |
+
lineidx: str = None,
|
42 |
+
class_selector: List[str] = None,
|
43 |
+
blob_storage: BlobStorage = None):
|
44 |
+
self.tsv_file = tsv_file
|
45 |
+
self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' \
|
46 |
+
if not lineidx else lineidx
|
47 |
+
self.linelist = op.splitext(tsv_file)[0] + '.linelist'
|
48 |
+
self.chunks = op.splitext(tsv_file)[0] + '.chunks'
|
49 |
+
self._fp = None
|
50 |
+
self._lineidx = None
|
51 |
+
self._sample_indices = None
|
52 |
+
self._class_boundaries = None
|
53 |
+
self._class_selector = class_selector
|
54 |
+
self._blob_storage = blob_storage
|
55 |
+
self._len = None
|
56 |
+
# the process always keeps the process which opens the file.
|
57 |
+
# If the pid is not equal to the currrent pid, we will re-open the file.
|
58 |
+
self.pid = None
|
59 |
+
# generate lineidx if not exist
|
60 |
+
if not op.isfile(self.lineidx) and if_generate_lineidx:
|
61 |
+
generate_lineidx(self.tsv_file, self.lineidx)
|
62 |
+
|
63 |
+
def __del__(self):
|
64 |
+
self.gcidx()
|
65 |
+
if self._fp:
|
66 |
+
self._fp.close()
|
67 |
+
# physically remove the tsv file if it is retrieved by BlobStorage
|
68 |
+
if self._blob_storage and 'azcopy' in self.tsv_file and os.path.exists(self.tsv_file):
|
69 |
+
try:
|
70 |
+
original_usage = disk_usage('/')
|
71 |
+
os.remove(self.tsv_file)
|
72 |
+
logger.info("Purged %s (disk usage: %.2f%% => %.2f%%)" %
|
73 |
+
(self.tsv_file, original_usage, disk_usage('/') * 100))
|
74 |
+
except:
|
75 |
+
# Known issue: multiple threads attempting to delete the file will raise a FileNotFound error.
|
76 |
+
# TODO: try Threadling.Lock to better handle the race condition
|
77 |
+
pass
|
78 |
+
|
79 |
+
def __str__(self):
|
80 |
+
return "TSVFile(tsv_file='{}')".format(self.tsv_file)
|
81 |
+
|
82 |
+
def __repr__(self):
|
83 |
+
return str(self)
|
84 |
+
|
85 |
+
def gcidx(self):
|
86 |
+
logger.debug('Run gc collect')
|
87 |
+
self._lineidx = None
|
88 |
+
self._sample_indices = None
|
89 |
+
#self._class_boundaries = None
|
90 |
+
return gc.collect()
|
91 |
+
|
92 |
+
def get_class_boundaries(self):
|
93 |
+
return self._class_boundaries
|
94 |
+
|
95 |
+
def num_rows(self, gcf=False):
|
96 |
+
if (self._len is None):
|
97 |
+
self._ensure_lineidx_loaded()
|
98 |
+
retval = len(self._sample_indices)
|
99 |
+
|
100 |
+
if (gcf):
|
101 |
+
self.gcidx()
|
102 |
+
|
103 |
+
self._len = retval
|
104 |
+
|
105 |
+
return self._len
|
106 |
+
|
107 |
+
def seek(self, idx: int):
|
108 |
+
self._ensure_tsv_opened()
|
109 |
+
self._ensure_lineidx_loaded()
|
110 |
+
try:
|
111 |
+
pos = self._lineidx[self._sample_indices[idx]]
|
112 |
+
except:
|
113 |
+
logger.info('=> {}-{}'.format(self.tsv_file, idx))
|
114 |
+
raise
|
115 |
+
self._fp.seek(pos)
|
116 |
+
return [s.strip() for s in self._fp.readline().split('\t')]
|
117 |
+
|
118 |
+
def seek_first_column(self, idx: int):
|
119 |
+
self._ensure_tsv_opened()
|
120 |
+
self._ensure_lineidx_loaded()
|
121 |
+
pos = self._lineidx[idx]
|
122 |
+
self._fp.seek(pos)
|
123 |
+
return read_to_character(self._fp, '\t')
|
124 |
+
|
125 |
+
def get_key(self, idx: int):
|
126 |
+
return self.seek_first_column(idx)
|
127 |
+
|
128 |
+
def __getitem__(self, index: int):
|
129 |
+
return self.seek(index)
|
130 |
+
|
131 |
+
def __len__(self):
|
132 |
+
return self.num_rows()
|
133 |
+
|
134 |
+
def _ensure_lineidx_loaded(self):
|
135 |
+
if self._lineidx is None:
|
136 |
+
logger.debug('=> loading lineidx: {}'.format(self.lineidx))
|
137 |
+
with open(self.lineidx, 'r') as fp:
|
138 |
+
lines = fp.readlines()
|
139 |
+
lines = [line.strip() for line in lines]
|
140 |
+
self._lineidx = [int(line) for line in lines]
|
141 |
+
|
142 |
+
# read the line list if exists
|
143 |
+
linelist = None
|
144 |
+
if op.isfile(self.linelist):
|
145 |
+
with open(self.linelist, 'r') as fp:
|
146 |
+
linelist = sorted(
|
147 |
+
[
|
148 |
+
int(line.strip())
|
149 |
+
for line in fp.readlines()
|
150 |
+
]
|
151 |
+
)
|
152 |
+
|
153 |
+
if op.isfile(self.chunks):
|
154 |
+
self._sample_indices = []
|
155 |
+
self._class_boundaries = []
|
156 |
+
class_boundaries = json.load(open(self.chunks, 'r'))
|
157 |
+
for class_name, boundary in class_boundaries.items():
|
158 |
+
start = len(self._sample_indices)
|
159 |
+
if class_name in self._class_selector:
|
160 |
+
for idx in range(boundary[0], boundary[1] + 1):
|
161 |
+
# NOTE: potentially slow when linelist is long, try to speed it up
|
162 |
+
if linelist and idx not in linelist:
|
163 |
+
continue
|
164 |
+
self._sample_indices.append(idx)
|
165 |
+
end = len(self._sample_indices)
|
166 |
+
self._class_boundaries.append((start, end))
|
167 |
+
else:
|
168 |
+
if linelist:
|
169 |
+
self._sample_indices = linelist
|
170 |
+
else:
|
171 |
+
self._sample_indices = list(range(len(self._lineidx)))
|
172 |
+
|
173 |
+
def _ensure_tsv_opened(self):
|
174 |
+
if self._fp is None:
|
175 |
+
if self._blob_storage:
|
176 |
+
self._fp = self._blob_storage.open(self.tsv_file)
|
177 |
+
else:
|
178 |
+
self._fp = open(self.tsv_file, 'r')
|
179 |
+
self.pid = os.getpid()
|
180 |
+
|
181 |
+
if self.pid != os.getpid():
|
182 |
+
logger.debug('=> re-open {} because the process id changed'.format(self.tsv_file))
|
183 |
+
self._fp = open(self.tsv_file, 'r')
|
184 |
+
self.pid = os.getpid()
|
185 |
+
|
186 |
+
|
187 |
+
class CompositeTSVFile:
|
188 |
+
def __init__(self,
|
189 |
+
file_list: List[str],
|
190 |
+
root: str = '.',
|
191 |
+
class_selector: List[str] = None,
|
192 |
+
is_train: bool = True,
|
193 |
+
sas_token_path: str = None,
|
194 |
+
azcopy_path: str = None):
|
195 |
+
self.root = root
|
196 |
+
self.tsvs = None
|
197 |
+
self.chunk_sizes = None
|
198 |
+
self.accum_chunk_sizes = None
|
199 |
+
self._class_selector = class_selector
|
200 |
+
self._class_boundaries = None
|
201 |
+
self.initialized = False
|
202 |
+
assert isinstance(file_list, list)
|
203 |
+
self.blob_storage = BlobStorage(is_train, sas_token_path, azcopy_path)
|
204 |
+
self.file_list = self.blob_storage.register_local_tsv_paths(file_list)
|
205 |
+
logger.info('=> Init CompositeTSVFile...')
|
206 |
+
self.initialize()
|
207 |
+
logger.info('=> Init CompositeTSVFile Done...')
|
208 |
+
|
209 |
+
def get_key(self, index: int):
|
210 |
+
idx_source, idx_row = self._calc_chunk_idx_row(index)
|
211 |
+
k = self.tsvs[idx_source].get_key(idx_row)
|
212 |
+
return '_'.join([self.file_list[idx_source], k])
|
213 |
+
|
214 |
+
def get_class_boundaries(self):
|
215 |
+
return self._class_boundaries
|
216 |
+
|
217 |
+
def get_chunk_size(self):
|
218 |
+
return self.chunk_sizes
|
219 |
+
|
220 |
+
def num_rows(self):
|
221 |
+
return sum(self.chunk_sizes)
|
222 |
+
|
223 |
+
def _calc_chunk_idx_row(self, index: int):
|
224 |
+
idx_chunk = 0
|
225 |
+
idx_row = index
|
226 |
+
while index >= self.accum_chunk_sizes[idx_chunk]:
|
227 |
+
idx_chunk += 1
|
228 |
+
idx_row = index - self.accum_chunk_sizes[idx_chunk-1]
|
229 |
+
return idx_chunk, idx_row
|
230 |
+
|
231 |
+
def __getitem__(self, index: int):
|
232 |
+
idx_source, idx_row = self._calc_chunk_idx_row(index)
|
233 |
+
if idx_source not in self.blob_storage:
|
234 |
+
self.blob_storage[idx_source] = TSVFile(
|
235 |
+
op.join(self.root, self.file_list[idx_source]),
|
236 |
+
class_selector=self._class_selector,
|
237 |
+
blob_storage=self.blob_storage,
|
238 |
+
if_generate_lineidx=True
|
239 |
+
)
|
240 |
+
return self.blob_storage[idx_source].seek(idx_row)
|
241 |
+
|
242 |
+
def __len__(self):
|
243 |
+
return sum(self.chunk_sizes)
|
244 |
+
|
245 |
+
def initialize(self):
|
246 |
+
"""
|
247 |
+
this function has to be called in init function if cache_policy is
|
248 |
+
enabled. Thus, let's always call it in init funciton to make it simple.
|
249 |
+
"""
|
250 |
+
if self.initialized:
|
251 |
+
return
|
252 |
+
self.tsvs = [
|
253 |
+
TSVFile(
|
254 |
+
op.join(self.root, f),
|
255 |
+
class_selector=self._class_selector
|
256 |
+
) for f in self.file_list
|
257 |
+
]
|
258 |
+
logger.debug("=> Calculating chunk sizes ...")
|
259 |
+
self.chunk_sizes = [tsv.num_rows(gcf=True) for tsv in self.tsvs]
|
260 |
+
|
261 |
+
self.accum_chunk_sizes = [0]
|
262 |
+
for size in self.chunk_sizes:
|
263 |
+
self.accum_chunk_sizes += [self.accum_chunk_sizes[-1] + size]
|
264 |
+
self.accum_chunk_sizes = self.accum_chunk_sizes[1:]
|
265 |
+
|
266 |
+
if (
|
267 |
+
self._class_selector
|
268 |
+
and all([tsv.get_class_boundaries() for tsv in self.tsvs])
|
269 |
+
):
|
270 |
+
"""
|
271 |
+
Note: When using CompositeTSVFile, make sure that the classes contained in each
|
272 |
+
tsv file do not overlap. Otherwise, the class boundaries won't be correct.
|
273 |
+
"""
|
274 |
+
self._class_boundaries = []
|
275 |
+
offset = 0
|
276 |
+
for tsv in self.tsvs:
|
277 |
+
boundaries = tsv.get_class_boundaries()
|
278 |
+
for bound in boundaries:
|
279 |
+
self._class_boundaries.append((bound[0] + offset, bound[1] + offset))
|
280 |
+
offset += len(tsv)
|
281 |
+
self.initialized = True
|
282 |
+
|
283 |
+
|
284 |
+
def load_list_file(fname: str) -> List[str]:
|
285 |
+
with open(fname, 'r') as fp:
|
286 |
+
lines = fp.readlines()
|
287 |
+
result = [line.strip() for line in lines]
|
288 |
+
if len(result) > 0 and result[-1] == '':
|
289 |
+
result = result[:-1]
|
290 |
+
return result
|
MedImageInsight/ImageDataLoader/zipdata.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as op
|
2 |
+
from zipfile import ZipFile, BadZipFile
|
3 |
+
import torch.utils.data as data
|
4 |
+
from PIL import Image
|
5 |
+
from io import BytesIO
|
6 |
+
import multiprocessing
|
7 |
+
|
8 |
+
_VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png']
|
9 |
+
|
10 |
+
|
11 |
+
class ZipData(data.Dataset):
|
12 |
+
_IGNORE_ATTRS = {'_zip_file'}
|
13 |
+
|
14 |
+
def __init__(self, path, map_file,
|
15 |
+
transform=None, target_transform=None,
|
16 |
+
extensions=None):
|
17 |
+
self._path = path
|
18 |
+
if not extensions:
|
19 |
+
extensions = _VALID_IMAGE_TYPES
|
20 |
+
self._zip_file = ZipFile(path)
|
21 |
+
self.zip_dict = {}
|
22 |
+
self.samples = []
|
23 |
+
self.transform = transform
|
24 |
+
self.target_transform = target_transform
|
25 |
+
self.class_to_idx = {}
|
26 |
+
with open(map_file, 'r') as f:
|
27 |
+
for line in iter(f.readline, ""):
|
28 |
+
line = line.strip()
|
29 |
+
if not line:
|
30 |
+
continue
|
31 |
+
cls_idx = [l for l in line.split('\t') if l]
|
32 |
+
if not cls_idx:
|
33 |
+
continue
|
34 |
+
if (len(cls_idx) < 2):
|
35 |
+
cls_idx = [l for l in line.split(' ') if l]
|
36 |
+
if not cls_idx:
|
37 |
+
continue
|
38 |
+
assert len(cls_idx) >= 2, "invalid line: {}".format(line)
|
39 |
+
idx = int(cls_idx[1])
|
40 |
+
cls = cls_idx[0]
|
41 |
+
del cls_idx
|
42 |
+
at_idx = cls.find('@')
|
43 |
+
assert at_idx >= 0, "invalid class: {}".format(cls)
|
44 |
+
cls = cls[at_idx + 1:]
|
45 |
+
if cls.startswith('/'):
|
46 |
+
# Python ZipFile expects no root
|
47 |
+
cls = cls[1:]
|
48 |
+
assert cls, "invalid class in line {}".format(line)
|
49 |
+
prev_idx = self.class_to_idx.get(cls)
|
50 |
+
assert prev_idx is None or prev_idx == idx, "class: {} idx: {} previously had idx: {}".format(
|
51 |
+
cls, idx, prev_idx
|
52 |
+
)
|
53 |
+
self.class_to_idx[cls] = idx
|
54 |
+
|
55 |
+
for fst in self._zip_file.infolist():
|
56 |
+
fname = fst.filename
|
57 |
+
target = self.class_to_idx.get(fname)
|
58 |
+
if target is None:
|
59 |
+
continue
|
60 |
+
if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0:
|
61 |
+
continue
|
62 |
+
ext = op.splitext(fname)[1].lower()
|
63 |
+
if ext in extensions:
|
64 |
+
self.samples.append((fname, target))
|
65 |
+
assert len(self), "No images found in: {} with map: {}".format(self._path, map_file)
|
66 |
+
|
67 |
+
def __repr__(self):
|
68 |
+
return 'ZipData({}, size={})'.format(self._path, len(self))
|
69 |
+
|
70 |
+
def __getstate__(self):
|
71 |
+
return {
|
72 |
+
key: val if key not in self._IGNORE_ATTRS else None
|
73 |
+
for key, val in self.__dict__.iteritems()
|
74 |
+
}
|
75 |
+
|
76 |
+
def __getitem__(self, index):
|
77 |
+
proc = multiprocessing.current_process()
|
78 |
+
pid = proc.pid # get pid of this process.
|
79 |
+
if pid not in self.zip_dict:
|
80 |
+
self.zip_dict[pid] = ZipFile(self._path)
|
81 |
+
zip_file = self.zip_dict[pid]
|
82 |
+
|
83 |
+
if index >= len(self) or index < 0:
|
84 |
+
raise KeyError("{} is invalid".format(index))
|
85 |
+
path, target = self.samples[index]
|
86 |
+
try:
|
87 |
+
sample = Image.open(BytesIO(zip_file.read(path))).convert('RGB')
|
88 |
+
except BadZipFile:
|
89 |
+
print("bad zip file")
|
90 |
+
return None, None
|
91 |
+
if self.transform is not None:
|
92 |
+
sample = self.transform(sample)
|
93 |
+
if self.target_transform is not None:
|
94 |
+
target = self.target_transform(target)
|
95 |
+
return sample, target
|
96 |
+
|
97 |
+
def __len__(self):
|
98 |
+
return len(self.samples)
|
MedImageInsight/ImageEncoder/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
from .build import build_image_encoder
|
6 |
+
|
7 |
+
from .coswin import *
|
8 |
+
from .davit_v1 import *
|
MedImageInsight/ImageEncoder/build.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .registry import image_encoders
|
2 |
+
from .registry import is_image_encoder
|
3 |
+
|
4 |
+
|
5 |
+
def build_image_encoder(config_encoder, verbose, **kwargs):
|
6 |
+
model_name = config_encoder['NAME']
|
7 |
+
if model_name.startswith('cls_'):
|
8 |
+
model_name = model_name[4:]
|
9 |
+
|
10 |
+
if not is_image_encoder(model_name):
|
11 |
+
raise ValueError(f'Unkown model: {model_name}')
|
12 |
+
|
13 |
+
return image_encoders(model_name)(config_encoder, verbose, **kwargs)
|
MedImageInsight/ImageEncoder/coswin.py
ADDED
@@ -0,0 +1,779 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# CoSwin: Convolutional Swin Transformer
|
3 |
+
# Copyright (c) 2021 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Written by Ze Liu
|
6 |
+
# Modified by Bin Xiao
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.utils.checkpoint as checkpoint
|
14 |
+
import numpy as np
|
15 |
+
from einops import rearrange, repeat
|
16 |
+
from einops.layers.torch import Rearrange
|
17 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
18 |
+
|
19 |
+
from .registry import register_image_encoder
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
class Mlp(nn.Module):
|
26 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
27 |
+
super().__init__()
|
28 |
+
out_features = out_features or in_features
|
29 |
+
hidden_features = hidden_features or in_features
|
30 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
31 |
+
self.act = act_layer()
|
32 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
33 |
+
self.drop = nn.Dropout(drop)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
x = self.fc1(x)
|
37 |
+
x = self.act(x)
|
38 |
+
x = self.drop(x)
|
39 |
+
x = self.fc2(x)
|
40 |
+
x = self.drop(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
def window_partition(x, window_size):
|
45 |
+
"""
|
46 |
+
Args:
|
47 |
+
x: (B, H, W, C)
|
48 |
+
window_size (int): window size
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
windows: (num_windows*B, window_size, window_size, C)
|
52 |
+
"""
|
53 |
+
B, H, W, C = x.shape
|
54 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
55 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
56 |
+
return windows
|
57 |
+
|
58 |
+
|
59 |
+
def window_reverse(windows, window_size, H, W):
|
60 |
+
"""
|
61 |
+
Args:
|
62 |
+
windows: (num_windows*B, window_size, window_size, C)
|
63 |
+
window_size (int): Window size
|
64 |
+
H (int): Height of image
|
65 |
+
W (int): Width of image
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
x: (B, H, W, C)
|
69 |
+
"""
|
70 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
71 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
72 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
73 |
+
return x
|
74 |
+
|
75 |
+
|
76 |
+
class WindowAttention(nn.Module):
|
77 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
78 |
+
It supports both of shifted and non-shifted window.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
dim (int): Number of input channels.
|
82 |
+
window_size (tuple[int]): The height and width of the window.
|
83 |
+
num_heads (int): Number of attention heads.
|
84 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
85 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
86 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
87 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
91 |
+
|
92 |
+
super().__init__()
|
93 |
+
self.dim = dim
|
94 |
+
self.window_size = window_size # Wh, Ww
|
95 |
+
self.num_heads = num_heads
|
96 |
+
head_dim = dim // num_heads
|
97 |
+
self.scale = qk_scale or head_dim ** -0.5
|
98 |
+
|
99 |
+
# define a parameter table of relative position bias
|
100 |
+
self.relative_position_bias_table = nn.Parameter(
|
101 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
102 |
+
|
103 |
+
# get pair-wise relative position index for each token inside the window
|
104 |
+
coords_h = torch.arange(self.window_size[0])
|
105 |
+
coords_w = torch.arange(self.window_size[1])
|
106 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
107 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
108 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
109 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
110 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
111 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
112 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
113 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
114 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
115 |
+
|
116 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
117 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
118 |
+
self.proj = nn.Linear(dim, dim)
|
119 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
120 |
+
|
121 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
122 |
+
self.softmax = nn.Softmax(dim=-1)
|
123 |
+
|
124 |
+
def forward(self, x, mask=None):
|
125 |
+
"""
|
126 |
+
Args:
|
127 |
+
x: input features with shape of (num_windows*B, N, C)
|
128 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
129 |
+
"""
|
130 |
+
B_, N, C = x.shape
|
131 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
132 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
133 |
+
|
134 |
+
q = q * self.scale
|
135 |
+
attn = (q @ k.transpose(-2, -1))
|
136 |
+
|
137 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
138 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
139 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
140 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
141 |
+
|
142 |
+
if mask is not None:
|
143 |
+
nW = mask.shape[0]
|
144 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
145 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
146 |
+
attn = self.softmax(attn)
|
147 |
+
else:
|
148 |
+
attn = self.softmax(attn)
|
149 |
+
|
150 |
+
attn = self.attn_drop(attn)
|
151 |
+
|
152 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
153 |
+
x = self.proj(x)
|
154 |
+
x = self.proj_drop(x)
|
155 |
+
return x
|
156 |
+
|
157 |
+
def extra_repr(self) -> str:
|
158 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
159 |
+
|
160 |
+
def flops(self, N):
|
161 |
+
# calculate flops for 1 window with token length of N
|
162 |
+
flops = 0
|
163 |
+
# qkv = self.qkv(x)
|
164 |
+
flops += N * self.dim * 3 * self.dim
|
165 |
+
# attn = (q @ k.transpose(-2, -1))
|
166 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
167 |
+
# x = (attn @ v)
|
168 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
169 |
+
# x = self.proj(x)
|
170 |
+
flops += N * self.dim * self.dim
|
171 |
+
return flops
|
172 |
+
|
173 |
+
|
174 |
+
class SwinTransformerBlock(nn.Module):
|
175 |
+
r""" Swin Transformer Block.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
dim (int): Number of input channels.
|
179 |
+
input_resolution (tuple[int]): Input resulotion.
|
180 |
+
num_heads (int): Number of attention heads.
|
181 |
+
window_size (int): Window size.
|
182 |
+
shift_size (int): Shift size for SW-MSA.
|
183 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
184 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
185 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
186 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
187 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
188 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
189 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
190 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
191 |
+
"""
|
192 |
+
|
193 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
194 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
195 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm, layer_scale=False):
|
196 |
+
super().__init__()
|
197 |
+
self.dim = dim
|
198 |
+
self.input_resolution = input_resolution
|
199 |
+
self.num_heads = num_heads
|
200 |
+
self.window_size = window_size
|
201 |
+
self.shift_size = shift_size
|
202 |
+
self.mlp_ratio = mlp_ratio
|
203 |
+
if min(self.input_resolution) <= self.window_size:
|
204 |
+
# if window size is larger than input resolution, we don't partition windows
|
205 |
+
self.shift_size = 0
|
206 |
+
self.window_size = min(self.input_resolution)
|
207 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
208 |
+
|
209 |
+
self.norm1 = norm_layer(dim)
|
210 |
+
self.attn = WindowAttention(
|
211 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
212 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
213 |
+
|
214 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
215 |
+
self.norm2 = norm_layer(dim)
|
216 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
217 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
218 |
+
|
219 |
+
if self.shift_size > 0:
|
220 |
+
# calculate attention mask for SW-MSA
|
221 |
+
H, W = self.input_resolution
|
222 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
223 |
+
h_slices = (slice(0, -self.window_size),
|
224 |
+
slice(-self.window_size, -self.shift_size),
|
225 |
+
slice(-self.shift_size, None))
|
226 |
+
w_slices = (slice(0, -self.window_size),
|
227 |
+
slice(-self.window_size, -self.shift_size),
|
228 |
+
slice(-self.shift_size, None))
|
229 |
+
cnt = 0
|
230 |
+
for h in h_slices:
|
231 |
+
for w in w_slices:
|
232 |
+
img_mask[:, h, w, :] = cnt
|
233 |
+
cnt += 1
|
234 |
+
|
235 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
236 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
237 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
238 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
239 |
+
else:
|
240 |
+
attn_mask = None
|
241 |
+
|
242 |
+
self.gamma = 1.0
|
243 |
+
if layer_scale:
|
244 |
+
logger.info('=> enable layer scale')
|
245 |
+
self.gamma = nn.Parameter(
|
246 |
+
1e-4*torch.ones(dim), requires_grad=True
|
247 |
+
)
|
248 |
+
|
249 |
+
self.register_buffer("attn_mask", attn_mask)
|
250 |
+
|
251 |
+
def forward(self, x):
|
252 |
+
H, W = self.input_resolution
|
253 |
+
B, L, C = x.shape
|
254 |
+
assert L == H * W, "input feature has wrong size"
|
255 |
+
|
256 |
+
shortcut = x
|
257 |
+
x = self.norm1(x)
|
258 |
+
x = x.view(B, H, W, C)
|
259 |
+
|
260 |
+
# cyclic shift
|
261 |
+
if self.shift_size > 0:
|
262 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
263 |
+
else:
|
264 |
+
shifted_x = x
|
265 |
+
|
266 |
+
# partition windows
|
267 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
268 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
269 |
+
|
270 |
+
# W-MSA/SW-MSA
|
271 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
272 |
+
|
273 |
+
# merge windows
|
274 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
275 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
276 |
+
|
277 |
+
# reverse cyclic shift
|
278 |
+
if self.shift_size > 0:
|
279 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
280 |
+
else:
|
281 |
+
x = shifted_x
|
282 |
+
x = x.view(B, H * W, C)
|
283 |
+
|
284 |
+
# FFN
|
285 |
+
x = shortcut + self.drop_path(self.gamma*x)
|
286 |
+
x = x + self.drop_path(self.gamma*self.mlp(self.norm2(x)))
|
287 |
+
|
288 |
+
return x
|
289 |
+
|
290 |
+
def extra_repr(self) -> str:
|
291 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
292 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
293 |
+
|
294 |
+
def flops(self):
|
295 |
+
flops = 0
|
296 |
+
H, W = self.input_resolution
|
297 |
+
# norm1
|
298 |
+
flops += self.dim * H * W
|
299 |
+
# W-MSA/SW-MSA
|
300 |
+
nW = H * W / self.window_size / self.window_size
|
301 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
302 |
+
# mlp
|
303 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
304 |
+
# norm2
|
305 |
+
flops += self.dim * H * W
|
306 |
+
return flops
|
307 |
+
|
308 |
+
|
309 |
+
class PatchMerging(nn.Module):
|
310 |
+
r""" Patch Merging Layer.
|
311 |
+
|
312 |
+
Args:
|
313 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
314 |
+
dim (int): Number of input channels.
|
315 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
316 |
+
"""
|
317 |
+
|
318 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
319 |
+
super().__init__()
|
320 |
+
self.input_resolution = input_resolution
|
321 |
+
self.dim = dim
|
322 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
323 |
+
self.norm = norm_layer(4 * dim)
|
324 |
+
|
325 |
+
def forward(self, x):
|
326 |
+
"""
|
327 |
+
x: B, H*W, C
|
328 |
+
"""
|
329 |
+
H, W = self.input_resolution
|
330 |
+
B, L, C = x.shape
|
331 |
+
assert L == H * W, "input feature has wrong size"
|
332 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
333 |
+
|
334 |
+
x = x.view(B, H, W, C)
|
335 |
+
|
336 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
337 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
338 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
339 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
340 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
341 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
342 |
+
|
343 |
+
x = self.norm(x)
|
344 |
+
x = self.reduction(x)
|
345 |
+
|
346 |
+
return x
|
347 |
+
|
348 |
+
def extra_repr(self) -> str:
|
349 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
350 |
+
|
351 |
+
def flops(self):
|
352 |
+
H, W = self.input_resolution
|
353 |
+
flops = H * W * self.dim
|
354 |
+
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
355 |
+
return flops
|
356 |
+
|
357 |
+
|
358 |
+
class BasicLayer(nn.Module):
|
359 |
+
""" A basic Swin Transformer layer for one stage.
|
360 |
+
|
361 |
+
Args:
|
362 |
+
dim (int): Number of input channels.
|
363 |
+
input_resolution (tuple[int]): Input resolution.
|
364 |
+
depth (int): Number of blocks.
|
365 |
+
num_heads (int): Number of attention heads.
|
366 |
+
window_size (int): Local window size.
|
367 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
368 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
369 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
370 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
371 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
372 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
373 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
374 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
375 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
376 |
+
"""
|
377 |
+
|
378 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
379 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
380 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
|
381 |
+
use_checkpoint=False, layer_scale=False):
|
382 |
+
|
383 |
+
super().__init__()
|
384 |
+
self.dim = dim
|
385 |
+
self.input_resolution = input_resolution
|
386 |
+
self.depth = depth
|
387 |
+
self.use_checkpoint = use_checkpoint
|
388 |
+
|
389 |
+
# build blocks
|
390 |
+
self.blocks = nn.ModuleList([
|
391 |
+
SwinTransformerBlock(
|
392 |
+
dim=dim, input_resolution=input_resolution,
|
393 |
+
num_heads=num_heads, window_size=window_size,
|
394 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
395 |
+
mlp_ratio=mlp_ratio,
|
396 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
397 |
+
drop=drop, attn_drop=attn_drop,
|
398 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
399 |
+
norm_layer=norm_layer,
|
400 |
+
layer_scale=layer_scale
|
401 |
+
)
|
402 |
+
for i in range(depth)])
|
403 |
+
|
404 |
+
# patch merging layer
|
405 |
+
if downsample is not None:
|
406 |
+
# self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
407 |
+
self.downsample = downsample(
|
408 |
+
input_resolution=input_resolution, patch_size=3, in_chans=dim, embed_dim=dim*2,
|
409 |
+
stride=2, padding=1, norm_layer=norm_layer
|
410 |
+
)
|
411 |
+
else:
|
412 |
+
self.downsample = None
|
413 |
+
|
414 |
+
def forward(self, x):
|
415 |
+
for blk in self.blocks:
|
416 |
+
if self.use_checkpoint:
|
417 |
+
x = checkpoint.checkpoint(blk, x)
|
418 |
+
else:
|
419 |
+
x = blk(x)
|
420 |
+
if self.downsample is not None:
|
421 |
+
x = self.downsample(x)
|
422 |
+
return x
|
423 |
+
|
424 |
+
def extra_repr(self) -> str:
|
425 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
426 |
+
|
427 |
+
def flops(self):
|
428 |
+
flops = 0
|
429 |
+
for blk in self.blocks:
|
430 |
+
flops += blk.flops()
|
431 |
+
if self.downsample is not None:
|
432 |
+
flops += self.downsample.flops()
|
433 |
+
return flops
|
434 |
+
|
435 |
+
|
436 |
+
class PatchEmbed(nn.Module):
|
437 |
+
r""" Image to Patch Embedding
|
438 |
+
|
439 |
+
Args:
|
440 |
+
img_size (int): Image size. Default: 224.
|
441 |
+
patch_size (int): Patch token size. Default: 4.
|
442 |
+
in_chans (int): Number of input image channels. Default: 3.
|
443 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
444 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
445 |
+
"""
|
446 |
+
|
447 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
448 |
+
super().__init__()
|
449 |
+
img_size = to_2tuple(img_size)
|
450 |
+
patch_size = to_2tuple(patch_size)
|
451 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
452 |
+
self.img_size = img_size
|
453 |
+
self.patch_size = patch_size
|
454 |
+
self.patches_resolution = patches_resolution
|
455 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
456 |
+
|
457 |
+
self.in_chans = in_chans
|
458 |
+
self.embed_dim = embed_dim
|
459 |
+
|
460 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
461 |
+
if norm_layer is not None:
|
462 |
+
self.norm = norm_layer(embed_dim)
|
463 |
+
else:
|
464 |
+
self.norm = None
|
465 |
+
|
466 |
+
def forward(self, x):
|
467 |
+
B, C, H, W = x.shape
|
468 |
+
# FIXME look at relaxing size constraints
|
469 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
470 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
471 |
+
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
|
472 |
+
if self.norm is not None:
|
473 |
+
x = self.norm(x)
|
474 |
+
return x
|
475 |
+
|
476 |
+
def flops(self):
|
477 |
+
Ho, Wo = self.patches_resolution
|
478 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
479 |
+
if self.norm is not None:
|
480 |
+
flops += Ho * Wo * self.embed_dim
|
481 |
+
return flops
|
482 |
+
|
483 |
+
|
484 |
+
class ConvEmbed(nn.Module):
|
485 |
+
""" Image to Patch Embedding
|
486 |
+
"""
|
487 |
+
|
488 |
+
def __init__(
|
489 |
+
self,
|
490 |
+
input_resolution=(224,224),
|
491 |
+
patch_size=7,
|
492 |
+
in_chans=3,
|
493 |
+
embed_dim=64,
|
494 |
+
stride=4,
|
495 |
+
padding=2,
|
496 |
+
norm_layer=None
|
497 |
+
):
|
498 |
+
super().__init__()
|
499 |
+
self.patch_size = patch_size
|
500 |
+
self.input_resolution = input_resolution
|
501 |
+
|
502 |
+
self.proj = nn.Conv2d(
|
503 |
+
in_chans, embed_dim,
|
504 |
+
kernel_size=patch_size,
|
505 |
+
stride=stride,
|
506 |
+
padding=padding
|
507 |
+
)
|
508 |
+
self.norm = norm_layer(embed_dim) if norm_layer else None
|
509 |
+
|
510 |
+
def forward(self, x):
|
511 |
+
if len(x.size()) == 3:
|
512 |
+
x = rearrange(
|
513 |
+
x, 'b (h w) c -> b c h w',
|
514 |
+
h=self.input_resolution[0],
|
515 |
+
w=self.input_resolution[1]
|
516 |
+
)
|
517 |
+
|
518 |
+
x = self.proj(x)
|
519 |
+
|
520 |
+
B, C, H, W = x.shape
|
521 |
+
x = rearrange(x, 'b c h w -> b (h w) c')
|
522 |
+
if self.norm:
|
523 |
+
x = self.norm(x)
|
524 |
+
|
525 |
+
return x
|
526 |
+
|
527 |
+
|
528 |
+
class SwinTransformer(nn.Module):
|
529 |
+
r""" Swin Transformer
|
530 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
531 |
+
https://arxiv.org/pdf/2103.14030
|
532 |
+
|
533 |
+
Args:
|
534 |
+
img_size (int | tuple(int)): Input image size. Default 224
|
535 |
+
patch_size (int | tuple(int)): Patch size. Default: 4
|
536 |
+
in_chans (int): Number of input image channels. Default: 3
|
537 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
538 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
539 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
540 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
541 |
+
window_size (int): Window size. Default: 7
|
542 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
543 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
544 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
545 |
+
drop_rate (float): Dropout rate. Default: 0
|
546 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
547 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
548 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
549 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
550 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
551 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
552 |
+
"""
|
553 |
+
|
554 |
+
def __init__(self, img_size=224, patch_size=7, patch_padding=2, patch_stride=4, in_chans=3,
|
555 |
+
num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
|
556 |
+
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
557 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
558 |
+
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
559 |
+
use_checkpoint=False, layer_scale=False, **kwargs):
|
560 |
+
super().__init__()
|
561 |
+
|
562 |
+
self.num_classes = num_classes
|
563 |
+
self.num_layers = len(depths)
|
564 |
+
self.embed_dim = embed_dim
|
565 |
+
self.ape = ape
|
566 |
+
self.patch_norm = patch_norm
|
567 |
+
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
568 |
+
self.mlp_ratio = mlp_ratio
|
569 |
+
|
570 |
+
# split image into non-overlapping patches
|
571 |
+
# self.patch_embed = PatchEmbed(
|
572 |
+
# img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
573 |
+
# norm_layer=norm_layer if self.patch_norm else None)
|
574 |
+
|
575 |
+
self.patch_embed = ConvEmbed(
|
576 |
+
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, padding=patch_padding,
|
577 |
+
norm_layer=norm_layer if self.patch_norm else None
|
578 |
+
)
|
579 |
+
|
580 |
+
img_size = to_2tuple(img_size)
|
581 |
+
patches_resolution = (
|
582 |
+
int(np.floor(float(img_size[0]+2*patch_padding-patch_size)/patch_stride+1)),
|
583 |
+
int(np.floor(float(img_size[0]+2*patch_padding-patch_size)/patch_stride+1))
|
584 |
+
)
|
585 |
+
num_patches = patches_resolution[0] * patches_resolution[1]
|
586 |
+
# num_patches = self.patch_embed.num_patches
|
587 |
+
# patches_resolution = self.patch_embed.patches_resolution
|
588 |
+
self.patches_resolution = patches_resolution
|
589 |
+
|
590 |
+
# absolute position embedding
|
591 |
+
if self.ape:
|
592 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
593 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
594 |
+
|
595 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
596 |
+
|
597 |
+
# stochastic depth
|
598 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
599 |
+
|
600 |
+
# build layers
|
601 |
+
self.layers = nn.ModuleList()
|
602 |
+
for i_layer in range(self.num_layers):
|
603 |
+
layer = BasicLayer(
|
604 |
+
dim=int(embed_dim * 2 ** i_layer),
|
605 |
+
input_resolution=(
|
606 |
+
patches_resolution[0] // (2 ** i_layer),
|
607 |
+
patches_resolution[1] // (2 ** i_layer)
|
608 |
+
),
|
609 |
+
depth=depths[i_layer],
|
610 |
+
num_heads=num_heads[i_layer],
|
611 |
+
window_size=window_size,
|
612 |
+
mlp_ratio=self.mlp_ratio,
|
613 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
614 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
615 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
616 |
+
norm_layer=norm_layer,
|
617 |
+
# downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
618 |
+
downsample=ConvEmbed if (i_layer < self.num_layers - 1) else None,
|
619 |
+
use_checkpoint=use_checkpoint,
|
620 |
+
layer_scale=layer_scale
|
621 |
+
)
|
622 |
+
self.layers.append(layer)
|
623 |
+
|
624 |
+
self.norm = norm_layer(self.num_features)
|
625 |
+
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
626 |
+
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
627 |
+
|
628 |
+
self.apply(self._init_weights)
|
629 |
+
|
630 |
+
@property
|
631 |
+
def dim_out(self):
|
632 |
+
return self.num_features
|
633 |
+
|
634 |
+
def _init_weights(self, m):
|
635 |
+
if isinstance(m, nn.Linear):
|
636 |
+
trunc_normal_(m.weight, std=.02)
|
637 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
638 |
+
nn.init.constant_(m.bias, 0)
|
639 |
+
elif isinstance(m, nn.LayerNorm):
|
640 |
+
nn.init.constant_(m.bias, 0)
|
641 |
+
nn.init.constant_(m.weight, 1.0)
|
642 |
+
|
643 |
+
def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
|
644 |
+
if os.path.isfile(pretrained):
|
645 |
+
logging.info(f'=> loading pretrained model {pretrained}')
|
646 |
+
pretrained_dict = torch.load(pretrained, map_location='cpu')
|
647 |
+
|
648 |
+
self.from_state_dict(pretrained_dict, pretrained_layers, verbose)
|
649 |
+
|
650 |
+
def from_state_dict(self, pretrained_dict, pretrained_layers=[], verbose=True):
|
651 |
+
model_dict = self.state_dict()
|
652 |
+
stripped_key = lambda x: x[14:] if x.startswith('image_encoder.') else x
|
653 |
+
|
654 |
+
pretrained_dict = {
|
655 |
+
stripped_key(k): v for k, v in pretrained_dict.items()
|
656 |
+
if stripped_key(k) in model_dict.keys()
|
657 |
+
}
|
658 |
+
need_init_state_dict = {}
|
659 |
+
for k, v in pretrained_dict.items():
|
660 |
+
need_init = (
|
661 |
+
(
|
662 |
+
k.split('.')[0] in pretrained_layers
|
663 |
+
or pretrained_layers[0] == '*'
|
664 |
+
)
|
665 |
+
and 'relative_position_index' not in k
|
666 |
+
and 'attn_mask' not in k
|
667 |
+
)
|
668 |
+
|
669 |
+
if need_init:
|
670 |
+
if verbose:
|
671 |
+
logger.info(f'=> init {k} from pretrained state dict')
|
672 |
+
|
673 |
+
if 'relative_position_bias_table' in k and v.size() != model_dict[k].size():
|
674 |
+
relative_position_bias_table_pretrained = v
|
675 |
+
relative_position_bias_table_current = model_dict[k]
|
676 |
+
L1, nH1 = relative_position_bias_table_pretrained.size()
|
677 |
+
L2, nH2 = relative_position_bias_table_current.size()
|
678 |
+
if nH1 != nH2:
|
679 |
+
logger.info(f"Error in loading {k}, passing")
|
680 |
+
else:
|
681 |
+
if L1 != L2:
|
682 |
+
logger.info(
|
683 |
+
'=> load_pretrained: resized variant: {} to {}'
|
684 |
+
.format((L1, nH1), (L2, nH2))
|
685 |
+
)
|
686 |
+
S1 = int(L1 ** 0.5)
|
687 |
+
S2 = int(L2 ** 0.5)
|
688 |
+
relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
|
689 |
+
relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
|
690 |
+
size=(S2, S2),
|
691 |
+
mode='bicubic')
|
692 |
+
v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
|
693 |
+
|
694 |
+
if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():
|
695 |
+
absolute_pos_embed_pretrained = v
|
696 |
+
absolute_pos_embed_current = model_dict[k]
|
697 |
+
_, L1, C1 = absolute_pos_embed_pretrained.size()
|
698 |
+
_, L2, C2 = absolute_pos_embed_current.size()
|
699 |
+
if C1 != C1:
|
700 |
+
logger.info(f"Error in loading {k}, passing")
|
701 |
+
else:
|
702 |
+
if L1 != L2:
|
703 |
+
logger.info(
|
704 |
+
'=> load_pretrained: resized variant: {} to {}'
|
705 |
+
.format((1, L1, C1), (1, L2, C2))
|
706 |
+
)
|
707 |
+
S1 = int(L1 ** 0.5)
|
708 |
+
S2 = int(L2 ** 0.5)
|
709 |
+
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
|
710 |
+
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
|
711 |
+
absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
|
712 |
+
absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
|
713 |
+
v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)
|
714 |
+
|
715 |
+
need_init_state_dict[k] = v
|
716 |
+
self.load_state_dict(need_init_state_dict, strict=False)
|
717 |
+
|
718 |
+
@torch.jit.ignore
|
719 |
+
def no_weight_decay(self):
|
720 |
+
return {'absolute_pos_embed'}
|
721 |
+
|
722 |
+
@torch.jit.ignore
|
723 |
+
def no_weight_decay_keywords(self):
|
724 |
+
return {'relative_position_bias_table'}
|
725 |
+
|
726 |
+
def forward_features(self, x):
|
727 |
+
x = self.patch_embed(x)
|
728 |
+
if self.ape:
|
729 |
+
x = x + self.absolute_pos_embed
|
730 |
+
x = self.pos_drop(x)
|
731 |
+
|
732 |
+
for layer in self.layers:
|
733 |
+
x = layer(x)
|
734 |
+
|
735 |
+
x = self.norm(x) # B L C
|
736 |
+
x = self.avgpool(x.transpose(1, 2)) # B C 1
|
737 |
+
x = torch.flatten(x, 1)
|
738 |
+
return x
|
739 |
+
|
740 |
+
def forward(self, x):
|
741 |
+
x = self.forward_features(x)
|
742 |
+
x = self.head(x)
|
743 |
+
return x
|
744 |
+
|
745 |
+
|
746 |
+
@register_image_encoder
|
747 |
+
def image_encoder(config_encoder, verbose, **kwargs):
|
748 |
+
spec = config_encoder['SPEC']
|
749 |
+
|
750 |
+
coswin = SwinTransformer(
|
751 |
+
img_size=config_encoder['IMAGE_SIZE'],
|
752 |
+
patch_size=spec['PATCH_SIZE'],
|
753 |
+
patch_padding=spec['PATCH_PADDING'],
|
754 |
+
patch_stride=spec['PATCH_STRIDE'],
|
755 |
+
in_chans=spec['IN_CHANS'],
|
756 |
+
num_classes=0,
|
757 |
+
embed_dim=spec['EMBED_DIM'],
|
758 |
+
depths=spec['DEPTHS'],
|
759 |
+
num_heads=spec['NUM_HEADS'],
|
760 |
+
window_size=spec['WINDOW_SIZE'],
|
761 |
+
mlp_ratio=spec['MLP_RATIO'],
|
762 |
+
qkv_bias=spec['QKV_BIAS'],
|
763 |
+
qk_scale=spec.get('QK_SCALE', None),
|
764 |
+
drop_rate=spec['DROP_RATE'],
|
765 |
+
drop_path_rate=spec['DROP_PATH_RATE'],
|
766 |
+
ape=spec['APE'],
|
767 |
+
patch_norm=spec['PATCH_NORM'],
|
768 |
+
layer_scale=spec.get('LAYER_SCALE', False),
|
769 |
+
use_checkpoint=spec.get('ENABLE_CHECKPOINT', False)
|
770 |
+
)
|
771 |
+
|
772 |
+
if config_encoder['LOAD_PRETRAINED']:
|
773 |
+
coswin.from_pretrained(
|
774 |
+
config_encoder['PRETRAINED'],
|
775 |
+
config_encoder['PRETRAINED_LAYERS'],
|
776 |
+
verbose
|
777 |
+
)
|
778 |
+
|
779 |
+
return coswin
|
MedImageInsight/ImageEncoder/davit_v1.py
ADDED
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import copy
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torch.utils.checkpoint as checkpoint
|
9 |
+
from collections import OrderedDict
|
10 |
+
|
11 |
+
from einops import rearrange
|
12 |
+
from timm.models.layers import DropPath, trunc_normal_
|
13 |
+
|
14 |
+
# helper methods
|
15 |
+
from .registry import register_image_encoder
|
16 |
+
|
17 |
+
import mup.init
|
18 |
+
from mup import MuReadout, set_base_shapes
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
class MySequential(nn.Sequential):
|
24 |
+
def forward(self, *inputs):
|
25 |
+
for module in self._modules.values():
|
26 |
+
if type(inputs) == tuple:
|
27 |
+
inputs = module(*inputs)
|
28 |
+
else:
|
29 |
+
inputs = module(inputs)
|
30 |
+
return inputs
|
31 |
+
|
32 |
+
|
33 |
+
class PreNorm(nn.Module):
|
34 |
+
def __init__(self, norm, fn, drop_path=None):
|
35 |
+
super().__init__()
|
36 |
+
self.norm = norm
|
37 |
+
self.fn = fn
|
38 |
+
self.drop_path = drop_path
|
39 |
+
|
40 |
+
def forward(self, x, *args, **kwargs):
|
41 |
+
shortcut = x
|
42 |
+
if self.norm != None:
|
43 |
+
x, size = self.fn(self.norm(x), *args, **kwargs)
|
44 |
+
else:
|
45 |
+
x, size = self.fn(x, *args, **kwargs)
|
46 |
+
|
47 |
+
if self.drop_path:
|
48 |
+
x = self.drop_path(x)
|
49 |
+
|
50 |
+
x = shortcut + x
|
51 |
+
|
52 |
+
return x, size
|
53 |
+
|
54 |
+
|
55 |
+
class Mlp(nn.Module):
|
56 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
in_features,
|
62 |
+
hidden_features=None,
|
63 |
+
out_features=None,
|
64 |
+
act_layer=nn.GELU,
|
65 |
+
):
|
66 |
+
super().__init__()
|
67 |
+
out_features = out_features or in_features
|
68 |
+
hidden_features = hidden_features or in_features
|
69 |
+
self.net = nn.Sequential(OrderedDict([
|
70 |
+
("fc1", nn.Linear(in_features, hidden_features)),
|
71 |
+
("act", act_layer()),
|
72 |
+
("fc2", nn.Linear(hidden_features, out_features))
|
73 |
+
]))
|
74 |
+
|
75 |
+
def forward(self, x, size):
|
76 |
+
return self.net(x), size
|
77 |
+
|
78 |
+
|
79 |
+
class DepthWiseConv2d(nn.Module):
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
dim_in,
|
83 |
+
kernel_size,
|
84 |
+
padding,
|
85 |
+
stride,
|
86 |
+
bias=True,
|
87 |
+
):
|
88 |
+
super().__init__()
|
89 |
+
self.dw = nn.Conv2d(
|
90 |
+
dim_in, dim_in,
|
91 |
+
kernel_size=kernel_size,
|
92 |
+
padding=padding,
|
93 |
+
groups=dim_in,
|
94 |
+
stride=stride,
|
95 |
+
bias=bias
|
96 |
+
)
|
97 |
+
|
98 |
+
def forward(self, x, size):
|
99 |
+
B, N, C = x.shape
|
100 |
+
H, W = size
|
101 |
+
assert N == H * W
|
102 |
+
|
103 |
+
x = self.dw(x.transpose(1, 2).view(B, C, H, W))
|
104 |
+
size = (x.size(-2), x.size(-1))
|
105 |
+
x = x.flatten(2).transpose(1, 2)
|
106 |
+
return x, size
|
107 |
+
|
108 |
+
|
109 |
+
class ConvEmbed(nn.Module):
|
110 |
+
""" Image to Patch Embedding
|
111 |
+
"""
|
112 |
+
|
113 |
+
def __init__(
|
114 |
+
self,
|
115 |
+
patch_size=7,
|
116 |
+
in_chans=3,
|
117 |
+
embed_dim=64,
|
118 |
+
stride=4,
|
119 |
+
padding=2,
|
120 |
+
norm_layer=None,
|
121 |
+
pre_norm=True
|
122 |
+
):
|
123 |
+
super().__init__()
|
124 |
+
self.patch_size = patch_size
|
125 |
+
|
126 |
+
self.proj = nn.Conv2d(
|
127 |
+
in_chans, embed_dim,
|
128 |
+
kernel_size=patch_size,
|
129 |
+
stride=stride,
|
130 |
+
padding=padding
|
131 |
+
)
|
132 |
+
|
133 |
+
dim_norm = in_chans if pre_norm else embed_dim
|
134 |
+
self.norm = norm_layer(dim_norm) if norm_layer else None
|
135 |
+
|
136 |
+
self.pre_norm = pre_norm
|
137 |
+
|
138 |
+
def forward(self, x, size):
|
139 |
+
H, W = size
|
140 |
+
if len(x.size()) == 3:
|
141 |
+
if self.norm and self.pre_norm:
|
142 |
+
x = self.norm(x)
|
143 |
+
x = rearrange(
|
144 |
+
x, 'b (h w) c -> b c h w',
|
145 |
+
h=H, w=W
|
146 |
+
)
|
147 |
+
|
148 |
+
x = self.proj(x)
|
149 |
+
|
150 |
+
_, _, H, W = x.shape
|
151 |
+
x = rearrange(x, 'b c h w -> b (h w) c')
|
152 |
+
if self.norm and not self.pre_norm:
|
153 |
+
x = self.norm(x)
|
154 |
+
|
155 |
+
return x, (H, W)
|
156 |
+
|
157 |
+
|
158 |
+
class ChannelAttention(nn.Module):
|
159 |
+
|
160 |
+
def __init__(self, dim, base_dim, groups=8, base_groups=8, qkv_bias=True, dynamic_scale=True, standparam=True):
|
161 |
+
super().__init__()
|
162 |
+
|
163 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
164 |
+
self.proj = nn.Linear(dim, dim)
|
165 |
+
self.dynamic_scale = dynamic_scale
|
166 |
+
|
167 |
+
self.dim = dim
|
168 |
+
self.groups = groups
|
169 |
+
self.group_dim = dim // groups
|
170 |
+
|
171 |
+
self.base_dim = base_dim
|
172 |
+
self.base_groups = base_groups
|
173 |
+
self.base_group_dim = base_dim // base_groups
|
174 |
+
|
175 |
+
self.group_wm = self.group_dim / self.base_group_dim # Width multiplier for each group.
|
176 |
+
self.standparam = standparam
|
177 |
+
|
178 |
+
def forward(self, x, size):
|
179 |
+
B, N, C = x.shape
|
180 |
+
assert C == self.dim
|
181 |
+
|
182 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
|
183 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # Shape: [B, groups, N, group_dim].
|
184 |
+
|
185 |
+
scale = N ** -0.5 if self.dynamic_scale else self.dim ** -0.5
|
186 |
+
|
187 |
+
# Change the scaling factor.
|
188 |
+
# Ref: examples/Transformer/model.py in muP.
|
189 |
+
# Note: We consider backward compatiblity and follow https://github.com/microsoft/mup/issues/18.
|
190 |
+
if self.standparam:
|
191 |
+
scale = N ** -0.5 if self.dynamic_scale else self.dim ** -0.5
|
192 |
+
else:
|
193 |
+
assert self.dynamic_scale # Currently only support dynamic scale.
|
194 |
+
scale = N ** -0.5
|
195 |
+
|
196 |
+
q = q * scale
|
197 |
+
attention = q.transpose(-1, -2) @ k
|
198 |
+
attention = attention.softmax(dim=-1)
|
199 |
+
|
200 |
+
if not self.standparam:
|
201 |
+
# Follow https://github.com/microsoft/mup/issues/18.
|
202 |
+
attention = attention / self.group_wm
|
203 |
+
|
204 |
+
x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
|
205 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
206 |
+
x = self.proj(x)
|
207 |
+
return x, size
|
208 |
+
|
209 |
+
|
210 |
+
class ChannelBlock(nn.Module):
|
211 |
+
|
212 |
+
def __init__(self, dim, base_dim, groups, base_groups, mlp_ratio=4., qkv_bias=True,
|
213 |
+
drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
214 |
+
conv_at_attn=True, conv_at_ffn=True, dynamic_scale=True, standparam=True):
|
215 |
+
super().__init__()
|
216 |
+
|
217 |
+
drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
218 |
+
|
219 |
+
self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
|
220 |
+
self.channel_attn = PreNorm(
|
221 |
+
norm_layer(dim),
|
222 |
+
ChannelAttention(dim, base_dim, groups=groups, base_groups=base_groups, qkv_bias=qkv_bias,
|
223 |
+
dynamic_scale=dynamic_scale, standparam=standparam),
|
224 |
+
drop_path
|
225 |
+
)
|
226 |
+
self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
|
227 |
+
self.ffn = PreNorm(
|
228 |
+
norm_layer(dim),
|
229 |
+
Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer),
|
230 |
+
drop_path
|
231 |
+
)
|
232 |
+
|
233 |
+
def forward(self, x, size):
|
234 |
+
if self.conv1:
|
235 |
+
x, size = self.conv1(x, size)
|
236 |
+
x, size = self.channel_attn(x, size)
|
237 |
+
|
238 |
+
if self.conv2:
|
239 |
+
x, size = self.conv2(x, size)
|
240 |
+
x, size = self.ffn(x, size)
|
241 |
+
|
242 |
+
return x, size
|
243 |
+
|
244 |
+
|
245 |
+
def window_partition(x, window_size: int):
|
246 |
+
B, H, W, C = x.shape
|
247 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
248 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
249 |
+
return windows
|
250 |
+
|
251 |
+
|
252 |
+
def window_reverse(windows, window_size: int, H: int, W: int):
|
253 |
+
B = windows.shape[0] // (H * W // window_size // window_size)
|
254 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
255 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
256 |
+
return x
|
257 |
+
|
258 |
+
|
259 |
+
class WindowAttention(nn.Module):
|
260 |
+
|
261 |
+
def __init__(self, dim, base_dim, num_heads, base_num_heads, window_size, qkv_bias=True, standparam=True):
|
262 |
+
|
263 |
+
super().__init__()
|
264 |
+
|
265 |
+
self.window_size = window_size
|
266 |
+
|
267 |
+
self.dim = dim
|
268 |
+
self.num_heads = num_heads
|
269 |
+
head_dim = dim // num_heads
|
270 |
+
|
271 |
+
self.base_dim = base_dim
|
272 |
+
self.base_num_heads = base_num_heads
|
273 |
+
base_head_dim = base_dim // base_num_heads
|
274 |
+
|
275 |
+
# Change the scaling factor.
|
276 |
+
# Ref: examples/Transformer/model.py in muP.
|
277 |
+
# Note: We consider backward compatiblity and follow https://github.com/microsoft/mup/issues/17.
|
278 |
+
if standparam:
|
279 |
+
scale = float(head_dim) ** -0.5
|
280 |
+
else:
|
281 |
+
# TODO: Here we ensure backward compatibility, which may not be optimal.
|
282 |
+
# We may add an argument called backward_comp. If it is set as False, we use
|
283 |
+
# float(head_dim) ** -1 * math.sqrt(attn_mult)
|
284 |
+
# as in the Transformer example in muP.
|
285 |
+
base_scale = float(base_head_dim) ** -0.5 # The same as scaling in standard parametrization.
|
286 |
+
head_wm = head_dim / base_head_dim # Width multiplier for each head.
|
287 |
+
scale = base_scale / head_wm
|
288 |
+
# scale_1 = (float(base_head_dim) ** 0.5) * (float(head_dim) ** -1) # Equivalent implementation as shown in the muP paper.
|
289 |
+
# assert np.isclose(scale, scale_1)
|
290 |
+
self.scale = scale
|
291 |
+
|
292 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
293 |
+
self.proj = nn.Linear(dim, dim)
|
294 |
+
|
295 |
+
self.softmax = nn.Softmax(dim=-1)
|
296 |
+
|
297 |
+
def forward(self, x, size):
|
298 |
+
|
299 |
+
H, W = size
|
300 |
+
B, L, C = x.shape
|
301 |
+
assert L == H * W, "input feature has wrong size"
|
302 |
+
|
303 |
+
x = x.view(B, H, W, C)
|
304 |
+
|
305 |
+
pad_l = pad_t = 0
|
306 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
307 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
308 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
309 |
+
_, Hp, Wp, _ = x.shape
|
310 |
+
|
311 |
+
x = window_partition(x, self.window_size)
|
312 |
+
x = x.view(-1, self.window_size * self.window_size, C)
|
313 |
+
|
314 |
+
B_, N, C = x.shape
|
315 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
316 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
317 |
+
|
318 |
+
q = q * self.scale
|
319 |
+
attn = (q @ k.transpose(-2, -1))
|
320 |
+
attn = self.softmax(attn)
|
321 |
+
|
322 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
323 |
+
x = self.proj(x)
|
324 |
+
|
325 |
+
# merge windows
|
326 |
+
x = x.view(
|
327 |
+
-1, self.window_size, self.window_size, C
|
328 |
+
)
|
329 |
+
x = window_reverse(x, self.window_size, Hp, Wp)
|
330 |
+
|
331 |
+
if pad_r > 0 or pad_b > 0:
|
332 |
+
x = x[:, :H, :W, :].contiguous()
|
333 |
+
|
334 |
+
x = x.view(B, H * W, C)
|
335 |
+
|
336 |
+
return x, size
|
337 |
+
|
338 |
+
|
339 |
+
class SpatialBlock(nn.Module):
|
340 |
+
|
341 |
+
def __init__(self, dim, base_dim, num_heads, base_num_heads, window_size,
|
342 |
+
mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
|
343 |
+
norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True, standparam=True):
|
344 |
+
super().__init__()
|
345 |
+
|
346 |
+
drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
347 |
+
|
348 |
+
self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
|
349 |
+
self.window_attn = PreNorm(
|
350 |
+
norm_layer(dim),
|
351 |
+
WindowAttention(dim, base_dim, num_heads, base_num_heads, window_size, qkv_bias=qkv_bias,
|
352 |
+
standparam=standparam),
|
353 |
+
drop_path
|
354 |
+
)
|
355 |
+
self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
|
356 |
+
self.ffn = PreNorm(
|
357 |
+
norm_layer(dim),
|
358 |
+
Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer),
|
359 |
+
drop_path
|
360 |
+
)
|
361 |
+
|
362 |
+
def forward(self, x, size):
|
363 |
+
if self.conv1:
|
364 |
+
x, size = self.conv1(x, size)
|
365 |
+
x, size = self.window_attn(x, size)
|
366 |
+
|
367 |
+
if self.conv2:
|
368 |
+
x, size = self.conv2(x, size)
|
369 |
+
x, size = self.ffn(x, size)
|
370 |
+
return x, size
|
371 |
+
|
372 |
+
|
373 |
+
class DaViT(nn.Module):
|
374 |
+
""" DaViT: Dual-Attention Transformer
|
375 |
+
|
376 |
+
Args:
|
377 |
+
img_size (int | tuple(int)): Input image size. Default: 224
|
378 |
+
patch_size (int | tuple(int)): Patch size. Default: 4
|
379 |
+
in_chans (int): Number of input image channels. Default: 3
|
380 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
381 |
+
depths (tuple(int)): Number of spatial and channel blocks in different stages. Default: (1, 1, 3, 1)
|
382 |
+
patch_size (tuple(int)): Patch sizes in different stages. Default: (7, 2, 2, 2)
|
383 |
+
patch_stride (tuple(int)): Patch strides in different stages. Default: (4, 2, 2, 2)
|
384 |
+
patch_padding (tuple(int)): Patch padding sizes in different stages. Default: (3, 0, 0, 0)
|
385 |
+
patch_prenorm (tuple(bool)): Use pre-normalization or not in different stages. Default: (False, False, False, False)
|
386 |
+
embed_dims (tuple(int)): Patch embedding dimension. Default: (64, 128, 192, 256)
|
387 |
+
base_embed_dims (tuple(int)): Patch embedding dimension (base case for muP). Default: (64, 128, 192, 256)
|
388 |
+
num_heads (tuple(int)): Number of attention heads in different layers. Default: (4, 8, 12, 16)
|
389 |
+
base_num_heads (tuple(int)): Number of attention heads in different layers (base case for muP). Default: (4, 8, 12, 16)
|
390 |
+
num_groups (tuple(int)): Number of groups in channel attention in different layers. Default: (3, 6, 12, 24)
|
391 |
+
base_num_groups (tuple(int)): Number of groups in channel attention in different layers (base case for muP). Default: (3, 6, 12, 24)
|
392 |
+
window_size (int): Window size. Default: 7
|
393 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
394 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
395 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
396 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
397 |
+
enable_checkpoint (bool): If True, enabling checkpoint. Default: False
|
398 |
+
conv_at_attn (bool): If True, add convolution layer before attention. Default: True
|
399 |
+
conv_at_ffn (bool): If True, add convolution layer before ffn. Default: True
|
400 |
+
dynamic_scale (bool): If True, scale of channel attention is respect to the number of tokens. Default: True
|
401 |
+
standparam (bool): Use standard parametrization or mu-parametrization. Default: True (i.e., use standard paramerization)
|
402 |
+
"""
|
403 |
+
|
404 |
+
def __init__(
|
405 |
+
self,
|
406 |
+
img_size=224,
|
407 |
+
in_chans=3,
|
408 |
+
num_classes=1000,
|
409 |
+
depths=(1, 1, 3, 1),
|
410 |
+
patch_size=(7, 2, 2, 2),
|
411 |
+
patch_stride=(4, 2, 2, 2),
|
412 |
+
patch_padding=(3, 0, 0, 0),
|
413 |
+
patch_prenorm=(False, False, False, False),
|
414 |
+
embed_dims=(64, 128, 192, 256),
|
415 |
+
base_embed_dims=(64, 128, 192, 256),
|
416 |
+
num_heads=(3, 6, 12, 24),
|
417 |
+
base_num_heads=(3, 6, 12, 24),
|
418 |
+
num_groups=(3, 6, 12, 24),
|
419 |
+
base_num_groups=(3, 6, 12, 24),
|
420 |
+
window_size=7,
|
421 |
+
mlp_ratio=4.,
|
422 |
+
qkv_bias=True,
|
423 |
+
drop_path_rate=0.1,
|
424 |
+
norm_layer=nn.LayerNorm,
|
425 |
+
enable_checkpoint=False,
|
426 |
+
conv_at_attn=True,
|
427 |
+
conv_at_ffn=True,
|
428 |
+
dynamic_scale=True,
|
429 |
+
standparam=True
|
430 |
+
):
|
431 |
+
super().__init__()
|
432 |
+
|
433 |
+
self.num_classes = num_classes
|
434 |
+
self.embed_dims = embed_dims
|
435 |
+
self.num_heads = num_heads
|
436 |
+
self.num_groups = num_groups
|
437 |
+
self.num_stages = len(self.embed_dims)
|
438 |
+
self.enable_checkpoint = enable_checkpoint
|
439 |
+
assert self.num_stages == len(self.num_heads) == len(self.num_groups)
|
440 |
+
|
441 |
+
num_stages = len(embed_dims)
|
442 |
+
self.img_size = img_size
|
443 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths) * 2)]
|
444 |
+
|
445 |
+
depth_offset = 0
|
446 |
+
convs = []
|
447 |
+
blocks = []
|
448 |
+
for i in range(num_stages):
|
449 |
+
conv_embed = ConvEmbed(
|
450 |
+
patch_size=patch_size[i],
|
451 |
+
stride=patch_stride[i],
|
452 |
+
padding=patch_padding[i],
|
453 |
+
in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
|
454 |
+
embed_dim=self.embed_dims[i],
|
455 |
+
norm_layer=norm_layer,
|
456 |
+
pre_norm=patch_prenorm[i]
|
457 |
+
)
|
458 |
+
convs.append(conv_embed)
|
459 |
+
|
460 |
+
logger.info(f'=> Depth offset in stage {i}: {depth_offset}')
|
461 |
+
block = MySequential(
|
462 |
+
*[
|
463 |
+
MySequential(OrderedDict([
|
464 |
+
(
|
465 |
+
'spatial_block', SpatialBlock(
|
466 |
+
embed_dims[i],
|
467 |
+
base_embed_dims[i],
|
468 |
+
num_heads[i],
|
469 |
+
base_num_heads[i],
|
470 |
+
window_size,
|
471 |
+
drop_path_rate=dpr[depth_offset + j * 2],
|
472 |
+
qkv_bias=qkv_bias,
|
473 |
+
mlp_ratio=mlp_ratio,
|
474 |
+
conv_at_attn=conv_at_attn,
|
475 |
+
conv_at_ffn=conv_at_ffn,
|
476 |
+
standparam=standparam
|
477 |
+
)
|
478 |
+
),
|
479 |
+
(
|
480 |
+
'channel_block', ChannelBlock(
|
481 |
+
embed_dims[i],
|
482 |
+
base_embed_dims[i],
|
483 |
+
num_groups[i],
|
484 |
+
base_num_groups[i],
|
485 |
+
drop_path_rate=dpr[depth_offset + j * 2 + 1],
|
486 |
+
qkv_bias=qkv_bias,
|
487 |
+
mlp_ratio=mlp_ratio,
|
488 |
+
conv_at_attn=conv_at_attn,
|
489 |
+
conv_at_ffn=conv_at_ffn,
|
490 |
+
dynamic_scale=dynamic_scale,
|
491 |
+
standparam=standparam
|
492 |
+
)
|
493 |
+
)
|
494 |
+
])) for j in range(depths[i])
|
495 |
+
]
|
496 |
+
)
|
497 |
+
blocks.append(block)
|
498 |
+
depth_offset += depths[i] * 2
|
499 |
+
|
500 |
+
self.convs = nn.ModuleList(convs)
|
501 |
+
self.blocks = nn.ModuleList(blocks)
|
502 |
+
|
503 |
+
self.norms = norm_layer(self.embed_dims[-1])
|
504 |
+
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
505 |
+
|
506 |
+
if standparam:
|
507 |
+
self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
|
508 |
+
else:
|
509 |
+
self.head = MuReadout(self.embed_dims[-1], num_classes,
|
510 |
+
readout_zero_init=True) # Follow examples/ResNet/resnet.py in muP.
|
511 |
+
|
512 |
+
if torch.cuda.is_available():
|
513 |
+
self.device = torch.device(type="cuda", index=0)
|
514 |
+
else:
|
515 |
+
self.device = torch.device(type="cpu")
|
516 |
+
|
517 |
+
def custom_init_weights(self, use_original_init=True):
|
518 |
+
self.use_original_init = use_original_init
|
519 |
+
logger.info('Custom init: {}'.format('original init' if self.use_original_init else 'muP init'))
|
520 |
+
self.apply(self._custom_init_weights)
|
521 |
+
|
522 |
+
@property
|
523 |
+
def dim_out(self):
|
524 |
+
return self.embed_dims[-1]
|
525 |
+
|
526 |
+
def _custom_init_weights(self, m):
|
527 |
+
# Customized initialization for weights.
|
528 |
+
if self.use_original_init:
|
529 |
+
# Original initialization.
|
530 |
+
# Note: This is not SP init. We do not implement SP init here.
|
531 |
+
custom_trunc_normal_ = trunc_normal_
|
532 |
+
custom_normal_ = nn.init.normal_
|
533 |
+
else:
|
534 |
+
# muP.
|
535 |
+
custom_trunc_normal_ = mup.init.trunc_normal_
|
536 |
+
custom_normal_ = mup.init.normal_
|
537 |
+
|
538 |
+
# These initializations will overwrite the existing inializations from the modules and adjusted by set_base_shapes().
|
539 |
+
if isinstance(m, MuReadout):
|
540 |
+
pass # Note: MuReadout is already zero initialized due to readout_zero_init=True.
|
541 |
+
elif isinstance(m, nn.Linear):
|
542 |
+
custom_trunc_normal_(m.weight, std=0.02)
|
543 |
+
if m.bias is not None:
|
544 |
+
nn.init.constant_(m.bias, 0)
|
545 |
+
elif isinstance(m, nn.Conv2d):
|
546 |
+
custom_normal_(m.weight, std=0.02)
|
547 |
+
for name, _ in m.named_parameters():
|
548 |
+
if name in ['bias']:
|
549 |
+
nn.init.constant_(m.bias, 0)
|
550 |
+
elif isinstance(m, nn.LayerNorm): # Follow P24 Layernorm Weights and Biases.
|
551 |
+
nn.init.constant_(m.weight, 1.0)
|
552 |
+
nn.init.constant_(m.bias, 0)
|
553 |
+
elif isinstance(m, nn.BatchNorm2d): # Follow P24 Layernorm Weights and Biases.
|
554 |
+
nn.init.constant_(m.weight, 1.0)
|
555 |
+
nn.init.constant_(m.bias, 0)
|
556 |
+
|
557 |
+
def _try_remap_keys(self, pretrained_dict):
|
558 |
+
remap_keys = {
|
559 |
+
"conv_embeds": "convs",
|
560 |
+
"main_blocks": "blocks",
|
561 |
+
"0.cpe.0.proj": "spatial_block.conv1.fn.dw",
|
562 |
+
"0.attn": "spatial_block.window_attn.fn",
|
563 |
+
"0.cpe.1.proj": "spatial_block.conv2.fn.dw",
|
564 |
+
"0.mlp": "spatial_block.ffn.fn.net",
|
565 |
+
"1.cpe.0.proj": "channel_block.conv1.fn.dw",
|
566 |
+
"1.attn": "channel_block.channel_attn.fn",
|
567 |
+
"1.cpe.1.proj": "channel_block.conv2.fn.dw",
|
568 |
+
"1.mlp": "channel_block.ffn.fn.net",
|
569 |
+
"0.norm1": "spatial_block.window_attn.norm",
|
570 |
+
"0.norm2": "spatial_block.ffn.norm",
|
571 |
+
"1.norm1": "channel_block.channel_attn.norm",
|
572 |
+
"1.norm2": "channel_block.ffn.norm"
|
573 |
+
}
|
574 |
+
|
575 |
+
full_key_mappings = {}
|
576 |
+
for k in pretrained_dict.keys():
|
577 |
+
old_k = k
|
578 |
+
for remap_key in remap_keys.keys():
|
579 |
+
if remap_key in k:
|
580 |
+
logger.info(f'=> Repace {remap_key} with {remap_keys[remap_key]}')
|
581 |
+
k = k.replace(remap_key, remap_keys[remap_key])
|
582 |
+
|
583 |
+
full_key_mappings[old_k] = k
|
584 |
+
|
585 |
+
return full_key_mappings
|
586 |
+
|
587 |
+
def from_state_dict(self, pretrained_dict, pretrained_layers=[], verbose=True):
|
588 |
+
model_dict = self.state_dict()
|
589 |
+
stripped_key = lambda x: x[14:] if x.startswith('image_encoder.') else x
|
590 |
+
full_key_mappings = self._try_remap_keys(pretrained_dict)
|
591 |
+
|
592 |
+
pretrained_dict = {
|
593 |
+
stripped_key(full_key_mappings[k]): v.to(self.device) for k, v in pretrained_dict.items()
|
594 |
+
if stripped_key(full_key_mappings[k]) in model_dict.keys()
|
595 |
+
}
|
596 |
+
need_init_state_dict = {}
|
597 |
+
for k, v in pretrained_dict.items():
|
598 |
+
need_init = (
|
599 |
+
k.split('.')[0] in pretrained_layers
|
600 |
+
or pretrained_layers[0] == '*'
|
601 |
+
)
|
602 |
+
if need_init:
|
603 |
+
if verbose:
|
604 |
+
logger.info(f'=> init {k} from pretrained state dict')
|
605 |
+
|
606 |
+
need_init_state_dict[k] = v.to(self.device)
|
607 |
+
self.load_state_dict(need_init_state_dict, strict=False)
|
608 |
+
|
609 |
+
def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
|
610 |
+
if os.path.isfile(pretrained):
|
611 |
+
logger.info(f'=> loading pretrained model {pretrained}')
|
612 |
+
pretrained_dict = torch.load(pretrained, map_location='cpu')
|
613 |
+
|
614 |
+
self.from_state_dict(pretrained_dict, pretrained_layers, verbose)
|
615 |
+
|
616 |
+
def forward_features(self, x):
|
617 |
+
input_size = (x.size(2), x.size(3))
|
618 |
+
for conv, block in zip(self.convs, self.blocks):
|
619 |
+
x, input_size = conv(x, input_size)
|
620 |
+
if self.enable_checkpoint:
|
621 |
+
x, input_size = checkpoint.checkpoint(block, x, input_size)
|
622 |
+
else:
|
623 |
+
x, input_size = block(x, input_size)
|
624 |
+
|
625 |
+
x = self.avgpool(x.transpose(1, 2))
|
626 |
+
x = torch.flatten(x, 1)
|
627 |
+
x = self.norms(x)
|
628 |
+
|
629 |
+
return x
|
630 |
+
|
631 |
+
def forward(self, x):
|
632 |
+
x = self.forward_features(x)
|
633 |
+
x = self.head(x)
|
634 |
+
return x
|
635 |
+
|
636 |
+
|
637 |
+
def create_encoder(config_encoder):
|
638 |
+
spec = config_encoder['SPEC']
|
639 |
+
standparam = spec.get('STANDPARAM', True)
|
640 |
+
|
641 |
+
if standparam:
|
642 |
+
# Dummy values for muP parameters.
|
643 |
+
base_embed_dims = spec['DIM_EMBED']
|
644 |
+
base_num_heads = spec['NUM_HEADS']
|
645 |
+
base_num_groups = spec['NUM_GROUPS']
|
646 |
+
else:
|
647 |
+
base_embed_dims = spec['BASE_DIM_EMBED']
|
648 |
+
base_num_heads = spec['BASE_NUM_HEADS']
|
649 |
+
base_num_groups = spec['BASE_NUM_GROUPS']
|
650 |
+
|
651 |
+
davit = DaViT(
|
652 |
+
num_classes=config_encoder['NUM_CLASSES'],
|
653 |
+
depths=spec['DEPTHS'],
|
654 |
+
embed_dims=spec['DIM_EMBED'],
|
655 |
+
base_embed_dims=base_embed_dims,
|
656 |
+
num_heads=spec['NUM_HEADS'],
|
657 |
+
base_num_heads=base_num_heads,
|
658 |
+
num_groups=spec['NUM_GROUPS'],
|
659 |
+
base_num_groups=base_num_groups,
|
660 |
+
patch_size=spec['PATCH_SIZE'],
|
661 |
+
patch_stride=spec['PATCH_STRIDE'],
|
662 |
+
patch_padding=spec['PATCH_PADDING'],
|
663 |
+
patch_prenorm=spec['PATCH_PRENORM'],
|
664 |
+
drop_path_rate=spec['DROP_PATH_RATE'],
|
665 |
+
img_size=config_encoder['IMAGE_SIZE'],
|
666 |
+
window_size=spec.get('WINDOW_SIZE', 7),
|
667 |
+
enable_checkpoint=spec.get('ENABLE_CHECKPOINT', False),
|
668 |
+
conv_at_attn=spec.get('CONV_AT_ATTN', True),
|
669 |
+
conv_at_ffn=spec.get('CONV_AT_FFN', True),
|
670 |
+
dynamic_scale=spec.get('DYNAMIC_SCALE', True),
|
671 |
+
standparam=standparam,
|
672 |
+
)
|
673 |
+
return davit
|
674 |
+
|
675 |
+
|
676 |
+
def create_mup_encoder(config_encoder):
|
677 |
+
def gen_config(config, wm):
|
678 |
+
new_config = copy.deepcopy(config)
|
679 |
+
for name in ['DIM_EMBED', 'NUM_HEADS', 'NUM_GROUPS']:
|
680 |
+
base_name = 'BASE_' + name
|
681 |
+
new_values = [round(base_value * wm) for base_value in
|
682 |
+
config['SPEC'][base_name]] # New value = base value * width multiplier.
|
683 |
+
logger.info(f'config["SPEC"]["{name}"]: {new_config["SPEC"][name]} -> {new_values}')
|
684 |
+
new_config['SPEC'][name] = new_values
|
685 |
+
return new_config
|
686 |
+
|
687 |
+
logger.info('muP: Create models and set base shapes')
|
688 |
+
logger.info('=> Create model')
|
689 |
+
model = create_encoder(config_encoder)
|
690 |
+
|
691 |
+
logger.info('=> Create base model')
|
692 |
+
base_config = gen_config(config_encoder, wm=1.0)
|
693 |
+
base_model = create_encoder(base_config)
|
694 |
+
|
695 |
+
logger.info('=> Create delta model')
|
696 |
+
delta_config = gen_config(config_encoder, wm=2.0)
|
697 |
+
delta_model = create_encoder(delta_config)
|
698 |
+
|
699 |
+
logger.info('=> Set base shapes in model for training')
|
700 |
+
set_base_shapes(model, base=base_model, delta=delta_model)
|
701 |
+
|
702 |
+
return model
|
703 |
+
|
704 |
+
|
705 |
+
@register_image_encoder
|
706 |
+
def image_encoder(config_encoder, verbose, **kwargs):
|
707 |
+
spec = config_encoder['SPEC']
|
708 |
+
standparam = spec.get('STANDPARAM', True)
|
709 |
+
|
710 |
+
if standparam:
|
711 |
+
logger.info('Create model with standard parameterization')
|
712 |
+
model = create_encoder(config_encoder)
|
713 |
+
model.custom_init_weights(use_original_init=True)
|
714 |
+
else:
|
715 |
+
logger.info('Create model with mu parameterization')
|
716 |
+
model = create_mup_encoder(config_encoder)
|
717 |
+
model.custom_init_weights(use_original_init=False)
|
718 |
+
|
719 |
+
logger.info('Load model from pretrained checkpoint')
|
720 |
+
if config_encoder['LOAD_PRETRAINED']:
|
721 |
+
model.from_pretrained(
|
722 |
+
config_encoder['PRETRAINED'],
|
723 |
+
config_encoder['PRETRAINED_LAYERS'],
|
724 |
+
verbose
|
725 |
+
)
|
726 |
+
|
727 |
+
return model
|
MedImageInsight/ImageEncoder/registry.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_image_encoders = {}
|
2 |
+
|
3 |
+
|
4 |
+
def register_image_encoder(fn):
|
5 |
+
module_name_split = fn.__module__.split('.')
|
6 |
+
model_name = module_name_split[-1]
|
7 |
+
|
8 |
+
_image_encoders[model_name] = fn
|
9 |
+
|
10 |
+
return fn
|
11 |
+
|
12 |
+
|
13 |
+
def image_encoders(model_name):
|
14 |
+
return _image_encoders[model_name]
|
15 |
+
|
16 |
+
|
17 |
+
def is_image_encoder(model_name):
|
18 |
+
return model_name in _image_encoders
|
MedImageInsight/LangEncoder/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
from .build import build_lang_encoder
|
6 |
+
from .build import build_tokenizer
|
7 |
+
|
8 |
+
from .transformer import *
|
9 |
+
# from .hf_model import *
|
10 |
+
# from .zcode import *
|
11 |
+
# from .pretrain import *
|
12 |
+
# from .tulrv6 import *
|
13 |
+
# from .t5 import *
|
MedImageInsight/LangEncoder/build.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from transformers import CLIPTokenizer, CLIPTokenizerFast
|
5 |
+
from transformers import AutoTokenizer
|
6 |
+
|
7 |
+
from .registry import lang_encoders
|
8 |
+
from .registry import is_lang_encoder
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
|
14 |
+
model_name = config_encoder['NAME']
|
15 |
+
|
16 |
+
if model_name.endswith('pretrain'):
|
17 |
+
model_name = 'pretrain'
|
18 |
+
|
19 |
+
if not is_lang_encoder(model_name):
|
20 |
+
raise ValueError(f'Unknown model: {model_name}')
|
21 |
+
|
22 |
+
return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs)
|
23 |
+
|
24 |
+
|
25 |
+
def post_process_clip(text):
|
26 |
+
text['input_ids'].squeeze_() # torch.Size([1, 77])
|
27 |
+
text['attention_mask'].squeeze_() # torch.Size([1, 77])
|
28 |
+
return text
|
29 |
+
|
30 |
+
|
31 |
+
def build_tokenizer(config_encoder):
|
32 |
+
tokenizer = None
|
33 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false' # 'true', avoid hanging
|
34 |
+
|
35 |
+
if config_encoder['TOKENIZER'] == 'clip':
|
36 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
37 |
+
pretrained_tokenizer = config_encoder.get(
|
38 |
+
'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
|
39 |
+
)
|
40 |
+
# print(pretrained_tokenizer)
|
41 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer)
|
42 |
+
tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})
|
43 |
+
tokenizer.post_process = post_process_clip
|
44 |
+
elif config_encoder['TOKENIZER'] == 'clip-fast':
|
45 |
+
pretrained_tokenizer = config_encoder.get(
|
46 |
+
'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
|
47 |
+
)
|
48 |
+
tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True)
|
49 |
+
tokenizer.post_process = post_process_clip
|
50 |
+
elif config_encoder['TOKENIZER'] == 'zcodepp':
|
51 |
+
from .zcodepp import ZCodeppTokenizer
|
52 |
+
tokenizer = ZCodeppTokenizer(config_encoder)
|
53 |
+
tokenizer.post_process = lambda x: x
|
54 |
+
elif config_encoder['TOKENIZER'] == 'zcode':
|
55 |
+
from transformers import XLMRobertaTokenizer
|
56 |
+
tokenizer = XLMRobertaTokenizer.from_pretrained(config_encoder['PRETRAINED_TOKENIZER'])
|
57 |
+
elif config_encoder['TOKENIZER'] == 'tulrv6':
|
58 |
+
from .modeling_tulrv6 import TULRv6Tokenizer
|
59 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
60 |
+
pretrained_tokenizer = config_encoder.get(
|
61 |
+
'PRETRAINED_TOKENIZER', 'tulrv6-base'
|
62 |
+
)
|
63 |
+
tokenizer = TULRv6Tokenizer.from_pretrained(pretrained_tokenizer)
|
64 |
+
# tokenizer.post_process = post_process_clip
|
65 |
+
else:
|
66 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
67 |
+
pretrained_tokenizer = config_encoder.get('PRETRAINED_TOKENIZER', '')
|
68 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
69 |
+
pretrained_tokenizer
|
70 |
+
if pretrained_tokenizer else config_encoder['TOKENIZER']
|
71 |
+
)
|
72 |
+
tokenizer.post_process = post_process_clip
|
73 |
+
|
74 |
+
# Extra configurations.
|
75 |
+
if 'TOKENIZER_CONF' in config_encoder:
|
76 |
+
tokenizer_conf = config_encoder['TOKENIZER_CONF']
|
77 |
+
|
78 |
+
num_pretrained_tokens = len(tokenizer)
|
79 |
+
|
80 |
+
addition_special_tokens_config = tokenizer_conf.get('ADDITIONAL_SPECIAL_TOKENS', None)
|
81 |
+
if addition_special_tokens_config == 'od+cap':
|
82 |
+
# Note: We still keep the additional special tokens from original tokenizer when we add new special tokens.
|
83 |
+
# This is to make sure tokenizer.additional_special_tokens afterwards includes original additional special tokens.
|
84 |
+
special_tokens_dict = {
|
85 |
+
'additional_special_tokens': \
|
86 |
+
tokenizer.additional_special_tokens + \
|
87 |
+
['<od>','</od>','<cap>','</cap>'] + \
|
88 |
+
[f'<loc_{x}>' for x in range(tokenizer_conf.get('NUM_LOCATION_TOKENS', 0))]
|
89 |
+
}
|
90 |
+
tokenizer.add_special_tokens(special_tokens_dict)
|
91 |
+
elif isinstance(addition_special_tokens_config, list):
|
92 |
+
special_tokens_dict = {
|
93 |
+
'additional_special_tokens': \
|
94 |
+
tokenizer.additional_special_tokens + \
|
95 |
+
addition_special_tokens_config + \
|
96 |
+
[f'<loc_{x}>' for x in range(tokenizer_conf.get('NUM_LOCATION_TOKENS', 0))]+
|
97 |
+
[f'<time_{x}>' for x in range(
|
98 |
+
tokenizer_conf.get('NUM_TIME_TOKENS', 0))]
|
99 |
+
}
|
100 |
+
tokenizer.add_special_tokens(special_tokens_dict)
|
101 |
+
elif addition_special_tokens_config is not None:
|
102 |
+
raise ValueError('ADDITIONAL_SPECIAL_TOKENS type error')
|
103 |
+
|
104 |
+
num_current_tokens = len(tokenizer)
|
105 |
+
logger.info(f'{num_pretrained_tokens} tokens in pretrained tokenizer => {num_current_tokens} in current tokenizer')
|
106 |
+
logger.info(f'All special tokens in tokenizer: {tokenizer.additional_special_tokens}')
|
107 |
+
|
108 |
+
return tokenizer
|
MedImageInsight/LangEncoder/registry.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_lang_encoders = {}
|
2 |
+
|
3 |
+
|
4 |
+
def register_lang_encoder(fn):
|
5 |
+
module_name_split = fn.__module__.split('.')
|
6 |
+
model_name = module_name_split[-1]
|
7 |
+
|
8 |
+
_lang_encoders[model_name] = fn
|
9 |
+
|
10 |
+
return fn
|
11 |
+
|
12 |
+
|
13 |
+
def lang_encoders(model_name):
|
14 |
+
return _lang_encoders[model_name]
|
15 |
+
|
16 |
+
|
17 |
+
def is_lang_encoder(model_name):
|
18 |
+
return model_name in _lang_encoders
|
MedImageInsight/LangEncoder/transformer.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import Tuple, Union
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from timm.models.layers import DropPath, trunc_normal_
|
12 |
+
|
13 |
+
from .registry import register_lang_encoder
|
14 |
+
from ..Utils import is_main_process
|
15 |
+
from ..Utils import register_norm_module
|
16 |
+
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
@register_norm_module
|
21 |
+
class LayerNorm(nn.Module):
|
22 |
+
def __init__(self, hidden_size, eps=1e-12):
|
23 |
+
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
24 |
+
"""
|
25 |
+
super(LayerNorm, self).__init__()
|
26 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
27 |
+
self.bias = nn.Parameter(torch.zeros(hidden_size))
|
28 |
+
self.variance_epsilon = eps
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
pdtype = x.dtype
|
32 |
+
x = x.float()
|
33 |
+
u = x.mean(-1, keepdim=True)
|
34 |
+
s = (x - u).pow(2).mean(-1, keepdim=True)
|
35 |
+
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
36 |
+
return self.weight * x.to(pdtype) + self.bias
|
37 |
+
|
38 |
+
|
39 |
+
class QuickGELU(nn.Module):
|
40 |
+
def forward(self, x: torch.Tensor):
|
41 |
+
return x * torch.sigmoid(1.702 * x)
|
42 |
+
|
43 |
+
|
44 |
+
class ResidualAttentionBlock(nn.Module):
|
45 |
+
def __init__(self,
|
46 |
+
d_model: int,
|
47 |
+
n_head: int,
|
48 |
+
attn_mask: torch.Tensor = None,
|
49 |
+
drop_path: float = 0.0):
|
50 |
+
super().__init__()
|
51 |
+
|
52 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
53 |
+
self.ln_1 = LayerNorm(d_model)
|
54 |
+
self.mlp = nn.Sequential(OrderedDict([
|
55 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
56 |
+
("gelu", QuickGELU()),
|
57 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
58 |
+
]))
|
59 |
+
self.ln_2 = LayerNorm(d_model)
|
60 |
+
self.attn_mask = attn_mask
|
61 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
62 |
+
|
63 |
+
def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
|
64 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
|
65 |
+
if self.attn_mask is not None else None
|
66 |
+
|
67 |
+
|
68 |
+
return self.attn(
|
69 |
+
x, x, x,
|
70 |
+
key_padding_mask=key_padding_mask,
|
71 |
+
need_weights=False,
|
72 |
+
attn_mask=self.attn_mask
|
73 |
+
)[0]
|
74 |
+
|
75 |
+
def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
|
76 |
+
x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
|
77 |
+
x = x + self.drop_path(self.mlp(self.ln_2(x)))
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class Transformer(nn.Module):
|
82 |
+
def __init__(self,
|
83 |
+
context_length: int,
|
84 |
+
vocab_size: int,
|
85 |
+
width: int,
|
86 |
+
layers: int,
|
87 |
+
heads: int,
|
88 |
+
drop_path: float = 0.0,
|
89 |
+
autogressive: bool =True,
|
90 |
+
key_padding_token: int = 0,
|
91 |
+
):
|
92 |
+
super().__init__()
|
93 |
+
|
94 |
+
self.token_embedding = nn.Embedding(vocab_size, width)
|
95 |
+
self.key_padding_token = key_padding_token
|
96 |
+
|
97 |
+
self.context_length = context_length
|
98 |
+
self.positional_embedding = nn.Parameter(
|
99 |
+
torch.empty(self.context_length, width)
|
100 |
+
)
|
101 |
+
|
102 |
+
self.width = width
|
103 |
+
self.layers = layers
|
104 |
+
self.autogressive = autogressive
|
105 |
+
attn_mask = self.build_attention_mask() if autogressive else None
|
106 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path, layers)] # stochastic depth decay rule
|
107 |
+
self.resblocks = nn.ModuleList(
|
108 |
+
[
|
109 |
+
ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
|
110 |
+
for i in range(layers)
|
111 |
+
]
|
112 |
+
)
|
113 |
+
|
114 |
+
self.ln_final = LayerNorm(width)
|
115 |
+
|
116 |
+
trunc_normal_(self.positional_embedding, std=.02)
|
117 |
+
# nn.init.normal_(self.token_embedding, std=.02)
|
118 |
+
trunc_normal_(self.token_embedding.weight, std=.02)
|
119 |
+
self.apply(self._init_weights)
|
120 |
+
|
121 |
+
@property
|
122 |
+
def dim_out(self):
|
123 |
+
return self.width
|
124 |
+
|
125 |
+
def build_attention_mask(self):
|
126 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
127 |
+
# pytorch uses additive attention mask; fill with -inf
|
128 |
+
mask = torch.empty(self.context_length, self.context_length)
|
129 |
+
mask.fill_(float("-inf"))
|
130 |
+
mask.triu_(1) # zero out the lower diagonal
|
131 |
+
return mask
|
132 |
+
|
133 |
+
def _init_weights(self, m):
|
134 |
+
if isinstance(m, (nn.Linear, nn.Conv2d)):
|
135 |
+
if is_main_process():
|
136 |
+
logger.info('=> init weight of Linear/Conv2d from trunc norm')
|
137 |
+
trunc_normal_(m.weight, std=0.02)
|
138 |
+
if m.bias is not None:
|
139 |
+
if is_main_process():
|
140 |
+
logger.info('=> init bias of Linear/Conv2d to zeros')
|
141 |
+
nn.init.constant_(m.bias, 0)
|
142 |
+
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
|
143 |
+
nn.init.constant_(m.bias, 0)
|
144 |
+
|
145 |
+
def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
|
146 |
+
if os.path.isfile(pretrained):
|
147 |
+
pretrained_dict = torch.load(pretrained, map_location='cpu')
|
148 |
+
logging.info(f'=> loading pretrained model {pretrained}')
|
149 |
+
model_dict = self.state_dict()
|
150 |
+
pretrained_dict = {
|
151 |
+
k: v for k, v in pretrained_dict.items()
|
152 |
+
if k in model_dict.keys()
|
153 |
+
}
|
154 |
+
need_init_state_dict = {}
|
155 |
+
for k, v in pretrained_dict.items():
|
156 |
+
need_init = (
|
157 |
+
k.split('.')[0] in pretrained_layers
|
158 |
+
or pretrained_layers[0] == '*'
|
159 |
+
)
|
160 |
+
if need_init:
|
161 |
+
if verbose:
|
162 |
+
logging.info(f'=> init {k} from {pretrained}')
|
163 |
+
|
164 |
+
need_init_state_dict[k] = v
|
165 |
+
self.load_state_dict(need_init_state_dict, strict=False)
|
166 |
+
|
167 |
+
|
168 |
+
@torch.jit.ignore
|
169 |
+
def no_weight_decay(self):
|
170 |
+
return {
|
171 |
+
'positional_embedding',
|
172 |
+
'token_embedding',
|
173 |
+
}
|
174 |
+
|
175 |
+
def forward(self, input_ids, attention_mask=None):
|
176 |
+
input_ids = input_ids.to(self.positional_embedding.device, non_blocking=True)
|
177 |
+
# Here we generate key_padding_mask using attention_mask instead of using
|
178 |
+
# a predefined key_padding_token (e.g., 0). This is to solve a discrepancy
|
179 |
+
# between Transformer 4.16.2 and 4.25.1, since Transformers 4.16.2 uses token id 0
|
180 |
+
# for padding but 4.25.1 uses EOS token (token id 49407) for padding.
|
181 |
+
key_padding_mask = (attention_mask == 0) if not self.autogressive else None
|
182 |
+
# a True value indicates that the corresponding key value will be ignored for the purpose of attention
|
183 |
+
x = self.token_embedding(input_ids) # [batch_size, n_ctx, d_model]
|
184 |
+
x = x + self.positional_embedding
|
185 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
186 |
+
for block in self.resblocks:
|
187 |
+
x = block(x, key_padding_mask)
|
188 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
189 |
+
|
190 |
+
x = self.ln_final(x)
|
191 |
+
|
192 |
+
return {'last_hidden_state': x}
|
193 |
+
|
194 |
+
|
195 |
+
@register_lang_encoder
|
196 |
+
def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
|
197 |
+
transformer = Transformer(
|
198 |
+
context_length=config_encoder['CONTEXT_LENGTH'],
|
199 |
+
vocab_size=tokenizer.vocab_size,
|
200 |
+
width=config_encoder['WIDTH'],
|
201 |
+
layers=config_encoder['LAYERS'],
|
202 |
+
heads=config_encoder['HEADS'],
|
203 |
+
autogressive=config_encoder.get('AUTOGRESSIVE', True),
|
204 |
+
key_padding_token=config_encoder.get('KEY_PADDING_TOKEN', 0),
|
205 |
+
)
|
206 |
+
|
207 |
+
if config_encoder['LOAD_PRETRAINED']:
|
208 |
+
transformer.load_pretrained()
|
209 |
+
|
210 |
+
return transformer
|
MedImageInsight/UniCLModel.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pathlib
|
2 |
+
import tempfile
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import copy
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
from timm.models.layers import trunc_normal_
|
11 |
+
|
12 |
+
from .ImageEncoder import build_image_encoder
|
13 |
+
from .LangEncoder import build_lang_encoder
|
14 |
+
from .LangEncoder import build_tokenizer
|
15 |
+
|
16 |
+
import mup.init
|
17 |
+
from mup import set_base_shapes
|
18 |
+
|
19 |
+
from safetensors.torch import load_file
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class UniCLModel(nn.Module):
|
26 |
+
def __init__(self, config: dict):
|
27 |
+
super().__init__()
|
28 |
+
|
29 |
+
self.conf_lang_encoder = config['LANG_ENCODER']
|
30 |
+
self.tokenizer = build_tokenizer(self.conf_lang_encoder)
|
31 |
+
|
32 |
+
self.lang_encoder = build_lang_encoder(self.conf_lang_encoder, self.tokenizer, config['VERBOSE'])
|
33 |
+
|
34 |
+
dim_projection = config['UNICL_MODEL']['DIM_PROJECTION']
|
35 |
+
if hasattr(self.lang_encoder, 'dim_out'):
|
36 |
+
dim_out = self.lang_encoder.dim_out
|
37 |
+
else:
|
38 |
+
with torch.no_grad():
|
39 |
+
dim_out = self.lang_encoder(
|
40 |
+
torch.zeros(1,1).type(torch.LongTensor)
|
41 |
+
)['last_hidden_state'].size(2)
|
42 |
+
|
43 |
+
self.lang_projection = nn.Parameter(torch.empty(dim_out, dim_projection))
|
44 |
+
|
45 |
+
self.conf_image_encoder = config['IMAGE_ENCODER']
|
46 |
+
self.image_encoder = build_image_encoder(self.conf_image_encoder, config['VERBOSE'])
|
47 |
+
|
48 |
+
self.image_projection = nn.Parameter(
|
49 |
+
torch.empty(self.image_encoder.dim_out, dim_projection)
|
50 |
+
)
|
51 |
+
|
52 |
+
self.logit_scale = nn.Parameter(torch.ones([]))
|
53 |
+
|
54 |
+
if torch.cuda.is_available():
|
55 |
+
self.device = torch.device(type="cuda", index=0)
|
56 |
+
else:
|
57 |
+
self.device = torch.device(type="cpu")
|
58 |
+
|
59 |
+
def custom_init_weights(self, use_original_init=True):
|
60 |
+
self.use_original_init = use_original_init
|
61 |
+
logger.info('Custom init: {}'.format('original init' if self.use_original_init else 'muP init'))
|
62 |
+
|
63 |
+
if self.use_original_init:
|
64 |
+
# Original initialization.
|
65 |
+
# Note: This is not SP init. We do not implement SP init here.
|
66 |
+
custom_trunc_normal_ = trunc_normal_ # Note: This should be the same as torch.nn.init.trunc_normal_
|
67 |
+
else:
|
68 |
+
# muP.
|
69 |
+
custom_trunc_normal_ = mup.init.trunc_normal_
|
70 |
+
|
71 |
+
custom_trunc_normal_(self.lang_projection, std=.02)
|
72 |
+
custom_trunc_normal_(self.image_projection, std=.02)
|
73 |
+
|
74 |
+
def _convert_old_weights(self, model_dict):
|
75 |
+
model_dict_updated = {}
|
76 |
+
for k, v in model_dict.items():
|
77 |
+
if k.startswith('visual.'):
|
78 |
+
model_dict_updated['image_encoder.'+k[7:]] = v
|
79 |
+
elif k.startswith('text.'):
|
80 |
+
model_dict_updated['lang_encoder.'+k[5:]] = v
|
81 |
+
elif k == 'vision_projection':
|
82 |
+
model_dict_updated['image_projection'] = v
|
83 |
+
elif k == 'text_projection':
|
84 |
+
model_dict_updated['lang_projection'] = v
|
85 |
+
else:
|
86 |
+
model_dict_updated[k] = v
|
87 |
+
|
88 |
+
return model_dict_updated
|
89 |
+
|
90 |
+
def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
|
91 |
+
if not os.path.isfile(pretrained):
|
92 |
+
logger.warning(f'=> Pretrained model ({pretrained}) is not a file, skip init weight')
|
93 |
+
return
|
94 |
+
|
95 |
+
## Load SafeTensors Version of Pretrained Model
|
96 |
+
pretrained_dict = load_file(pretrained)
|
97 |
+
logger.info(f'=> Loading pretrained model {pretrained}')
|
98 |
+
model_dict = self.state_dict()
|
99 |
+
pretrained_dict = self._convert_old_weights(pretrained_dict)
|
100 |
+
## To ensure cuda is mapped to all weights in the SafeTensors version model
|
101 |
+
pretrained_dict = {
|
102 |
+
k: v.to(self.device) for k, v in pretrained_dict.items()
|
103 |
+
}
|
104 |
+
need_init_state_dict = {}
|
105 |
+
image_encoder_state_dict = {}
|
106 |
+
for k, v in pretrained_dict.items():
|
107 |
+
need_init = (
|
108 |
+
k.split('.')[0] in pretrained_layers
|
109 |
+
or pretrained_layers[0] == '*'
|
110 |
+
)
|
111 |
+
|
112 |
+
if need_init:
|
113 |
+
if k.startswith('image_encoder.'):
|
114 |
+
image_encoder_state_dict[k] = v.to(self.device)
|
115 |
+
else:
|
116 |
+
if verbose:
|
117 |
+
logger.info(f'=> init {k} from {pretrained}')
|
118 |
+
|
119 |
+
if 'positional_embedding' in k and v.size() != model_dict[k].size():
|
120 |
+
positional_embedding_pretrained = v
|
121 |
+
positional_embedding_current = model_dict[k]
|
122 |
+
L1, nH1 = positional_embedding_pretrained.size()
|
123 |
+
L2, nH2 = positional_embedding_current.size()
|
124 |
+
if nH1 != nH2:
|
125 |
+
logger.info(f"Error in loading {k}, passing")
|
126 |
+
else:
|
127 |
+
if L1 != L2:
|
128 |
+
logger.info(
|
129 |
+
'=> load_pretrained: resized variant: {} to {}'
|
130 |
+
.format((L1, nH1), (L2, nH2))
|
131 |
+
)
|
132 |
+
|
133 |
+
posemb = positional_embedding_pretrained.float()
|
134 |
+
posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1)
|
135 |
+
posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear')
|
136 |
+
posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0)
|
137 |
+
v = posemb_grid
|
138 |
+
|
139 |
+
need_init_state_dict[k] = v.to(self.device)
|
140 |
+
self.image_encoder.from_state_dict(image_encoder_state_dict, ['*'], verbose)
|
141 |
+
self.load_state_dict(need_init_state_dict, strict=False)
|
142 |
+
|
143 |
+
@torch.jit.ignore
|
144 |
+
def no_weight_decay(self):
|
145 |
+
no_weight_decay = {'logit_scale'}
|
146 |
+
if hasattr(self.lang_encoder, 'no_weight_decay'):
|
147 |
+
for k in self.lang_encoder.no_weight_decay():
|
148 |
+
no_weight_decay.add('lang_encoder.'+k)
|
149 |
+
|
150 |
+
if hasattr(self.image_encoder, 'no_weight_decay'):
|
151 |
+
for k in self.visual.no_weight_decay():
|
152 |
+
no_weight_decay.add('image_encoder.'+k)
|
153 |
+
|
154 |
+
return no_weight_decay
|
155 |
+
|
156 |
+
@property
|
157 |
+
def dtype(self):
|
158 |
+
return self.logit_scale.dtype
|
159 |
+
|
160 |
+
def encode_image(self, image, norm=True):
|
161 |
+
x = self.image_encoder.forward_features(image)
|
162 |
+
x = x @ self.image_projection
|
163 |
+
|
164 |
+
if norm:
|
165 |
+
x = x / x.norm(dim=-1, keepdim=True)
|
166 |
+
|
167 |
+
return x
|
168 |
+
|
169 |
+
def encode_text(self, text, norm=True):
|
170 |
+
x = self.lang_encoder(**text)
|
171 |
+
x = x['last_hidden_state']
|
172 |
+
|
173 |
+
if self.conf_lang_encoder['TOKENIZER'] == 'clip':
|
174 |
+
x = x[torch.arange(x.size(0)), text['input_ids'].argmax(dim=-1)]
|
175 |
+
else:
|
176 |
+
x = x[:, 0]
|
177 |
+
|
178 |
+
x = x @ self.lang_projection
|
179 |
+
|
180 |
+
if norm:
|
181 |
+
x = x / x.norm(dim=-1, keepdim=True)
|
182 |
+
|
183 |
+
return x
|
184 |
+
|
185 |
+
def forward(self, image, text):
|
186 |
+
features_image = self.encode_image(image)
|
187 |
+
features_text = self.encode_text(text)
|
188 |
+
|
189 |
+
# cosine similarity as logits
|
190 |
+
T = self.logit_scale.exp()
|
191 |
+
|
192 |
+
return features_image, features_text, T
|
193 |
+
|
194 |
+
|
195 |
+
def create_model(config):
|
196 |
+
model = UniCLModel(config)
|
197 |
+
return model
|
198 |
+
|
199 |
+
|
200 |
+
def create_mup_model(config):
|
201 |
+
def gen_config(config, wm):
|
202 |
+
# TODO: Currently only support the case that all UniCL, lang encoder, and image encoder use
|
203 |
+
# mu parameterization. This requirement can be relaxed.
|
204 |
+
assert (not config['UNICL_MODEL']['STANDPARAM']) and \
|
205 |
+
(not config['LANG_ENCODER']['STANDPARAM']) and \
|
206 |
+
(not config['IMAGE_ENCODER']['SPEC']['STANDPARAM'])
|
207 |
+
new_config = copy.deepcopy(config)
|
208 |
+
logger.info(f'Generate config with width mult = {wm}:')
|
209 |
+
|
210 |
+
# Generate config for UniCL head.
|
211 |
+
new_config_section = new_config['UNICL_MODEL']
|
212 |
+
new_config_section['STANDPARAM'] = True # Use standard parameterization when determining base shapes.
|
213 |
+
for name in ['DIM_PROJECTION']:
|
214 |
+
base_name = 'BASE_' + name
|
215 |
+
new_values = round(new_config_section[base_name] * wm) # New value = base value * width multiplier.
|
216 |
+
logger.info(f'config["UNICL_MODEL"]["{name}"]: {new_config_section[name]} -> {new_values}')
|
217 |
+
new_config_section[name] = new_values
|
218 |
+
|
219 |
+
# Generate config for lang encoder.
|
220 |
+
new_config_section = new_config['LANG_ENCODER']
|
221 |
+
new_config_section['STANDPARAM'] = True
|
222 |
+
for name in ['WIDTH', 'HEADS']:
|
223 |
+
base_name = 'BASE_' + name
|
224 |
+
new_values = round(new_config_section[base_name] * wm) # New value = base value * width multiplier.
|
225 |
+
logger.info(f'config["LANG_ENCODER"]["{name}"]: {new_config_section[name]} -> {new_values}')
|
226 |
+
new_config_section[name] = new_values
|
227 |
+
|
228 |
+
# Generate config for image encoder.
|
229 |
+
new_config_section = new_config['IMAGE_ENCODER']['SPEC']
|
230 |
+
new_config_section['STANDPARAM'] = True
|
231 |
+
for name in ['DIM_EMBED', 'NUM_HEADS', 'NUM_GROUPS']:
|
232 |
+
base_name = 'BASE_' + name
|
233 |
+
new_values = [round(base_value * wm) for base_value in new_config_section[base_name]] # New value = base value * width multiplier.
|
234 |
+
logger.info(f'config["IMAGE_ENCODER"]["SPEC"]["{name}"]: {new_config_section[name]} -> {new_values}')
|
235 |
+
new_config_section[name] = new_values
|
236 |
+
|
237 |
+
return new_config
|
238 |
+
|
239 |
+
logger.info('muP: Create models and set base shapes')
|
240 |
+
logger.info('=> Create model')
|
241 |
+
model = create_model(config)
|
242 |
+
# Temporarily remove the lang and image encoders from model to prevent from
|
243 |
+
# setting the base shape for these encoders again.
|
244 |
+
lang_encoder, image_encoder = model.lang_encoder, model.image_encoder
|
245 |
+
model.lang_encoder, model.image_encoder = None, None
|
246 |
+
|
247 |
+
logger.info('=> Create base model')
|
248 |
+
base_config = gen_config(config, wm=1.0)
|
249 |
+
base_model = create_model(base_config)
|
250 |
+
del base_model.lang_encoder, base_model.image_encoder
|
251 |
+
|
252 |
+
logger.info('=> Create delta model')
|
253 |
+
delta_config = gen_config(config, wm=2.0)
|
254 |
+
delta_model = create_model(delta_config)
|
255 |
+
del delta_model.lang_encoder, delta_model.image_encoder
|
256 |
+
|
257 |
+
logger.info('=> Set base shapes in model for training')
|
258 |
+
set_base_shapes(model, base=base_model, delta=delta_model)
|
259 |
+
|
260 |
+
# Restore the lang and image encoders in the model.
|
261 |
+
model.lang_encoder, model.image_encoder = lang_encoder, image_encoder
|
262 |
+
|
263 |
+
return model
|
264 |
+
|
265 |
+
|
266 |
+
def build_unicl_model(config, **kwargs):
|
267 |
+
standparam = config['UNICL_MODEL'].get('STANDPARAM', True)
|
268 |
+
|
269 |
+
if standparam:
|
270 |
+
logger.info('Create model with standard parameterization')
|
271 |
+
model = create_model(config)
|
272 |
+
|
273 |
+
use_original_init = True
|
274 |
+
else:
|
275 |
+
logger.info('Create model with mu parameterization')
|
276 |
+
model = create_mup_model(config)
|
277 |
+
use_original_init = False
|
278 |
+
|
279 |
+
# Initialize other parameters.
|
280 |
+
model.custom_init_weights(use_original_init=use_original_init)
|
281 |
+
|
282 |
+
if config['UNICL_MODEL']['LOAD_PRETRAINED']:
|
283 |
+
pretrained_path = config['UNICL_MODEL']['PRETRAINED']
|
284 |
+
from .Distributed.Utils import is_valid_url, download_file
|
285 |
+
if is_valid_url(pretrained_path):
|
286 |
+
with tempfile.TemporaryDirectory() as tmp_path:
|
287 |
+
file_local_path = pathlib.Path(tmp_path) / 'base_model.pt'
|
288 |
+
download_file(pretrained_path, file_local_path)
|
289 |
+
model.from_pretrained(str(file_local_path), config['UNICL_MODEL']['PRETRAINED_LAYERS'], config['VERBOSE'])
|
290 |
+
else:
|
291 |
+
model.from_pretrained(pretrained_path, config['UNICL_MODEL']['PRETRAINED_LAYERS'], config['VERBOSE'])
|
292 |
+
|
293 |
+
return model
|
MedImageInsight/Utils/Arguments.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import yaml
|
7 |
+
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
def add_env_parser_to_yaml():
|
12 |
+
"""
|
13 |
+
Adding ability of resolving environment variables to the yaml SafeLoader.
|
14 |
+
Environment variables in the form of "${<env_var_name>}" can be resolved as strings.
|
15 |
+
If the <env_var_name> is not in the env, <env_var_name> itself would be used.
|
16 |
+
|
17 |
+
E.g.:
|
18 |
+
config:
|
19 |
+
username: admin
|
20 |
+
password: ${SERVICE_PASSWORD}
|
21 |
+
service: https://${SERVICE_HOST}/service
|
22 |
+
"""
|
23 |
+
loader = yaml.SafeLoader
|
24 |
+
env_pattern = re.compile(r".*?\${(.*?)}.*?")
|
25 |
+
|
26 |
+
def env_constructor(loader, node):
|
27 |
+
value = loader.construct_scalar(node)
|
28 |
+
for group in env_pattern.findall(value):
|
29 |
+
value = value.replace(f"${{{group}}}", os.environ.get(group, group))
|
30 |
+
return value
|
31 |
+
|
32 |
+
yaml.add_implicit_resolver("!ENV", env_pattern, Loader=loader)
|
33 |
+
yaml.add_constructor("!ENV", env_constructor, Loader=loader)
|
34 |
+
|
35 |
+
|
36 |
+
def load_config_dict_to_opt(opt, config_dict, splitter='.', log_new=False):
|
37 |
+
"""
|
38 |
+
Load the key, value pairs from config_dict to opt, overriding existing values in opt
|
39 |
+
if there is any.
|
40 |
+
"""
|
41 |
+
if not isinstance(config_dict, dict):
|
42 |
+
raise TypeError("Config must be a Python dictionary")
|
43 |
+
for k, v in config_dict.items():
|
44 |
+
k_parts = k.split(splitter)
|
45 |
+
pointer = opt
|
46 |
+
for k_part in k_parts[:-1]:
|
47 |
+
if '[' in k_part and ']' in k_part:
|
48 |
+
# for the format "a.b[0][1].c: d"
|
49 |
+
k_part_splits = k_part.split('[')
|
50 |
+
k_part = k_part_splits.pop(0)
|
51 |
+
pointer = pointer[k_part]
|
52 |
+
for i in k_part_splits:
|
53 |
+
assert i[-1] == ']'
|
54 |
+
pointer = pointer[int(i[:-1])]
|
55 |
+
else:
|
56 |
+
if k_part not in pointer:
|
57 |
+
pointer[k_part] = {}
|
58 |
+
pointer = pointer[k_part]
|
59 |
+
assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
|
60 |
+
if '[' in k_parts[-1] and ']' in k_parts[-1]:
|
61 |
+
k_part_splits = k_parts[-1].split('[')
|
62 |
+
k_part = k_part_splits.pop(0)
|
63 |
+
pointer = pointer[k_part]
|
64 |
+
for i in k_part_splits[:-1]:
|
65 |
+
assert i[-1] == ']'
|
66 |
+
pointer = pointer[int(i[:-1])]
|
67 |
+
assert k_part_splits[-1][-1] == ']'
|
68 |
+
ori_value = pointer[int(k_part_splits[-1][:-1])]
|
69 |
+
pointer[int(k_part_splits[-1][:-1])] = v
|
70 |
+
else:
|
71 |
+
ori_value = pointer.get(k_parts[-1])
|
72 |
+
pointer[k_parts[-1]] = v
|
73 |
+
if ori_value:
|
74 |
+
logger.warning(f"Overrided {k} from {ori_value} to {v}")
|
75 |
+
elif log_new:
|
76 |
+
logger.warning(f"Added {k}: {v}")
|
77 |
+
|
78 |
+
|
79 |
+
def load_opt_from_config_files(conf_files):
|
80 |
+
"""
|
81 |
+
Load opt from the config files, settings in later files can override those in previous files.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
conf_files (list): a list of config file paths
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
dict: a dictionary of opt settings
|
88 |
+
"""
|
89 |
+
opt = {}
|
90 |
+
for conf_file in conf_files:
|
91 |
+
with open(conf_file, encoding='utf-8') as f:
|
92 |
+
# config_dict = yaml.safe_load(f)
|
93 |
+
config_dict = yaml.unsafe_load(f)
|
94 |
+
|
95 |
+
load_config_dict_to_opt(opt, config_dict)
|
96 |
+
|
97 |
+
return opt
|
98 |
+
|
99 |
+
|
100 |
+
def load_opt_command(args):
|
101 |
+
parser = argparse.ArgumentParser(description='MainzTrain: Pretrain or fine-tune models for NLP tasks.')
|
102 |
+
parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate')
|
103 |
+
parser.add_argument('--conf_files', nargs='+', required=True, help='Path(s) to the MainzTrain config file(s).')
|
104 |
+
parser.add_argument('--user_dir', help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.')
|
105 |
+
parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"<PARAM_NAME_1>": <PARAM_VALUE_1>, "<PARAM_GROUP_2>.<PARAM_SUBGROUP_2>.<PARAM_2>": <PARAM_VALUE_2>}. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.')
|
106 |
+
|
107 |
+
cmdline_args = parser.parse_args() if not args else parser.parse_args(args)
|
108 |
+
|
109 |
+
add_env_parser_to_yaml()
|
110 |
+
opt = load_opt_from_config_files(cmdline_args.conf_files)
|
111 |
+
|
112 |
+
if cmdline_args.config_overrides:
|
113 |
+
config_overrides_string = ' '.join(cmdline_args.config_overrides)
|
114 |
+
config_overrides_string = os.path.expandvars(config_overrides_string)
|
115 |
+
logger.warning(f"Command line config overrides: {config_overrides_string}")
|
116 |
+
config_dict = yaml.safe_load(config_overrides_string)
|
117 |
+
load_config_dict_to_opt(opt, config_dict)
|
118 |
+
|
119 |
+
# combine cmdline_args into opt dictionary
|
120 |
+
for key, val in cmdline_args.__dict__.items():
|
121 |
+
if val is not None:
|
122 |
+
opt[key] = val
|
123 |
+
|
124 |
+
return opt, cmdline_args
|
125 |
+
|
126 |
+
|
127 |
+
def save_opt_to_json(opt, conf_file):
|
128 |
+
with open(conf_file, 'w', encoding='utf-8') as f:
|
129 |
+
json.dump(opt, f, indent=4)
|
130 |
+
|
131 |
+
|
132 |
+
def save_opt_to_yaml(opt, conf_file):
|
133 |
+
with open(conf_file, 'w', encoding='utf-8') as f:
|
134 |
+
yaml.dump(opt, f)
|
MedImageInsight/Utils/GeneraUtils.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import logging
|
3 |
+
import copy
|
4 |
+
import itertools
|
5 |
+
import random
|
6 |
+
from collections.abc import Iterable, Iterator
|
7 |
+
import torch
|
8 |
+
from torch._C import default_generator
|
9 |
+
import torch.distributed as dist
|
10 |
+
import time
|
11 |
+
from functools import wraps, partial
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
class ObjectView(object):
|
17 |
+
def __init__(self, d):
|
18 |
+
self.__dict__ = d
|
19 |
+
|
20 |
+
|
21 |
+
class AverageMeter(object):
|
22 |
+
"""Computes and stores the average and current value."""
|
23 |
+
def __init__(self):
|
24 |
+
self.reset()
|
25 |
+
|
26 |
+
def reset(self):
|
27 |
+
self.val = 0
|
28 |
+
self.avg = 0
|
29 |
+
self.sum = 0
|
30 |
+
self.count = 0
|
31 |
+
|
32 |
+
def update(self, val, n=1, decay=0):
|
33 |
+
self.val = val
|
34 |
+
if decay:
|
35 |
+
alpha = math.exp(-n / decay) # exponential decay over 100 updates
|
36 |
+
self.sum = alpha * self.sum + (1 - alpha) * val * n
|
37 |
+
self.count = alpha * self.count + (1 - alpha) * n
|
38 |
+
else:
|
39 |
+
self.sum += val * n
|
40 |
+
self.count += n
|
41 |
+
self.avg = self.sum / self.count
|
42 |
+
|
43 |
+
def getstate(self):
|
44 |
+
return {'val': self.val,
|
45 |
+
'avg': self.avg,
|
46 |
+
'sum': self.sum,
|
47 |
+
'count': self.count}
|
48 |
+
|
49 |
+
def setstate(self, state):
|
50 |
+
self.val = state['val']
|
51 |
+
self.avg = state['avg']
|
52 |
+
self.sum = state['sum']
|
53 |
+
self.count = state['count']
|
54 |
+
|
55 |
+
|
56 |
+
def move_batch_to_device(batch, device):
|
57 |
+
"""
|
58 |
+
Move the batch to the device.
|
59 |
+
It should be called before feeding the batch to the model.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
batch (torch.tensor or container of torch.tensor): input batch
|
63 |
+
device (torch.device): device to move the batch to
|
64 |
+
Returns:
|
65 |
+
return_batch: same type as the input batch with internal tensors moved to device
|
66 |
+
"""
|
67 |
+
if torch.is_tensor(batch):
|
68 |
+
return_batch = batch.to(device)
|
69 |
+
elif isinstance(batch, list):
|
70 |
+
return_batch = [move_batch_to_device(t, device) for t in batch]
|
71 |
+
elif isinstance(batch, tuple):
|
72 |
+
return_batch = tuple(move_batch_to_device(t, device) for t in batch)
|
73 |
+
elif isinstance(batch, dict):
|
74 |
+
return_batch = {}
|
75 |
+
for k in batch:
|
76 |
+
return_batch[k] = move_batch_to_device(batch[k], device)
|
77 |
+
else:
|
78 |
+
logger.debug(f"Can not move type {type(batch)} to device. Skipping it in the batch.")
|
79 |
+
return_batch = batch
|
80 |
+
|
81 |
+
return return_batch
|
82 |
+
|
83 |
+
|
84 |
+
def cast_batch_to_dtype(batch, dtype):
|
85 |
+
"""
|
86 |
+
Cast the float32 tensors in a batch to a specified torch dtype.
|
87 |
+
It should be called before feeding the batch to the FP16 DeepSpeed model.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
batch (torch.tensor or container of torch.tensor): input batch
|
91 |
+
Returns:
|
92 |
+
return_batch: same type as the input batch with internal float32 tensors casted to the specified dtype.
|
93 |
+
"""
|
94 |
+
if torch.is_tensor(batch):
|
95 |
+
if torch.is_floating_point(batch):
|
96 |
+
return_batch = batch.to(dtype)
|
97 |
+
else:
|
98 |
+
return_batch = batch
|
99 |
+
elif isinstance(batch, list):
|
100 |
+
return_batch = [cast_batch_to_dtype(t, dtype) for t in batch]
|
101 |
+
elif isinstance(batch, tuple):
|
102 |
+
return_batch = tuple(cast_batch_to_dtype(t, dtype) for t in batch)
|
103 |
+
elif isinstance(batch, dict):
|
104 |
+
return_batch = {}
|
105 |
+
for k in batch:
|
106 |
+
return_batch[k] = cast_batch_to_dtype(batch[k], dtype)
|
107 |
+
else:
|
108 |
+
logger.debug(f"Can not cast type {type(batch)} to {dtype}. Skipping it in the batch.")
|
109 |
+
return_batch = batch
|
110 |
+
|
111 |
+
return return_batch
|
112 |
+
|
113 |
+
|
114 |
+
def cast_batch_to_half(batch):
|
115 |
+
"""
|
116 |
+
Cast the float32 tensors in a batch to float16.
|
117 |
+
It should be called before feeding the batch to the FP16 DeepSpeed model.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
batch (torch.tensor or container of torch.tensor): input batch
|
121 |
+
Returns:
|
122 |
+
return_batch: same type as the input batch with internal float32 tensors casted to float16
|
123 |
+
"""
|
124 |
+
return cast_batch_to_dtype(batch, torch.float16)
|
125 |
+
|
126 |
+
|
127 |
+
def cast_batch_to_bf16(batch):
|
128 |
+
"""
|
129 |
+
Cast the float32 tensors in a batch to bfloat16.
|
130 |
+
It should be called before feeding the batch to the FP16 DeepSpeed model.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
batch (torch.tensor or container of torch.tensor): input batch
|
134 |
+
Returns:
|
135 |
+
return_batch: same type as the input batch with internal float32 tensors casted to bfloat16
|
136 |
+
"""
|
137 |
+
return cast_batch_to_dtype(batch, torch.bfloat16)
|
138 |
+
|
139 |
+
|
140 |
+
# copied from MainzSpeech/moe_tools
|
141 |
+
def peek_first_item_from_iterator(it):
|
142 |
+
# extract first item from iterator
|
143 |
+
first_item = next(it)
|
144 |
+
# create iterator with the first item added back in
|
145 |
+
new_it = itertools.chain([copy.deepcopy(first_item)], it)
|
146 |
+
return first_item, new_it
|
147 |
+
|
148 |
+
|
149 |
+
# copied from MainzSpeech/moe_tools
|
150 |
+
def generate_dummy_batch(it):
|
151 |
+
"""
|
152 |
+
Generates a dummy batch by peeking at given iterable or iterator on rank 0,
|
153 |
+
then broadcast dummy_batch to all other ranks.
|
154 |
+
"""
|
155 |
+
from mpi4py import MPI
|
156 |
+
assert isinstance(it, Iterable) or isinstance(it, Iterator)
|
157 |
+
if isinstance(it, Iterable):
|
158 |
+
it = iter(it)
|
159 |
+
if MPI.COMM_WORLD.Get_rank() == 0:
|
160 |
+
dummy_batch, it = peek_first_item_from_iterator(it)
|
161 |
+
else:
|
162 |
+
dummy_batch = None
|
163 |
+
dummy_batch = MPI.COMM_WORLD.bcast(dummy_batch, root=0)
|
164 |
+
assert dummy_batch is not None
|
165 |
+
return dummy_batch, it
|
166 |
+
|
167 |
+
|
168 |
+
def retry_on_failure(func=None, *, max_retries=3, on_error_func=None, on_retry_func=None, raise_err_func=None, sleep_time=30, error_types=(Exception,)):
|
169 |
+
"""
|
170 |
+
Decorator utility to retry a function, this decorator must be used without arguments (@retry_on_failure) or with all named arguments (@retry_on_failure(max_retries=10)).
|
171 |
+
Args:
|
172 |
+
max_retries (int): The number of retries to perform, in addition to the initial retry. Defaults to 3.
|
173 |
+
sleep_time (int): The time in seconds to wait before the next retry. Defaults to 30.
|
174 |
+
error_types (Tuple[type]): a tuple of exception types which are used to except any error being retried, if the exception that is thrown is not an instance of one of these types, the function is not retried. Defaults to (Exception,) which covers all exceptions.
|
175 |
+
on_retry_func (callable(num_retries)): A function with a single argument, the number of retries done so far. This function is called just before any retry. Defaults to a function logging `num_retries`.
|
176 |
+
on_error_func (callable(num_retries)): A function with a single argument, the number of retries done in total. This function is called after `max_retries` has been tried. Defaults to a function logging `num_retries`.
|
177 |
+
raise_err_func (callable(err)): A function with a single argument, the exception that was thrown. This function is called after `max_retries` has been tried. Defaults to raising the error.
|
178 |
+
"""
|
179 |
+
if on_error_func is None:
|
180 |
+
def on_error_func(retried_times):
|
181 |
+
logger.warning(f"Failed after retrying {retried_times} times")
|
182 |
+
|
183 |
+
if on_retry_func is None:
|
184 |
+
def on_retry_func(idx):
|
185 |
+
logger.warning(f"Retrying on failure {idx}")
|
186 |
+
|
187 |
+
if raise_err_func is None:
|
188 |
+
def raise_err_func(err):
|
189 |
+
raise err
|
190 |
+
|
191 |
+
if func is None:
|
192 |
+
return partial(
|
193 |
+
retry_on_failure,
|
194 |
+
max_retries=max_retries,
|
195 |
+
on_error_func=on_error_func,
|
196 |
+
on_retry_func=on_retry_func,
|
197 |
+
raise_err_func=raise_err_func,
|
198 |
+
sleep_time=sleep_time,
|
199 |
+
error_types=error_types,
|
200 |
+
)
|
201 |
+
|
202 |
+
@wraps(func)
|
203 |
+
def decorator(*args, **kwargs):
|
204 |
+
num_retries = 0
|
205 |
+
while True:
|
206 |
+
try:
|
207 |
+
return func(*args, **kwargs)
|
208 |
+
except error_types as err:
|
209 |
+
num_retries += 1
|
210 |
+
on_retry_func(num_retries)
|
211 |
+
if num_retries > max_retries:
|
212 |
+
on_error_func(num_retries)
|
213 |
+
raise_err_func(err)
|
214 |
+
time.sleep(sleep_time)
|
215 |
+
|
216 |
+
return decorator
|
217 |
+
|
218 |
+
|
219 |
+
class TemporaryRngState:
|
220 |
+
'''
|
221 |
+
Context manager for working with a temporary random number generator (RNG) state.
|
222 |
+
The constructor gets a random number from the Python RNG that is used as
|
223 |
+
(part of) the seed for the temporary RNG
|
224 |
+
and then stores the current RNG state to restore the it later on.
|
225 |
+
If add_rank_to_seed=True, the GPU rank is added to the seed.
|
226 |
+
This is useful to initialize MoE models
|
227 |
+
where the experts on different GPUs should be initialized independently.
|
228 |
+
Note that this feature requires torch.distributed to be initialized.
|
229 |
+
On enter, the context managers sets the RNG state to the random seed created in the constructor
|
230 |
+
to establish a temporary RNG state.
|
231 |
+
On exit, the context manager resets the RNG state to the previously remembered state.
|
232 |
+
Thereby, any RNG operations executed with this context manager
|
233 |
+
do not affect the global, non-temporary RNG state.
|
234 |
+
However, the usage of this context manager does advance the Python RNG
|
235 |
+
since it uses that RNG to generate the random seed in the constructor.
|
236 |
+
The context manager resets the Python RNG state and
|
237 |
+
the PyTorch RNG state for CPU and GPU (if cuda is initialized).
|
238 |
+
It does not currently reset the numpy RNG state.
|
239 |
+
'''
|
240 |
+
def __init__(self, add_rank_to_seed=False):
|
241 |
+
self.seed = random.randrange(2**32)
|
242 |
+
if add_rank_to_seed and dist.is_initialized():
|
243 |
+
self.seed += dist.get_rank()
|
244 |
+
self.python_rng_state = random.getstate()
|
245 |
+
self.torch_rng_state = torch.get_rng_state()
|
246 |
+
if torch.cuda.is_initialized():
|
247 |
+
self.torch_rng_state_cuda = torch.cuda.get_rng_state()
|
248 |
+
|
249 |
+
def __enter__(self):
|
250 |
+
# increment seed for different RNGs to avoid correlation
|
251 |
+
# in the (very unlikely) case that the different RNGs
|
252 |
+
# use the exact same algorithm
|
253 |
+
random.seed(self.seed)
|
254 |
+
# do not call torch.maunal_seed here, because that sets the seed of all GPUs
|
255 |
+
default_generator.manual_seed(self.seed + 1)
|
256 |
+
if torch.cuda.is_initialized():
|
257 |
+
torch.cuda.manual_seed(self.seed + 2) # only set seed of default cuda device
|
258 |
+
|
259 |
+
def __exit__(self, exc_type, exc_value, exc_traceback):
|
260 |
+
random.setstate(self.python_rng_state)
|
261 |
+
torch.set_rng_state(self.torch_rng_state)
|
262 |
+
if torch.cuda.is_initialized():
|
263 |
+
torch.cuda.set_rng_state(self.torch_rng_state_cuda)
|
MedImageInsight/Utils/GlobalExceptHook.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import logging
|
3 |
+
|
4 |
+
logger = logging.getLogger(__name__)
|
5 |
+
|
6 |
+
_orig_except_hook = None
|
7 |
+
|
8 |
+
|
9 |
+
def _global_except_hook(exctype, value, traceback):
|
10 |
+
"""Catches an unhandled exception and call MPI_Abort()."""
|
11 |
+
try:
|
12 |
+
if _orig_except_hook:
|
13 |
+
_orig_except_hook(exctype, value, traceback)
|
14 |
+
else:
|
15 |
+
sys.__excepthook__(exctype, value, traceback)
|
16 |
+
|
17 |
+
finally:
|
18 |
+
import mpi4py.MPI
|
19 |
+
rank = mpi4py.MPI.COMM_WORLD.Get_rank()
|
20 |
+
logger.warning("******************************************")
|
21 |
+
logger.warning("MainzTrainer:")
|
22 |
+
logger.warning(f" Uncaught exception on rank {rank}.")
|
23 |
+
logger.warning(" Calling MPI_Abort() to shut down MPI...")
|
24 |
+
logger.warning("******************************************")
|
25 |
+
logging.shutdown()
|
26 |
+
|
27 |
+
try:
|
28 |
+
import mpi4py.MPI
|
29 |
+
mpi4py.MPI.COMM_WORLD.Abort(1)
|
30 |
+
except Exception as e:
|
31 |
+
# Something is completely broken...
|
32 |
+
# There's nothing we can do any more
|
33 |
+
sys.stderr.write("Sorry, failed to stop MPI and the process may hang.\n")
|
34 |
+
sys.stderr.flush()
|
35 |
+
raise e
|
36 |
+
|
37 |
+
|
38 |
+
def add_hook():
|
39 |
+
"""
|
40 |
+
Add a global hook function that captures all unhandled exceptions.
|
41 |
+
The function calls MPI_Abort() to force all processes abort.
|
42 |
+
|
43 |
+
An MPI runtime is expected to kill all of its child processes
|
44 |
+
if one of them exits abnormally or without calling `MPI_Finalize()`.
|
45 |
+
However, when a Python program run on `mpi4py`, the MPI runtime
|
46 |
+
often fails to detect a process failure, and the rest of the processes
|
47 |
+
hang infinitely.
|
48 |
+
|
49 |
+
See https://github.com/chainer/chainermn/issues/236 and
|
50 |
+
https://mpi4py.readthedocs.io/en/stable/mpi4py.run.html for more
|
51 |
+
information.
|
52 |
+
"""
|
53 |
+
global _orig_except_hook
|
54 |
+
|
55 |
+
if _orig_except_hook is not None:
|
56 |
+
logger.warning("GlobalExceptHook.add_hook() seems to be called multiple times. Ignoring.")
|
57 |
+
return
|
58 |
+
|
59 |
+
logger.info("Adding global except hook for the distributed job to shutdown MPI if unhandled exception is raised on some of the ranks.")
|
60 |
+
_orig_except_hook = sys.excepthook
|
61 |
+
sys.excepthook = _global_except_hook
|
MedImageInsight/Utils/MPIAdapter.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from mpi4py import MPI
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import subprocess
|
6 |
+
import torch
|
7 |
+
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
class MPIAdapter:
|
12 |
+
"""
|
13 |
+
MPIAdapter automatically detects and analyzes the training environment for distributed training
|
14 |
+
and offers methods to set up distributed training jobs.
|
15 |
+
|
16 |
+
For example, it determines whether training happens on AML, Philly, or locally.
|
17 |
+
It also determines variables such as the world size and the rank of each GPU.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, set_env_vars=True, master_address=None, port='55551'):
|
21 |
+
local_address = '127.0.0.1'
|
22 |
+
default_torch_distributed_port = str(port) # chosen arbitrarily
|
23 |
+
|
24 |
+
if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
|
25 |
+
# application was started without MPI
|
26 |
+
# default to single node with single process
|
27 |
+
self.env_info = 'no MPI'
|
28 |
+
self.world_size = 1
|
29 |
+
self.local_size = 1
|
30 |
+
self.rank = 0
|
31 |
+
self.local_rank = 0
|
32 |
+
self.master_address = local_address
|
33 |
+
self.master_port = default_torch_distributed_port
|
34 |
+
else:
|
35 |
+
# application was started with MPI
|
36 |
+
# get MPI parameters
|
37 |
+
self.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
|
38 |
+
self.local_size = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
|
39 |
+
self.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
40 |
+
self.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
|
41 |
+
|
42 |
+
if master_address is not None:
|
43 |
+
self.master_address = master_address
|
44 |
+
self.master_port = default_torch_distributed_port
|
45 |
+
self.env_info = 'manually set master ip'
|
46 |
+
elif 'PHILLY_CONTAINER_IP' in os.environ:
|
47 |
+
# application is running on Philly
|
48 |
+
# read environment variables on master node and broadcast via MPI
|
49 |
+
self.env_info = 'philly'
|
50 |
+
if self.rank == 0:
|
51 |
+
self.master_address = os.environ['PHILLY_CONTAINER_IP']
|
52 |
+
self.master_port = os.environ['PHILLY_CONTAINER_PORT_RANGE_START']
|
53 |
+
else:
|
54 |
+
self.master_address = None
|
55 |
+
self.master_port = None
|
56 |
+
self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
|
57 |
+
self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)
|
58 |
+
elif "AMLK8S_NUM_WORKER" in os.environ or "AZ_CMK8S_JOB_WORK_DIR" in os.environ:
|
59 |
+
# application is running on AMLK8S (ITP)
|
60 |
+
# read master address from a specific file.
|
61 |
+
self.env_info = 'AMLK8S (ITP)'
|
62 |
+
# from: https://k8s-wiki.azureml.com/faq.html
|
63 |
+
regexp = r"[\s\S]*export[\s]*DLTS_SD_worker0_IP=([0-9.]+)[\s|s]*"
|
64 |
+
with open("/dlts-runtime/env/init.env", 'r') as f:
|
65 |
+
line = f.read()
|
66 |
+
match = re.match(regexp, line)
|
67 |
+
if match:
|
68 |
+
self.master_address = str(match.group(1))
|
69 |
+
else:
|
70 |
+
# Did not find master node ip in file. It must be a single-node
|
71 |
+
# debugging job with custom "mpirun" command
|
72 |
+
assert self.world_size == self.local_size, \
|
73 |
+
"It's not a single-node debugging job on AMLK8S (ITP), but no master ip is found in file."
|
74 |
+
self.env_info = 'single-node AMLK8S (ITP) debugging job'
|
75 |
+
self.master_address = local_address
|
76 |
+
self.master_port = default_torch_distributed_port
|
77 |
+
elif 'AZ_BATCH_MASTER_NODE' in os.environ:
|
78 |
+
# application is running on multiple nodes on AML
|
79 |
+
self.env_info = 'multi-node AML'
|
80 |
+
master_node_params = os.environ['AZ_BATCH_MASTER_NODE'].split(':')
|
81 |
+
self.master_address = master_node_params[0]
|
82 |
+
self.master_port = default_torch_distributed_port
|
83 |
+
elif self.world_size == self.local_size:
|
84 |
+
# application is running with MPI on single node
|
85 |
+
self.env_info = 'single-node AML or other MPI environment'
|
86 |
+
self.master_address = local_address
|
87 |
+
self.master_port = default_torch_distributed_port
|
88 |
+
else:
|
89 |
+
# multi-node MPI environment, but not Philly or AML
|
90 |
+
# we use "hostname -I" command on rank 0 to get the master address
|
91 |
+
self.env_info = 'multi-node other MPI environment'
|
92 |
+
if self.rank == 0:
|
93 |
+
hostname_cmd = ["hostname -I"]
|
94 |
+
result = subprocess.check_output(hostname_cmd, shell=True)
|
95 |
+
self.master_address = result.decode('utf-8').split()[0]
|
96 |
+
self.master_port = default_torch_distributed_port
|
97 |
+
else:
|
98 |
+
self.master_address = None
|
99 |
+
self.master_port = None
|
100 |
+
self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
|
101 |
+
self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)
|
102 |
+
|
103 |
+
self.init_method_url = f'tcp://{self.master_address}:{self.master_port}'
|
104 |
+
if set_env_vars:
|
105 |
+
self._set_env_vars()
|
106 |
+
|
107 |
+
def log_info(self):
|
108 |
+
"""
|
109 |
+
Logs information about distributed training environment.
|
110 |
+
"""
|
111 |
+
# use logger.warning because MainzTrain has a hidden convention
|
112 |
+
# of not printing logger.info messages on processes with rank > 0
|
113 |
+
logger.warning('----------------')
|
114 |
+
logger.warning('MPI Adapter data')
|
115 |
+
logger.warning('----------------')
|
116 |
+
logger.warning(f'environment info: {self.env_info}')
|
117 |
+
logger.warning(f'init method url: {self.init_method_url}')
|
118 |
+
logger.warning(f'world size: {self.world_size}')
|
119 |
+
logger.warning(f'local size: {self.local_size}')
|
120 |
+
logger.warning(f'rank: {self.rank}')
|
121 |
+
logger.warning(f'local rank: {self.local_rank}')
|
122 |
+
logger.warning(f'master address: {self.master_address}')
|
123 |
+
logger.warning(f'master port: {self.master_port}')
|
124 |
+
logger.warning('----------------')
|
125 |
+
|
126 |
+
def init_process_group(self, backend):
|
127 |
+
"""
|
128 |
+
Initializes the default PyTorch distributed process group.
|
129 |
+
"""
|
130 |
+
# use logger.warning because MainzTrain has a hidden convention
|
131 |
+
# of not printing logger.info messages on processes with rank > 0
|
132 |
+
logger.warning('trying to initialize process group ...')
|
133 |
+
torch.distributed.init_process_group(backend=backend,
|
134 |
+
init_method=self.init_method_url,
|
135 |
+
world_size=self.world_size,
|
136 |
+
rank=self.rank)
|
137 |
+
logger.warning('process group initialized')
|
138 |
+
|
139 |
+
def _set_env_vars(self):
|
140 |
+
"""
|
141 |
+
Sets environment variables for world size, rank, local rank, master addr, and master port.
|
142 |
+
"""
|
143 |
+
os.environ['WORLD_SIZE'] = str(self.world_size)
|
144 |
+
os.environ['RANK'] = str(self.rank)
|
145 |
+
os.environ["LOCAL_RANK"] = str(self.local_rank)
|
146 |
+
os.environ['MASTER_ADDR'] = self.master_address
|
147 |
+
os.environ['MASTER_PORT'] = self.master_port
|
MedImageInsight/Utils/Utils.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import torch.distributed as dist
|
5 |
+
import yaml
|
6 |
+
|
7 |
+
from fvcore.nn import FlopCountAnalysis
|
8 |
+
from fvcore.nn import flop_count_table
|
9 |
+
from fvcore.nn import flop_count_str
|
10 |
+
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
NORM_MODULES = [
|
16 |
+
torch.nn.BatchNorm1d,
|
17 |
+
torch.nn.BatchNorm2d,
|
18 |
+
torch.nn.BatchNorm3d,
|
19 |
+
torch.nn.SyncBatchNorm,
|
20 |
+
# NaiveSyncBatchNorm inherits from BatchNorm2d
|
21 |
+
torch.nn.GroupNorm,
|
22 |
+
torch.nn.InstanceNorm1d,
|
23 |
+
torch.nn.InstanceNorm2d,
|
24 |
+
torch.nn.InstanceNorm3d,
|
25 |
+
torch.nn.LayerNorm,
|
26 |
+
torch.nn.LocalResponseNorm,
|
27 |
+
]
|
28 |
+
|
29 |
+
def register_norm_module(cls):
|
30 |
+
NORM_MODULES.append(cls)
|
31 |
+
|
32 |
+
return cls
|
33 |
+
|
34 |
+
|
35 |
+
def is_main_process():
|
36 |
+
rank = 0
|
37 |
+
if 'OMPI_COMM_WORLD_SIZE' in os.environ:
|
38 |
+
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
39 |
+
|
40 |
+
return rank == 0
|
41 |
+
|
42 |
+
|
43 |
+
@torch.no_grad()
|
44 |
+
def analysis_model(model, dump_input, verbose=False):
|
45 |
+
model.eval()
|
46 |
+
flops = FlopCountAnalysis(model, dump_input)
|
47 |
+
total = flops.total()
|
48 |
+
model.train()
|
49 |
+
params_total = sum(p.numel() for p in model.parameters())
|
50 |
+
params_learned = sum(
|
51 |
+
p.numel() for p in model.parameters() if p.requires_grad
|
52 |
+
)
|
53 |
+
logger.info(f"flop count table:\n {flop_count_table(flops)}")
|
54 |
+
if verbose:
|
55 |
+
logger.info(f"flop count str:\n {flop_count_str(flops)}")
|
56 |
+
logger.info(f" Total flops: {total/1000/1000:.3f}M,")
|
57 |
+
logger.info(f" Total params: {params_total/1000/1000:.3f}M,")
|
58 |
+
logger.info(f" Learned params: {params_learned/1000/1000:.3f}M")
|
59 |
+
|
60 |
+
return total, flop_count_table(flops), flop_count_str(flops)
|
61 |
+
|
62 |
+
|
63 |
+
def load_config_dict_to_opt(opt, config_dict, splitter='.'):
|
64 |
+
"""
|
65 |
+
Load the key, value pairs from config_dict to opt, overriding existing values in opt
|
66 |
+
if there is any.
|
67 |
+
"""
|
68 |
+
if not isinstance(config_dict, dict):
|
69 |
+
raise TypeError("Config must be a Python dictionary")
|
70 |
+
for k, v in config_dict.items():
|
71 |
+
k_parts = k.split(splitter)
|
72 |
+
pointer = opt
|
73 |
+
for k_part in k_parts[:-1]:
|
74 |
+
if k_part not in pointer:
|
75 |
+
pointer[k_part] = {}
|
76 |
+
pointer = pointer[k_part]
|
77 |
+
assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
|
78 |
+
ori_value = pointer.get(k_parts[-1])
|
79 |
+
pointer[k_parts[-1]] = v
|
80 |
+
if ori_value:
|
81 |
+
print(f"Overrided {k} from {ori_value} to {pointer[k_parts[-1]]}")
|
82 |
+
|
83 |
+
|
84 |
+
def load_opt_from_config_file(conf_file):
|
85 |
+
"""
|
86 |
+
Load opt from the config file.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
conf_file: config file path
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
dict: a dictionary of opt settings
|
93 |
+
"""
|
94 |
+
opt = {}
|
95 |
+
with open(conf_file, encoding='utf-8') as f:
|
96 |
+
config_dict = yaml.safe_load(f)
|
97 |
+
load_config_dict_to_opt(opt, config_dict)
|
98 |
+
|
99 |
+
return opt
|
100 |
+
|
101 |
+
def cast_batch_to_dtype(batch, dtype):
|
102 |
+
"""
|
103 |
+
Cast the float32 tensors in a batch to a specified torch dtype.
|
104 |
+
It should be called before feeding the batch to the FP16 DeepSpeed model.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
batch (torch.tensor or container of torch.tensor): input batch
|
108 |
+
Returns:
|
109 |
+
return_batch: same type as the input batch with internal float32 tensors casted to the specified dtype.
|
110 |
+
"""
|
111 |
+
if torch.is_tensor(batch):
|
112 |
+
if torch.is_floating_point(batch):
|
113 |
+
return_batch = batch.to(dtype)
|
114 |
+
else:
|
115 |
+
return_batch = batch
|
116 |
+
elif isinstance(batch, list):
|
117 |
+
return_batch = [cast_batch_to_dtype(t, dtype) for t in batch]
|
118 |
+
elif isinstance(batch, tuple):
|
119 |
+
return_batch = tuple(cast_batch_to_dtype(t, dtype) for t in batch)
|
120 |
+
elif isinstance(batch, dict):
|
121 |
+
return_batch = {}
|
122 |
+
for k in batch:
|
123 |
+
return_batch[k] = cast_batch_to_dtype(batch[k], dtype)
|
124 |
+
else:
|
125 |
+
logger.debug(f"Can not cast type {type(batch)} to {dtype}. Skipping it in the batch.")
|
126 |
+
return_batch = batch
|
127 |
+
|
128 |
+
return return_batch
|
129 |
+
|
130 |
+
|
131 |
+
def cast_batch_to_half(batch):
|
132 |
+
"""
|
133 |
+
Cast the float32 tensors in a batch to float16.
|
134 |
+
It should be called before feeding the batch to the FP16 DeepSpeed model.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
batch (torch.tensor or container of torch.tensor): input batch
|
138 |
+
Returns:
|
139 |
+
return_batch: same type as the input batch with internal float32 tensors casted to float16
|
140 |
+
"""
|
141 |
+
return cast_batch_to_dtype(batch, torch.float16)
|
MedImageInsight/Utils/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .Utils import analysis_model
|
2 |
+
from .Utils import is_main_process
|
3 |
+
from .Utils import register_norm_module
|
4 |
+
from .Utils import NORM_MODULES
|
5 |
+
from .Utils import load_config_dict_to_opt
|
6 |
+
from .Utils import load_opt_from_config_file
|
7 |
+
from .Utils import cast_batch_to_half
|
MedImageInsight/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .UniCLModel import build_unicl_model
|
2 |
+
|
3 |
+
__all__ = [
|
4 |
+
'build_od_model',
|
5 |
+
'build_unicl_model',
|
6 |
+
'build_tokenizer_from_name',
|
7 |
+
'get_image_preprocess',
|
8 |
+
'build_unicl_matching_model'
|
9 |
+
]
|