Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Upload 26 files
Browse files- Dockerfile +27 -0
- LICENSE +201 -0
- README.md +27 -8
- angle_calculation/angle_model.py +444 -0
- angle_calculation/classic.py +349 -0
- angle_calculation/envelope_correction.py +34 -0
- angle_calculation/granum_utils.py +80 -0
- angle_calculation/image_transforms.py +34 -0
- angle_calculation/sampling.py +142 -0
- app.py +602 -0
- grana_detection/mmwrapper.py +42 -0
- model.py +629 -0
- period_calculation/config.py +19 -0
- period_calculation/data_reader.py +861 -0
- period_calculation/image_transforms.py +79 -0
- period_calculation/models/abstract_model.py +61 -0
- period_calculation/models/gauss_model.py +237 -0
- period_calculation/period_measurer.py +54 -0
- requirements.txt +11 -0
- settings.py +1 -0
- styles.css +47 -0
- weights/AS_square_v16.ckpt +3 -0
- weights/model_weights_detector.pt +3 -0
- weights/period_measurer_weights-1.298_real_full-fa12970.ckpt +3 -0
- weights/yolo/20240604_yolov8_segm_ABRCR1_all_train4_best.pt +3 -0
- weights/yolo/current_yolo.pt +3 -0
Dockerfile
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.9 as build
|
| 2 |
+
|
| 3 |
+
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
|
| 4 |
+
|
| 5 |
+
# Set up a new user named "user" with user ID 1000
|
| 6 |
+
RUN useradd -m -u 1000 user
|
| 7 |
+
|
| 8 |
+
# Switch to the "user" user
|
| 9 |
+
USER user
|
| 10 |
+
|
| 11 |
+
# Set home to the user's home directory
|
| 12 |
+
ENV HOME=/home/user \
|
| 13 |
+
PATH=/home/user/.local/bin:$PATH
|
| 14 |
+
|
| 15 |
+
# Set the working directory to the user's home directory
|
| 16 |
+
WORKDIR $HOME/app
|
| 17 |
+
|
| 18 |
+
COPY requirements.txt .
|
| 19 |
+
RUN pip install --no-cache-dir -r ./requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
| 20 |
+
# RUN mim install mmengine
|
| 21 |
+
# RUN mim install "mmcv==2.1.0" & mim install "mmdet==3.3.0"
|
| 22 |
+
|
| 23 |
+
FROM build as final
|
| 24 |
+
|
| 25 |
+
COPY --chown=user . .
|
| 26 |
+
|
| 27 |
+
CMD python app.py
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,13 +1,32 @@
|
|
| 1 |
---
|
| 2 |
title: GRANA
|
| 3 |
-
emoji: 🐨
|
| 4 |
-
colorFrom: indigo
|
| 5 |
-
colorTo: indigo
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.4.0
|
| 8 |
app_file: app.py
|
| 9 |
-
|
| 10 |
-
|
| 11 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: GRANA
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
app_file: app.py
|
| 4 |
+
sdk: gradio
|
| 5 |
+
sdk_version: 4.44.0
|
| 6 |
---
|
| 7 |
+
# GRANA
|
| 8 |
+
<img src="https://img.shields.io/badge/Python-3.9-blue"/>
|
| 9 |
+
<a href="www.chloroplast.pl/GRANA"><img src="https://img.shields.io/badge/GRANA-Website-green" /></a>
|
| 10 |
+
<a href="https://huggingface.co/spaces/chloroplast/GRANA"><img src="https://img.shields.io/badge/GRANA-Demo-green" /></a>
|
| 11 |
+
<img src="https://img.shields.io/badge/Gradio-4.44.0-darkgreen"/>
|
| 12 |
+
|
| 13 |
+
GRANA (**G**raphical **R**ecognition and **A**nalysis of **N**anostructural **A**ssemblies)
|
| 14 |
+
is an an AI-enhanced, user-friendly
|
| 15 |
+
software tool that recognizes grana on thylakoid network electron micrographs
|
| 16 |
+
and generates a complex set of their structural parameters measurements.
|
| 17 |
+
|
| 18 |
+
## Website
|
| 19 |
+
More information about GRANA, including **example dataset**, can be found at [GRANA website](https://www.chloroplast.pl/grana).
|
| 20 |
+
|
| 21 |
+
## Demo
|
| 22 |
+
Demo version of GRANA is available at [Hugging Face Spaces](https://huggingface.co/spaces/chloroplast/GRANA).
|
| 23 |
+
Using demo version, you can analyze up to 5 images at once.
|
| 24 |
+
|
| 25 |
+
## Running as Docker container
|
| 26 |
+
The recommended way to run GRANA is to use Docker container.
|
| 27 |
|
| 28 |
+
To run the container, use the following command:
|
| 29 |
+
```bash
|
| 30 |
+
docker run -p 7860:7860 mbuk/grana_measure:v0.5.4
|
| 31 |
+
```
|
| 32 |
+
After running the command, you can access the GRANA interface at `http://localhost:7860`.
|
angle_calculation/angle_model.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import pytorch_lightning as pl
|
| 4 |
+
import timm
|
| 5 |
+
|
| 6 |
+
# from hydra.utils import instantiate
|
| 7 |
+
from scipy.stats import circmean, circstd
|
| 8 |
+
from scipy import ndimage
|
| 9 |
+
from skimage.transform import resize
|
| 10 |
+
|
| 11 |
+
from sampling import get_crop_batch
|
| 12 |
+
from granum_utils import get_circle_mask
|
| 13 |
+
import image_transforms
|
| 14 |
+
from envelope_correction import calculate_best_angle_from_mask
|
| 15 |
+
## loss
|
| 16 |
+
|
| 17 |
+
class ConfidenceScaler:
|
| 18 |
+
def __init__(self, data: np.ndarray):
|
| 19 |
+
self.data = data
|
| 20 |
+
self.data.sort()
|
| 21 |
+
def __call__(self, x):
|
| 22 |
+
return np.searchsorted(self.data,x) / len(self.data)
|
| 23 |
+
|
| 24 |
+
class PatchedPredictor:
|
| 25 |
+
def __init__(self,
|
| 26 |
+
model,
|
| 27 |
+
crop_size=96,
|
| 28 |
+
normalization=dict(mean=0,std=1),
|
| 29 |
+
n_samples=32,
|
| 30 |
+
mask=None,# 'circle', None
|
| 31 |
+
filter_outliers=True,
|
| 32 |
+
apply_radon=False, # apply Radon transform
|
| 33 |
+
radon_size=(128,128), # (int, int) reshape radon transformed image to this shape,
|
| 34 |
+
angle_confidence_threshold=0,
|
| 35 |
+
use_envelope_correction=True
|
| 36 |
+
):
|
| 37 |
+
self.model = model
|
| 38 |
+
self.crop_size = crop_size
|
| 39 |
+
self.normalization = normalization
|
| 40 |
+
self.n_samples = n_samples
|
| 41 |
+
if mask not in [None, 'circle']:
|
| 42 |
+
raise ValueError(f'unknown mask {mask}')
|
| 43 |
+
self.mask = mask
|
| 44 |
+
self.filter_outliers = filter_outliers
|
| 45 |
+
|
| 46 |
+
self.apply_radon = apply_radon
|
| 47 |
+
self.radon_size = radon_size
|
| 48 |
+
|
| 49 |
+
self.angle_confidence_threshold = angle_confidence_threshold
|
| 50 |
+
self.use_envelope_correction = use_envelope_correction
|
| 51 |
+
|
| 52 |
+
@torch.no_grad()
|
| 53 |
+
def __call__(self, img: np.ndarray, mask: np.ndarray):
|
| 54 |
+
pl.seed_everything(44)
|
| 55 |
+
# get crops with different scales and rotation
|
| 56 |
+
crops, angles_tta, scales_tta = get_crop_batch(
|
| 57 |
+
img, mask,
|
| 58 |
+
crop_size=self.crop_size,
|
| 59 |
+
samples_per_scale=self.n_samples,
|
| 60 |
+
use_variance_threshold=True
|
| 61 |
+
)
|
| 62 |
+
if len(crops) == 0:
|
| 63 |
+
return dict(
|
| 64 |
+
est_angle=np.nan,
|
| 65 |
+
est_angle_confidence=0.,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# preprocess batch (normalize, mask, transform)
|
| 69 |
+
batch = self._preprocess_batch(crops)
|
| 70 |
+
|
| 71 |
+
# predict for batch - we don't use period and lumen anymore
|
| 72 |
+
preds_direction, preds_period, preds_lumen_width = self.model(batch)
|
| 73 |
+
# # convert to numpy
|
| 74 |
+
# preds_direction = preds_direction.numpy()
|
| 75 |
+
# preds_period = preds_period.numpy()
|
| 76 |
+
# preds_lumen_width = preds_lumen_width.numpy()
|
| 77 |
+
|
| 78 |
+
# aggregate angles
|
| 79 |
+
est_angles = (preds_direction - angles_tta) % 180
|
| 80 |
+
est_angle = circmean(est_angles, low=-90, high=90) + 90
|
| 81 |
+
est_angle_std = circstd(est_angles, low=-90, high=90)
|
| 82 |
+
est_angle_confidence = self._std_to_confidence(est_angle_std, 10) # confidence 0.5 for std =10 degrees
|
| 83 |
+
|
| 84 |
+
if est_angle_confidence < self.angle_confidence_threshold:
|
| 85 |
+
est_angle = np.nan
|
| 86 |
+
est_angle_confidence = 0.
|
| 87 |
+
|
| 88 |
+
if self.use_envelope_correction and (not np.isnan(est_angle)):
|
| 89 |
+
angle_correction = -calculate_best_angle_from_mask(
|
| 90 |
+
ndimage.rotate(mask, -est_angle, reshape=True, order=0)
|
| 91 |
+
)
|
| 92 |
+
est_angle += angle_correction
|
| 93 |
+
|
| 94 |
+
return dict(
|
| 95 |
+
est_angle=est_angle,
|
| 96 |
+
est_angle_confidence=est_angle_confidence,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
def _apply_radon(self, batch): # may reauire circle mask
|
| 100 |
+
crops_radon = image_transforms.batched_radon(batch.numpy())
|
| 101 |
+
crops_radon = np.transpose(resize(np.transpose(crops_radon, (1, 2, 0)), self.radon_size), (2, 0, 1))
|
| 102 |
+
return torch.tensor(crops_radon)
|
| 103 |
+
|
| 104 |
+
def _preprocess_batch(self, batch):
|
| 105 |
+
if self.mask == 'circle':
|
| 106 |
+
mask = get_circle_mask(batch.shape[1])
|
| 107 |
+
batch[:,mask] = 0
|
| 108 |
+
if self.apply_radon:
|
| 109 |
+
batch = self._apply_radon(batch)
|
| 110 |
+
batch = ((batch/255) - self.normalization['mean'])/self.normalization['std']
|
| 111 |
+
return batch.unsqueeze(1) # add channel dimension
|
| 112 |
+
|
| 113 |
+
def _filter_outliers(self, x, qmin=0.25, qmax=0.75):
|
| 114 |
+
x_min, x_max = np.quantile(x, [qmin, qmax])
|
| 115 |
+
return x[(x>=x_min) & (x<=x_max)]
|
| 116 |
+
|
| 117 |
+
def _std_to_confidence(self, x, x_thr, y_thr=0.5):
|
| 118 |
+
"""transform [0, inf] to [1,0], such that f(x_thr)=y_thr"""
|
| 119 |
+
return 1 / (1+x*(1-y_thr)/(x_thr*y_thr))
|
| 120 |
+
|
| 121 |
+
class CosineLoss(torch.nn.Module):
|
| 122 |
+
def __init__(self, p=1, degrees=False, scale=1):
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.p = p
|
| 125 |
+
self.degrees = degrees
|
| 126 |
+
self.scale = scale
|
| 127 |
+
def forward(self, x, y):
|
| 128 |
+
if self.degrees:
|
| 129 |
+
x = torch.deg2rad(x)
|
| 130 |
+
y = torch.deg2rad(y)
|
| 131 |
+
return torch.mean((1-torch.cos(x-y))**self.p) * self.scale
|
| 132 |
+
|
| 133 |
+
## model
|
| 134 |
+
class AngleParser2d(torch.nn.Module):
|
| 135 |
+
def __init__(self, angle_range=180):
|
| 136 |
+
super().__init__()
|
| 137 |
+
self.angle_range = angle_range
|
| 138 |
+
def forward(self, batch):
|
| 139 |
+
# r = torch.linalg.norm(batch, dim=1)
|
| 140 |
+
preds_y_proj = torch.sigmoid(batch[:,0]) - 0.5
|
| 141 |
+
preds_x_proj = torch.sigmoid(batch[:,1]) - 0.5
|
| 142 |
+
preds_direction = self.angle_range/360.*torch.rad2deg(torch.arctan2(preds_y_proj, preds_x_proj))
|
| 143 |
+
return preds_direction
|
| 144 |
+
|
| 145 |
+
class AngleRegularizer(torch.nn.Module):
|
| 146 |
+
def __init__(self, strength=1.0, scale=1.0, p=2):
|
| 147 |
+
super().__init__()
|
| 148 |
+
self.strength = strength
|
| 149 |
+
self.scale = scale
|
| 150 |
+
self.p = p
|
| 151 |
+
def forward(self, batch):
|
| 152 |
+
r = torch.linalg.norm(batch, dim=1)
|
| 153 |
+
return self.strength * torch.norm(r - self.scale, p=self.p)
|
| 154 |
+
|
| 155 |
+
class AngleRegularizerLog(torch.nn.Module):
|
| 156 |
+
def __init__(self, strength=1.0, scale=1.0, p=2):
|
| 157 |
+
super().__init__()
|
| 158 |
+
self.strength = strength
|
| 159 |
+
self.scale = scale
|
| 160 |
+
self.p = p
|
| 161 |
+
def forward(self, batch):
|
| 162 |
+
r = torch.linalg.norm(batch, dim=1)
|
| 163 |
+
return self.strength * torch.norm(torch.log(r/self.scale), p=self.p)
|
| 164 |
+
|
| 165 |
+
class StripsModel(pl.LightningModule):
|
| 166 |
+
def __init__(self,
|
| 167 |
+
model_name = 'resnet18',
|
| 168 |
+
lr=0.001,
|
| 169 |
+
optimizer_hparams=dict(),
|
| 170 |
+
lr_hparams=dict(classname='MultiStepLR', kwargs=dict(milestones=[100, 150], gamma=0.1)),
|
| 171 |
+
loss_hparams=dict(rotation_weight=10., lumen_fraction_weight=50.),
|
| 172 |
+
angle_hparams=dict(angle_range=180.),
|
| 173 |
+
regularizer_hparams=None,
|
| 174 |
+
sigmoid_smoother=10.
|
| 175 |
+
):
|
| 176 |
+
super().__init__()
|
| 177 |
+
# Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
|
| 178 |
+
self.save_hyperparameters()
|
| 179 |
+
# Create model - implemented in non-abstract classes
|
| 180 |
+
self.model = timm.create_model(model_name, in_chans=1, num_classes=4) #2 + self.hparams.angle_hparams['ndim'])
|
| 181 |
+
self.angle_parser = AngleParser2d(**self.hparams.angle_hparams)
|
| 182 |
+
self.regularizer = self._get_regularizer(self.hparams.regularizer_hparams)
|
| 183 |
+
self.losses = {
|
| 184 |
+
'direction': CosineLoss(2., True),
|
| 185 |
+
'period': torch.nn.functional.mse_loss,
|
| 186 |
+
'lumen_fraction': torch.nn.functional.mse_loss
|
| 187 |
+
}
|
| 188 |
+
self.losses_weights = {
|
| 189 |
+
'direction': self.hparams.loss_hparams['rotation_weight'],
|
| 190 |
+
'period': 1,
|
| 191 |
+
'lumen_fraction': self.hparams.loss_hparams['lumen_fraction_weight'],
|
| 192 |
+
'regularization': self.hparams.loss_hparams.get('regularization_weight', 0.)
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
def _get_regularizer(self, regularizer_params):
|
| 196 |
+
if regularizer_params is None:
|
| 197 |
+
return None
|
| 198 |
+
else:
|
| 199 |
+
return instantiate(regularizer_params)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def forward(self, x, return_raw=False):
|
| 203 |
+
"""get predictions from image batch"""
|
| 204 |
+
preds = self.model(x) # preds: logit angle_sin, logit angle_cos, period, logit lumen fraction or logit angle, period, logit lumen fraction
|
| 205 |
+
preds_direction = self.angle_parser(preds)
|
| 206 |
+
preds_period = preds[:,-2]
|
| 207 |
+
preds_lumen_fraction = torch.sigmoid(preds[:,-1]*self.hparams.sigmoid_smoother) #lumen fraction is between 0 and 1, so we take sigmoid fo this
|
| 208 |
+
|
| 209 |
+
outputs = [preds_direction, preds_period, preds_lumen_fraction]
|
| 210 |
+
if return_raw:
|
| 211 |
+
outputs.append(preds)
|
| 212 |
+
|
| 213 |
+
return tuple(outputs)
|
| 214 |
+
|
| 215 |
+
def configure_optimizers(self):
|
| 216 |
+
# AdamW is Adam with a correct implementation of weight decay (see here
|
| 217 |
+
# for details: https://arxiv.org/pdf/1711.05101.pdf)
|
| 218 |
+
optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, **self.hparams.optimizer_hparams)
|
| 219 |
+
# scheduler = getattr(torch.optim.lr_scheduler, self.hparams.lr_hparams['classname'])(optimizer, **self.hparams.lr_hparams['kwargs'])
|
| 220 |
+
scheduler = instantiate({**self.hparams.lr_hparams, '_partial_': True})(optimizer)
|
| 221 |
+
return [optimizer], [scheduler]
|
| 222 |
+
|
| 223 |
+
def process_batch_supervised(self, batch):
|
| 224 |
+
"""get predictions, losses and mean errors (MAE)"""
|
| 225 |
+
|
| 226 |
+
# get predictions
|
| 227 |
+
preds = {}
|
| 228 |
+
preds['direction'], preds['period'], preds['lumen_fraction'], preds_raw = self.forward(batch['image'], return_raw=True) # preds: angle, period, lumen fraction, raw preds
|
| 229 |
+
|
| 230 |
+
# calculate losses
|
| 231 |
+
losses = {
|
| 232 |
+
'direction': self.losses['direction'](2*batch['direction'], 2*preds['direction']),
|
| 233 |
+
'period': self.losses['period'](batch['period'], preds['period']),
|
| 234 |
+
'lumen_fraction': self.losses['lumen_fraction'](batch['lumen_fraction'], preds['lumen_fraction']),
|
| 235 |
+
}
|
| 236 |
+
if self.regularizer is not None:
|
| 237 |
+
losses['regularization'] = self.regularizer(preds_raw[:,:2])
|
| 238 |
+
|
| 239 |
+
losses['final'] = \
|
| 240 |
+
losses['direction']*self.losses_weights['direction'] + \
|
| 241 |
+
losses['period']*self.losses_weights['period'] + \
|
| 242 |
+
losses['lumen_fraction']*self.losses_weights['lumen_fraction'] + \
|
| 243 |
+
losses.get('regularization', 0.)*self.losses_weights.get('regularization', 0.)
|
| 244 |
+
|
| 245 |
+
# calculate mean errors
|
| 246 |
+
period_difference = np.mean(abs(
|
| 247 |
+
batch['period'].detach().cpu().numpy() - \
|
| 248 |
+
preds['period'].detach().cpu().numpy()
|
| 249 |
+
))
|
| 250 |
+
|
| 251 |
+
a1 = batch['direction'].detach().cpu().numpy()
|
| 252 |
+
a2 = preds['direction'].detach().cpu().numpy()
|
| 253 |
+
angle_difference = np.mean(0.5*np.degrees(np.arccos(np.cos(2*np.radians(a2-a1)))))
|
| 254 |
+
|
| 255 |
+
lumen_fraction_difference = np.mean(abs(preds['lumen_fraction'].detach().cpu().numpy()-batch['lumen_fraction'].detach().cpu().numpy()))
|
| 256 |
+
|
| 257 |
+
mae = {
|
| 258 |
+
'period': period_difference,
|
| 259 |
+
'direction': angle_difference,
|
| 260 |
+
'lumen_fraction': lumen_fraction_difference
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
return preds, losses, mae
|
| 264 |
+
|
| 265 |
+
def log_all(self, losses, mae, prefix=''):
|
| 266 |
+
self.log(f"{prefix}angle_loss", losses['direction'].item())
|
| 267 |
+
self.log(f"{prefix}period_loss", losses['period'].item())
|
| 268 |
+
self.log(f"{prefix}lumen_fraction_loss", losses['lumen_fraction'].item())
|
| 269 |
+
self.log(f"{prefix}period_difference", mae['period'])
|
| 270 |
+
self.log(f"{prefix}angle_difference", mae['direction'])
|
| 271 |
+
self.log(f"{prefix}lumen_fraction_difference", mae['lumen_fraction'])
|
| 272 |
+
self.log(f"{prefix}loss", losses['final'])
|
| 273 |
+
if 'regularization' in losses:
|
| 274 |
+
self.log(f"{prefix}regularization_loss", losses['regularization'].item())
|
| 275 |
+
|
| 276 |
+
def training_step(self, batch, batch_idx):
|
| 277 |
+
# "batch" is the output of the training data loader.
|
| 278 |
+
preds, losses, mae = self.process_batch_supervised(batch)
|
| 279 |
+
self.log_all(losses, mae, prefix='train_')
|
| 280 |
+
|
| 281 |
+
return losses['final']
|
| 282 |
+
|
| 283 |
+
def validation_step(self, batch, batch_idx):
|
| 284 |
+
preds, losses, mae = self.process_batch_supervised(batch)
|
| 285 |
+
self.log_all(losses, mae, prefix='val_')
|
| 286 |
+
|
| 287 |
+
def test_step(self, batch, batch_idx):
|
| 288 |
+
preds, losses, mae = self.process_batch_supervised(batch)
|
| 289 |
+
self.log_all(losses, mae, prefix='test_')
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class StripsModelLumenWidth(pl.LightningModule):
|
| 293 |
+
def __init__(self,
|
| 294 |
+
model_name = 'resnet18',
|
| 295 |
+
lr=0.001,
|
| 296 |
+
optimizer_hparams=dict(),
|
| 297 |
+
lr_hparams=dict(classname='MultiStepLR', kwargs=dict(milestones=[100, 150], gamma=0.1)),
|
| 298 |
+
loss_hparams=dict(rotation_weight=10., lumen_width_weight=50.),
|
| 299 |
+
angle_hparams=dict(angle_range=180.),
|
| 300 |
+
regularizer_hparams=None,
|
| 301 |
+
sigmoid_smoother=10.
|
| 302 |
+
):
|
| 303 |
+
super().__init__()
|
| 304 |
+
# Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
|
| 305 |
+
self.save_hyperparameters()
|
| 306 |
+
# Create model - implemented in non-abstract classes
|
| 307 |
+
self.model = timm.create_model(model_name, in_chans=1, num_classes=4) #2 + self.hparams.angle_hparams['ndim'])
|
| 308 |
+
self.angle_parser = AngleParser2d(**self.hparams.angle_hparams)
|
| 309 |
+
self.regularizer = self._get_regularizer(self.hparams.regularizer_hparams)
|
| 310 |
+
self.losses = {
|
| 311 |
+
'direction': CosineLoss(2., True),
|
| 312 |
+
'period': torch.nn.functional.mse_loss,
|
| 313 |
+
'lumen_width': torch.nn.functional.mse_loss
|
| 314 |
+
}
|
| 315 |
+
self.losses_weights = {
|
| 316 |
+
'direction': self.hparams.loss_hparams['rotation_weight'],
|
| 317 |
+
'period': 1,
|
| 318 |
+
'lumen_width': self.hparams.loss_hparams['lumen_width_weight'],
|
| 319 |
+
'regularization': self.hparams.loss_hparams.get('regularization_weight', 0.)
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
def _get_regularizer(self, regularizer_params):
|
| 323 |
+
if regularizer_params is None:
|
| 324 |
+
return None
|
| 325 |
+
else:
|
| 326 |
+
return instantiate(regularizer_params)
|
| 327 |
+
|
| 328 |
+
def forward(self, x, return_raw=False):
|
| 329 |
+
"""get predictions from image batch"""
|
| 330 |
+
preds = self.model(x) # preds: logit angle_sin, logit angle_cos, period, logit lumen fraction or logit angle, period, logit lumen fraction
|
| 331 |
+
preds_direction = self.angle_parser(preds)
|
| 332 |
+
preds_period = preds[:,-2]
|
| 333 |
+
preds_lumen_width = preds[:,-1] #lumen fraction is between 0 and 1, so we take sigmoid fo this
|
| 334 |
+
|
| 335 |
+
outputs = [preds_direction, preds_period, preds_lumen_width]
|
| 336 |
+
if return_raw:
|
| 337 |
+
outputs.append(preds)
|
| 338 |
+
|
| 339 |
+
return tuple(outputs)
|
| 340 |
+
|
| 341 |
+
def configure_optimizers(self):
|
| 342 |
+
# AdamW is Adam with a correct implementation of weight decay (see here
|
| 343 |
+
# for details: https://arxiv.org/pdf/1711.05101.pdf)
|
| 344 |
+
optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, **self.hparams.optimizer_hparams)
|
| 345 |
+
# scheduler = getattr(torch.optim.lr_scheduler, self.hparams.lr_hparams['classname'])(optimizer, **self.hparams.lr_hparams['kwargs'])
|
| 346 |
+
scheduler = instantiate({**self.hparams.lr_hparams, '_partial_': True})(optimizer)
|
| 347 |
+
return [optimizer], [scheduler]
|
| 348 |
+
|
| 349 |
+
def process_batch_supervised(self, batch):
|
| 350 |
+
"""get predictions, losses and mean errors (MAE)"""
|
| 351 |
+
|
| 352 |
+
# get predictions
|
| 353 |
+
preds = {}
|
| 354 |
+
preds['direction'], preds['period'], preds['lumen_width'], preds_raw = self.forward(batch['image'], return_raw=True) # preds: angle, period, lumen fraction, raw preds
|
| 355 |
+
|
| 356 |
+
# calculate losses
|
| 357 |
+
losses = {
|
| 358 |
+
'direction': self.losses['direction'](2*batch['direction'], 2*preds['direction']),
|
| 359 |
+
'period': self.losses['period'](batch['period'], preds['period']),
|
| 360 |
+
'lumen_width': self.losses['lumen_width'](batch['lumen_width'], preds['lumen_width']),
|
| 361 |
+
}
|
| 362 |
+
if self.regularizer is not None:
|
| 363 |
+
losses['regularization'] = self.regularizer(preds_raw[:,:2])
|
| 364 |
+
|
| 365 |
+
losses['final'] = \
|
| 366 |
+
losses['direction']*self.losses_weights['direction'] + \
|
| 367 |
+
losses['period']*self.losses_weights['period'] + \
|
| 368 |
+
losses['lumen_width']*self.losses_weights['lumen_width'] + \
|
| 369 |
+
losses.get('regularization', 0.)*self.losses_weights.get('regularization', 0.)
|
| 370 |
+
|
| 371 |
+
# calculate mean errors
|
| 372 |
+
period_difference = np.mean(abs(
|
| 373 |
+
batch['period'].detach().cpu().numpy() - \
|
| 374 |
+
preds['period'].detach().cpu().numpy()
|
| 375 |
+
))
|
| 376 |
+
|
| 377 |
+
a1 = batch['direction'].detach().cpu().numpy()
|
| 378 |
+
a2 = preds['direction'].detach().cpu().numpy()
|
| 379 |
+
angle_difference = np.mean(0.5*np.degrees(np.arccos(np.cos(2*np.radians(a2-a1)))))
|
| 380 |
+
|
| 381 |
+
lumen_width_difference = np.mean(abs(preds['lumen_width'].detach().cpu().numpy()-batch['lumen_width'].detach().cpu().numpy()))
|
| 382 |
+
|
| 383 |
+
lumen_fraction_pred = preds['lumen_width'].detach().cpu().numpy()/preds['period'].detach().cpu().numpy()
|
| 384 |
+
lumen_fraction_gt = batch['lumen_width'].detach().cpu().numpy()/batch['period'].detach().cpu().numpy()
|
| 385 |
+
lumen_fraction_difference = np.mean(abs(lumen_fraction_pred-lumen_fraction_gt))
|
| 386 |
+
|
| 387 |
+
mae = {
|
| 388 |
+
'period': period_difference,
|
| 389 |
+
'direction': angle_difference,
|
| 390 |
+
'lumen_width': lumen_width_difference,
|
| 391 |
+
'lumen_fraction': lumen_fraction_difference
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
return preds, losses, mae
|
| 395 |
+
|
| 396 |
+
def log_all(self, losses, mae, prefix=''):
|
| 397 |
+
for k, v in losses.items():
|
| 398 |
+
self.log(f'{prefix}{k}_loss', v.item() if isinstance(v, torch.Tensor) else v)
|
| 399 |
+
for k, v in mae.items():
|
| 400 |
+
self.log(f'{prefix}{k}_difference', v.item() if isinstance(v, torch.Tensor) else v)
|
| 401 |
+
|
| 402 |
+
def training_step(self, batch, batch_idx):
|
| 403 |
+
# "batch" is the output of the training data loader.
|
| 404 |
+
preds, losses, mae = self.process_batch_supervised(batch)
|
| 405 |
+
self.log_all(losses, mae, prefix='train_')
|
| 406 |
+
|
| 407 |
+
return losses['final']
|
| 408 |
+
|
| 409 |
+
def validation_step(self, batch, batch_idx):
|
| 410 |
+
preds, losses, mae = self.process_batch_supervised(batch)
|
| 411 |
+
self.log_all(losses, mae, prefix='val_')
|
| 412 |
+
|
| 413 |
+
def test_step(self, batch, batch_idx):
|
| 414 |
+
preds, losses, mae = self.process_batch_supervised(batch)
|
| 415 |
+
self.log_all(losses, mae, prefix='test_')
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
# class StripsModel(StripsModelGeneral):
|
| 420 |
+
# def __init__(self, model_name, *args, **kwargs):
|
| 421 |
+
# super().__init__( *args, **kwargs)
|
| 422 |
+
# self.model = timm.create_model(model_name, in_chans=1, num_classes=4)
|
| 423 |
+
# def forward(self, x):
|
| 424 |
+
# """get predictions from image batch"""
|
| 425 |
+
# preds = self.model(x) # preds: logit angle_sin, logit angle_cos, period, logit lumen fraction
|
| 426 |
+
# preds_sin = 1. - 2*torch.sigmoid(preds[:,0])
|
| 427 |
+
# preds_cos = 1. - 2*torch.sigmoid(preds[:,1])
|
| 428 |
+
# preds_direction = 0.5*torch.rad2deg(torch.arctan2(preds_sin, preds_cos))
|
| 429 |
+
# preds_period = preds[:,2]
|
| 430 |
+
# preds_lumen_fraction = torch.sigmoid(preds[:,3]) #lumen fraction is between 0 and 1, so we take sigmoid fo this
|
| 431 |
+
# return preds_direction, preds_period, preds_lumen_fraction
|
| 432 |
+
|
| 433 |
+
# class StripsModelAngle1(StripsModelGeneral):
|
| 434 |
+
# def __init__(self, model_name, *args, **kwargs):
|
| 435 |
+
# super().__init__( *args, **kwargs)
|
| 436 |
+
# self.model = timm.create_model(model_name, in_chans=1, num_classes=3)
|
| 437 |
+
# def forward(self, x):
|
| 438 |
+
# """get predictions from image batch"""
|
| 439 |
+
# preds = self.model(x) # preds: logit angle_sin, logit angle
|
| 440 |
+
# preds_direction = torch.pi * torch.sigmoid(preds[:,0])
|
| 441 |
+
# preds_period = preds[:,1]
|
| 442 |
+
# preds_lumen_fraction = torch.sigmoid(preds[:,2]) #lumen fraction is between 0 and 1, so we take sigmoid fo this
|
| 443 |
+
# return preds_direction, preds_period, preds_lumen_fraction
|
| 444 |
+
|
angle_calculation/classic.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy import signal
|
| 3 |
+
from scipy import ndimage
|
| 4 |
+
from scipy.fftpack import next_fast_len
|
| 5 |
+
from skimage.transform import rotate
|
| 6 |
+
from skimage._shared.utils import convert_to_float
|
| 7 |
+
from skimage.transform import warp
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import cv2
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
|
| 12 |
+
def get_directional_std(image, theta=None,*, preserve_range=False):
|
| 13 |
+
|
| 14 |
+
if image.ndim != 2:
|
| 15 |
+
raise ValueError('The input image must be 2-D')
|
| 16 |
+
if theta is None:
|
| 17 |
+
theta = np.arange(180)
|
| 18 |
+
|
| 19 |
+
image = convert_to_float(image.copy(), preserve_range) #TODO: needed?
|
| 20 |
+
|
| 21 |
+
shape_min = min(image.shape)
|
| 22 |
+
img_shape = np.array(image.shape)
|
| 23 |
+
|
| 24 |
+
# Crop image to make it square
|
| 25 |
+
slices = tuple(slice(int(np.ceil(excess / 2)),
|
| 26 |
+
int(np.ceil(excess / 2) + shape_min))
|
| 27 |
+
if excess > 0 else slice(None)
|
| 28 |
+
for excess in (img_shape - shape_min))
|
| 29 |
+
image = image[slices]
|
| 30 |
+
shape_min = min(image.shape)
|
| 31 |
+
img_shape = np.array(image.shape)
|
| 32 |
+
|
| 33 |
+
radius = shape_min // 2
|
| 34 |
+
coords = np.array(np.ogrid[:image.shape[0], :image.shape[1]],
|
| 35 |
+
dtype=object)
|
| 36 |
+
dist = ((coords - img_shape // 2) ** 2).sum(0)
|
| 37 |
+
outside_reconstruction_circle = dist > radius ** 2
|
| 38 |
+
image[outside_reconstruction_circle] = 0
|
| 39 |
+
|
| 40 |
+
valid_square_slice = slice(int(np.ceil(radius*(1-1/np.sqrt(2)))), int(np.ceil(radius*(1+1/np.sqrt(2)))) )
|
| 41 |
+
|
| 42 |
+
# padded_image is always square
|
| 43 |
+
if image.shape[0] != image.shape[1]:
|
| 44 |
+
raise ValueError('padded_image must be a square')
|
| 45 |
+
center = image.shape[0] // 2
|
| 46 |
+
result = np.zeros(len(theta))
|
| 47 |
+
|
| 48 |
+
for i, angle in enumerate(np.deg2rad(theta)):
|
| 49 |
+
cos_a, sin_a = np.cos(angle), np.sin(angle)
|
| 50 |
+
R = np.array([[cos_a, sin_a, -center * (cos_a + sin_a - 1)],
|
| 51 |
+
[-sin_a, cos_a, -center * (cos_a - sin_a - 1)],
|
| 52 |
+
[0, 0, 1]])
|
| 53 |
+
rotated = warp(image, R, clip=False)
|
| 54 |
+
result[i] = rotated[valid_square_slice, valid_square_slice].std(axis=0).mean()
|
| 55 |
+
return result
|
| 56 |
+
|
| 57 |
+
def acf2d(x, nlags=None):
|
| 58 |
+
xo = x - x.mean(axis=0)
|
| 59 |
+
n = len(x)
|
| 60 |
+
if nlags is None:
|
| 61 |
+
nlags = n -1
|
| 62 |
+
lag_len = nlags
|
| 63 |
+
|
| 64 |
+
xi = np.arange(1, n + 1)
|
| 65 |
+
d = np.expand_dims(np.hstack((xi, xi[:-1][::-1])),1)
|
| 66 |
+
|
| 67 |
+
nobs = len(xo)
|
| 68 |
+
n = next_fast_len(2 * nobs + 1)
|
| 69 |
+
Frf = np.fft.fft(xo, n=n, axis=0)
|
| 70 |
+
|
| 71 |
+
acov = np.fft.ifft(Frf * np.conjugate(Frf), axis=0)[:nobs] / d[nobs - 1 :]
|
| 72 |
+
acov = acov.real
|
| 73 |
+
ac = acov[: nlags + 1] / acov[:1]
|
| 74 |
+
return ac
|
| 75 |
+
|
| 76 |
+
def get_period(acf_table, n_samples=50):
|
| 77 |
+
#TODO: use peak heights to select best candidates. use std to eliminate outliers
|
| 78 |
+
period_candidates = []
|
| 79 |
+
period_candidates_hights = []
|
| 80 |
+
for i in np.random.randint(0, acf_table.shape[1], min(acf_table.shape[1], n_samples)):
|
| 81 |
+
peaks = signal.find_peaks(acf_table[:,i])[0]
|
| 82 |
+
if len(peaks) == 0:
|
| 83 |
+
continue
|
| 84 |
+
peak_idx = peaks[0]
|
| 85 |
+
period_candidates.append(peak_idx)
|
| 86 |
+
period_candidates_hights.append(acf_table[peak_idx,i])
|
| 87 |
+
period_candidates = np.array(period_candidates)
|
| 88 |
+
period_candidates_hights = np.array(period_candidates_hights)
|
| 89 |
+
|
| 90 |
+
if len(period_candidates) == 0:
|
| 91 |
+
return np.nan, np.nan
|
| 92 |
+
elif len(period_candidates) == 1:
|
| 93 |
+
return period_candidates[0], np.nan
|
| 94 |
+
q1, q3 = np.quantile(period_candidates, [0.25, 0.75])
|
| 95 |
+
candidates_std = np.std(period_candidates[(period_candidates>=q1)&(period_candidates<=q3)])
|
| 96 |
+
# return period_candidates, period_candidates_hights
|
| 97 |
+
return np.median(period_candidates), candidates_std
|
| 98 |
+
|
| 99 |
+
def get_rotation_with_confidence(padded_image, blur_size=4, make_plots=True):
|
| 100 |
+
std_by_angle = get_directional_std(cv2.blur(padded_image, (blur_size,blur_size)))
|
| 101 |
+
rotation_angle = np.argmin(std_by_angle)
|
| 102 |
+
|
| 103 |
+
rotation_quality = 1 - np.min(std_by_angle)/np.median(std_by_angle)
|
| 104 |
+
if make_plots:
|
| 105 |
+
plt.plot(std_by_angle)
|
| 106 |
+
plt.axvline(rotation_angle, c='k')
|
| 107 |
+
plt.title(f'quality: {rotation_quality:0.2f}')
|
| 108 |
+
return rotation_angle, rotation_quality
|
| 109 |
+
|
| 110 |
+
def calculate_autocorrelation(oriented_img, blur_kernel=(7,1), make_plots=True):
|
| 111 |
+
autocorrelation = acf2d(cv2.blur(oriented_img.T, blur_kernel))
|
| 112 |
+
if make_plots:
|
| 113 |
+
fig, axs = plt.subplots(ncols=2, figsize=(12,6))
|
| 114 |
+
axs[0].imshow(autocorrelation)
|
| 115 |
+
axs[1].plot(autocorrelation.sum(axis=1))
|
| 116 |
+
return autocorrelation
|
| 117 |
+
|
| 118 |
+
def get_period_with_confidence(autocorrelation_tab, n_samples=30):
|
| 119 |
+
period, period_std = get_period(autocorrelation_tab, n_samples=n_samples)
|
| 120 |
+
if period_std == np.nan:
|
| 121 |
+
period_confidence = 0.001
|
| 122 |
+
else:
|
| 123 |
+
period_confidence = period/(period+2*period_std)
|
| 124 |
+
return period, period_confidence
|
| 125 |
+
|
| 126 |
+
def calculate_white_fraction(img, blur_size=4, make_plots=True): #TODO: add mask
|
| 127 |
+
blurred = cv2.blur(img, (blur_size, blur_size))
|
| 128 |
+
blurred_sum = blurred.sum(axis=0)
|
| 129 |
+
lower, upper = np.quantile(blurred_sum, [0.15, 0.85])
|
| 130 |
+
sign = blurred_sum > (lower+upper)/2
|
| 131 |
+
|
| 132 |
+
sign_change = sign[:-1] != sign[1:]
|
| 133 |
+
sign_change_indices = np.where(sign_change)[0]
|
| 134 |
+
|
| 135 |
+
if len(sign_change_indices) >= 2 + (sign[-1] == sign[0]):
|
| 136 |
+
cut_first = sign_change_indices[0]+1
|
| 137 |
+
|
| 138 |
+
if sign[-1] == sign[0]:
|
| 139 |
+
cut_last = sign_change_indices[-2]
|
| 140 |
+
else:
|
| 141 |
+
cut_last = sign_change_indices[-1]
|
| 142 |
+
|
| 143 |
+
white_fraction = np.mean(sign[cut_first:cut_last])
|
| 144 |
+
else:
|
| 145 |
+
white_fraction = np.nan
|
| 146 |
+
cut_first, cut_last = None, None
|
| 147 |
+
if make_plots:
|
| 148 |
+
fig, axs = plt.subplots(ncols=3, figsize=(16,6))
|
| 149 |
+
blurred_sum_normalized = blurred_sum - blurred_sum.min()
|
| 150 |
+
blurred_sum_normalized /= blurred_sum_normalized.max()
|
| 151 |
+
axs[0].plot(blurred_sum_normalized)
|
| 152 |
+
axs[0].plot(sign)
|
| 153 |
+
axs[1].plot(blurred_sum_normalized[cut_first:cut_last])
|
| 154 |
+
axs[1].plot(sign[cut_first:cut_last])
|
| 155 |
+
axs[2].imshow(img, cmap='gray')
|
| 156 |
+
for i, idx in enumerate(sign_change_indices):
|
| 157 |
+
plt.axvline(idx, c=['r', 'lime'][i%2])
|
| 158 |
+
fig.suptitle(f'fraction: {white_fraction:0.2f}')
|
| 159 |
+
|
| 160 |
+
return white_fraction
|
| 161 |
+
|
| 162 |
+
def process_img_crop(img, nm_per_px=1, make_plots=False, return_extra=False):
|
| 163 |
+
|
| 164 |
+
# image must be square
|
| 165 |
+
assert img.shape[0] == img.shape[1]
|
| 166 |
+
crop_size = img.shape[0]
|
| 167 |
+
|
| 168 |
+
# find orientation
|
| 169 |
+
rotation_angle, rotation_quality = get_rotation_with_confidence(img, blur_size=4, make_plots=make_plots)
|
| 170 |
+
|
| 171 |
+
# rotate and crop image
|
| 172 |
+
crop_margin = int((1 - 1/np.sqrt(2))*crop_size*0.5)
|
| 173 |
+
oriented_img = rotate(img, -rotation_angle)[2*crop_margin:-crop_margin, crop_margin:-crop_margin]
|
| 174 |
+
|
| 175 |
+
# calculate autocorrelation
|
| 176 |
+
autocorrelation = calculate_autocorrelation(oriented_img, blur_kernel=(7,1), make_plots=make_plots)
|
| 177 |
+
|
| 178 |
+
# find period
|
| 179 |
+
period, period_confidence = get_period_with_confidence(autocorrelation)
|
| 180 |
+
if make_plots:
|
| 181 |
+
print(f'period: {period}, confidence: {period_confidence}')
|
| 182 |
+
|
| 183 |
+
# find white fraction
|
| 184 |
+
white_fraction = calculate_white_fraction(oriented_img, make_plots=make_plots)
|
| 185 |
+
white_width = white_fraction*period
|
| 186 |
+
|
| 187 |
+
result = {
|
| 188 |
+
'direction': rotation_angle,
|
| 189 |
+
'direction confidence': rotation_quality,
|
| 190 |
+
'period': period*nm_per_px,
|
| 191 |
+
'period confidence': period_confidence,
|
| 192 |
+
'lumen width': white_width*nm_per_px
|
| 193 |
+
}
|
| 194 |
+
if return_extra:
|
| 195 |
+
result['extra'] = {
|
| 196 |
+
'autocorrelation': autocorrelation,
|
| 197 |
+
'oriented_img': oriented_img
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
return result
|
| 201 |
+
|
| 202 |
+
def get_top_k(a, k):
|
| 203 |
+
ind = np.argpartition(a, -k)[-k:]
|
| 204 |
+
return a[ind]
|
| 205 |
+
|
| 206 |
+
def get_crops(img, distance_map, crop_size, N_sample):
|
| 207 |
+
crop_r= np.sqrt(2)*crop_size / 2
|
| 208 |
+
possible_positions_y, possible_positions_x = np.where(distance_map >= crop_r)
|
| 209 |
+
no_edge_mask = (possible_positions_y>crop_r) & \
|
| 210 |
+
(possible_positions_x>crop_r) & \
|
| 211 |
+
(possible_positions_y<(distance_map.shape[0]-crop_r)) & \
|
| 212 |
+
(possible_positions_x<(distance_map.shape[1]-crop_r))
|
| 213 |
+
|
| 214 |
+
possible_positions_x = possible_positions_x[no_edge_mask]
|
| 215 |
+
possible_positions_y = possible_positions_y[no_edge_mask]
|
| 216 |
+
N_available = len(possible_positions_x)
|
| 217 |
+
positions_indices = np.random.choice(np.arange(N_available), min(N_sample, N_available), replace=False)
|
| 218 |
+
|
| 219 |
+
for idx in positions_indices:
|
| 220 |
+
yield img[possible_positions_y[idx]-crop_size//2:possible_positions_y[idx]+crop_size//2,possible_positions_x[idx]-crop_size//2:possible_positions_x[idx]+crop_size//2].copy()
|
| 221 |
+
|
| 222 |
+
def sliced_mean(x, slice_size):
|
| 223 |
+
cs_y = np.cumsum(x, axis=0)
|
| 224 |
+
cs_y = np.concatenate((np.zeros((1, cs_y.shape[1]), dtype=cs_y.dtype), cs_y), axis=0)
|
| 225 |
+
slices_y = (cs_y[slice_size:] - cs_y[:-slice_size])/slice_size
|
| 226 |
+
cs_xy = np.cumsum(slices_y, axis=1)
|
| 227 |
+
cs_xy = np.concatenate((np.zeros((cs_xy.shape[0], 1), dtype=cs_xy.dtype), cs_xy), axis=1)
|
| 228 |
+
slices_xy = (cs_xy[:,slice_size:] - cs_xy[:,:-slice_size])/slice_size
|
| 229 |
+
return slices_xy
|
| 230 |
+
|
| 231 |
+
def sliced_var(x, slice_size):
|
| 232 |
+
x = x.astype('float64')
|
| 233 |
+
return sliced_mean(x**2, slice_size) - sliced_mean(x, slice_size)**2
|
| 234 |
+
|
| 235 |
+
def select_samples(granum_image, granum_mask, crop_size=96, n_samples=64, granum_fraction_min=1.0, variance_p=2):
|
| 236 |
+
granum_occupancy = sliced_mean(granum_mask, crop_size)
|
| 237 |
+
possible_indices = np.stack(np.where(granum_occupancy >= granum_fraction_min), axis=1)
|
| 238 |
+
|
| 239 |
+
if variance_p == 0:
|
| 240 |
+
p = np.ones(len(possible_indices))
|
| 241 |
+
else:
|
| 242 |
+
variance_map = sliced_var(granum_image, crop_size)
|
| 243 |
+
p = variance_map[possible_indices[:,0], possible_indices[:,1]]**variance_p
|
| 244 |
+
p /= np.sum(p)
|
| 245 |
+
|
| 246 |
+
chosen_indices = np.random.choice(
|
| 247 |
+
np.arange(len(possible_indices)),
|
| 248 |
+
min(len(possible_indices), n_samples),
|
| 249 |
+
replace=False,
|
| 250 |
+
p = p
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
crops = []
|
| 254 |
+
for crop_idx, idx in enumerate(chosen_indices):
|
| 255 |
+
crops.append(
|
| 256 |
+
granum_image[
|
| 257 |
+
possible_indices[idx,0]:possible_indices[idx,0]+crop_size,
|
| 258 |
+
possible_indices[idx,1]:possible_indices[idx,1]+crop_size
|
| 259 |
+
]
|
| 260 |
+
)
|
| 261 |
+
return np.array(crops)
|
| 262 |
+
|
| 263 |
+
def calculate_distance_map(mask):
|
| 264 |
+
padded = np.pad(mask, pad_width=1, mode='constant', constant_values=False)
|
| 265 |
+
distance_map_padded = ndimage.distance_transform_edt(padded)
|
| 266 |
+
return distance_map_padded[1:-1,1:-1]
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def measure_object(
|
| 270 |
+
img, mask,
|
| 271 |
+
nm_per_px=1, n_tries = 3,
|
| 272 |
+
direction_thr_min = 0.07, direction_thr_enough = 0.1,
|
| 273 |
+
crop_size = 200,
|
| 274 |
+
**kwargs):
|
| 275 |
+
|
| 276 |
+
distance_map = calculate_distance_map(mask)
|
| 277 |
+
crop_size = min(crop_size, int(min(get_top_k(distance_map.flatten(), n_tries)*0.5**0.5)))
|
| 278 |
+
|
| 279 |
+
direction_confidence = 0
|
| 280 |
+
best_stripes_data = {}
|
| 281 |
+
for i, img_crop in enumerate(get_crops(img, distance_map, crop_size, n_tries)):
|
| 282 |
+
stripes_data = process_img_crop(img_crop, nm_per_px=nm_per_px)
|
| 283 |
+
if stripes_data['direction confidence'] >= direction_confidence:
|
| 284 |
+
best_stripes_data = deepcopy(stripes_data)
|
| 285 |
+
direction_confidence = stripes_data['direction confidence']
|
| 286 |
+
if direction_confidence > direction_thr_enough:
|
| 287 |
+
break
|
| 288 |
+
|
| 289 |
+
result = best_stripes_data
|
| 290 |
+
|
| 291 |
+
if direction_confidence >= direction_thr_min:
|
| 292 |
+
|
| 293 |
+
mask_oriented = rotate(mask, 90-result['direction'], resize=True).astype('bool')
|
| 294 |
+
idx_begin_x, idx_end_x = np.where(np.any(mask_oriented, axis=0))[0][np.array([0, -1])]
|
| 295 |
+
idx_begin_y, idx_end_y = np.where(np.any(mask_oriented, axis=1))[0][np.array([0, -1])]
|
| 296 |
+
result['mask_oriented'] = mask_oriented[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
|
| 297 |
+
result['img_oriented'] = rotate(img, 90-result['direction'], resize=True)[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
|
| 298 |
+
|
| 299 |
+
# measurements = measure_granum_shape(result['mask_oriented'], nm_per_px=nm_per_px, oriented=True)
|
| 300 |
+
# else:
|
| 301 |
+
# measurements = measure_granum_shape(mask, nm_per_px=nm_per_px, oriented=False)
|
| 302 |
+
|
| 303 |
+
# result.update(**measurements)
|
| 304 |
+
# N_layers = result['height'] / result['period']
|
| 305 |
+
# if np.isfinite(N_layers):
|
| 306 |
+
# N_layers = round(N_layers)
|
| 307 |
+
|
| 308 |
+
return result
|
| 309 |
+
|
| 310 |
+
# def measure_object(
|
| 311 |
+
# img, mask,
|
| 312 |
+
# nm_per_px=1, n_tries = 3,
|
| 313 |
+
# direction_thr_min = 0.07, direction_thr_enough = 0.1,
|
| 314 |
+
# crop_size = 200,
|
| 315 |
+
# **kwargs):
|
| 316 |
+
|
| 317 |
+
# distance_map = calculate_distance_map(mask)
|
| 318 |
+
# crop_size = min(crop_size, int((min(get_top_k(distance_map.flatten(), n_tries)*0.5)**0.5)))
|
| 319 |
+
|
| 320 |
+
# direction_confidence = 0
|
| 321 |
+
# best_stripes_data = {}
|
| 322 |
+
# for i, img_crop in enumerate(select_samples(img, mask, crop_size=crop_size, n_samples=n_tries)):
|
| 323 |
+
# stripes_data = process_img_crop(img_crop, nm_per_px=nm_per_px)
|
| 324 |
+
# if stripes_data['direction_confidence'] >= direction_confidence:
|
| 325 |
+
# best_stripes_data = deepcopy(stripes_data)
|
| 326 |
+
# direction_confidence = stripes_data['direction_confidence']
|
| 327 |
+
# if direction_confidence > direction_thr_enough:
|
| 328 |
+
# break
|
| 329 |
+
|
| 330 |
+
# result = best_stripes_data
|
| 331 |
+
|
| 332 |
+
# if direction_confidence >= direction_thr_min:
|
| 333 |
+
|
| 334 |
+
# mask_oriented = rotate(mask, 90-result['direction'], resize=True).astype('bool')
|
| 335 |
+
# idx_begin_x, idx_end_x = np.where(np.any(mask_oriented, axis=0))[0][np.array([0, -1])]
|
| 336 |
+
# idx_begin_y, idx_end_y = np.where(np.any(mask_oriented, axis=1))[0][np.array([0, -1])]
|
| 337 |
+
# result['mask_oriented'] = mask_oriented[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
|
| 338 |
+
# result['img_oriented'] = rotate(img, 90-result['direction'], resize=True)[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
|
| 339 |
+
|
| 340 |
+
# # measurements = measure_granum_shape(result['mask_oriented'], nm_per_px=nm_per_px, oriented=True)
|
| 341 |
+
# # else:
|
| 342 |
+
# # measurements = measure_granum_shape(mask, nm_per_px=nm_per_px, oriented=False)
|
| 343 |
+
|
| 344 |
+
# # result.update(**measurements)
|
| 345 |
+
# # N_layers = result['height'] / result['period']
|
| 346 |
+
# # if np.isfinite(N_layers):
|
| 347 |
+
# # N_layers = round(N_layers)
|
| 348 |
+
|
| 349 |
+
# return result #{**measurements, **best_stripes_data, 'N layers': N_layers}
|
angle_calculation/envelope_correction.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import scipy
|
| 3 |
+
|
| 4 |
+
def detect_boundaries(mask, axis):
|
| 5 |
+
# calculate the boundaries of the mask
|
| 6 |
+
#axis = 0 results in x_from, x_to
|
| 7 |
+
#axis = 1 results in y_from, y_to
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
sum = mask.sum(axis=axis)
|
| 11 |
+
|
| 12 |
+
ind_from = min(sum.nonzero()[0])
|
| 13 |
+
ind_to = max(sum.nonzero()[0])
|
| 14 |
+
return ind_from, ind_to
|
| 15 |
+
|
| 16 |
+
def area(mask):
|
| 17 |
+
x1, y1 = detect_boundaries(mask, 0)
|
| 18 |
+
a = y1 - x1
|
| 19 |
+
x2, y2 = detect_boundaries(mask, 1)
|
| 20 |
+
b = y2 - x2
|
| 21 |
+
|
| 22 |
+
return (a * b, x1, y1, x2, y2)
|
| 23 |
+
|
| 24 |
+
def calculate_best_angle_from_mask(mask, angles=np.arange(-10,10,0.5)):
|
| 25 |
+
areas = []
|
| 26 |
+
for angle in angles:
|
| 27 |
+
rotated_mask = scipy.ndimage.rotate(mask, angle, reshape=True, order = 0) # order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
| 28 |
+
this_area = area(rotated_mask)
|
| 29 |
+
areas.append(this_area[0])
|
| 30 |
+
|
| 31 |
+
best_angle = angles[np.argmin(areas)]
|
| 32 |
+
return best_angle
|
| 33 |
+
|
| 34 |
+
|
angle_calculation/granum_utils.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 2 |
+
import numpy as np
|
| 3 |
+
from scipy import ndimage
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Any, List
|
| 7 |
+
from zipfile import ZipFile
|
| 8 |
+
|
| 9 |
+
def add_text(image: Image.Image, text: str, location=(0.5, 0.5), color='red', size=40) -> Image.Image:
|
| 10 |
+
draw = ImageDraw.Draw(image)
|
| 11 |
+
font = ImageFont.load_default(size=size)
|
| 12 |
+
draw.text((int(image.size[0]*location[0]), int(image.size[1]*location[1])), text, font=font, fill=color)
|
| 13 |
+
return image
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def select_unique_mask(mask):
|
| 17 |
+
"""if mask consists of multiple parts, select the largest"""
|
| 18 |
+
blobs = ndimage.label(mask)[0]
|
| 19 |
+
blob_labels, blob_sizes = np.unique(blobs, return_counts=True)
|
| 20 |
+
best_blob_label = blob_labels[1:][np.argmax(blob_sizes[1:])]
|
| 21 |
+
return blobs == best_blob_label
|
| 22 |
+
|
| 23 |
+
def object_slice(mask, margin=128):
|
| 24 |
+
rows = np.any(mask, axis=1)
|
| 25 |
+
cols = np.any(mask, axis=0)
|
| 26 |
+
row_min, row_max = np.where(rows)[0][[0, -1]]
|
| 27 |
+
col_min, col_max = np.where(cols)[0][[0, -1]]
|
| 28 |
+
|
| 29 |
+
# Create a slice object for the bounding box
|
| 30 |
+
bounding_box_slice = (
|
| 31 |
+
slice(max(0,row_min-margin), min(row_max + 1+margin, len(rows)+1)),
|
| 32 |
+
slice(max(0,col_min-margin), min(col_max + 1+margin, len(cols)+1))
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
return bounding_box_slice
|
| 36 |
+
|
| 37 |
+
def resize_to(image: Image.Image, s=4032) -> Image.Image:
|
| 38 |
+
w, h = image.size
|
| 39 |
+
longest_size = max(h, w)
|
| 40 |
+
|
| 41 |
+
resize_factor = longest_size / s
|
| 42 |
+
|
| 43 |
+
resized_image = image.resize((int(w/resize_factor), int(h/resize_factor)))
|
| 44 |
+
return resized_image
|
| 45 |
+
|
| 46 |
+
def rolling_mean(x, window):
|
| 47 |
+
cs = np.r_[0, np.cumsum(x)]
|
| 48 |
+
rolling_sum = cs[window:] - cs[:-window]
|
| 49 |
+
return rolling_sum/window
|
| 50 |
+
|
| 51 |
+
@dataclass
|
| 52 |
+
class Granum:
|
| 53 |
+
image: Any = None#Optional[np.ndarray] = None
|
| 54 |
+
mask: Any = None #Optional[np.ndarray] = None
|
| 55 |
+
scaler: Any = None
|
| 56 |
+
nm_per_px: float = float('nan')
|
| 57 |
+
detection_confidence: float = float('nan')
|
| 58 |
+
|
| 59 |
+
def zip_files(files: List[str], output_name: str) -> None:
|
| 60 |
+
with ZipFile(output_name, "w") as zipObj:
|
| 61 |
+
for file in files:
|
| 62 |
+
zipObj.write(file)
|
| 63 |
+
|
| 64 |
+
def filter_boundary_detections(masks, scaler=None):
|
| 65 |
+
last_index_right = -1 if scaler is None else masks.shape[1]-1-scaler.pad_right
|
| 66 |
+
last_index_bottom = -1 if scaler is None else masks.shape[2]-1-scaler.pad_bottom
|
| 67 |
+
doesnt_touch_boundary_mask = ~(np.any(masks[:,0,:] != 0, axis=1) | np.any(masks[:,last_index_right:,:] != 0, axis=(1,2)) | np.any(masks[:,:,0] != 0, axis=1) | np.any(masks[:,:,last_index_bottom:] != 0, axis=(1,2)))
|
| 68 |
+
return doesnt_touch_boundary_mask
|
| 69 |
+
|
| 70 |
+
def get_circle_mask(shape, r=None):
|
| 71 |
+
if isinstance(shape, int):
|
| 72 |
+
shape = (shape, shape)
|
| 73 |
+
if r is None:
|
| 74 |
+
r = min(shape)/2
|
| 75 |
+
X, Y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
|
| 76 |
+
center_x = shape[1] / 2 - 0.5
|
| 77 |
+
center_y = shape[0] / 2 - 0.5
|
| 78 |
+
|
| 79 |
+
mask = ((X-center_x)**2 + (Y-center_y)**2) >= r**2
|
| 80 |
+
return mask
|
angle_calculation/image_transforms.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
|
| 4 |
+
def batched_radon(image_batch):
|
| 5 |
+
batch_size, img_size = image_batch.shape[:2]
|
| 6 |
+
if batch_size > 512: # limit batch size to 512 because cv2.warpAffine fails for batch> 512
|
| 7 |
+
return np.concatenate([batched_radon(image_batch[i:i+512]) for i in range(0,batch_size,512)], axis=0)
|
| 8 |
+
theta = np.arange(180)
|
| 9 |
+
radon_image = np.zeros((image_batch.shape[0], img_size, len(theta)),
|
| 10 |
+
dtype='float32')
|
| 11 |
+
|
| 12 |
+
for i, angle in enumerate(theta):
|
| 13 |
+
M = cv2.getRotationMatrix2D(((img_size-1)/2.0,(img_size-1)/2.0),angle,1)
|
| 14 |
+
rotated = cv2.warpAffine(np.transpose(image_batch, (1, 2, 0)),M,(img_size,img_size))
|
| 15 |
+
if batch_size == 1: # cv2.warpAffine cancels batch dimension if equal to 1
|
| 16 |
+
rotated = rotated[:,:, np.newaxis]
|
| 17 |
+
rotated = np.transpose(rotated, (2, 0, 1))
|
| 18 |
+
rotated = rotated / np.array(255, dtype='float32')
|
| 19 |
+
radon_image[:, :, i] = rotated.sum(axis=1)
|
| 20 |
+
return radon_image
|
| 21 |
+
|
| 22 |
+
def get_center_crop_coords(height: int, width: int, crop_height: int, crop_width: int):
|
| 23 |
+
"""from https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/crops/functional.py"""
|
| 24 |
+
y1 = (height - crop_height) // 2
|
| 25 |
+
y2 = y1 + crop_height
|
| 26 |
+
x1 = (width - crop_width) // 2
|
| 27 |
+
x2 = x1 + crop_width
|
| 28 |
+
return x1, y1, x2, y2
|
| 29 |
+
|
| 30 |
+
def center_crop(img: np.ndarray, crop_height: int, crop_width: int):
|
| 31 |
+
height, width = img.shape[:2]
|
| 32 |
+
x1, y1, x2, y2 = get_center_crop_coords(height, width, crop_height, crop_width)
|
| 33 |
+
img = img[y1:y2, x1:x2]
|
| 34 |
+
return img
|
angle_calculation/sampling.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from scipy import ndimage
|
| 6 |
+
import torch
|
| 7 |
+
from torchvision.transforms import functional as tvf
|
| 8 |
+
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
def sliced_mean(x, slice_size):
|
| 12 |
+
cs_y = np.cumsum(x, axis=0)
|
| 13 |
+
cs_y = np.concatenate((np.zeros((1, cs_y.shape[1]), dtype=cs_y.dtype), cs_y), axis=0)
|
| 14 |
+
slices_y = (cs_y[slice_size:] - cs_y[:-slice_size])/slice_size
|
| 15 |
+
cs_xy = np.cumsum(slices_y, axis=1)
|
| 16 |
+
cs_xy = np.concatenate((np.zeros((cs_xy.shape[0], 1), dtype=cs_xy.dtype), cs_xy), axis=1)
|
| 17 |
+
slices_xy = (cs_xy[:,slice_size:] - cs_xy[:,:-slice_size])/slice_size
|
| 18 |
+
return slices_xy
|
| 19 |
+
|
| 20 |
+
def sliced_var(x, slice_size):
|
| 21 |
+
x = x.astype('float64')
|
| 22 |
+
return sliced_mean(x**2, slice_size) - sliced_mean(x, slice_size)**2
|
| 23 |
+
|
| 24 |
+
def calculate_local_variance(img, var_window):
|
| 25 |
+
"""return local variance map with the same size as input image"""
|
| 26 |
+
var = sliced_var(img, var_window)
|
| 27 |
+
|
| 28 |
+
left_pad = var_window // 2 -1
|
| 29 |
+
right_pad = var_window -1 - left_pad
|
| 30 |
+
var_padded = np.pad(
|
| 31 |
+
var,
|
| 32 |
+
pad_width=(
|
| 33 |
+
(left_pad,right_pad),
|
| 34 |
+
(left_pad,right_pad)
|
| 35 |
+
))
|
| 36 |
+
return var_padded
|
| 37 |
+
|
| 38 |
+
def get_crop_batch(img: np.ndarray, mask: np.ndarray, crop_size=96, crop_scales=np.geomspace(0.5, 2, 7), samples_per_scale=32, use_variance_threshold=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 39 |
+
"""
|
| 40 |
+
Generate a batch of cropped images from an input image and corresponding mask, at various scales and rotations.
|
| 41 |
+
|
| 42 |
+
Parameters
|
| 43 |
+
----------
|
| 44 |
+
img : np.ndarray
|
| 45 |
+
The input image from which crops are generated.
|
| 46 |
+
mask : np.ndarray
|
| 47 |
+
The binary mask indicating the region of interest in the image.
|
| 48 |
+
crop_size : int, optional
|
| 49 |
+
The size of the square crop.
|
| 50 |
+
crop_scales : np.ndarray, optional
|
| 51 |
+
An array of scale factors to apply to the crop size.
|
| 52 |
+
samples_per_scale : int, optional
|
| 53 |
+
Number of samples to generate per scale factor.
|
| 54 |
+
use_variance_threshold : bool, optional
|
| 55 |
+
Flag to use variance thresholding for selecting crop locations.
|
| 56 |
+
|
| 57 |
+
Returns
|
| 58 |
+
-------
|
| 59 |
+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
| 60 |
+
A tuple containing the tensor of crops, their rotation angles, and scale factors.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
# pad
|
| 64 |
+
pad_size = int(np.ceil(0.5*crop_size*max(crop_scales)*(np.sqrt(2)-1)))
|
| 65 |
+
img_padded = np.pad(img, pad_size)
|
| 66 |
+
mask_padded = np.pad(mask, pad_size)
|
| 67 |
+
|
| 68 |
+
# distance map
|
| 69 |
+
distance_map_padded = ndimage.distance_transform_edt(mask_padded)
|
| 70 |
+
# TODO: adjust scales and samples_per_scale
|
| 71 |
+
|
| 72 |
+
if use_variance_threshold:
|
| 73 |
+
variance_window = min(crop_size//2, min(img.shape))
|
| 74 |
+
variance_map_padded = np.pad(calculate_local_variance(img, variance_window), pad_size)
|
| 75 |
+
variance_median = np.ma.median(np.ma.masked_where(distance_map_padded<0.5*variance_window, variance_map_padded))
|
| 76 |
+
variance_mask = variance_map_padded >= variance_median
|
| 77 |
+
else:
|
| 78 |
+
variance_mask = np.ones_like(mask_padded)
|
| 79 |
+
|
| 80 |
+
# initilize output
|
| 81 |
+
crops_granum = []
|
| 82 |
+
angles_granum = []
|
| 83 |
+
scales_granum = []
|
| 84 |
+
# loop over scales
|
| 85 |
+
for scale in crop_scales:
|
| 86 |
+
half_crop_size_scaled = int(np.floor(scale*0.5*crop_size)) # half of crop size after scaling
|
| 87 |
+
crop_pad = int(np.ceil((np.sqrt(2) - 1)*half_crop_size_scaled)) # pad added in order to allow rotation
|
| 88 |
+
half_crop_size_external = half_crop_size_scaled + crop_pad # size of "external crop" which will be rotated
|
| 89 |
+
|
| 90 |
+
possible_indices = np.stack(np.where(variance_mask & (distance_map_padded >= 2*half_crop_size_scaled)), axis=1)
|
| 91 |
+
if len(possible_indices) == 0:
|
| 92 |
+
continue
|
| 93 |
+
chosen_indices = np.random.choice(np.arange(len(possible_indices)), min(len(possible_indices), samples_per_scale), replace=False)
|
| 94 |
+
|
| 95 |
+
crops = [
|
| 96 |
+
img_padded[y-half_crop_size_external:y+half_crop_size_external, x-half_crop_size_external:x+half_crop_size_external] for y, x in possible_indices[chosen_indices]
|
| 97 |
+
]
|
| 98 |
+
|
| 99 |
+
# rotate
|
| 100 |
+
rotation_angles = np.random.rand(len(crops))*180 - 90
|
| 101 |
+
crops = [
|
| 102 |
+
ndimage.rotate(crop, angle, reshape=False)[crop_pad:-crop_pad,crop_pad:-crop_pad] for crop, angle in zip(crops, rotation_angles)
|
| 103 |
+
]
|
| 104 |
+
# add to output
|
| 105 |
+
crops_granum.append(tvf.resize(torch.tensor(np.array(crops)), (crop_size,crop_size),antialias=True)) # resize crops to crop_size
|
| 106 |
+
angles_granum.extend(rotation_angles.tolist())
|
| 107 |
+
scales_granum.extend([scale]*len(crops))
|
| 108 |
+
|
| 109 |
+
if len(angles_granum) == 0:
|
| 110 |
+
return [], [], []
|
| 111 |
+
|
| 112 |
+
crops_granum = torch.concat(crops_granum)
|
| 113 |
+
angles_granum = torch.tensor(angles_granum, dtype=torch.float)
|
| 114 |
+
scales_granum = torch.tensor(scales_granum, dtype=torch.float)
|
| 115 |
+
|
| 116 |
+
return crops_granum, angles_granum, scales_granum
|
| 117 |
+
|
| 118 |
+
def get_crop_batch_from_path(img_path, mask_path=None, use_variance_threshold=False):
|
| 119 |
+
"""
|
| 120 |
+
Load an image and its mask from file paths and generate a batch of cropped images.
|
| 121 |
+
|
| 122 |
+
Parameters
|
| 123 |
+
----------
|
| 124 |
+
img_path : str
|
| 125 |
+
Path to the input image.
|
| 126 |
+
mask_path : str, optional
|
| 127 |
+
Path to the binary mask image. If None, assumes mask path by replacing image extension with '.npy'.
|
| 128 |
+
use_variance_threshold : bool, optional
|
| 129 |
+
Flag to use variance thresholding for selecting crop locations.
|
| 130 |
+
|
| 131 |
+
Returns
|
| 132 |
+
-------
|
| 133 |
+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
| 134 |
+
A tuple containing the tensor of crops, their rotation angles, and scale factors, obtained from the specified image path.
|
| 135 |
+
"""
|
| 136 |
+
if mask_path is None:
|
| 137 |
+
mask_path = str(Path(img_path).with_suffix('.npy'))
|
| 138 |
+
mask = np.load(mask_path)
|
| 139 |
+
img = np.array(Image.open(img_path))[:,:,0]
|
| 140 |
+
|
| 141 |
+
return get_crop_batch(img, mask, use_variance_threshold=use_variance_threshold)
|
| 142 |
+
|
app.py
ADDED
|
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import time
|
| 4 |
+
import uuid
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from decimal import Decimal
|
| 7 |
+
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
from settings import DEMO
|
| 12 |
+
|
| 13 |
+
plt.switch_backend("agg") # fix for "RuntimeError: main thread is not in main loop"
|
| 14 |
+
import numpy as np
|
| 15 |
+
import pandas as pd
|
| 16 |
+
from PIL import Image
|
| 17 |
+
|
| 18 |
+
from model import GranaAnalyser
|
| 19 |
+
|
| 20 |
+
ga = GranaAnalyser(
|
| 21 |
+
"weights/yolo/20240604_yolov8_segm_ABRCR1_all_train4_best.pt",
|
| 22 |
+
"weights/AS_square_v16.ckpt",
|
| 23 |
+
"weights/period_measurer_weights-1.298_real_full-fa12970.ckpt",
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def calc_ratio(pixels, nano):
|
| 28 |
+
"""
|
| 29 |
+
Calculates ratio of pixels to nanometers and returns as str to populate ratio_input
|
| 30 |
+
:param pixels:
|
| 31 |
+
:param nano:
|
| 32 |
+
:return:
|
| 33 |
+
"""
|
| 34 |
+
if not (pixels and nano):
|
| 35 |
+
pass
|
| 36 |
+
else:
|
| 37 |
+
res = pixels / nano
|
| 38 |
+
return res
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# https://jakevdp.github.io/PythonDataScienceHandbook/05.13-kernel-density-estimation.html
|
| 42 |
+
def KDE(dataset, h):
|
| 43 |
+
# the Kernel function
|
| 44 |
+
def K(x):
|
| 45 |
+
return np.exp(-(x ** 2) / 2) / np.sqrt(2 * np.pi)
|
| 46 |
+
|
| 47 |
+
n_samples = dataset.size
|
| 48 |
+
|
| 49 |
+
x_range = dataset # x-value range for plotting KDEs
|
| 50 |
+
|
| 51 |
+
total_sum = 0
|
| 52 |
+
# iterate over datapoints
|
| 53 |
+
for i, xi in enumerate(dataset):
|
| 54 |
+
total_sum += K((x_range - xi) / h)
|
| 55 |
+
|
| 56 |
+
y_range = total_sum / (h * n_samples)
|
| 57 |
+
|
| 58 |
+
return y_range
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def prepare_files_for_download(
|
| 62 |
+
dir_name,
|
| 63 |
+
grana_data,
|
| 64 |
+
aggregated_data,
|
| 65 |
+
detection_visualizations_dict,
|
| 66 |
+
images_grana_dict,
|
| 67 |
+
):
|
| 68 |
+
"""
|
| 69 |
+
Save and zip files for download
|
| 70 |
+
:param dir_name:
|
| 71 |
+
:param grana_data: DataFrame containing all grana measurements
|
| 72 |
+
:param aggregated_data: dict containing aggregated measurements
|
| 73 |
+
:return:
|
| 74 |
+
"""
|
| 75 |
+
dir_to_zip = f"{dir_name}/to_zip"
|
| 76 |
+
|
| 77 |
+
# raw data
|
| 78 |
+
grana_data_csv_path = f"{dir_to_zip}/grana_raw_data.csv"
|
| 79 |
+
grana_data.to_csv(grana_data_csv_path, index=False)
|
| 80 |
+
|
| 81 |
+
# aggregated measurements
|
| 82 |
+
aggregated_csv_path = f"{dir_to_zip}/grana_aggregated_data.csv"
|
| 83 |
+
aggregated_data.to_csv(aggregated_csv_path)
|
| 84 |
+
|
| 85 |
+
# annotated pictures
|
| 86 |
+
masked_images_dir = f"{dir_to_zip}/annotated_images"
|
| 87 |
+
os.makedirs(masked_images_dir)
|
| 88 |
+
for img_name, img in detection_visualizations_dict.items():
|
| 89 |
+
filename_split = img_name.split(".")
|
| 90 |
+
extension = filename_split[-1]
|
| 91 |
+
filename = ".".join(filename_split[:-1])
|
| 92 |
+
filename = f"{filename}_annotated.{extension}"
|
| 93 |
+
img.save(f"{masked_images_dir}/{filename}")
|
| 94 |
+
|
| 95 |
+
# single_grana images
|
| 96 |
+
grana_images_dir = f"{dir_to_zip}/single_grana_images"
|
| 97 |
+
os.makedirs(grana_images_dir)
|
| 98 |
+
org_images_dict = pd.Series(
|
| 99 |
+
grana_data["source image"].values, index=grana_data["granum ID"]
|
| 100 |
+
).to_dict()
|
| 101 |
+
for img_name, img in images_grana_dict.items():
|
| 102 |
+
org_filename = org_images_dict[img_name]
|
| 103 |
+
org_filename_split = org_filename.split(".")
|
| 104 |
+
org_filename_no_ext = ".".join(org_filename_split[:-1])
|
| 105 |
+
img_name_ext = f"{org_filename_no_ext}_granum_{str(img_name)}.png"
|
| 106 |
+
img.save(f"{grana_images_dir}/{img_name_ext}")
|
| 107 |
+
|
| 108 |
+
# zip all files
|
| 109 |
+
date_str = datetime.today().strftime("%Y-%m-%d")
|
| 110 |
+
zip_name = f"GRANA_results_{date_str}"
|
| 111 |
+
zip_path = f"{dir_name}/{zip_name}"
|
| 112 |
+
shutil.make_archive(zip_path, "zip", dir_to_zip)
|
| 113 |
+
|
| 114 |
+
# delete to_zip dir
|
| 115 |
+
zip_dir_path = os.path.join(os.getcwd(), dir_to_zip)
|
| 116 |
+
shutil.rmtree(zip_dir_path)
|
| 117 |
+
|
| 118 |
+
download_file_path = f"{zip_path}.zip"
|
| 119 |
+
return download_file_path
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def show_info_on_submit(s):
|
| 123 |
+
return (
|
| 124 |
+
gr.Button(interactive=False),
|
| 125 |
+
gr.Button(interactive=False),
|
| 126 |
+
gr.Row(visible=True),
|
| 127 |
+
gr.Row(visible=False),
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def load_css():
|
| 132 |
+
with open("styles.css", "r") as f:
|
| 133 |
+
css_content = f.read()
|
| 134 |
+
return css_content
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
primary_hue = gr.themes.Color(
|
| 138 |
+
c50="#e1f8ee",
|
| 139 |
+
c100="#b7efd5",
|
| 140 |
+
c200="#8de6bd",
|
| 141 |
+
c300="#63dda5",
|
| 142 |
+
c400="#39d48d",
|
| 143 |
+
c500="#27b373",
|
| 144 |
+
c600="#1e8958",
|
| 145 |
+
c700="#155f3d",
|
| 146 |
+
c800="#0c3522",
|
| 147 |
+
c900="#030b07",
|
| 148 |
+
c950="#000",
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
theme = gr.themes.Default(
|
| 153 |
+
primary_hue=primary_hue,
|
| 154 |
+
font=[gr.themes.GoogleFont("Ubuntu"), "ui-sans-serif", "system-ui", "sans-serif"],
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def draw_violin_plot(y, ylabel, title):
|
| 159 |
+
# only generate plot for 3 or more values
|
| 160 |
+
if y.count() < 3:
|
| 161 |
+
return None
|
| 162 |
+
|
| 163 |
+
# Colors
|
| 164 |
+
RED_DARK = "#850e00"
|
| 165 |
+
DARK_GREEN = "#0c3522"
|
| 166 |
+
BRIGHT_GREEN = "#8de6bd"
|
| 167 |
+
|
| 168 |
+
# Create jittered version of "x" (which is only 1)
|
| 169 |
+
x_jittered = []
|
| 170 |
+
kde = KDE(y, (y.max() - y.min()) / y.size / 2)
|
| 171 |
+
kde = kde / kde.max() * 0.2
|
| 172 |
+
for y_val in kde:
|
| 173 |
+
x_jittered.append(1 + np.random.uniform(-y_val, y_val, 1))
|
| 174 |
+
|
| 175 |
+
fig = plt.figure()
|
| 176 |
+
ax = fig.add_subplot(1, 1, 1)
|
| 177 |
+
ax.scatter(x=x_jittered, y=y, s=20, alpha=0.4, c=DARK_GREEN)
|
| 178 |
+
|
| 179 |
+
violins = ax.violinplot(
|
| 180 |
+
y,
|
| 181 |
+
widths=0.45,
|
| 182 |
+
bw_method="silverman",
|
| 183 |
+
showmeans=False,
|
| 184 |
+
showmedians=False,
|
| 185 |
+
showextrema=False,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# change violin color
|
| 189 |
+
for pc in violins["bodies"]:
|
| 190 |
+
pc.set_facecolor(BRIGHT_GREEN)
|
| 191 |
+
|
| 192 |
+
# add a boxplot to ax
|
| 193 |
+
# but make the whiskers length equal to 1 SD, i.e. in the proportion of the IQ range, but this length should start from the mean but be visible from the box boundary
|
| 194 |
+
lower = np.mean(y) - 1 * np.std(y)
|
| 195 |
+
upper = np.mean(y) + 1 * np.std(y)
|
| 196 |
+
|
| 197 |
+
medianprops = dict(linewidth=1, color="black", solid_capstyle="butt")
|
| 198 |
+
boxplot_stats = [
|
| 199 |
+
{
|
| 200 |
+
"med": np.median(y),
|
| 201 |
+
"q1": np.percentile(y, 25),
|
| 202 |
+
"q3": np.percentile(y, 75),
|
| 203 |
+
"whislo": lower,
|
| 204 |
+
"whishi": upper,
|
| 205 |
+
}
|
| 206 |
+
]
|
| 207 |
+
|
| 208 |
+
ax.bxp(
|
| 209 |
+
boxplot_stats, # data for the boxplot
|
| 210 |
+
showfliers=False, # do not show the outliers beyond the caps.
|
| 211 |
+
showcaps=True, # show the caps
|
| 212 |
+
medianprops=medianprops,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
# Add mean value point
|
| 216 |
+
ax.scatter(1, y.mean(), s=30, color=RED_DARK, zorder=3)
|
| 217 |
+
|
| 218 |
+
ax.set_xticks([])
|
| 219 |
+
ax.set_ylabel(ylabel)
|
| 220 |
+
ax.set_title(title)
|
| 221 |
+
fig.tight_layout()
|
| 222 |
+
|
| 223 |
+
return fig
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def transform_aggregated_results_table(results_dict):
|
| 227 |
+
MEASUREMENT_HEADER = "measurement [unit]"
|
| 228 |
+
VALUE_HEADER = "value +-SD"
|
| 229 |
+
|
| 230 |
+
def get_value_str(value, std):
|
| 231 |
+
if np.isnan(value) or np.isnan(std):
|
| 232 |
+
return "-"
|
| 233 |
+
value_str = str(Decimal(str(value)).quantize(Decimal("0.01")))
|
| 234 |
+
std_str = str(Decimal(str(std)).quantize(Decimal("0.01")))
|
| 235 |
+
return f"{value_str} +-{std_str}"
|
| 236 |
+
|
| 237 |
+
def append_to_dict(new_key, old_val_key, old_sd_key):
|
| 238 |
+
aggregated_dict[MEASUREMENT_HEADER].append(new_key)
|
| 239 |
+
value_str = get_value_str(results_dict[old_val_key], results_dict[old_sd_key])
|
| 240 |
+
aggregated_dict[VALUE_HEADER].append(value_str)
|
| 241 |
+
|
| 242 |
+
aggregated_dict = {MEASUREMENT_HEADER: [], VALUE_HEADER: []}
|
| 243 |
+
|
| 244 |
+
# area
|
| 245 |
+
append_to_dict("area [nm^2]", "area nm^2", "area nm^2 std")
|
| 246 |
+
|
| 247 |
+
# perimeter
|
| 248 |
+
append_to_dict("perimeter [nm]", "perimeter nm", "perimeter nm std")
|
| 249 |
+
|
| 250 |
+
# diameter
|
| 251 |
+
append_to_dict("diameter [nm]", "diameter nm", "diameter nm std")
|
| 252 |
+
|
| 253 |
+
# height
|
| 254 |
+
append_to_dict("height [nm]", "height nm", "height nm std")
|
| 255 |
+
|
| 256 |
+
# number of layers
|
| 257 |
+
append_to_dict("number of thylakoids", "Number of layers", "Number of layers std")
|
| 258 |
+
|
| 259 |
+
# SRD
|
| 260 |
+
append_to_dict("SRD [nm]", "period nm", "period nm std")
|
| 261 |
+
|
| 262 |
+
# GSI
|
| 263 |
+
append_to_dict("GSI", "GSI", "GSI std")
|
| 264 |
+
|
| 265 |
+
# N grana
|
| 266 |
+
aggregated_dict[MEASUREMENT_HEADER].append("number of grana")
|
| 267 |
+
aggregated_dict[VALUE_HEADER].append(str(int(results_dict["N grana"])))
|
| 268 |
+
|
| 269 |
+
return aggregated_dict
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def rename_columns_in_results_table(results_table):
|
| 273 |
+
column_names = {
|
| 274 |
+
"Granum ID": "granum ID",
|
| 275 |
+
"File name": "source image",
|
| 276 |
+
"area nm^2": "area [nm^2]",
|
| 277 |
+
"perimeter nm": "perimeter [nm]",
|
| 278 |
+
"diameter nm": "diameter [nm]",
|
| 279 |
+
"height nm": "height [nm]",
|
| 280 |
+
"Number of layers": "number of thylakoids",
|
| 281 |
+
"period nm": "SRD [nm]",
|
| 282 |
+
"period SD nm": "SRD SD [nm]",
|
| 283 |
+
}
|
| 284 |
+
results_table = results_table.rename(columns=column_names)
|
| 285 |
+
return results_table
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
with gr.Blocks(css=load_css(), theme=theme) as demo:
|
| 289 |
+
|
| 290 |
+
svg = """
|
| 291 |
+
<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 30.73 33.38">
|
| 292 |
+
<defs>
|
| 293 |
+
<style>
|
| 294 |
+
.cls-1 {
|
| 295 |
+
fill: #27b373;
|
| 296 |
+
stroke-width: 0px;
|
| 297 |
+
}
|
| 298 |
+
</style>
|
| 299 |
+
</defs>
|
| 300 |
+
<path class="cls-1" d="M19.69,11.73h-3.22c-2.74,0-4.96,2.22-4.96,4.96h0c0,2.74,2.22,4.96,4.96,4.96h3.43c.56,0,1,.51.89,1.09-.08.43-.49.72-.92.72h-8.62c-.74,0-1.34-.6-1.34-1.34v-10.87c0-.74.6-1.34,1.34-1.34h13.44c2.73,0,4.95-2.22,4.95-4.95h0c0-2.75-2.22-4.97-4.96-4.97h-13.85C4.85,0,0,4.85,0,10.83v11.71c0,5.98,4.85,10.83,10.83,10.83h9.07c5.76,0,10.49-4.52,10.81-10.21.35-6.29-4.72-11.44-11.02-11.44ZM19.9,31.4h-9.07c-4.89,0-8.85-3.96-8.85-8.85v-11.71C1.98,5.95,5.95,1.98,10.83,1.98h13.81c1.64,0,2.97,1.33,2.97,2.97h0c0,1.65-1.33,2.97-2.96,2.97h-13.4c-1.83,0-3.32,1.49-3.32,3.32v10.87c0,1.83,1.49,3.32,3.32,3.32h8.56c1.51,0,2.83-1.12,2.97-2.62.16-1.72-1.2-3.16-2.88-3.16h-3.52c-1.64,0-2.97-1.33-2.97-2.97h0c0-1.64,1.33-2.97,2.97-2.97h3.34c4.83,0,8.9,3.81,9.01,8.64s-3.9,9.04-8.84,9.04Z"/>
|
| 301 |
+
<path class="cls-1" d="M19.9,29.41h-9.07c-3.79,0-6.87-3.07-6.87-6.87v-11.71c0-3.79,3.07-6.87,6.87-6.87h13.81c.55,0,.99.44.99.99h0c0,.55-.44.99-.99.99h-13.81c-2.7,0-4.88,2.19-4.88,4.88v11.71c0,2.7,2.19,4.88,4.88,4.88h8.94c2.64,0,4.91-2.05,5-4.7s-2.12-5.05-4.87-5.05h-3.52c-.55,0-.99-.44-.99-.99h0c0-.55.44-.99.99-.99h3.36c3.74,0,6.9,2.92,7.01,6.66.11,3.87-3.01,7.06-6.85,7.06Z"/>
|
| 302 |
+
</svg>
|
| 303 |
+
"""
|
| 304 |
+
|
| 305 |
+
gr.HTML(
|
| 306 |
+
f'<div class="header"><div id="header-logo">{svg}</div><div id="header-text">GRANA<div></div>'
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
with gr.Row(elem_classes="input-row"): # input
|
| 310 |
+
with gr.Column():
|
| 311 |
+
gr.HTML(
|
| 312 |
+
"<h1>1. Choose images to upload. All the images need to be of the same scale and experimental variant.</h1>"
|
| 313 |
+
)
|
| 314 |
+
img_input = gr.File(file_count="multiple")
|
| 315 |
+
|
| 316 |
+
gr.HTML("<h1>2. Set the scale of the images for the measurements.</h1>")
|
| 317 |
+
with gr.Row():
|
| 318 |
+
with gr.Column():
|
| 319 |
+
gr.HTML("Either provide pixel per nanometer ratio...")
|
| 320 |
+
ratio_input = gr.Number(
|
| 321 |
+
label="pixel per nm", precision=3, step=0.001
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
with gr.Column():
|
| 325 |
+
gr.HTML("...or length of the scale bar in pixels and nanometers.")
|
| 326 |
+
pixels_input = gr.Number(label="Length in pixels")
|
| 327 |
+
nano_input = gr.Number(label="Length in nanometers")
|
| 328 |
+
|
| 329 |
+
pixels_input.change(
|
| 330 |
+
calc_ratio,
|
| 331 |
+
inputs=[pixels_input, nano_input],
|
| 332 |
+
outputs=ratio_input,
|
| 333 |
+
)
|
| 334 |
+
nano_input.change(
|
| 335 |
+
calc_ratio,
|
| 336 |
+
inputs=[pixels_input, nano_input],
|
| 337 |
+
outputs=ratio_input,
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
with gr.Row():
|
| 341 |
+
clear_btn = gr.ClearButton(img_input, "Clear")
|
| 342 |
+
submit_btn = gr.Button("Submit", variant="primary")
|
| 343 |
+
|
| 344 |
+
with gr.Row(visible=False) as loading_row:
|
| 345 |
+
with gr.Column():
|
| 346 |
+
gr.HTML(
|
| 347 |
+
"<div class='processed-info'>Images are being processed. This may take a while...</div>"
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
with gr.Row(visible=False) as output_row:
|
| 351 |
+
with gr.Column():
|
| 352 |
+
gr.HTML(
|
| 353 |
+
'<div class="results-header">Results</div>'
|
| 354 |
+
"<p>Full results are a zip file containing:<p>"
|
| 355 |
+
"<ul>- grana_raw_data.csv: a table with full grana measurements,</ul>"
|
| 356 |
+
"<ul>- grana_aggregated_data.csv: a table with aggregated measurements,</ul>"
|
| 357 |
+
'<ul>- directory "annotated_images" with all submitted images with masks on detected grana,</ul>'
|
| 358 |
+
'<ul>- directory "single_grana_images" with images of all detected grana.</ul>'
|
| 359 |
+
"<p>Note that GRANA only stores the result files for 1 hour.</p>",
|
| 360 |
+
elem_classes="input-row",
|
| 361 |
+
)
|
| 362 |
+
with gr.Row(elem_classes="input-row"):
|
| 363 |
+
download_file_out = gr.DownloadButton(
|
| 364 |
+
label="Download results",
|
| 365 |
+
variant="primary",
|
| 366 |
+
elem_classes="margin-bottom",
|
| 367 |
+
)
|
| 368 |
+
with gr.Row():
|
| 369 |
+
gr.HTML(
|
| 370 |
+
'<h2 class="title">Annotated images</h2>'
|
| 371 |
+
"Gallery of uploaded images with masks of recognized grana structures. "
|
| 372 |
+
"Each granum mask is "
|
| 373 |
+
"labeled with its number. Note that only fully visible grana in the image are masked."
|
| 374 |
+
)
|
| 375 |
+
with gr.Row(elem_classes="margin-bottom"):
|
| 376 |
+
gallery_out = gr.Gallery(
|
| 377 |
+
columns=4,
|
| 378 |
+
rows=2,
|
| 379 |
+
object_fit="contain",
|
| 380 |
+
label="Detection visualizations",
|
| 381 |
+
show_download_button=False,
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
with gr.Row(elem_classes="input-row"):
|
| 385 |
+
gr.HTML(
|
| 386 |
+
'<h2 class="title">Aggregated results for all uploaded images</h2>'
|
| 387 |
+
)
|
| 388 |
+
with gr.Row(elem_classes=["input-row", "margin-bottom"]):
|
| 389 |
+
table_out = gr.Dataframe(label="Aggregated data")
|
| 390 |
+
|
| 391 |
+
with gr.Row():
|
| 392 |
+
gr.HTML(
|
| 393 |
+
'<h2 class="title">Violin graphs</h2>'
|
| 394 |
+
"These graphs present aggregated results for selected structural parameters. "
|
| 395 |
+
"The graph for each parameter is only generated if three or more values are available. "
|
| 396 |
+
"Each graph "
|
| 397 |
+
"displays individual data points, a box plot indicating the first and third quartiles, whiskers "
|
| 398 |
+
"marking the standard deviation (SD), the median value (horizontal line on the box plot), "
|
| 399 |
+
"the mean value (red dot), and a density plot where the width represents the frequency."
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
with gr.Row():
|
| 403 |
+
area_plot_out = gr.Plot(label="Area")
|
| 404 |
+
perimeter_plot_out = gr.Plot(label="Perimeter")
|
| 405 |
+
gsi_plot_out = gr.Plot(label="GSI")
|
| 406 |
+
|
| 407 |
+
with gr.Row(elem_classes="margin-bottom"):
|
| 408 |
+
diameter_plot_out = gr.Plot(label="Diameter")
|
| 409 |
+
height_plot_out = gr.Plot(label="Height")
|
| 410 |
+
srd_plot_out = gr.Plot(label="SRD")
|
| 411 |
+
|
| 412 |
+
with gr.Row():
|
| 413 |
+
gr.HTML(
|
| 414 |
+
'<h2 class="title">Recognized and rotated grana structures</h2>'
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
with gr.Row(elem_classes="margin-bottom"):
|
| 418 |
+
gallery_single_grana_out = gr.Gallery(
|
| 419 |
+
columns=4,
|
| 420 |
+
rows=2,
|
| 421 |
+
object_fit="contain",
|
| 422 |
+
label="Single grana images",
|
| 423 |
+
show_download_button=False,
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
with gr.Row():
|
| 427 |
+
gr.HTML(
|
| 428 |
+
'<h2 class="title">Full results</h2>'
|
| 429 |
+
"Note that structural parameters other than area and perimeter are only calculated for the grana "
|
| 430 |
+
"whose direction and/or SRD could be estimated."
|
| 431 |
+
)
|
| 432 |
+
with gr.Row():
|
| 433 |
+
table_full_out = gr.Dataframe(label="Full measurements data")
|
| 434 |
+
|
| 435 |
+
submit_btn.click(
|
| 436 |
+
show_info_on_submit,
|
| 437 |
+
inputs=[submit_btn],
|
| 438 |
+
outputs=[submit_btn, clear_btn, loading_row, output_row],
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
def enable_submit():
|
| 442 |
+
return (
|
| 443 |
+
gr.Button(interactive=True),
|
| 444 |
+
gr.Button(interactive=True),
|
| 445 |
+
gr.Row(visible=False),
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
def gradio_analize_image(images, scale):
|
| 449 |
+
"""
|
| 450 |
+
Model accepts following parameters:
|
| 451 |
+
:param images: list of images to be processed, in either tiff or png format
|
| 452 |
+
:param scale: float, nm to pixel ratio
|
| 453 |
+
|
| 454 |
+
Model returns the following objects:
|
| 455 |
+
- detection_visualizations: list of images with masks to be displayed as gallery and served to download
|
| 456 |
+
as zip of images
|
| 457 |
+
- grana_data: dataframe with measurements for each image to be served to download as a csv file
|
| 458 |
+
- images_grana: list of images with single grana to be served to download as zip of images
|
| 459 |
+
- aggregated_data: dataframe with aggregated measurements for all images to be displayed as table and served
|
| 460 |
+
to download as csv
|
| 461 |
+
"""
|
| 462 |
+
|
| 463 |
+
# validate that at least one image has been uploaded
|
| 464 |
+
if images is None or len(images) == 0:
|
| 465 |
+
raise gr.Error("Please upload at least one image")
|
| 466 |
+
|
| 467 |
+
# on demo instance, we limit the number of images to 5
|
| 468 |
+
if DEMO:
|
| 469 |
+
if len(images) > 5:
|
| 470 |
+
raise gr.Error("In demo version it is possible to analyze up to 5 images.")
|
| 471 |
+
|
| 472 |
+
# validate that scale has been provided correctly
|
| 473 |
+
if scale is None or scale == 0:
|
| 474 |
+
raise gr.Error("Please provide scale. Use dot as decimal separator")
|
| 475 |
+
|
| 476 |
+
# validate that all images are png or tiff
|
| 477 |
+
for image in images:
|
| 478 |
+
if not image.name.lower().endswith((".png", ".tif", ".jpg", ".jpeg")):
|
| 479 |
+
raise gr.Error("Only png, tiff, jpg ang jpeg images are supported")
|
| 480 |
+
|
| 481 |
+
# clean up previous results
|
| 482 |
+
# find all directories in current working directory that start with "results_"
|
| 483 |
+
# that were created more than 1 hour ago and delete them with all contents
|
| 484 |
+
for directory_name in os.listdir():
|
| 485 |
+
if directory_name.startswith("results_"):
|
| 486 |
+
dir_path = os.path.join(os.getcwd(), directory_name)
|
| 487 |
+
if os.path.isdir(dir_path):
|
| 488 |
+
if time.time() - os.path.getctime(dir_path) > 60 * 60:
|
| 489 |
+
shutil.rmtree(dir_path)
|
| 490 |
+
|
| 491 |
+
# create a directory for results
|
| 492 |
+
results_dir_name = "results_{uuid}".format(uuid=uuid.uuid4().hex)
|
| 493 |
+
os.makedirs(results_dir_name)
|
| 494 |
+
zip_dir_name = f"{results_dir_name}/to_zip"
|
| 495 |
+
os.makedirs(zip_dir_name)
|
| 496 |
+
|
| 497 |
+
# model takes a dict of images, so we need to convert input to list of PIL.PngImagePlugin.PngImageFile or
|
| 498 |
+
# PIL.TiffImagePlugin.TiffImageFile objects
|
| 499 |
+
images_dict = {
|
| 500 |
+
image.name.split("/")[-1]: Image.open(image.name)
|
| 501 |
+
for i, image in enumerate(images)
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
# model works here
|
| 505 |
+
(
|
| 506 |
+
detection_visualizations_dict,
|
| 507 |
+
grana_data,
|
| 508 |
+
images_grana_dict,
|
| 509 |
+
aggregated_data,
|
| 510 |
+
) = ga.predict(images_dict, scale)
|
| 511 |
+
detection_visualizations = list(detection_visualizations_dict.values())
|
| 512 |
+
images_grana = list(images_grana_dict.values())
|
| 513 |
+
|
| 514 |
+
# rearrange aggregated data to be displayed as table
|
| 515 |
+
aggregated_dict = transform_aggregated_results_table(aggregated_data)
|
| 516 |
+
aggregated_df_transposed = pd.DataFrame.from_dict(aggregated_dict)
|
| 517 |
+
|
| 518 |
+
# rename columns in full results
|
| 519 |
+
grana_data = rename_columns_in_results_table(grana_data)
|
| 520 |
+
|
| 521 |
+
# save files returned by model to disk so they can be retrieved for downloading
|
| 522 |
+
download_file_path = prepare_files_for_download(
|
| 523 |
+
results_dir_name,
|
| 524 |
+
grana_data,
|
| 525 |
+
aggregated_df_transposed,
|
| 526 |
+
detection_visualizations_dict,
|
| 527 |
+
images_grana_dict,
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
# generate plot
|
| 531 |
+
area_fig = draw_violin_plot(
|
| 532 |
+
grana_data["area [nm^2]"].dropna(),
|
| 533 |
+
"Granum area [nm^2]",
|
| 534 |
+
"Grana areas from all uploaded images",
|
| 535 |
+
)
|
| 536 |
+
perimeter_fig = draw_violin_plot(
|
| 537 |
+
grana_data["perimeter [nm]"].dropna(),
|
| 538 |
+
"Granum perimeter [nm]",
|
| 539 |
+
"Grana perimeters from all uploaded images",
|
| 540 |
+
)
|
| 541 |
+
gsi_fig = draw_violin_plot(
|
| 542 |
+
grana_data["GSI"].dropna(),
|
| 543 |
+
"GSI",
|
| 544 |
+
"GSI from all uploaded images",
|
| 545 |
+
)
|
| 546 |
+
diameter_fig = draw_violin_plot(
|
| 547 |
+
grana_data["diameter [nm]"].dropna(),
|
| 548 |
+
"Granum diameter [nm]",
|
| 549 |
+
"Grana diameters from all uploaded images",
|
| 550 |
+
)
|
| 551 |
+
height_fig = draw_violin_plot(
|
| 552 |
+
grana_data["height [nm]"].dropna(),
|
| 553 |
+
"Granum height [nm]",
|
| 554 |
+
"Grana heights from all uploaded images",
|
| 555 |
+
)
|
| 556 |
+
srd_fig = draw_violin_plot(
|
| 557 |
+
grana_data["SRD [nm]"].dropna(), "SRD [nm]", "SRD from all uploaded images"
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
return [
|
| 561 |
+
gr.Row(visible=True),
|
| 562 |
+
gr.Row(visible=True),
|
| 563 |
+
download_file_path,
|
| 564 |
+
detection_visualizations,
|
| 565 |
+
aggregated_df_transposed,
|
| 566 |
+
area_fig,
|
| 567 |
+
perimeter_fig,
|
| 568 |
+
gsi_fig,
|
| 569 |
+
diameter_fig,
|
| 570 |
+
height_fig,
|
| 571 |
+
srd_fig,
|
| 572 |
+
images_grana,
|
| 573 |
+
grana_data,
|
| 574 |
+
]
|
| 575 |
+
|
| 576 |
+
submit_btn.click(
|
| 577 |
+
fn=gradio_analize_image,
|
| 578 |
+
inputs=[
|
| 579 |
+
img_input,
|
| 580 |
+
ratio_input,
|
| 581 |
+
],
|
| 582 |
+
outputs=[
|
| 583 |
+
loading_row,
|
| 584 |
+
output_row,
|
| 585 |
+
# file_download_checkboxes,
|
| 586 |
+
download_file_out,
|
| 587 |
+
gallery_out,
|
| 588 |
+
table_out,
|
| 589 |
+
area_plot_out,
|
| 590 |
+
perimeter_plot_out,
|
| 591 |
+
gsi_plot_out,
|
| 592 |
+
diameter_plot_out,
|
| 593 |
+
height_plot_out,
|
| 594 |
+
srd_plot_out,
|
| 595 |
+
gallery_single_grana_out,
|
| 596 |
+
table_full_out,
|
| 597 |
+
],
|
| 598 |
+
).then(fn=enable_submit, inputs=[], outputs=[submit_btn, clear_btn, loading_row])
|
| 599 |
+
|
| 600 |
+
demo.launch(
|
| 601 |
+
share=False, debug=True, server_name="0.0.0.0", allowed_paths=["images/logo.svg"]
|
| 602 |
+
)
|
grana_detection/mmwrapper.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Union, Optional
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from mmdet.apis import DetInferencer
|
| 6 |
+
from ultralytics.engine.results import Results
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
class MMDetector(DetInferencer):
|
| 10 |
+
def __call__(
|
| 11 |
+
self,
|
| 12 |
+
inputs,
|
| 13 |
+
) -> Results:
|
| 14 |
+
"""Call the inferencer as in DetInferencer but for single image.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
inputs (np.ndarray | str): Inputs for the inferencer.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
Result: yolo-like result
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
ori_inputs = self._inputs_to_list(inputs)
|
| 24 |
+
|
| 25 |
+
data = list(self.preprocess(
|
| 26 |
+
ori_inputs, batch_size=1))[0][1]
|
| 27 |
+
|
| 28 |
+
preds = self.forward(data)[0]
|
| 29 |
+
|
| 30 |
+
yolo_result = Results(
|
| 31 |
+
orig_img=ori_inputs[0], path="", names=[""],
|
| 32 |
+
boxes=torch.cat((preds.pred_instances.bboxes, preds.pred_instances.scores.unsqueeze(-1), preds.pred_instances.labels.unsqueeze(-1)), dim=1),
|
| 33 |
+
masks=preds.pred_instances.masks
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
return yolo_result
|
| 37 |
+
|
| 38 |
+
def predict(self, source: Image.Image, conf=None):
|
| 39 |
+
"""yolo interface"""
|
| 40 |
+
if conf is not None:
|
| 41 |
+
warnings.warn(f"confidence value {conf} ignored")
|
| 42 |
+
return [self.__call__(np.array(source.convert("RGB")))]
|
model.py
ADDED
|
@@ -0,0 +1,629 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import warnings
|
| 3 |
+
from io import BytesIO
|
| 4 |
+
from copy import deepcopy
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import List, Tuple, Dict, Optional, Any, Union
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
from scipy import ndimage
|
| 14 |
+
from skimage import measure
|
| 15 |
+
import torchvision.transforms.functional as tvf
|
| 16 |
+
from ultralytics import YOLO
|
| 17 |
+
import torch
|
| 18 |
+
import cv2
|
| 19 |
+
import gradio
|
| 20 |
+
|
| 21 |
+
import sys, os
|
| 22 |
+
sys.path.append(os.path.abspath('angle_calculation'))
|
| 23 |
+
# from classic import measure_object
|
| 24 |
+
from sampling import get_crop_batch
|
| 25 |
+
from angle_model import PatchedPredictor, StripsModelLumenWidth
|
| 26 |
+
|
| 27 |
+
from period_calculation.period_measurer import PeriodMeasurer
|
| 28 |
+
# from grana_detection.mmwrapper import MMDetector # mmdet installation in docker is problematic for now
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class Granum:
|
| 32 |
+
id: Optional[int] = None
|
| 33 |
+
image: Any = None
|
| 34 |
+
mask: Any = None
|
| 35 |
+
scaler: Any = None
|
| 36 |
+
nm_per_px: float = float('nan')
|
| 37 |
+
detection_confidence: float = float('nan')
|
| 38 |
+
img_oriented: Optional[np.ndarray] = None # oriented fragment of the image
|
| 39 |
+
mask_oriented: Optional[np.ndarray] = None # oriented fragment of the mask
|
| 40 |
+
measurements: dict = field(default_factory=dict) # dict with grana measurements
|
| 41 |
+
|
| 42 |
+
class ScalerPadder:
|
| 43 |
+
"""resize and pad image to specific range.
|
| 44 |
+
minimal_pad: obligatory padding, e.g. required for detector
|
| 45 |
+
"""
|
| 46 |
+
def __init__(self, target_size=1024, target_short_edge_min=640, minimal_pad=16, pad_to_multiply=32):
|
| 47 |
+
self.minimal_pad = minimal_pad
|
| 48 |
+
self.target_size = target_size - 2*self.minimal_pad # detection pad is necessary padding size
|
| 49 |
+
self.target_short_edge_min = target_short_edge_min - 2*self.minimal_pad
|
| 50 |
+
|
| 51 |
+
self.max_size_nm = 6000 # training images covers ~3100 nm
|
| 52 |
+
self.min_size_nm = 2400 # training images covers ~3100 nm
|
| 53 |
+
self.pad_to_multiply = pad_to_multiply
|
| 54 |
+
|
| 55 |
+
def transform(self, image: Image.Image, px_per_nm: float=1.298) -> Image.Image:
|
| 56 |
+
self.original_size = image.size
|
| 57 |
+
self.original_px_per_nm = px_per_nm
|
| 58 |
+
w, h = self.original_size
|
| 59 |
+
longest_size = max(h, w)
|
| 60 |
+
img_size_nm = longest_size / px_per_nm
|
| 61 |
+
if img_size_nm > self.max_size_nm:
|
| 62 |
+
error_message = f'too large image, image size: {img_size_nm:0.1f}nm, max allowed: {self.max_size_nm}nm'
|
| 63 |
+
# raise ValueError(error_message)
|
| 64 |
+
# warnings.warn(warning_message)
|
| 65 |
+
gradio.Warning(error_message)
|
| 66 |
+
# add_text(image, warning_message, location=(0.1, 0.1), color='blue', size=int(40*longest_size/self.target_size))
|
| 67 |
+
|
| 68 |
+
self.resize_factor = self.target_size / (max(self.min_size_nm, img_size_nm) * px_per_nm)
|
| 69 |
+
self.px_per_nm_transformed = px_per_nm * self.resize_factor
|
| 70 |
+
|
| 71 |
+
resized_image = resize_with_cv2(image, (int(h*self.resize_factor), int(w*self.resize_factor)))
|
| 72 |
+
|
| 73 |
+
if w >= h:
|
| 74 |
+
pad_w = self.target_size-resized_image.size[0]
|
| 75 |
+
pad_h = max(0, self.target_short_edge_min-resized_image.size[1])
|
| 76 |
+
else:
|
| 77 |
+
pad_w = max(0, self.target_short_edge_min-resized_image.size[0])
|
| 78 |
+
pad_h = self.target_size-resized_image.size[1]
|
| 79 |
+
|
| 80 |
+
# apply minimal padding
|
| 81 |
+
pad_w += 2*self.minimal_pad
|
| 82 |
+
pad_h += 2*self.minimal_pad
|
| 83 |
+
|
| 84 |
+
# round to multiplication
|
| 85 |
+
pad_w += (self.pad_to_multiply - resized_image.size[0]%self.pad_to_multiply)%self.pad_to_multiply
|
| 86 |
+
pad_h += (self.pad_to_multiply - resized_image.size[1]%self.pad_to_multiply)%self.pad_to_multiply
|
| 87 |
+
|
| 88 |
+
self.pad_right = pad_w // 2
|
| 89 |
+
self.pad_left = pad_w - self.pad_right
|
| 90 |
+
|
| 91 |
+
self.pad_up = pad_h // 2
|
| 92 |
+
self.pad_bottom = pad_h - self.pad_up
|
| 93 |
+
|
| 94 |
+
padded_image = tvf.pad(resized_image, [self.pad_left,self.pad_up, self.pad_right, self.pad_bottom], padding_mode='reflect') # fill 114 as in YOLO
|
| 95 |
+
return padded_image
|
| 96 |
+
|
| 97 |
+
@property
|
| 98 |
+
def unpad_slice(self) -> Tuple[slice]:
|
| 99 |
+
return slice(self.pad_up,-self.pad_bottom if self.pad_bottom>0 else None), slice(self.pad_left,-self.pad_right if self.pad_right>0 else None)
|
| 100 |
+
|
| 101 |
+
def inverse_transform(self, image: Union[np.ndarray, Image.Image], output_size: Optional[Tuple[int]]=None, output_nm_per_px: Optional[float]=None, return_pil: bool=True) -> Image.Image:
|
| 102 |
+
if isinstance(image, Image.Image):
|
| 103 |
+
image = np.array(image)
|
| 104 |
+
# h, w = image.shape[:2]
|
| 105 |
+
# unpadded_image = image[self.pad_up:h-self.pad_bottom,self.pad_left:w-self.pad_right]
|
| 106 |
+
unapdded_image = image[self.unpad_slice]
|
| 107 |
+
|
| 108 |
+
if output_size is not None and output_nm_per_px is not None:
|
| 109 |
+
raise ValueError("one of output_size or output_nm_per_px must not be None")
|
| 110 |
+
elif output_nm_per_px is not None:
|
| 111 |
+
resize_factor = self.original_nm_per_px/output_nm_per_px
|
| 112 |
+
output_size = (int(self.original_size[0]*resize_factor), int(self.original_size[1]*resize_factor))
|
| 113 |
+
elif output_size is None:
|
| 114 |
+
output_size = self.original_size
|
| 115 |
+
resized_image = resize_with_cv2(unapdded_image, (output_size[1],output_size[0]), return_pil=return_pil) #Image.fromarray(unpadded_image).resize(self.original_size)
|
| 116 |
+
|
| 117 |
+
return resized_image
|
| 118 |
+
|
| 119 |
+
def close_contour(contour):
|
| 120 |
+
if not np.array_equal(contour[0], contour[-1]):
|
| 121 |
+
contour = np.vstack((contour, contour[0]))
|
| 122 |
+
return contour
|
| 123 |
+
|
| 124 |
+
def binary_mask_to_polygon(binary_mask, tol=0.01):
|
| 125 |
+
padded_binary_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)
|
| 126 |
+
contours = measure.find_contours(padded_binary_mask, 0.5)
|
| 127 |
+
# assert len(contours) == 1 #raise error if there are more than 1 contour
|
| 128 |
+
contour = contours[0]
|
| 129 |
+
contour -= 1 # correct for padding
|
| 130 |
+
contour = close_contour(contour)
|
| 131 |
+
|
| 132 |
+
polygon = measure.approximate_polygon(contour, tol)
|
| 133 |
+
|
| 134 |
+
polygon = np.flip(polygon, axis=1)
|
| 135 |
+
# after padding and subtracting 1 we may get -0.5 points in our polygon. Replace it with 0
|
| 136 |
+
polygon = np.where(polygon>=0, polygon, 0)
|
| 137 |
+
# segmentation = polygon.ravel().tolist()
|
| 138 |
+
|
| 139 |
+
return polygon
|
| 140 |
+
|
| 141 |
+
def measure_shape(binary_mask):
|
| 142 |
+
contour = binary_mask_to_polygon(binary_mask)
|
| 143 |
+
perimeter = np.sum(np.linalg.norm(contour[:-1] - contour[1:], axis=1))
|
| 144 |
+
area = np.sum(binary_mask)
|
| 145 |
+
|
| 146 |
+
return perimeter, area
|
| 147 |
+
|
| 148 |
+
def calculate_gsi(perimeter, height, area):
|
| 149 |
+
a = 0.5*(perimeter - 2*height)
|
| 150 |
+
return 1 - area/(a*height)
|
| 151 |
+
|
| 152 |
+
def object_slice(mask):
|
| 153 |
+
rows = np.any(mask, axis=1)
|
| 154 |
+
cols = np.any(mask, axis=0)
|
| 155 |
+
row_min, row_max = np.where(rows)[0][[0, -1]]
|
| 156 |
+
col_min, col_max = np.where(cols)[0][[0, -1]]
|
| 157 |
+
|
| 158 |
+
# Create a slice object for the bounding box
|
| 159 |
+
bounding_box_slice = (slice(row_min, row_max + 1), slice(col_min, col_max + 1))
|
| 160 |
+
|
| 161 |
+
return bounding_box_slice
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def figure_to_pil(fig):
|
| 165 |
+
buf = BytesIO()
|
| 166 |
+
fig.savefig(buf, format='png')
|
| 167 |
+
buf.seek(0)
|
| 168 |
+
|
| 169 |
+
# Load the image from the buffer as a PIL Image
|
| 170 |
+
image = deepcopy(Image.open(buf))
|
| 171 |
+
|
| 172 |
+
# Close the buffer
|
| 173 |
+
buf.close()
|
| 174 |
+
return image
|
| 175 |
+
|
| 176 |
+
def resize_to(image: Image.Image, s: int=4032, return_factor: bool =False) -> Image.Image:
|
| 177 |
+
w, h = image.size
|
| 178 |
+
longest_size = max(h, w)
|
| 179 |
+
|
| 180 |
+
resize_factor = longest_size / s
|
| 181 |
+
|
| 182 |
+
resized_image = image.resize((int(w/resize_factor), int(h/resize_factor)))
|
| 183 |
+
if return_factor:
|
| 184 |
+
return resized_image, resize_factor
|
| 185 |
+
return resized_image
|
| 186 |
+
|
| 187 |
+
def resize_with_cv2(image, shape, return_pil=True):
|
| 188 |
+
"""resize using cv2 with cv2.INTER_LINEAR - consistent with YOLO"""
|
| 189 |
+
h, w = shape
|
| 190 |
+
if isinstance(image, Image.Image):
|
| 191 |
+
image = np.array(image)
|
| 192 |
+
|
| 193 |
+
resized = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR)
|
| 194 |
+
if return_pil:
|
| 195 |
+
return Image.fromarray(resized)
|
| 196 |
+
else:
|
| 197 |
+
return resized
|
| 198 |
+
|
| 199 |
+
def select_unique_mask(mask):
|
| 200 |
+
"""if mask consists of multiple parts, select the largest"""
|
| 201 |
+
if not np.any(mask): # if mask is empty, return without change
|
| 202 |
+
return mask
|
| 203 |
+
blobs = ndimage.label(mask)[0]
|
| 204 |
+
blob_labels, blob_sizes = np.unique(blobs, return_counts=True)
|
| 205 |
+
best_blob_label = blob_labels[1:][np.argmax(blob_sizes[1:])]
|
| 206 |
+
return blobs == best_blob_label
|
| 207 |
+
|
| 208 |
+
def sliced_mean(x, slice_size):
|
| 209 |
+
cs_y = np.cumsum(x, axis=0)
|
| 210 |
+
cs_y = np.concatenate((np.zeros((1, cs_y.shape[1]), dtype=cs_y.dtype), cs_y), axis=0)
|
| 211 |
+
slices_y = (cs_y[slice_size:] - cs_y[:-slice_size])/slice_size
|
| 212 |
+
cs_xy = np.cumsum(slices_y, axis=1)
|
| 213 |
+
cs_xy = np.concatenate((np.zeros((cs_xy.shape[0], 1), dtype=cs_xy.dtype), cs_xy), axis=1)
|
| 214 |
+
slices_xy = (cs_xy[:,slice_size:] - cs_xy[:,:-slice_size])/slice_size
|
| 215 |
+
return slices_xy
|
| 216 |
+
|
| 217 |
+
def sliced_var(x, slice_size):
|
| 218 |
+
x = x.astype('float64')
|
| 219 |
+
return sliced_mean(x**2, slice_size) - sliced_mean(x, slice_size)**2
|
| 220 |
+
|
| 221 |
+
def calculate_distance_map(mask):
|
| 222 |
+
padded = np.pad(mask, pad_width=1, mode='constant', constant_values=False)
|
| 223 |
+
distance_map_padded = ndimage.distance_transform_edt(padded)
|
| 224 |
+
return distance_map_padded[1:-1,1:-1]
|
| 225 |
+
|
| 226 |
+
def select_samples(granum_image, granum_mask, crop_size=96, n_samples=64, granum_fraction_min=0.75, variance_p=0.):
|
| 227 |
+
granum_occupancy = sliced_mean(granum_mask, crop_size)
|
| 228 |
+
possible_indices = np.stack(np.where(granum_occupancy >= granum_fraction_min), axis=1)
|
| 229 |
+
|
| 230 |
+
if variance_p == 0:
|
| 231 |
+
p = np.ones(len(possible_indices))
|
| 232 |
+
else:
|
| 233 |
+
variance_map = sliced_var(granum_image, crop_size)
|
| 234 |
+
p = variance_map[possible_indices[:,0], possible_indices[:,1]]**variance_p
|
| 235 |
+
p /= np.sum(p)
|
| 236 |
+
|
| 237 |
+
chosen_indices = np.random.choice(
|
| 238 |
+
np.arange(len(possible_indices)),
|
| 239 |
+
min(len(possible_indices), n_samples),
|
| 240 |
+
replace=False,
|
| 241 |
+
p = p
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
crops = []
|
| 245 |
+
for crop_idx, idx in enumerate(chosen_indices):
|
| 246 |
+
crops.append(
|
| 247 |
+
granum_image[
|
| 248 |
+
possible_indices[idx,0]:possible_indices[idx,0]+crop_size,
|
| 249 |
+
possible_indices[idx,1]:possible_indices[idx,1]+crop_size
|
| 250 |
+
]
|
| 251 |
+
)
|
| 252 |
+
return np.array(crops)
|
| 253 |
+
|
| 254 |
+
def calculate_height(mask_oriented): #HACK
|
| 255 |
+
span = mask_oriented.shape[0] - np.argmax(mask_oriented[::-1], axis=0) - np.argmax(mask_oriented, axis=0)
|
| 256 |
+
return np.quantile(span, 0.8)
|
| 257 |
+
|
| 258 |
+
def calculate_diameter(mask_oriented):
|
| 259 |
+
"""returns mean diameter"""
|
| 260 |
+
# calculate 0.25 and 0.75 lines
|
| 261 |
+
vertical_mask = np.any(mask_oriented, axis=1)
|
| 262 |
+
upper_granum_bound = np.argmax(vertical_mask)
|
| 263 |
+
lower_granum_bound = mask_oriented.shape[0] - np.argmax(vertical_mask[::-1])
|
| 264 |
+
upper = round(0.75*upper_granum_bound + 0.25*lower_granum_bound)
|
| 265 |
+
lower = max(upper+1, round(0.25*upper_granum_bound + 0.75*lower_granum_bound))
|
| 266 |
+
valid_rows_slice = slice(upper, lower)
|
| 267 |
+
|
| 268 |
+
# calculate diameters
|
| 269 |
+
span = mask_oriented.shape[1] - np.argmax(mask_oriented[valid_rows_slice,::-1], axis=1) - np.argmax(mask_oriented[valid_rows_slice], axis=1)
|
| 270 |
+
return np.mean(span)
|
| 271 |
+
|
| 272 |
+
def robust_mean(x, q=0.1):
|
| 273 |
+
x_med = np.median(x)
|
| 274 |
+
deviations = abs(x- x_med)
|
| 275 |
+
if max(deviations) == 0:
|
| 276 |
+
mask = np.ones(len(x), dtype='bool')
|
| 277 |
+
else:
|
| 278 |
+
threshold = np.quantile(deviations, 1-q)
|
| 279 |
+
mask = x[deviations<= threshold]
|
| 280 |
+
|
| 281 |
+
return np.mean(x[mask])
|
| 282 |
+
|
| 283 |
+
def rotate_image_and_mask(image, mask, direction):
|
| 284 |
+
mask_oriented = ndimage.rotate(mask.astype('int'), -direction, reshape=True).astype('bool')
|
| 285 |
+
idx_begin_x, idx_end_x = np.where(np.any(mask_oriented, axis=0))[0][np.array([0, -1])]
|
| 286 |
+
idx_begin_y, idx_end_y = np.where(np.any(mask_oriented, axis=1))[0][np.array([0, -1])]
|
| 287 |
+
img_oriented = ndimage.rotate(image, -direction, reshape=True) #[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
|
| 288 |
+
return img_oriented, mask_oriented
|
| 289 |
+
|
| 290 |
+
class GranaAnalyser:
|
| 291 |
+
def __init__(self, weights_detector: str, weights_orientation: str, weights_period: str, period_sd_threshold_nm: float=2.5) -> None:
|
| 292 |
+
"""
|
| 293 |
+
Initializes the GranaAnalyser with specified weights for detection and measuring.
|
| 294 |
+
|
| 295 |
+
This method loads the weights for the grana detection and measuring algorithms
|
| 296 |
+
from the specified file paths. It also loads mock data for visualization and
|
| 297 |
+
analysis purposes.
|
| 298 |
+
|
| 299 |
+
Parameters:
|
| 300 |
+
weights_detector (str): The file path to the weights file for the grana detection algorithm.
|
| 301 |
+
weights_orientation (str): The file path to the weights file for the grana orientation algorithm.
|
| 302 |
+
weights_period (str): The file path to the weights file for the grana period algorithm.
|
| 303 |
+
"""
|
| 304 |
+
self.detector = YOLO(weights_detector)
|
| 305 |
+
|
| 306 |
+
self.orienter = PatchedPredictor(
|
| 307 |
+
StripsModelLumenWidth.load_from_checkpoint(weights_orientation, map_location='cpu').eval(),
|
| 308 |
+
normalization = dict(mean=0.250, std=0.135),
|
| 309 |
+
n_samples=32,
|
| 310 |
+
mask=None,
|
| 311 |
+
crop_size=64,
|
| 312 |
+
angle_confidence_threshold=0.2
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
self.measurement_px_per_nm = 1/0.768 # image scale required for measurement
|
| 316 |
+
|
| 317 |
+
self.period_measurer = PeriodMeasurer(
|
| 318 |
+
weights_period,
|
| 319 |
+
px_per_nm=self.measurement_px_per_nm,
|
| 320 |
+
sd_threshold_nm=period_sd_threshold_nm,
|
| 321 |
+
period_threshold_nm_min=14, period_threshold_nm_max=30
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def get_grana_data(self, image, detections, scaler, border_margin=1, min_count=1) -> List[Granum]:
|
| 326 |
+
"""filter detections and create grana data"""
|
| 327 |
+
image_numpy = np.array(image)
|
| 328 |
+
if image_numpy.ndim == 3:
|
| 329 |
+
image_numpy = image_numpy[:,:,0]
|
| 330 |
+
|
| 331 |
+
mask_all = None
|
| 332 |
+
grana = []
|
| 333 |
+
for mask, confidence in zip(
|
| 334 |
+
detections.masks.data.cpu().numpy().astype('bool'),
|
| 335 |
+
detections.boxes.conf.cpu().numpy()
|
| 336 |
+
):
|
| 337 |
+
granum_mask = select_unique_mask(mask[scaler.unpad_slice])
|
| 338 |
+
# check if mask is empty after padding
|
| 339 |
+
if not np.any(granum_mask):
|
| 340 |
+
continue
|
| 341 |
+
granum_mask = ndimage.binary_fill_holes(granum_mask)
|
| 342 |
+
|
| 343 |
+
# check if touches boundary:
|
| 344 |
+
if (np.sum(granum_mask[:border_margin])>min_count) or \
|
| 345 |
+
(np.sum(granum_mask[-border_margin:])>min_count) or \
|
| 346 |
+
(np.sum(granum_mask[:,:border_margin])>min_count) or \
|
| 347 |
+
(np.sum(granum_mask[:,-border_margin:])>min_count):
|
| 348 |
+
|
| 349 |
+
continue
|
| 350 |
+
|
| 351 |
+
# check grana overlap
|
| 352 |
+
if mask_all is None:
|
| 353 |
+
mask_all = granum_mask
|
| 354 |
+
else:
|
| 355 |
+
intersection = mask_all & granum_mask
|
| 356 |
+
|
| 357 |
+
if intersection.sum() >= (granum_mask.sum() * 0.2):
|
| 358 |
+
continue
|
| 359 |
+
mask_all = mask_all | granum_mask
|
| 360 |
+
|
| 361 |
+
granum = Granum(
|
| 362 |
+
image = image,
|
| 363 |
+
mask = granum_mask,
|
| 364 |
+
scaler=scaler,
|
| 365 |
+
detection_confidence=float(confidence)
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
granum.image_numpy = image_numpy
|
| 369 |
+
grana.append(granum)
|
| 370 |
+
return grana
|
| 371 |
+
|
| 372 |
+
def measure_grana(self, grana: List[Granum], measurement_image: np.ndarray) -> List[Granum]:
|
| 373 |
+
"""measure grana: includes orientation detection, period detection and geometric measurements"""
|
| 374 |
+
for granum in grana:
|
| 375 |
+
measurement_mask = resize_with_cv2(granum.mask.astype(np.uint8), measurement_image.shape[:2], return_pil=False).astype('bool')
|
| 376 |
+
|
| 377 |
+
granum.bounding_box_slice = object_slice(measurement_mask)
|
| 378 |
+
granum.image_crop = measurement_image[granum.bounding_box_slice][:,:]
|
| 379 |
+
granum.mask_crop = measurement_mask[granum.bounding_box_slice]
|
| 380 |
+
|
| 381 |
+
# initialize measurements
|
| 382 |
+
granum.measurements = {}
|
| 383 |
+
|
| 384 |
+
# measure shape
|
| 385 |
+
granum.measurements['perimeter px'], granum.measurements['area px'] = measure_shape(granum.mask_crop)
|
| 386 |
+
|
| 387 |
+
# measrure orientation
|
| 388 |
+
orienter_predictions = self.orienter(granum.image_crop, granum.mask_crop)
|
| 389 |
+
granum.measurements['direction'] = orienter_predictions["est_angle"]
|
| 390 |
+
granum.measurements['direction confidence'] = orienter_predictions["est_angle_confidence"]
|
| 391 |
+
|
| 392 |
+
if not np.isnan(granum.measurements["direction"]):
|
| 393 |
+
img_oriented, mask_oriented = rotate_image_and_mask(granum.image_crop, granum.mask_crop, granum.measurements["direction"])
|
| 394 |
+
oriented_granum_slice = object_slice(mask_oriented)
|
| 395 |
+
granum.img_oriented = img_oriented[oriented_granum_slice]
|
| 396 |
+
granum.mask_oriented = mask_oriented[oriented_granum_slice]
|
| 397 |
+
granum.measurements['height px'] = calculate_height(granum.mask_oriented)
|
| 398 |
+
granum.measurements['GSI'] = calculate_gsi(
|
| 399 |
+
granum.measurements['perimeter px'],
|
| 400 |
+
granum.measurements['height px'],
|
| 401 |
+
granum.measurements['area px']
|
| 402 |
+
)
|
| 403 |
+
granum.measurements['diameter px'] = calculate_diameter(granum.mask_oriented)
|
| 404 |
+
|
| 405 |
+
oriented_granum_slice = object_slice(granum.mask_oriented)
|
| 406 |
+
granum.measurements["period nm"], granum.measurements["period SD nm"] = self.period_measurer(granum.img_oriented, granum.mask_oriented)
|
| 407 |
+
|
| 408 |
+
if not pd.isna(granum.measurements['period nm']):
|
| 409 |
+
granum.measurements['Number of layers'] = round(granum.measurements['height px']/ self.measurement_px_per_nm / granum.measurements['period nm'])
|
| 410 |
+
|
| 411 |
+
return grana
|
| 412 |
+
|
| 413 |
+
def extract_grana_data(self, grana: List[Granum]) -> pd.DataFrame:
|
| 414 |
+
"""collect and scale grana data"""
|
| 415 |
+
grana_data = []
|
| 416 |
+
for granum in grana:
|
| 417 |
+
granum_entry = {
|
| 418 |
+
'Granum ID': granum.id,
|
| 419 |
+
'detection confidence': granum.detection_confidence
|
| 420 |
+
}
|
| 421 |
+
# fill with None if absent:
|
| 422 |
+
for key in ['direction', 'Number of layers', 'GSI', 'period nm', 'period SD nm']:
|
| 423 |
+
granum_entry[key] = granum.measurements.get(key, None)
|
| 424 |
+
# scale linearly:
|
| 425 |
+
for key in ['height px', 'diameter px', 'perimeter px', 'perimeter px']:
|
| 426 |
+
granum_entry[f"{key[:-3]} nm"] = granum.measurements.get(key, np.nan) / self.measurement_px_per_nm
|
| 427 |
+
# scale quadratically
|
| 428 |
+
granum_entry['area nm^2'] = granum.measurements['area px'] / self.measurement_px_per_nm**2
|
| 429 |
+
|
| 430 |
+
grana_data.append(granum_entry)
|
| 431 |
+
|
| 432 |
+
return pd.DataFrame(grana_data)
|
| 433 |
+
|
| 434 |
+
def visualize_detections(self, grana: List[Granum], image: Image.Image) -> Image.Image:
|
| 435 |
+
visualization_longer_edge = 1024
|
| 436 |
+
scale = visualization_longer_edge/max(image.size)
|
| 437 |
+
visualization_size = (round(scale*image.size[0]), round(scale*image.size[1]))
|
| 438 |
+
visualization_image = np.array(image.resize(visualization_size).convert('RGB'))
|
| 439 |
+
|
| 440 |
+
if len(grana) > 0:
|
| 441 |
+
grana_mask = resize_with_cv2(
|
| 442 |
+
np.any(np.array([granum.mask for granum in grana]),axis=0).astype(np.uint8),
|
| 443 |
+
visualization_size[::-1],
|
| 444 |
+
return_pil=False
|
| 445 |
+
).astype('bool')
|
| 446 |
+
visualization_image[grana_mask]= (0.7*visualization_image[grana_mask] + 0.3*np.array([[[39, 179, 115]]])).astype(np.uint8)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
for granum in grana:
|
| 450 |
+
scale = visualization_longer_edge/max(granum.mask.shape)
|
| 451 |
+
y, x = ndimage.center_of_mass(granum.mask)
|
| 452 |
+
cv2.putText(visualization_image, f'{granum.id}', org=(int(x*scale)-10, int(y*scale)+10), fontFace=cv2.FONT_HERSHEY_SIMPLEX , fontScale=1, color=(39, 179, 115),thickness = 2)
|
| 453 |
+
|
| 454 |
+
return Image.fromarray(visualization_image)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def generate_grana_images(self, grana: List[Granum], image_name: str ="") -> List[Image.Image]:
|
| 458 |
+
grana_images = {}
|
| 459 |
+
for granum in grana:
|
| 460 |
+
fig, ax = plt.subplots()
|
| 461 |
+
if granum.img_oriented is None:
|
| 462 |
+
image_to_plot = granum.image_crop
|
| 463 |
+
mask_to_plot = granum.mask_crop
|
| 464 |
+
extra_caption = " orientation and period unknown"
|
| 465 |
+
else:
|
| 466 |
+
image_to_plot = granum.img_oriented
|
| 467 |
+
mask_to_plot = granum.mask_oriented
|
| 468 |
+
extra_caption = ""
|
| 469 |
+
|
| 470 |
+
ax.imshow(0.5*255*(~mask_to_plot) +image_to_plot*(1-0.5*(~mask_to_plot)), cmap='gray', vmin=0, vmax=255)
|
| 471 |
+
ax.axis('off')
|
| 472 |
+
ax.set_title(f'[{granum.id}]{image_name}\n{extra_caption}')
|
| 473 |
+
granum_image = figure_to_pil(fig)
|
| 474 |
+
grana_images[granum.id] = granum_image
|
| 475 |
+
plt.close('all')
|
| 476 |
+
|
| 477 |
+
return grana_images
|
| 478 |
+
|
| 479 |
+
def format_data(self, grana_data: pd.DataFrame) -> pd.DataFrame:
|
| 480 |
+
rounding_roles = {'area nm^2': 0, 'perimeter nm': 1, 'diameter nm': 1, 'height nm': 1, 'period nm': 1, 'period SD nm': 2, 'GSI':2, 'direction': 1}
|
| 481 |
+
rounded_data = grana_data.round(rounding_roles)
|
| 482 |
+
columns_order = ['Granum ID', 'File name', 'area nm^2', 'perimeter nm', 'GSI','diameter nm', 'height nm', 'Number of layers','period nm', 'period SD nm', 'direction']
|
| 483 |
+
return rounded_data[columns_order]
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def aggregate_data(self, grana_data: pd.DataFrame, confidence: Optional[float]=None) -> Dict:
|
| 487 |
+
if confidence is None:
|
| 488 |
+
filtered = grana_data
|
| 489 |
+
else:
|
| 490 |
+
filtered = grana_data.loc[grana_data['aggregated confidence'] >= confidence]
|
| 491 |
+
aggregation = filtered[['area nm^2', 'perimeter nm', 'diameter nm', 'height nm', 'Number of layers', 'period nm', 'GSI']].mean().to_dict()
|
| 492 |
+
aggregation_std = filtered[['area nm^2', 'perimeter nm', 'diameter nm', 'height nm', 'Number of layers', 'period nm', 'GSI']].std().to_dict()
|
| 493 |
+
aggregation_std = {f"{k} std": v for k, v in aggregation_std.items()}
|
| 494 |
+
aggregation_result = {**aggregation, **aggregation_std, 'N grana': len(filtered)}
|
| 495 |
+
return aggregation_result
|
| 496 |
+
|
| 497 |
+
def predict_on_single(self, image: Image.Image, scale: float, detection_confidence: float=0.25, granum_id_start=1, image_name: str = "") -> Tuple[List[Image.Image], pd.DataFrame, List[Image.Image]]:
|
| 498 |
+
"""
|
| 499 |
+
Predicts and aggregates data related to grana using a dictionary of images.
|
| 500 |
+
|
| 501 |
+
Parameters:
|
| 502 |
+
image (Image.Image): PIL Image object to be analyzed
|
| 503 |
+
scale (float): scale of the image: px per nm.
|
| 504 |
+
detection_confidence (float): The detection confidence threshold shape measurement
|
| 505 |
+
|
| 506 |
+
Returns:
|
| 507 |
+
Tuple[Image.Image, pandas.DataFrame, List[Image.Image]]:
|
| 508 |
+
A tuple containing:
|
| 509 |
+
- detection_visualization (Image.Image): PIL image representing
|
| 510 |
+
the detection visualizations.
|
| 511 |
+
- grana_data (pandas.DataFrame): A DataFrame containing the simulated granum data.
|
| 512 |
+
- grana_images (List[Image.Image]): A list of PIL images of the grana.
|
| 513 |
+
"""
|
| 514 |
+
# convert to grayscale
|
| 515 |
+
image = image.convert("L")
|
| 516 |
+
|
| 517 |
+
# detect
|
| 518 |
+
scaler = ScalerPadder(target_size=1024, target_short_edge_min=640)
|
| 519 |
+
scaled_image = scaler.transform(image, px_per_nm=scale)
|
| 520 |
+
detections = self.detector.predict(source=scaled_image, conf=detection_confidence)[0]
|
| 521 |
+
|
| 522 |
+
# get grana data
|
| 523 |
+
grana = self.get_grana_data(image, detections, scaler)
|
| 524 |
+
for granum_id, granum in enumerate(grana, start=granum_id_start):
|
| 525 |
+
granum.id = granum_id
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
# visualize detections
|
| 529 |
+
detection_visualization = self.visualize_detections(grana, image)
|
| 530 |
+
|
| 531 |
+
# measure grana
|
| 532 |
+
measurement_image_resize_factor = self.measurement_px_per_nm / scale
|
| 533 |
+
measurement_image_shape = (
|
| 534 |
+
int(image.size[1]*measurement_image_resize_factor),
|
| 535 |
+
int(image.size[0]*measurement_image_resize_factor)
|
| 536 |
+
)
|
| 537 |
+
measurement_image = resize_with_cv2( # numpy image in scale valid for measurement
|
| 538 |
+
image, measurement_image_shape, return_pil=False
|
| 539 |
+
)
|
| 540 |
+
grana = self.measure_grana(grana, measurement_image)
|
| 541 |
+
|
| 542 |
+
# pandas DataFrame
|
| 543 |
+
grana_data = self.extract_grana_data(grana)
|
| 544 |
+
|
| 545 |
+
# list of PIL images
|
| 546 |
+
grana_images = self.generate_grana_images(grana, image_name=image_name)
|
| 547 |
+
|
| 548 |
+
return detection_visualization, grana_data, grana_images
|
| 549 |
+
|
| 550 |
+
def predict(self, images: Dict[str, Image.Image], scale: float, detection_confidence: float=0.25, parameter_confidence: Optional[float]=None) -> Tuple[List[Image.Image], pd.DataFrame, List[Image.Image], Dict]:
|
| 551 |
+
"""
|
| 552 |
+
Predicts and aggregates data related to grana using a dictionary of images.
|
| 553 |
+
|
| 554 |
+
Parameters:
|
| 555 |
+
images (Dict[str, Image.Image]): A dictionary of PIL Image objects to be analyzed,
|
| 556 |
+
keyed by their names.
|
| 557 |
+
scale (float): scale of the image: px per nm
|
| 558 |
+
detection_confidence (float): The detection confidence threshold shape measurement
|
| 559 |
+
parameter_confidence (float): The confidence threshold used for data aggregation. Only
|
| 560 |
+
data with aggregated confidence above this threshold will
|
| 561 |
+
be considered.
|
| 562 |
+
|
| 563 |
+
Returns:
|
| 564 |
+
Tuple[List[Image.Image], pandas.DataFrame, List[Image.Image], Dict]:
|
| 565 |
+
A tuple containing:
|
| 566 |
+
- detection_visualizations (List[Image.Image]): A list of PIL images representing
|
| 567 |
+
the detection visualizations.
|
| 568 |
+
- grana_data (pandas.DataFrame): A DataFrame containing the simulated granum data.
|
| 569 |
+
- grana_images (List[Image.Image]): A list of PIL images of the grana.
|
| 570 |
+
- aggregated_data (Dict): A dictionary containing the aggregated data results.
|
| 571 |
+
"""
|
| 572 |
+
detection_visualizations_all = {}
|
| 573 |
+
grana_data_all = None
|
| 574 |
+
grana_images_all = {}
|
| 575 |
+
|
| 576 |
+
granum_id_start = 1
|
| 577 |
+
for image_name, image in images.items():
|
| 578 |
+
detection_visualization, grana_data, grana_images = self.predict_on_single(image, scale=scale, detection_confidence=detection_confidence, granum_id_start=granum_id_start, image_name=image_name)
|
| 579 |
+
granum_id_start += len(grana_data)
|
| 580 |
+
detection_visualizations_all[image_name] = detection_visualization
|
| 581 |
+
grana_images_all.update(grana_images)
|
| 582 |
+
|
| 583 |
+
grana_data['File name'] = image_name
|
| 584 |
+
if grana_data_all is None:
|
| 585 |
+
grana_data_all = grana_data
|
| 586 |
+
else:
|
| 587 |
+
# grana_data['Granum ID'] += len(grana_data_all)
|
| 588 |
+
grana_data_all = pd.concat([grana_data_all, grana_data])
|
| 589 |
+
|
| 590 |
+
# dict
|
| 591 |
+
# grana_data_all.to_csv('grana_data_all.csv', index=False)
|
| 592 |
+
aggregated_data = self.aggregate_data(grana_data_all, parameter_confidence)
|
| 593 |
+
|
| 594 |
+
formatted_grana_data = self.format_data(grana_data_all)
|
| 595 |
+
|
| 596 |
+
return detection_visualizations_all, formatted_grana_data, grana_images_all, aggregated_data
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
class GranaDetector(GranaAnalyser):
|
| 600 |
+
"""supplementary class for grana detection only
|
| 601 |
+
"""
|
| 602 |
+
def __init__(self, weights_detector: str, detector_config: Optional[str] = None, model_type="yolo") -> None:
|
| 603 |
+
|
| 604 |
+
if model_type == "yolo":
|
| 605 |
+
self.detector = YOLO(weights_detector)
|
| 606 |
+
elif model_type == "mmdetection":
|
| 607 |
+
self.detector = MMDetector(model=detector_config, weights=weights_detector)
|
| 608 |
+
else:
|
| 609 |
+
raise NotImplementedError()
|
| 610 |
+
|
| 611 |
+
def predict_on_single(self, image: Image.Image, scale: float, detection_confidence: float=0.25, granum_id_start=1, use_scaling=True, granum_border_margin=1, granum_border_min_count=1, scaler_sizes=(1024, 640)) -> List[Granum]:
|
| 612 |
+
# convert to grayscale
|
| 613 |
+
image = image.convert("L")
|
| 614 |
+
|
| 615 |
+
# detect
|
| 616 |
+
if use_scaling:
|
| 617 |
+
scaler = ScalerPadder(target_size=scaler_sizes[0], target_short_edge_min=scaler_sizes[1])
|
| 618 |
+
else:
|
| 619 |
+
#dummy scaler
|
| 620 |
+
scaler = ScalerPadder(target_size=max(image.size), target_short_edge_min=min(image.size), minimal_pad=0, pad_to_multiply=1)
|
| 621 |
+
scaled_image = scaler.transform(image, scale=scale)
|
| 622 |
+
detections = self.detector.predict(source=scaled_image, conf=detection_confidence)[0]
|
| 623 |
+
|
| 624 |
+
# get grana data
|
| 625 |
+
grana = self.get_grana_data(image, detections, scaler, border_margin=granum_border_margin, min_count=granum_border_min_count)
|
| 626 |
+
for i_granum, granum in enumerate(grana, start=1):
|
| 627 |
+
granum.id = i_granum
|
| 628 |
+
|
| 629 |
+
return grana
|
period_calculation/config.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import albumentations as A
|
| 3 |
+
from albumentations.pytorch import ToTensorV2
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
transforms = [
|
| 7 |
+
A.Normalize(**{'mean': 0.2845, 'std': 0.1447}, max_pixel_value=1.0),
|
| 8 |
+
# Applies the formula (img - mean * max_pixel_value) / (std * max_pixel_value)
|
| 9 |
+
ToTensorV2()
|
| 10 |
+
]
|
| 11 |
+
|
| 12 |
+
model_config = {
|
| 13 |
+
'receptive_field_height': 220,
|
| 14 |
+
'receptive_field_width': 38,
|
| 15 |
+
'stride_height': 64,
|
| 16 |
+
'stride_width': 2,
|
| 17 |
+
'image_height': 476,
|
| 18 |
+
'image_width': 476}
|
| 19 |
+
|
period_calculation/data_reader.py
ADDED
|
@@ -0,0 +1,861 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import skimage.io
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import torch
|
| 6 |
+
import scipy
|
| 7 |
+
from PIL import Image, ImageFilter, ImageChops
|
| 8 |
+
# from config import model_config
|
| 9 |
+
from period_calculation.config import model_config
|
| 10 |
+
|
| 11 |
+
# Function to add Gaussian noise
|
| 12 |
+
|
| 13 |
+
def add_microscope_noise(base_image_as_numpy, noise_intensity):
|
| 14 |
+
###### The code below is for adding noise to the image
|
| 15 |
+
# noise intensity is a number between 0 and 1
|
| 16 |
+
# --- priginal implementation was provided by Michał Bykowski
|
| 17 |
+
# --- and adapted
|
| 18 |
+
# This routine works with PIL images and numpy internally (changing formats as it goes)
|
| 19 |
+
# but the input and output are numpy arrays
|
| 20 |
+
|
| 21 |
+
def add_noise(image, mean=0, std_dev=50): # std_dev impacts the amount of noise
|
| 22 |
+
# Generating noise
|
| 23 |
+
noise = np.random.normal(mean, std_dev, (image.height, image.width))
|
| 24 |
+
# Adding noise to the image
|
| 25 |
+
noisy_image = np.array(image) + noise
|
| 26 |
+
# Ensuring the values remain within valid grayscale range
|
| 27 |
+
noisy_image = np.clip(noisy_image, 0, 255)
|
| 28 |
+
return Image.fromarray(noisy_image.astype('uint8'))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
base_image = Image.fromarray(base_image_as_numpy)
|
| 32 |
+
gray_value = 128
|
| 33 |
+
gray = Image.new('L', base_image.size, color=gray_value)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
gray = add_noise(gray, std_dev=noise_intensity * 76)
|
| 37 |
+
gray = gray.filter(ImageFilter.GaussianBlur(radius=3))
|
| 38 |
+
gray = add_noise(gray, std_dev=noise_intensity * 23)
|
| 39 |
+
gray = gray.filter(ImageFilter.GaussianBlur(radius=2))
|
| 40 |
+
gray = add_noise(gray, std_dev=noise_intensity * 15)
|
| 41 |
+
|
| 42 |
+
# soft light works as in Photoshop
|
| 43 |
+
# Superimposes two images on top of each other using the Soft Light algorithm
|
| 44 |
+
result = ImageChops.soft_light(base_image, gray)
|
| 45 |
+
|
| 46 |
+
return np.array(result)
|
| 47 |
+
|
| 48 |
+
def detect_boundaries(mask, axis):
|
| 49 |
+
# calculate the boundaries of the mask
|
| 50 |
+
#axis = 0 results in x_from, x_to
|
| 51 |
+
#axis = 1 results in y_from, y_to
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
sum = mask.sum(axis=axis)
|
| 55 |
+
|
| 56 |
+
ind_from = min(sum.nonzero()[0])
|
| 57 |
+
ind_to = max(sum.nonzero()[0])
|
| 58 |
+
return ind_from, ind_to
|
| 59 |
+
|
| 60 |
+
def add_symmetric_filling_beyond_mask(img, mask):
|
| 61 |
+
for x in range(img.shape[1]):
|
| 62 |
+
if sum(mask[:, x]) != 0: #if there is at least one nonzero index
|
| 63 |
+
nonzero_indices = mask[:, x].nonzero()[0]
|
| 64 |
+
|
| 65 |
+
y_min = min(nonzero_indices)
|
| 66 |
+
y_max = max(nonzero_indices)
|
| 67 |
+
|
| 68 |
+
if y_max == y_min: #there is only one point
|
| 69 |
+
img[:, x] = img[y_min, x]
|
| 70 |
+
else:
|
| 71 |
+
next = y_min + 1
|
| 72 |
+
step = +1 # we start by going upwards
|
| 73 |
+
for y in reversed(range(y_min)):
|
| 74 |
+
img[y, x] = img[next, x]
|
| 75 |
+
if next == y_max or next == y_min: #we hit the boundaries - we reverse
|
| 76 |
+
step *= -1 #reverse direction
|
| 77 |
+
next += step
|
| 78 |
+
|
| 79 |
+
next = y_max - 1
|
| 80 |
+
step = -1 # we start by going downwards
|
| 81 |
+
for y in range(y_max + 1, img.shape[0]): #we hit the boundaries - we reverse
|
| 82 |
+
img[y, x] = img[next, x]
|
| 83 |
+
if next == y_max or next == y_min:
|
| 84 |
+
step *= -1 # reverse direction
|
| 85 |
+
next += step
|
| 86 |
+
return img
|
| 87 |
+
class AbstractDataset(torch.utils.data.Dataset):
|
| 88 |
+
|
| 89 |
+
def __init__(self,
|
| 90 |
+
model = None,
|
| 91 |
+
transforms=[],
|
| 92 |
+
#### distortions during training ####
|
| 93 |
+
hv_symmetry=True, # True or False
|
| 94 |
+
|
| 95 |
+
min_horizontal_subsampling = 50, # None to turn off; or minimal percentage of horizontal size of the image
|
| 96 |
+
min_vertical_subsampling = 70, # None to turn off; or minimal percentage of vertical size of the image
|
| 97 |
+
max_random_tilt = 3, # None to turn off; or maximum tilt in degrees
|
| 98 |
+
max_add_colors_to_histogram = 10, # 0 to turn off; or points of the histogram to be added
|
| 99 |
+
max_remove_colors_from_histogram = 30, # 0 to turn off; or points of the histogram to be removed
|
| 100 |
+
max_noise_intensity = 3.0, # 0.0 to turn off; or max intensity of the noise
|
| 101 |
+
|
| 102 |
+
gaussian_phase_transforms_epoch=None, # None to turn off; or number of the epoch when the gaussian phase starts
|
| 103 |
+
min_horizontal_subsampling_gaussian_phase = 30, # None to turn off; or minimal percentage of horizontal size of the image
|
| 104 |
+
min_vertical_subsampling_gaussian_phase = 70, # None to turn off; or minimal percentage of vertical size of the image
|
| 105 |
+
max_random_tilt_gaussian_phase = 2, # None to turn off; or maximum tilt in degrees
|
| 106 |
+
max_add_colors_to_histogram_gaussian_phase = 10, # 0 to turn off; or points of the histogram to be added
|
| 107 |
+
max_remove_colors_from_histogram_gaussian_phase = 60, # 0 to turn off; or points of the histogram to be removed
|
| 108 |
+
max_noise_intensity_gaussian_phase = 3.5, # 0.0 to turn off; or max intensity of the noise
|
| 109 |
+
|
| 110 |
+
#### controling variables ####
|
| 111 |
+
transform_level=2, # 0 - no transforms, 1 - only the basic transform, 2 - all transforms, -1 - subsampling for high images
|
| 112 |
+
retain_raw_images=False,
|
| 113 |
+
retain_masks=False):
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
self.model = model # we need that to check epoch number during training
|
| 117 |
+
|
| 118 |
+
self.hv_symmetry = hv_symmetry
|
| 119 |
+
|
| 120 |
+
self.min_horizontal_subsampling = min_horizontal_subsampling
|
| 121 |
+
self.min_vertical_subsampling = min_vertical_subsampling
|
| 122 |
+
self.max_random_tilt = max_random_tilt
|
| 123 |
+
self.max_add_colors_to_histogram = max_add_colors_to_histogram
|
| 124 |
+
self.max_remove_colors_from_histogram = max_remove_colors_from_histogram
|
| 125 |
+
self.max_noise_intensity = max_noise_intensity
|
| 126 |
+
|
| 127 |
+
self.gaussian_phase_transforms_epoch = gaussian_phase_transforms_epoch
|
| 128 |
+
self.min_horizontal_subsampling_gaussian_phase = min_horizontal_subsampling_gaussian_phase
|
| 129 |
+
self.min_vertical_subsampling_gaussian_phase = min_vertical_subsampling_gaussian_phase
|
| 130 |
+
self.max_random_tilt_gaussian_phase = max_random_tilt_gaussian_phase
|
| 131 |
+
self.max_add_colors_to_histogram_gaussian_phase = max_add_colors_to_histogram_gaussian_phase
|
| 132 |
+
self.max_remove_colors_from_histogram_gaussian_phase = max_remove_colors_from_histogram_gaussian_phase
|
| 133 |
+
self.max_noise_intensity_gaussian_phase = max_noise_intensity_gaussian_phase
|
| 134 |
+
|
| 135 |
+
self.image_height = model_config['image_height']
|
| 136 |
+
self.image_width = model_config['image_width']
|
| 137 |
+
|
| 138 |
+
self.transform_level = transform_level
|
| 139 |
+
self.retain_raw_images = retain_raw_images
|
| 140 |
+
self.retain_masks = retain_masks
|
| 141 |
+
self.transforms = transforms
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_image_and_mask(self, row):
|
| 145 |
+
raise NotImplementedError("Subclass needs to implement this method")
|
| 146 |
+
|
| 147 |
+
def load_and_transform_image_and_mask(self, row):
|
| 148 |
+
img, mask = self.get_image_and_mask(row)
|
| 149 |
+
|
| 150 |
+
angle = row['angle']
|
| 151 |
+
#check if gaussian phase is on
|
| 152 |
+
if self.gaussian_phase_transforms_epoch is not None and self.model.current_epoch >= self.gaussian_phase_transforms_epoch:
|
| 153 |
+
max_random_tilt = self.max_random_tilt_gaussian_phase
|
| 154 |
+
max_noise_intensity = self.max_noise_intensity_gaussian_phase
|
| 155 |
+
min_horizontal_subsampling = self.min_horizontal_subsampling_gaussian_phase
|
| 156 |
+
min_vertical_subsampling = self.min_vertical_subsampling_gaussian_phase
|
| 157 |
+
max_add_colors_to_histogram = self.max_add_colors_to_histogram_gaussian_phase
|
| 158 |
+
max_remove_colors_from_histogram = self.max_remove_colors_from_histogram_gaussian_phase
|
| 159 |
+
else:
|
| 160 |
+
max_random_tilt = self.max_random_tilt
|
| 161 |
+
max_noise_intensity = self.max_noise_intensity
|
| 162 |
+
min_horizontal_subsampling = self.min_horizontal_subsampling
|
| 163 |
+
min_vertical_subsampling = self.min_vertical_subsampling
|
| 164 |
+
max_add_colors_to_histogram = self.max_add_colors_to_histogram
|
| 165 |
+
max_remove_colors_from_histogram = self.max_remove_colors_from_histogram
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
if self.transform_level >= 2 and max_random_tilt is not None:
|
| 173 |
+
####### RANDOM TILT
|
| 174 |
+
angle += np.random.uniform(-max_random_tilt, max_random_tilt)
|
| 175 |
+
|
| 176 |
+
img = scipy.ndimage.rotate(img, 90 - angle, reshape=True, order=3) # HORIZONTAL POSITION
|
| 177 |
+
###the part of the image that is added after rotation is all black (0s)
|
| 178 |
+
mask = scipy.ndimage.rotate(mask, 90 - angle, reshape=True, order = 0) # HORIZONTAL POSITION
|
| 179 |
+
#order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
| 180 |
+
|
| 181 |
+
############# CROP
|
| 182 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 183 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 184 |
+
|
| 185 |
+
#crop the image to the verical and horizontal limits.
|
| 186 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 187 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
img_raw = img.copy()
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
if self.transform_level >= 2:
|
| 194 |
+
########## ADDING NOISE
|
| 195 |
+
|
| 196 |
+
if max_noise_intensity > 0.0:
|
| 197 |
+
noise_intensity = np.random.random() * max_noise_intensity
|
| 198 |
+
noisy_img = add_microscope_noise(img, noise_intensity=noise_intensity)
|
| 199 |
+
img[mask] = noisy_img[mask]
|
| 200 |
+
|
| 201 |
+
if self.transform_level == -1:
|
| 202 |
+
#special case where we take at most 300 middle pixels from the image
|
| 203 |
+
# (vertical subsampling)
|
| 204 |
+
# to handle very latge images correctly
|
| 205 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 206 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 207 |
+
|
| 208 |
+
y_size = y_to - y_from + 1
|
| 209 |
+
|
| 210 |
+
random_size = 300 #not so random, ay?
|
| 211 |
+
|
| 212 |
+
if y_size > random_size:
|
| 213 |
+
random_start = y_size // 2 - random_size // 2
|
| 214 |
+
|
| 215 |
+
y_from = random_start
|
| 216 |
+
y_to = random_start + random_size - 1
|
| 217 |
+
|
| 218 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 219 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 220 |
+
|
| 221 |
+
# recrop the image if necessary
|
| 222 |
+
# -- even after only horizontal subsampling it may be necessary to recrop the image
|
| 223 |
+
|
| 224 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 225 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 226 |
+
|
| 227 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 228 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 229 |
+
|
| 230 |
+
if self.transform_level >= 1:
|
| 231 |
+
############## HORIZONTAL SUBSAMPLING
|
| 232 |
+
if min_horizontal_subsampling is not None:
|
| 233 |
+
x_size = x_to - x_from + 1
|
| 234 |
+
|
| 235 |
+
# add some random horizontal shift
|
| 236 |
+
random_size = np.random.randint(x_size * min_horizontal_subsampling / 100.0, x_size + 1)
|
| 237 |
+
random_start = np.random.randint(0, x_size - random_size + 1) + x_from
|
| 238 |
+
|
| 239 |
+
img = img[:, random_start:(random_start + random_size)]
|
| 240 |
+
mask = mask[:, random_start:(random_start + random_size)]
|
| 241 |
+
|
| 242 |
+
############ VERTICAL SUBSAMPLING
|
| 243 |
+
if min_vertical_subsampling is not None:
|
| 244 |
+
|
| 245 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 246 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 247 |
+
|
| 248 |
+
y_size = y_to - y_from + 1
|
| 249 |
+
|
| 250 |
+
random_size = np.random.randint(y_size * min_vertical_subsampling / 100.0, y_size + 1)
|
| 251 |
+
random_start = np.random.randint(0, y_size - random_size + 1) + y_from
|
| 252 |
+
|
| 253 |
+
y_from = random_start
|
| 254 |
+
y_to = random_start + random_size - 1
|
| 255 |
+
|
| 256 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 257 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 258 |
+
|
| 259 |
+
if min_horizontal_subsampling is not None or min_vertical_subsampling is not None:
|
| 260 |
+
#recrop the image if necessary
|
| 261 |
+
# -- even after only horizontal subsampling it may be necessary to recrop the image
|
| 262 |
+
|
| 263 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 264 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 265 |
+
|
| 266 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 267 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
######### ADD SYMMETRIC FILLING OF THE IMAGE BEYOND THE MASK
|
| 271 |
+
#img = add_symmetric_filling_beyond_mask(img, mask)
|
| 272 |
+
#This leaves holes in the image, so we will not use it
|
| 273 |
+
|
| 274 |
+
#plt.imshow(img)
|
| 275 |
+
#plt.show()
|
| 276 |
+
######### HORIZONTAL AND VERTICAL SYMMETRY.
|
| 277 |
+
# When superimposed, the result is 180 degree rotation
|
| 278 |
+
if self.transform_level >= 1 and self.hv_symmetry:
|
| 279 |
+
for axis in range(2):
|
| 280 |
+
if np.random.randint(0, 2) % 2 == 0:
|
| 281 |
+
img = np.flip(img, axis = axis)
|
| 282 |
+
mask = np.flip(mask, axis = axis)
|
| 283 |
+
#plt.imshow(img)
|
| 284 |
+
#plt.show()
|
| 285 |
+
|
| 286 |
+
if self.transform_level >= 2 and (max_add_colors_to_histogram > 0 or max_remove_colors_from_histogram > 0):
|
| 287 |
+
lower_bound = np.random.randint(-max_add_colors_to_histogram, max_remove_colors_from_histogram + 1)
|
| 288 |
+
upper_bound = np.random.randint(255 - max_remove_colors_from_histogram, 255 + max_add_colors_to_histogram + 1)
|
| 289 |
+
# first clip the values outstanding from the range (lower_bound -- upper_bound)
|
| 290 |
+
img[mask] = np.clip(img[mask], lower_bound, upper_bound)
|
| 291 |
+
# the range (lower_bound -- upper_bound) gets mapped to the range (0--255)
|
| 292 |
+
# but only in a portion of the image where mask = True
|
| 293 |
+
img[mask] = np.interp(img[mask], (lower_bound, upper_bound), (0, 255)).astype(np.uint8)
|
| 294 |
+
|
| 295 |
+
#### since preserve_range in skimage.transform.resize is set to False, the image
|
| 296 |
+
#### will be converted to float. Consult:
|
| 297 |
+
# https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize
|
| 298 |
+
# https://scikit-image.org/docs/dev/user_guide/data_types.html
|
| 299 |
+
|
| 300 |
+
# In our case the image gets conveted to floats ranging 0-1
|
| 301 |
+
old_height = img.shape[0]
|
| 302 |
+
img = skimage.transform.resize(img, (self.image_height, self.image_width), order=3)
|
| 303 |
+
new_height = img.shape[0]
|
| 304 |
+
mask = skimage.transform.resize(mask, (self.image_height, self.image_width), order=0, preserve_range=True)
|
| 305 |
+
# order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
| 306 |
+
|
| 307 |
+
scale_factor = new_height / old_height
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
#plt.imshow(img)
|
| 311 |
+
#plt.show()
|
| 312 |
+
#plt.imshow(mask)
|
| 313 |
+
#plt.show()
|
| 314 |
+
return img, mask, scale_factor, img_raw
|
| 315 |
+
|
| 316 |
+
def get_annotations_row(self, idx):
|
| 317 |
+
raise NotImplementedError("Subclass needs to implement this method")
|
| 318 |
+
|
| 319 |
+
def __getitem__(self, idx):
|
| 320 |
+
row = self.get_annotations_row(idx)
|
| 321 |
+
|
| 322 |
+
image, mask, scale_factor, image_raw = self.load_and_transform_image_and_mask(row)
|
| 323 |
+
|
| 324 |
+
image_data = {
|
| 325 |
+
'image': image,
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
for transform in self.transforms:
|
| 329 |
+
image_data = transform(**image_data)
|
| 330 |
+
# transform operates on image field ONLY of image_data, and returns a dictionary with the same keys
|
| 331 |
+
|
| 332 |
+
ret_dict = {
|
| 333 |
+
'image': image_data['image'],
|
| 334 |
+
'period_px': torch.tensor(row['period_nm'] * scale_factor * row['px_per_nm'], dtype=torch.float32),
|
| 335 |
+
'filename': row['granum_image'],
|
| 336 |
+
'px_per_nm': row['px_per_nm'],
|
| 337 |
+
'scale': scale_factor, # the scale factor is used to calculate the true period error
|
| 338 |
+
# (before scale) in losses and metrics
|
| 339 |
+
'neutral': -self.transforms[0].mean/self.transforms[0].std #value of 0 after the scale transform
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
if self.retain_raw_images:
|
| 343 |
+
ret_dict['image_raw'] = image_raw
|
| 344 |
+
|
| 345 |
+
if self.retain_masks:
|
| 346 |
+
ret_dict['mask'] = mask
|
| 347 |
+
|
| 348 |
+
return ret_dict
|
| 349 |
+
|
| 350 |
+
def __len__(self):
|
| 351 |
+
raise NotImplementedError("Subclass needs to implement this method")
|
| 352 |
+
|
| 353 |
+
class ImageDataset(AbstractDataset):
|
| 354 |
+
def __init__(self, annotations, data_dir: Path, *args, **kwargs):
|
| 355 |
+
super().__init__(*args, **kwargs)
|
| 356 |
+
self.data_dir = Path(data_dir)
|
| 357 |
+
|
| 358 |
+
self.id = 1
|
| 359 |
+
|
| 360 |
+
if isinstance(annotations, str):
|
| 361 |
+
annotations = data_dir / annotations #make it a Path object relative to data_dir
|
| 362 |
+
|
| 363 |
+
if isinstance(annotations, Path):
|
| 364 |
+
self.annotations = pd.read_csv(data_dir / annotations)
|
| 365 |
+
no_period = ['27_k7 [1]_4.png']
|
| 366 |
+
del_img = ['38_k42[1]_19.png', 'n6363_araLL_60kx_6 [1]_0.png', '27_hs8 [1]_5.png', '27_k7 [1]_20.png',
|
| 367 |
+
'F1_1_60kx_01 [1]_2.png']
|
| 368 |
+
self.annotations = self.annotations[~self.annotations['granum_image'].isin(no_period)]
|
| 369 |
+
self.annotations = self.annotations[~self.annotations['granum_image'].isin(del_img)]
|
| 370 |
+
else:
|
| 371 |
+
self.annotations = annotations
|
| 372 |
+
|
| 373 |
+
def get_image_and_mask(self, row):
|
| 374 |
+
filename = row['granum_image']
|
| 375 |
+
img_path = self.data_dir / filename
|
| 376 |
+
img_raw = skimage.io.imread(img_path)
|
| 377 |
+
|
| 378 |
+
img = img_raw[:, :, 0] # all three channels are equal, with the exception
|
| 379 |
+
# of the last channel which is the full blue (0,0,255) for outside the mask (so blue channel is 255, red and green are 0)
|
| 380 |
+
mask = (img_raw != (0, 0, 255)).any(axis=2)
|
| 381 |
+
return img, mask
|
| 382 |
+
|
| 383 |
+
def get_annotations_row(self, idx):
|
| 384 |
+
row = self.annotations.iloc[idx].to_dict()
|
| 385 |
+
row['idx'] = idx
|
| 386 |
+
return row
|
| 387 |
+
|
| 388 |
+
def __len__(self):
|
| 389 |
+
return len(self.annotations)
|
| 390 |
+
|
| 391 |
+
class ArtificialDataset(AbstractDataset):
|
| 392 |
+
def __init__(self,
|
| 393 |
+
min_period = 20,
|
| 394 |
+
max_period = 140,
|
| 395 |
+
white_fraction_min = 0.15,
|
| 396 |
+
white_fraction_max=0.45,
|
| 397 |
+
|
| 398 |
+
noise_min_sd = 0.0,
|
| 399 |
+
noise_max_sd = 100.0,
|
| 400 |
+
noise_max_sd_everywhere = 20.0, # 20.0
|
| 401 |
+
leftovers_max = 5,
|
| 402 |
+
|
| 403 |
+
get_real_masks_dataset = None, #None or instance of ImageDataset
|
| 404 |
+
*args, **kwargs):
|
| 405 |
+
super().__init__(*args, **kwargs)
|
| 406 |
+
self.id = 0
|
| 407 |
+
self.min_period = min_period
|
| 408 |
+
self.max_period = max_period
|
| 409 |
+
self.white_fraction_min = white_fraction_min
|
| 410 |
+
self.white_fraction_max = white_fraction_max
|
| 411 |
+
|
| 412 |
+
self.receptive_field_height = model_config['receptive_field_height']
|
| 413 |
+
self.stride_height = model_config['stride_height']
|
| 414 |
+
self.receptive_field_width = model_config['receptive_field_width']
|
| 415 |
+
self.stride_width = model_config['stride_width']
|
| 416 |
+
|
| 417 |
+
self.noise_min_sd = noise_min_sd
|
| 418 |
+
self.noise_max_sd = noise_max_sd
|
| 419 |
+
self.noise_max_sd_everywhere = noise_max_sd_everywhere
|
| 420 |
+
|
| 421 |
+
self.leftovers_max = leftovers_max
|
| 422 |
+
|
| 423 |
+
self.get_real_masks_dataset = get_real_masks_dataset
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def get_image_and_mask(self, row):
|
| 427 |
+
# generate a rectangular image of black and white horizontal stripes
|
| 428 |
+
# with black stripes varying with white stripes
|
| 429 |
+
|
| 430 |
+
period_px = row['period_nm'] * row['px_per_nm']
|
| 431 |
+
# white occupying 5-20 % of a total period (white+black)
|
| 432 |
+
white_px = np.random.randint(period_px * self.white_fraction_min, period_px * self.white_fraction_max + 1)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
# mask is rectangle of True values
|
| 436 |
+
img = np.zeros((self.image_height, self.image_width), dtype=np.uint8)
|
| 437 |
+
mask = np.ones((self.image_height, self.image_width), dtype=bool)
|
| 438 |
+
black_px = period_px - white_px
|
| 439 |
+
random_start = np.random.randint(0, period_px+1)
|
| 440 |
+
for i in range(self.image_height):
|
| 441 |
+
if (random_start+i) % (black_px + white_px) < black_px:
|
| 442 |
+
# sample width with random numbers from 0 to 101
|
| 443 |
+
img[i, :] = np.random.randint(0, 101, self.image_width)
|
| 444 |
+
else:
|
| 445 |
+
#sample width with random numbers from 156 to 255
|
| 446 |
+
img[i, :] = np.random.randint(156, 256, self.image_width)
|
| 447 |
+
|
| 448 |
+
if self.noise_max_sd_everywhere > self.noise_min_sd:
|
| 449 |
+
sd = np.random.uniform(self.noise_min_sd, self.noise_max_sd_everywhere)
|
| 450 |
+
noise = np.random.normal(0, sd, (self.image_height, self.image_width))
|
| 451 |
+
img = np.clip(img+noise.astype(img.dtype), 0, 255)
|
| 452 |
+
|
| 453 |
+
if self.noise_max_sd > self.noise_min_sd:
|
| 454 |
+
# there is also a metagrid in the image
|
| 455 |
+
# consisting of overlapping receptive fields of size 190x42
|
| 456 |
+
# with stride 64x4
|
| 457 |
+
# the metagrid is 5x102
|
| 458 |
+
overlapping_fields_count_height = (self.image_height - self.receptive_field_height) // self.stride_height + 1
|
| 459 |
+
overlapping_fields_count_width = (self.image_width - self.receptive_field_width) // self.stride_width + 1
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
sd = np.random.uniform(self.noise_min_sd, self.noise_max_sd)
|
| 463 |
+
noise = np.random.normal(0, sd, (self.image_height, self.image_width))
|
| 464 |
+
|
| 465 |
+
#there will be some left-over metagrid rectangles
|
| 466 |
+
leftovers_count = np.random.randint(1, self.leftovers_max)
|
| 467 |
+
for i in range(leftovers_count):
|
| 468 |
+
metagrid_row = np.random.randint(0, overlapping_fields_count_height)
|
| 469 |
+
metagrid_col = np.random.randint(0, overlapping_fields_count_width)
|
| 470 |
+
#zero-out the noise inside the selected metagrid
|
| 471 |
+
noise[metagrid_row * self.stride_height:metagrid_row * self.stride_height + self.receptive_field_height + 1, \
|
| 472 |
+
metagrid_col * self.stride_width :metagrid_col * self.stride_width + self.receptive_field_width + 1] = 0
|
| 473 |
+
|
| 474 |
+
#add noise to the image
|
| 475 |
+
img = np.clip(img+noise.astype(img.dtype), 0, 255)
|
| 476 |
+
|
| 477 |
+
if self.get_real_masks_dataset is not None:
|
| 478 |
+
ret_dict = self.get_real_masks_dataset.__getitem__(row['idx'] % len(self.get_real_masks_dataset))
|
| 479 |
+
mask = ret_dict['mask'] #this mask is already sized target height-by-width
|
| 480 |
+
|
| 481 |
+
img[mask == False] = 0
|
| 482 |
+
|
| 483 |
+
return img, mask
|
| 484 |
+
|
| 485 |
+
def get_annotations_row(self, idx):
|
| 486 |
+
return {'idx': idx,
|
| 487 |
+
'period_nm': np.random.randint(self.min_period, self.max_period),
|
| 488 |
+
'px_per_nm': 1.0,
|
| 489 |
+
'granum_image': 'artificial_%d.png' % idx,
|
| 490 |
+
'angle': 90}
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def __len__(self):
|
| 494 |
+
return 237 # number of samples as in real data in the train set (70% of 339 is 237,3)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class AdHocDataset(AbstractDataset):
|
| 498 |
+
def __init__(self, images_masks_pxpernm: list[tuple[np.ndarray, np.ndarray, float]], *args, **kwargs):
|
| 499 |
+
super().__init__(*args, **kwargs)
|
| 500 |
+
self.data = images_masks_pxpernm
|
| 501 |
+
|
| 502 |
+
def __len__(self):
|
| 503 |
+
return len(self.data)
|
| 504 |
+
|
| 505 |
+
def __getitem__(self, idx):
|
| 506 |
+
image, mask, px_per_nm = self.data[idx]
|
| 507 |
+
|
| 508 |
+
image, mask, scale_factor, image_raw = self.load_and_transform_image_and_mask(image, mask)
|
| 509 |
+
|
| 510 |
+
image_data = {
|
| 511 |
+
'image': image,
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
for transform in self.transforms:
|
| 515 |
+
image_data = transform(**image_data)
|
| 516 |
+
# transform operates on image field ONLY of image_data, and returns a dictionary with the same keys
|
| 517 |
+
|
| 518 |
+
ret_dict = {
|
| 519 |
+
'image': image_data['image'],
|
| 520 |
+
'period_px': torch.tensor(0, dtype=torch.float32),
|
| 521 |
+
'filename': str(idx),
|
| 522 |
+
'px_per_nm': px_per_nm,
|
| 523 |
+
'scale': scale_factor, # the scale factor is used to calculate the true period error
|
| 524 |
+
# (before scale) in losses and metrics
|
| 525 |
+
'neutral': -self.transforms[0].mean/self.transforms[0].std #value of 0 after the scale transform
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
if self.retain_raw_images:
|
| 529 |
+
ret_dict['image_raw'] = image_raw
|
| 530 |
+
|
| 531 |
+
if self.retain_masks:
|
| 532 |
+
ret_dict['mask'] = mask
|
| 533 |
+
|
| 534 |
+
return ret_dict
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def load_and_transform_image_and_mask(self, img, mask):
|
| 538 |
+
|
| 539 |
+
angle = 90
|
| 540 |
+
#check if gaussian phase is on
|
| 541 |
+
if self.gaussian_phase_transforms_epoch is not None and self.model.current_epoch >= self.gaussian_phase_transforms_epoch:
|
| 542 |
+
max_random_tilt = self.max_random_tilt_gaussian_phase
|
| 543 |
+
max_noise_intensity = self.max_noise_intensity_gaussian_phase
|
| 544 |
+
min_horizontal_subsampling = self.min_horizontal_subsampling_gaussian_phase
|
| 545 |
+
min_vertical_subsampling = self.min_vertical_subsampling_gaussian_phase
|
| 546 |
+
max_add_colors_to_histogram = self.max_add_colors_to_histogram_gaussian_phase
|
| 547 |
+
max_remove_colors_from_histogram = self.max_remove_colors_from_histogram_gaussian_phase
|
| 548 |
+
else:
|
| 549 |
+
max_random_tilt = self.max_random_tilt
|
| 550 |
+
max_noise_intensity = self.max_noise_intensity
|
| 551 |
+
min_horizontal_subsampling = self.min_horizontal_subsampling
|
| 552 |
+
min_vertical_subsampling = self.min_vertical_subsampling
|
| 553 |
+
max_add_colors_to_histogram = self.max_add_colors_to_histogram
|
| 554 |
+
max_remove_colors_from_histogram = self.max_remove_colors_from_histogram
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
if self.transform_level >= 2 and max_random_tilt is not None:
|
| 558 |
+
####### RANDOM TILT
|
| 559 |
+
angle += np.random.uniform(-max_random_tilt, max_random_tilt)
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
img = scipy.ndimage.rotate(img, 90 - angle, reshape=True, order=3) # HORIZONTAL POSITION
|
| 563 |
+
###the part of the image that is added after rotation is all black (0s)
|
| 564 |
+
mask = scipy.ndimage.rotate(mask, 90 - angle, reshape=True, order = 0) # HORIZONTAL POSITION
|
| 565 |
+
#order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
| 566 |
+
|
| 567 |
+
############# CROP
|
| 568 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 569 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 570 |
+
|
| 571 |
+
#crop the image to the verical and horizontal limits.
|
| 572 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 573 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
img_raw = img.copy()
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
if self.transform_level >= 2:
|
| 580 |
+
########## ADDING NOISE
|
| 581 |
+
|
| 582 |
+
if max_noise_intensity > 0.0:
|
| 583 |
+
noise_intensity = np.random.random() * max_noise_intensity
|
| 584 |
+
noisy_img = add_microscope_noise(img, noise_intensity=noise_intensity)
|
| 585 |
+
img[mask] = noisy_img[mask]
|
| 586 |
+
|
| 587 |
+
if self.transform_level == -1:
|
| 588 |
+
#special case where we take at most 300 middle pixels from the image
|
| 589 |
+
# (vertical subsampling)
|
| 590 |
+
# to handle very latge images correctly
|
| 591 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 592 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 593 |
+
|
| 594 |
+
y_size = y_to - y_from + 1
|
| 595 |
+
|
| 596 |
+
random_size = 300 #not so random, ay?
|
| 597 |
+
|
| 598 |
+
if y_size > random_size:
|
| 599 |
+
random_start = y_size // 2 - random_size // 2
|
| 600 |
+
|
| 601 |
+
y_from = random_start
|
| 602 |
+
y_to = random_start + random_size - 1
|
| 603 |
+
|
| 604 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 605 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 606 |
+
|
| 607 |
+
# recrop the image if necessary
|
| 608 |
+
# -- even after only horizontal subsampling it may be necessary to recrop the image
|
| 609 |
+
|
| 610 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 611 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 612 |
+
|
| 613 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 614 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 615 |
+
|
| 616 |
+
if self.transform_level >= 1:
|
| 617 |
+
############## HORIZONTAL SUBSAMPLING
|
| 618 |
+
if min_horizontal_subsampling is not None:
|
| 619 |
+
x_size = x_to - x_from + 1
|
| 620 |
+
|
| 621 |
+
# add some random horizontal shift
|
| 622 |
+
random_size = np.random.randint(x_size * min_horizontal_subsampling / 100.0, x_size + 1)
|
| 623 |
+
random_start = np.random.randint(0, x_size - random_size + 1) + x_from
|
| 624 |
+
|
| 625 |
+
img = img[:, random_start:(random_start + random_size)]
|
| 626 |
+
mask = mask[:, random_start:(random_start + random_size)]
|
| 627 |
+
|
| 628 |
+
############ VERTICAL SUBSAMPLING
|
| 629 |
+
if min_vertical_subsampling is not None:
|
| 630 |
+
|
| 631 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 632 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 633 |
+
|
| 634 |
+
y_size = y_to - y_from + 1
|
| 635 |
+
|
| 636 |
+
random_size = np.random.randint(y_size * min_vertical_subsampling / 100.0, y_size + 1)
|
| 637 |
+
random_start = np.random.randint(0, y_size - random_size + 1) + y_from
|
| 638 |
+
|
| 639 |
+
y_from = random_start
|
| 640 |
+
y_to = random_start + random_size - 1
|
| 641 |
+
|
| 642 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 643 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 644 |
+
|
| 645 |
+
if min_horizontal_subsampling is not None or min_vertical_subsampling is not None:
|
| 646 |
+
#recrop the image if necessary
|
| 647 |
+
# -- even after only horizontal subsampling it may be necessary to recrop the image
|
| 648 |
+
|
| 649 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 650 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 651 |
+
|
| 652 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 653 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
######### ADD SYMMETRIC FILLING OF THE IMAGE BEYOND THE MASK
|
| 657 |
+
#img = add_symmetric_filling_beyond_mask(img, mask)
|
| 658 |
+
#This leaves holes in the image, so we will not use it
|
| 659 |
+
|
| 660 |
+
#plt.imshow(img)
|
| 661 |
+
#plt.show()
|
| 662 |
+
######### HORIZONTAL AND VERTICAL SYMMETRY.
|
| 663 |
+
# When superimposed, the result is 180 degree rotation
|
| 664 |
+
if self.transform_level >= 1 and self.hv_symmetry:
|
| 665 |
+
for axis in range(2):
|
| 666 |
+
if np.random.randint(0, 2) % 2 == 0:
|
| 667 |
+
img = np.flip(img, axis = axis)
|
| 668 |
+
mask = np.flip(mask, axis = axis)
|
| 669 |
+
#plt.imshow(img)
|
| 670 |
+
#plt.show()
|
| 671 |
+
|
| 672 |
+
if self.transform_level >= 2 and (max_add_colors_to_histogram > 0 or max_remove_colors_from_histogram > 0):
|
| 673 |
+
lower_bound = np.random.randint(-max_add_colors_to_histogram, max_remove_colors_from_histogram + 1)
|
| 674 |
+
upper_bound = np.random.randint(255 - max_remove_colors_from_histogram, 255 + max_add_colors_to_histogram + 1)
|
| 675 |
+
# first clip the values outstanding from the range (lower_bound -- upper_bound)
|
| 676 |
+
img[mask] = np.clip(img[mask], lower_bound, upper_bound)
|
| 677 |
+
# the range (lower_bound -- upper_bound) gets mapped to the range (0--255)
|
| 678 |
+
# but only in a portion of the image where mask = True
|
| 679 |
+
img[mask] = np.interp(img[mask], (lower_bound, upper_bound), (0, 255)).astype(np.uint8)
|
| 680 |
+
|
| 681 |
+
#### since preserve_range in skimage.transform.resize is set to False, the image
|
| 682 |
+
#### will be converted to float. Consult:
|
| 683 |
+
# https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize
|
| 684 |
+
# https://scikit-image.org/docs/dev/user_guide/data_types.html
|
| 685 |
+
|
| 686 |
+
# In our case the image gets conveted to floats ranging 0-1
|
| 687 |
+
old_height = img.shape[0]
|
| 688 |
+
img = skimage.transform.resize(img, (self.image_height, self.image_width), order=3)
|
| 689 |
+
new_height = img.shape[0]
|
| 690 |
+
mask = skimage.transform.resize(mask, (self.image_height, self.image_width), order=0, preserve_range=True)
|
| 691 |
+
# order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
| 692 |
+
|
| 693 |
+
scale_factor = new_height / old_height
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
#plt.imshow(img)
|
| 697 |
+
#plt.show()
|
| 698 |
+
#plt.imshow(mask)
|
| 699 |
+
#plt.show()
|
| 700 |
+
return img, mask, scale_factor, img_raw
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
class AdHocDataset2(AbstractDataset):
|
| 704 |
+
def __init__(self, images_masks_pxpernm: list[tuple[np.ndarray, np.ndarray, float]], *args, **kwargs):
|
| 705 |
+
super().__init__(*args, **kwargs)
|
| 706 |
+
self.data = images_masks_pxpernm
|
| 707 |
+
|
| 708 |
+
def __len__(self):
|
| 709 |
+
return len(self.data)
|
| 710 |
+
|
| 711 |
+
def __getitem__(self, idx):
|
| 712 |
+
image, mask, px_per_nm = self.data[idx]
|
| 713 |
+
|
| 714 |
+
image, mask, scale_factor, image_raw = self.load_and_transform_image_and_mask(image, mask)
|
| 715 |
+
|
| 716 |
+
image_data = {
|
| 717 |
+
'image': image,
|
| 718 |
+
}
|
| 719 |
+
|
| 720 |
+
for transform in self.transforms:
|
| 721 |
+
image_data = transform(**image_data)
|
| 722 |
+
# transform operates on image field ONLY of image_data, and returns a dictionary with the same keys
|
| 723 |
+
|
| 724 |
+
ret_dict = {
|
| 725 |
+
'image': image_data['image'],
|
| 726 |
+
'scale': scale_factor, # the scale factor is used to calculate the true period error
|
| 727 |
+
# (before scale) in losses and metrics
|
| 728 |
+
'neutral': -self.transforms[0].mean/self.transforms[0].std #value of 0 after the scale transform
|
| 729 |
+
}
|
| 730 |
+
|
| 731 |
+
return ret_dict
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
def load_and_transform_image_and_mask(self, img, mask):
|
| 735 |
+
|
| 736 |
+
img_raw = img.copy()
|
| 737 |
+
|
| 738 |
+
if self.transform_level == -1:
|
| 739 |
+
#special case where we take at most 300 middle pixels from the image
|
| 740 |
+
# (vertical subsampling)
|
| 741 |
+
# to handle very latge images correctly
|
| 742 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 743 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 744 |
+
|
| 745 |
+
y_size = y_to - y_from + 1
|
| 746 |
+
|
| 747 |
+
max_size = 300
|
| 748 |
+
|
| 749 |
+
if y_size > max_size:
|
| 750 |
+
random_start = y_size // 2 - max_size // 2
|
| 751 |
+
|
| 752 |
+
y_from = random_start
|
| 753 |
+
y_to = random_start + max_size - 1
|
| 754 |
+
|
| 755 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 756 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 757 |
+
|
| 758 |
+
# recrop the image if necessary
|
| 759 |
+
# -- even after only horizontal subsampling it may be necessary to recrop the image
|
| 760 |
+
|
| 761 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 762 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 763 |
+
|
| 764 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 765 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
#### since preserve_range in skimage.transform.resize is set to False, the image
|
| 769 |
+
#### will be converted to float. Consult:
|
| 770 |
+
# https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize
|
| 771 |
+
# https://scikit-image.org/docs/dev/user_guide/data_types.html
|
| 772 |
+
|
| 773 |
+
# In our case the image gets conveted to floats ranging 0-1
|
| 774 |
+
old_height = img.shape[0]
|
| 775 |
+
img = skimage.transform.resize(img, (self.image_height, self.image_width), order=3)
|
| 776 |
+
new_height = img.shape[0]
|
| 777 |
+
mask = skimage.transform.resize(mask, (self.image_height, self.image_width), order=0, preserve_range=True)
|
| 778 |
+
# order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
| 779 |
+
|
| 780 |
+
scale_factor = new_height / old_height
|
| 781 |
+
|
| 782 |
+
return img, mask, scale_factor, img_raw
|
| 783 |
+
|
| 784 |
+
class AdHocDataset3(AbstractDataset):
|
| 785 |
+
def __init__(self, images_and_masks: list[tuple[np.ndarray, np.ndarray]], *args, **kwargs):
|
| 786 |
+
super().__init__(*args, **kwargs)
|
| 787 |
+
self.data = images_and_masks
|
| 788 |
+
|
| 789 |
+
def __len__(self):
|
| 790 |
+
return len(self.data)
|
| 791 |
+
|
| 792 |
+
def __getitem__(self, idx):
|
| 793 |
+
image, mask = self.data[idx]
|
| 794 |
+
|
| 795 |
+
image, mask, scale_factor = self.load_and_transform_image_and_mask(image, mask)
|
| 796 |
+
|
| 797 |
+
image_data = {
|
| 798 |
+
'image': image,
|
| 799 |
+
}
|
| 800 |
+
|
| 801 |
+
for transform in self.transforms:
|
| 802 |
+
image_data = transform(**image_data)
|
| 803 |
+
# transform operates on image field ONLY of image_data, and returns a dictionary with the same keys
|
| 804 |
+
|
| 805 |
+
ret_dict = {
|
| 806 |
+
'image': image_data['image'],
|
| 807 |
+
'scale': scale_factor, # the scale factor is used to calculate the true period error
|
| 808 |
+
# (before scale) in losses and metrics
|
| 809 |
+
#value of 0 after the scale transform
|
| 810 |
+
}
|
| 811 |
+
|
| 812 |
+
return ret_dict
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
def load_and_transform_image_and_mask(self, img, mask):
|
| 816 |
+
|
| 817 |
+
if self.transform_level == -1:
|
| 818 |
+
#special case where we take at most 300 middle pixels from the image
|
| 819 |
+
# (vertical subsampling)
|
| 820 |
+
# to handle very latge images correctly
|
| 821 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 822 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 823 |
+
|
| 824 |
+
y_size = y_to - y_from + 1
|
| 825 |
+
|
| 826 |
+
max_size = 300
|
| 827 |
+
|
| 828 |
+
if y_size > max_size:
|
| 829 |
+
random_start = y_size // 2 - max_size // 2
|
| 830 |
+
|
| 831 |
+
y_from = random_start
|
| 832 |
+
y_to = random_start + max_size - 1
|
| 833 |
+
|
| 834 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 835 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 836 |
+
|
| 837 |
+
# recrop the image if necessary
|
| 838 |
+
# -- even after only horizontal subsampling it may be necessary to recrop the image
|
| 839 |
+
|
| 840 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
| 841 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
| 842 |
+
|
| 843 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 844 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
#### since preserve_range in skimage.transform.resize is set to False, the image
|
| 848 |
+
#### will be converted to float. Consult:
|
| 849 |
+
# https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize
|
| 850 |
+
# https://scikit-image.org/docs/dev/user_guide/data_types.html
|
| 851 |
+
|
| 852 |
+
# In our case the image gets conveted to floats ranging 0-1
|
| 853 |
+
old_height = img.shape[0]
|
| 854 |
+
img = skimage.transform.resize(img, (self.image_height, self.image_width), order=3)
|
| 855 |
+
new_height = img.shape[0]
|
| 856 |
+
mask = skimage.transform.resize(mask, (self.image_height, self.image_width), order=0, preserve_range=True)
|
| 857 |
+
# order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
| 858 |
+
|
| 859 |
+
scale_factor = new_height / old_height
|
| 860 |
+
|
| 861 |
+
return img, mask, scale_factor
|
period_calculation/image_transforms.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import cv2
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def np_batched_radon(image_batch):
|
| 7 |
+
#image_batch: torch tensor, batch_size x 1 x img_size x img_size
|
| 8 |
+
# squeeze order #1 and transform to numpy
|
| 9 |
+
|
| 10 |
+
image_batch = image_batch.squeeze(1).cpu().numpy()
|
| 11 |
+
|
| 12 |
+
batch_size, img_size = image_batch.shape[:2]
|
| 13 |
+
if batch_size > 512: # limit batch size to 512 because cv2.warpAffine fails for batch> 512
|
| 14 |
+
return np.concatenate([np_batched_radon(image_batch[i:i+512]) for i in range(0,batch_size,512)], axis=0)
|
| 15 |
+
theta = np.arange(180)
|
| 16 |
+
radon_image = np.zeros((image_batch.shape[0], img_size, len(theta)),
|
| 17 |
+
dtype='float32')
|
| 18 |
+
|
| 19 |
+
for i, angle in enumerate(theta):
|
| 20 |
+
M = cv2.getRotationMatrix2D(((img_size-1)/2.0,(img_size-1)/2.0),angle,1)
|
| 21 |
+
rotated = cv2.warpAffine(np.transpose(image_batch, (1, 2, 0)),M,(img_size,img_size))
|
| 22 |
+
|
| 23 |
+
#plt.imshow(rotated[:,:,0])
|
| 24 |
+
#plt.show()
|
| 25 |
+
|
| 26 |
+
if batch_size == 1: # cv2.warpAffine cancels batch dimension if equal to 1
|
| 27 |
+
rotated = rotated[:,:, np.newaxis]
|
| 28 |
+
rotated = np.transpose(rotated, (2, 0, 1)) / 224.0
|
| 29 |
+
#rotated = rotated / np.array(255, dtype='float32')
|
| 30 |
+
radon_image[:, :, i] = rotated.sum(axis=1)
|
| 31 |
+
|
| 32 |
+
#plot the image
|
| 33 |
+
|
| 34 |
+
# plt.imshow(radon_image[0])
|
| 35 |
+
# plt.show()
|
| 36 |
+
|
| 37 |
+
return radon_image
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def torch_batched_radon(image_batch, neutral_value):
|
| 41 |
+
#image_batch: batch_size x 1 x img_size x img_size
|
| 42 |
+
#np_batched_radon(image_batch - neutral_value)
|
| 43 |
+
|
| 44 |
+
image_batch = image_batch - neutral_value # so the 0 value is neutral
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
batch_size = image_batch.shape[0]
|
| 48 |
+
img_size = image_batch.shape[2]
|
| 49 |
+
|
| 50 |
+
theta = np.arange(180) # we don't need torch here, we will evaluate individual angles below
|
| 51 |
+
|
| 52 |
+
radon_image = torch.zeros((batch_size, 1, img_size, len(theta)), dtype=torch.float, device=image_batch.device)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
for i, angle in enumerate(theta):
|
| 56 |
+
#M = cv2.getRotationMatrix2D(((img_size-1)/2.0,(img_size-1)/2.0),angle,1)
|
| 57 |
+
#calculate the same rotation matrix but with torch:
|
| 58 |
+
M = torch.tensor(cv2.getRotationMatrix2D(((img_size-1)/2.0,(img_size-1)/2.0),angle,1)).to(image_batch.device, dtype=torch.float32)
|
| 59 |
+
angle = torch.tensor((angle+90)/180.0*np.pi)
|
| 60 |
+
M1 = torch.tensor([[torch.sin(angle), torch.cos(angle), 0],
|
| 61 |
+
[torch.cos(angle), -torch.sin(angle), 0]]).to(image_batch.device, dtype=torch.float32)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# we need to add a batch dimension to the rotation matrix
|
| 65 |
+
M1 = M1.repeat(batch_size, 1, 1)
|
| 66 |
+
|
| 67 |
+
grid = torch.nn.functional.affine_grid(M1, image_batch.shape, align_corners=False)
|
| 68 |
+
rotated = torch.nn.functional.grid_sample(image_batch, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
| 69 |
+
rotated = rotated.squeeze(1)
|
| 70 |
+
|
| 71 |
+
#plt.imshow(rotated[0].cpu().numpy())
|
| 72 |
+
#plt.show()
|
| 73 |
+
|
| 74 |
+
radon_image[:, 0, :, i] = rotated.sum(axis=1) / 224.0 + neutral_value
|
| 75 |
+
|
| 76 |
+
#plt.imshow(radon_image[0, 0].cpu().numpy())
|
| 77 |
+
#plt.show()
|
| 78 |
+
|
| 79 |
+
return radon_image
|
period_calculation/models/abstract_model.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytorch_lightning as pl
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from hydra.utils import instantiate
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class AbstractModel(pl.LightningModule):
|
| 8 |
+
def __init__(self,
|
| 9 |
+
lr=0.001,
|
| 10 |
+
optimizer_hparams=dict(),
|
| 11 |
+
scheduler=dict(classname='MultiStepLR', kwargs=dict(milestones=[100, 150], gamma=0.1))
|
| 12 |
+
):
|
| 13 |
+
super().__init__()
|
| 14 |
+
# Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
|
| 15 |
+
self.save_hyperparameters()
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
raise NotImplementedError("Subclass needs to implement this method")
|
| 19 |
+
|
| 20 |
+
def configure_optimizers(self):
|
| 21 |
+
# AdamW is Adam with a correct implementation of weight decay (see here
|
| 22 |
+
# for details: https://arxiv.org/pdf/1711.05101.pdf)
|
| 23 |
+
print("configuring the optimizer and lr scheduler with learning rate=%.5f"%self.hparams.lr)
|
| 24 |
+
optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, **self.hparams.optimizer_hparams)
|
| 25 |
+
# scheduler = getattr(torch.optim.lr_scheduler, self.hparams.lr_hparams['classname'])(optimizer, **self.hparams.lr_hparams['kwargs'])
|
| 26 |
+
if self.hparams.scheduler is not None:
|
| 27 |
+
scheduler = instantiate({**self.hparams.scheduler, '_partial_': True})(optimizer)
|
| 28 |
+
|
| 29 |
+
return [optimizer], [scheduler]
|
| 30 |
+
else:
|
| 31 |
+
return optimizer
|
| 32 |
+
|
| 33 |
+
def additional_losses(self):
|
| 34 |
+
"""get additional_losses"""
|
| 35 |
+
return torch.zeros((1))
|
| 36 |
+
|
| 37 |
+
def process_batch_supervised(self, batch):
|
| 38 |
+
"""get predictions, losses and mean errors (MAE)"""
|
| 39 |
+
raise NotImplementedError("Subclass needs to implement this method")
|
| 40 |
+
|
| 41 |
+
def log_all(self, losses, metrics, prefix=''):
|
| 42 |
+
for k, v in losses.items():
|
| 43 |
+
self.log(f'{prefix}{k}_loss', v.item() if isinstance(v, torch.Tensor) else v)
|
| 44 |
+
|
| 45 |
+
for k, v in metrics.items():
|
| 46 |
+
self.log(f'{prefix}{k}', v.item() if isinstance(v, torch.Tensor) else v)
|
| 47 |
+
|
| 48 |
+
def training_step(self, batch, batch_idx):
|
| 49 |
+
# "batch" is the output of the training data loader.
|
| 50 |
+
preds, losses, metrics = self.process_batch_supervised(batch)
|
| 51 |
+
self.log_all(losses, metrics, prefix='train_')
|
| 52 |
+
|
| 53 |
+
return losses['final']
|
| 54 |
+
|
| 55 |
+
def validation_step(self, batch, batch_idx):
|
| 56 |
+
preds, losses, metrics = self.process_batch_supervised(batch)
|
| 57 |
+
self.log_all(losses, metrics, prefix='val_')
|
| 58 |
+
|
| 59 |
+
def test_step(self, batch, batch_idx):
|
| 60 |
+
preds, losses, metrics = self.process_batch_supervised(batch)
|
| 61 |
+
self.log_all(losses, metrics, prefix='test_')
|
period_calculation/models/gauss_model.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytorch_lightning as pl
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from period_calculation.models.abstract_model import AbstractModel
|
| 6 |
+
|
| 7 |
+
from period_calculation.config import model_config # this is a dictionary with the model configuration
|
| 8 |
+
|
| 9 |
+
class GaussPeriodModel(AbstractModel):
|
| 10 |
+
def __init__(self,
|
| 11 |
+
*args, **kwargs
|
| 12 |
+
):
|
| 13 |
+
super().__init__(*args, **kwargs)
|
| 14 |
+
|
| 15 |
+
self.seq = torch.nn.Sequential(
|
| 16 |
+
torch.nn.Conv2d(1, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
| 17 |
+
torch.nn.ReLU(),
|
| 18 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
| 19 |
+
torch.nn.MaxPool2d((2, 2), stride=(2, 2)),
|
| 20 |
+
|
| 21 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
| 22 |
+
torch.nn.ReLU(),
|
| 23 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
| 24 |
+
torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
|
| 25 |
+
|
| 26 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
| 27 |
+
torch.nn.ReLU(),
|
| 28 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
| 29 |
+
torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
|
| 30 |
+
|
| 31 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
| 32 |
+
torch.nn.ReLU(),
|
| 33 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
| 34 |
+
torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
|
| 35 |
+
|
| 36 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
| 37 |
+
torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
|
| 38 |
+
|
| 39 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
| 40 |
+
torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
|
| 41 |
+
|
| 42 |
+
torch.nn.Dropout(0.1)
|
| 43 |
+
)
|
| 44 |
+
self.query = torch.nn.Parameter(torch.empty(1, 2, 32)) #two heads only
|
| 45 |
+
torch.nn.init.xavier_normal_(self.query)
|
| 46 |
+
|
| 47 |
+
self.linear1 = torch.nn.Linear(64, 8)
|
| 48 |
+
self.linear2 = torch.nn.Linear(8, 1)
|
| 49 |
+
|
| 50 |
+
self.query_sd = torch.nn.Parameter(torch.empty(1, 2, 32))
|
| 51 |
+
torch.nn.init.xavier_normal_(self.query_sd)
|
| 52 |
+
|
| 53 |
+
self.linear_sd1 = torch.nn.Linear(64, 8)
|
| 54 |
+
self.linear_sd2 = torch.nn.Linear(8, 1)
|
| 55 |
+
self.relu = torch.nn.ReLU()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def copy_network_trunk(self, model):
|
| 59 |
+
# https://discuss.pytorch.org/t/copy-weights-from-only-one-layer-of-one-model-to-another-model-with-different-structure/153419
|
| 60 |
+
with torch.no_grad():
|
| 61 |
+
for i, layer in enumerate(model.seq):
|
| 62 |
+
if i%2 == 0 and i!=20: #convolutional layers are the ones with even indexes with the exeption of the 20th (=dropout)
|
| 63 |
+
self.seq[i].weight.copy_(layer.weight)
|
| 64 |
+
self.seq[i].bias.copy_(layer.bias)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def copy_final_layers(self, model):
|
| 68 |
+
# https://discuss.pytorch.org/t/copy-weights-from-only-one-layer-of-one-model-to-another-model-with-different-structure/153419
|
| 69 |
+
|
| 70 |
+
with torch.no_grad():
|
| 71 |
+
self.linear1.weight.copy_(model.linear1.weight)
|
| 72 |
+
self.linear1.bias.copy_(model.linear1.bias)
|
| 73 |
+
|
| 74 |
+
self.linear2.weight.copy_(model.linear2.weight)
|
| 75 |
+
self.linear2.bias.copy_(model.linear2.bias)
|
| 76 |
+
|
| 77 |
+
self.query.copy_(model.query)
|
| 78 |
+
|
| 79 |
+
def duplicate_final_layers(self):
|
| 80 |
+
# https://discuss.pytorch.org/t/copy-weights-from-only-one-layer-of-one-model-to-another-model-with-different-structure/153419
|
| 81 |
+
|
| 82 |
+
with torch.no_grad():
|
| 83 |
+
self.linear_sd1.weight.copy_(self.linear1.weight)
|
| 84 |
+
self.linear_sd1.bias.copy_(self.linear1.bias)
|
| 85 |
+
|
| 86 |
+
self.linear_sd2.weight.copy_(self.linear2.weight/10)
|
| 87 |
+
self.linear_sd2.bias.copy_(self.linear2.bias/10)
|
| 88 |
+
|
| 89 |
+
self.query_sd.copy_(self.query)
|
| 90 |
+
|
| 91 |
+
def forward(self, x, neutral=None, return_raw=False):
|
| 92 |
+
#https://www.nature.com/articles/s41598-023-43852-x
|
| 93 |
+
|
| 94 |
+
# x is sized # batch x 1 x 476 x 476
|
| 95 |
+
|
| 96 |
+
preds = self.seq(x) # batch x 32 x 5 x 220
|
| 97 |
+
features = torch.flatten(preds, 2) # batch x 32 x 1100
|
| 98 |
+
|
| 99 |
+
# attention
|
| 100 |
+
energy = self.query @ features # batch x 2 x 1100
|
| 101 |
+
weights = torch.nn.functional.softmax(energy, 2) # batch x 2 x 1100
|
| 102 |
+
response = features @ weights.transpose(1, 2) # batch x 32 x 2
|
| 103 |
+
response = torch.flatten(response, 1) # batch x 64
|
| 104 |
+
|
| 105 |
+
preds = self.linear1(response) # batch x 8
|
| 106 |
+
preds = self.linear2(self.relu(preds)) # batch x 1
|
| 107 |
+
|
| 108 |
+
# attention sd
|
| 109 |
+
|
| 110 |
+
energy_sd = self.query_sd @ features # batch x 2 x 1100
|
| 111 |
+
weights_sd = torch.nn.functional.softmax(energy_sd, 2) # batch x 2 x 1100
|
| 112 |
+
response_sd = features @ weights_sd.transpose(1, 2) # batch x 32 x 2
|
| 113 |
+
response_sd = torch.flatten(response_sd, 1) # batch x 64
|
| 114 |
+
|
| 115 |
+
preds_sd = self.linear_sd1(response_sd) # batch x 8
|
| 116 |
+
preds_sd = self.linear_sd2(self.relu(preds_sd)) # batch x 1
|
| 117 |
+
|
| 118 |
+
outputs = [ model_config['receptive_field_height']/(preds[:,0]) , torch.exp(preds_sd[:,0]) ]
|
| 119 |
+
if return_raw:
|
| 120 |
+
outputs.append(preds)
|
| 121 |
+
outputs.append(preds_sd)
|
| 122 |
+
outputs.append(weights)
|
| 123 |
+
outputs.append(weights_sd)
|
| 124 |
+
|
| 125 |
+
return tuple(outputs)
|
| 126 |
+
|
| 127 |
+
def additional_losses(self):
|
| 128 |
+
"""get additional_losses"""
|
| 129 |
+
# additional (orthogonal) loss
|
| 130 |
+
# we multiply the two heads and later the MSE loss (towards zero) sums the result in L2 norm
|
| 131 |
+
# the idea is that the scalar product of two orthogonal vectors is zero
|
| 132 |
+
scalar_product = torch.cat((self.query[0, 0] * self.query[0, 1], self.query_sd[0, 0] * self.query_sd[0, 1]), dim=0)
|
| 133 |
+
orthogonal_loss = torch.nn.functional.mse_loss(scalar_product, torch.zeros_like(scalar_product))
|
| 134 |
+
return orthogonal_loss
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def process_batch_supervised(self, batch):
|
| 139 |
+
"""get predictions, losses and mean errors (metrics)"""
|
| 140 |
+
|
| 141 |
+
# get predictions
|
| 142 |
+
preds = {}
|
| 143 |
+
preds['period_px'], preds['sd'] = self.forward(batch['image'], batch['neutral'][0], return_raw=False) # preds: period, sd, orto, preds_raw
|
| 144 |
+
|
| 145 |
+
# https://johaupt.github.io/blog/NN_prediction_uncertainty.html
|
| 146 |
+
# calculate losses
|
| 147 |
+
mse_period_px = torch.nn.functional.mse_loss(batch['period_px'],
|
| 148 |
+
preds['period_px'])
|
| 149 |
+
|
| 150 |
+
gaussian_nll = torch.nn.functional.gaussian_nll_loss(batch['period_px'],
|
| 151 |
+
preds['period_px'],
|
| 152 |
+
(preds['sd']) ** 2)
|
| 153 |
+
|
| 154 |
+
orthogonal_weight = 0.1
|
| 155 |
+
orthogonal_loss = self.additional_losses()
|
| 156 |
+
length_of_the_first_phase = 0
|
| 157 |
+
if self.current_epoch < length_of_the_first_phase:
|
| 158 |
+
#transition from MSE to Gaussian Negative Log Likelihood with sin/cos over first epochs
|
| 159 |
+
angle = torch.tensor((self.current_epoch) / (length_of_the_first_phase) * np.pi / 2)
|
| 160 |
+
total_loss = (gaussian_nll) * torch.sin(angle) + (mse_period_px) * torch.cos(angle) + orthogonal_weight * orthogonal_loss
|
| 161 |
+
else:
|
| 162 |
+
total_loss = gaussian_nll + orthogonal_weight * orthogonal_loss
|
| 163 |
+
|
| 164 |
+
losses = {
|
| 165 |
+
'gaussian_nll': gaussian_nll,
|
| 166 |
+
'mse_period_px': mse_period_px,
|
| 167 |
+
'orthogonal': orthogonal_loss,
|
| 168 |
+
'final': total_loss
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
# calculate mean errors
|
| 172 |
+
ground_truth_detached = batch['period_px'].detach().cpu().numpy()
|
| 173 |
+
print(ground_truth_detached)
|
| 174 |
+
mean_detached = preds['period_px'].detach().cpu().numpy()
|
| 175 |
+
print(mean_detached)
|
| 176 |
+
sd_detached = preds['sd'].detach().cpu().numpy()
|
| 177 |
+
print("==>", sd_detached)
|
| 178 |
+
px_per_nm_detached = batch['px_per_nm'].detach().cpu().numpy()
|
| 179 |
+
scale_detached = batch['scale'].detach().cpu().numpy()
|
| 180 |
+
|
| 181 |
+
period_px_difference = np.mean(abs(
|
| 182 |
+
ground_truth_detached - mean_detached
|
| 183 |
+
))
|
| 184 |
+
|
| 185 |
+
#initiate both with python array with 5 zeros
|
| 186 |
+
true_period_px_difference = [0.0] * 5
|
| 187 |
+
true_period_nm_difference = [0.0] * 5
|
| 188 |
+
|
| 189 |
+
for i, dist in enumerate([1.0, 2.0, 3.0, 4.0, 5.0]):
|
| 190 |
+
true_period_px_difference[i] = (np.sum(abs(
|
| 191 |
+
((ground_truth_detached - mean_detached) / scale_detached) * (sd_detached / scale_detached <dist))) \
|
| 192 |
+
/ np.sum(sd_detached / scale_detached < dist)) if np.sum(sd_detached / scale_detached < dist) > 0 else 0
|
| 193 |
+
|
| 194 |
+
for i, dist in enumerate([1.0, 2.0, 3.0, 4.0, 5.0]):
|
| 195 |
+
true_period_nm_difference[i] = (np.sum(abs(
|
| 196 |
+
((ground_truth_detached - mean_detached) / (scale_detached * px_per_nm_detached)) * (sd_detached / scale_detached <dist))) \
|
| 197 |
+
/ np.sum(sd_detached / scale_detached < dist)) if np.sum(sd_detached / scale_detached < dist) > 0 else 0
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
true_period_px_difference_all = np.mean(abs(
|
| 201 |
+
((ground_truth_detached - mean_detached) / scale_detached)
|
| 202 |
+
))
|
| 203 |
+
|
| 204 |
+
true_period_nm_difference_all = np.mean(abs(
|
| 205 |
+
((ground_truth_detached - mean_detached) / (scale_detached * px_per_nm_detached))
|
| 206 |
+
))
|
| 207 |
+
|
| 208 |
+
metrics = {
|
| 209 |
+
'period_px': period_px_difference,
|
| 210 |
+
'true_period_px_1': true_period_px_difference[0],
|
| 211 |
+
'true_period_px_2': true_period_px_difference[1],
|
| 212 |
+
'true_period_px_3': true_period_px_difference[2],
|
| 213 |
+
'true_period_px_4': true_period_px_difference[3],
|
| 214 |
+
'true_period_px_5': true_period_px_difference[4],
|
| 215 |
+
'true_period_px_all': true_period_px_difference_all,
|
| 216 |
+
|
| 217 |
+
'true_period_nm_1': true_period_nm_difference[0],
|
| 218 |
+
'true_period_nm_2': true_period_nm_difference[1],
|
| 219 |
+
'true_period_nm_3': true_period_nm_difference[2],
|
| 220 |
+
'true_period_nm_4': true_period_nm_difference[3],
|
| 221 |
+
'true_period_nm_5': true_period_nm_difference[4],
|
| 222 |
+
'true_period_nm_all': true_period_nm_difference_all,
|
| 223 |
+
|
| 224 |
+
'count_1': np.sum(sd_detached / scale_detached < 1.0),
|
| 225 |
+
'count_2': np.sum(sd_detached / scale_detached < 2.0),
|
| 226 |
+
'count_3': np.sum(sd_detached / scale_detached < 3.0),
|
| 227 |
+
'count_4': np.sum(sd_detached / scale_detached < 4.0),
|
| 228 |
+
'count_5': np.sum(sd_detached / scale_detached < 5.0),
|
| 229 |
+
|
| 230 |
+
'count_all': np.sum(sd_detached > 0.0),
|
| 231 |
+
'mean_sd': np.mean(sd_detached),
|
| 232 |
+
'sd_sd': np.std(sd_detached),
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
return preds, losses, metrics
|
| 236 |
+
|
| 237 |
+
|
period_calculation/period_measurer.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import albumentations as A
|
| 4 |
+
from albumentations.pytorch import ToTensorV2
|
| 5 |
+
import skimage
|
| 6 |
+
import scipy
|
| 7 |
+
import numpy as np
|
| 8 |
+
from pytorch_lightning import seed_everything
|
| 9 |
+
|
| 10 |
+
from period_calculation.data_reader import AdHocDataset3
|
| 11 |
+
from period_calculation.models.gauss_model import GaussPeriodModel
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
transforms = [
|
| 15 |
+
A.Normalize(**{'mean': 0.2845, 'std': 0.1447}, max_pixel_value=1.0),
|
| 16 |
+
# Applies the formula (img - mean * max_pixel_value) / (std * max_pixel_value)
|
| 17 |
+
ToTensorV2()
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
class PeriodMeasurer:
|
| 21 |
+
"""returns period in pixels"""
|
| 22 |
+
def __init__(
|
| 23 |
+
self, weights_file, image_height=476, image_width=476,
|
| 24 |
+
px_per_nm = 1,
|
| 25 |
+
sd_threshold_nm=np.inf,
|
| 26 |
+
period_threshold_nm_min=0, period_threshold_nm_max=np.inf):
|
| 27 |
+
|
| 28 |
+
self.model = GaussPeriodModel.load_from_checkpoint(weights_file).to("cpu") #.eval()?
|
| 29 |
+
self.px_per_nm = px_per_nm
|
| 30 |
+
self.sd_threshold_nm = sd_threshold_nm
|
| 31 |
+
self.period_threshold_nm_min = period_threshold_nm_min
|
| 32 |
+
self.period_threshold_nm_max = period_threshold_nm_max
|
| 33 |
+
|
| 34 |
+
def __call__(self, img: np.ndarray, mask: np.ndarray) -> float:
|
| 35 |
+
seed_everything(44)
|
| 36 |
+
dataset = AdHocDataset3(
|
| 37 |
+
images_and_masks = [(img, mask)],
|
| 38 |
+
transform_level=-1,
|
| 39 |
+
retain_raw_images=False,
|
| 40 |
+
transforms=transforms
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
image_data = dataset[0]
|
| 44 |
+
with torch.no_grad():
|
| 45 |
+
y_hat, sd_hat = self.model(image_data["image"].unsqueeze(0), return_raw=False)
|
| 46 |
+
|
| 47 |
+
y_hat_nm = (y_hat/image_data["scale"]).item() / self.px_per_nm
|
| 48 |
+
sd_hat_nm = (sd_hat/image_data["scale"]).item() /self.px_per_nm
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if (sd_hat_nm>self.sd_threshold_nm) or (y_hat_nm<self.period_threshold_nm_min) or (y_hat_nm>self.period_threshold_nm_max):
|
| 52 |
+
y_hat_nm = np.nan
|
| 53 |
+
|
| 54 |
+
return y_hat_nm, sd_hat_nm
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pandas==2.1.1
|
| 2 |
+
torch==2.1.0
|
| 3 |
+
torchvision==0.16
|
| 4 |
+
ultralytics==8.0.216
|
| 5 |
+
scikit-image==0.22.0
|
| 6 |
+
pytorch-lightning==2.1.2
|
| 7 |
+
timm==0.9.11
|
| 8 |
+
albumentations==1.4.10
|
| 9 |
+
hydra-core==1.3.2
|
| 10 |
+
gradio==4.44.0
|
| 11 |
+
albucore==0.0.16
|
settings.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
DEMO = False
|
styles.css
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.header {
|
| 2 |
+
display: flex;
|
| 3 |
+
padding: 30px;
|
| 4 |
+
text-align: center;
|
| 5 |
+
justify-content: center;
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
#header-text {
|
| 9 |
+
font-size: 50px;
|
| 10 |
+
line-height: 50px;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
#header-logo {
|
| 14 |
+
width: 50px;
|
| 15 |
+
height: 50px;
|
| 16 |
+
margin-right: 10px;
|
| 17 |
+
/*background-image: url("file=images/logo.svg");*/
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
.input-row {
|
| 21 |
+
max-width: 900px;
|
| 22 |
+
margin: 0 auto;
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
.margin-bottom {
|
| 26 |
+
margin-bottom: 48px;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
.results-header {
|
| 30 |
+
margin-top: 48px;
|
| 31 |
+
text-align: center;
|
| 32 |
+
font-size: 45px;
|
| 33 |
+
margin-bottom: 12px;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
.processed-info {
|
| 37 |
+
display: flex;
|
| 38 |
+
padding: 30px;
|
| 39 |
+
text-align: center;
|
| 40 |
+
justify-content: center;
|
| 41 |
+
font-size: 26px;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
.title {
|
| 45 |
+
margin-bottom: 8px!important;
|
| 46 |
+
font-size: 22px;
|
| 47 |
+
}
|
weights/AS_square_v16.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9c1d7d4e56f0ea28b34dd2457807e9266d0d5539cf5adfb0719fed10791f79c5
|
| 3 |
+
size 44771917
|
weights/model_weights_detector.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37fe3d98789572cf147d5f0d1ea99d50a57e5c2028454c3825295e89b6350fd5
|
| 3 |
+
size 23926765
|
weights/period_measurer_weights-1.298_real_full-fa12970.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:90133b01545ae9d30eafcbdec480410e0f8b9fae3ba1aabb280450ce9589100a
|
| 3 |
+
size 350396
|
weights/yolo/20240604_yolov8_segm_ABRCR1_all_train4_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:038555e9b9f3900ec29c9c79d7b3d5b50aa3ca37b3e665ee7aa2394facc7e20e
|
| 3 |
+
size 23926765
|
weights/yolo/current_yolo.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37fe3d98789572cf147d5f0d1ea99d50a57e5c2028454c3825295e89b6350fd5
|
| 3 |
+
size 23926765
|