Spaces:
Sleeping
Sleeping
HERIUN
commited on
Commit
•
591ba45
1
Parent(s):
6a07cb2
add models
Browse files- models/DocScanner/LICENSE.md +54 -0
- models/DocScanner/OCR_eval.py +78 -0
- models/DocScanner/README.md +96 -0
- models/DocScanner/__init__.py +0 -0
- models/DocScanner/__pycache__/__init__.cpython-38.pyc +0 -0
- models/DocScanner/__pycache__/__init__.cpython-39.pyc +0 -0
- models/DocScanner/__pycache__/extractor.cpython-38.pyc +0 -0
- models/DocScanner/__pycache__/extractor.cpython-39.pyc +0 -0
- models/DocScanner/__pycache__/inference.cpython-38.pyc +0 -0
- models/DocScanner/__pycache__/inference.cpython-39.pyc +0 -0
- models/DocScanner/__pycache__/model.cpython-38.pyc +0 -0
- models/DocScanner/__pycache__/model.cpython-39.pyc +0 -0
- models/DocScanner/__pycache__/seg.cpython-38.pyc +0 -0
- models/DocScanner/__pycache__/seg.cpython-39.pyc +0 -0
- models/DocScanner/__pycache__/update.cpython-38.pyc +0 -0
- models/DocScanner/__pycache__/update.cpython-39.pyc +0 -0
- models/DocScanner/eval.m +64 -0
- models/DocScanner/evalUnwarp.m +102 -0
- models/DocScanner/extractor.py +140 -0
- models/DocScanner/inference.py +65 -0
- models/DocScanner/model.py +104 -0
- models/DocScanner/ocr_img.txt +62 -0
- models/DocScanner/requirements.txt +6 -0
- models/DocScanner/seg.py +576 -0
- models/DocScanner/update.py +119 -0
- models/DocTr-Plus/GeoTr.py +960 -0
- models/DocTr-Plus/LICENSE.md +54 -0
- models/DocTr-Plus/OCR_eval.py +121 -0
- models/DocTr-Plus/README.md +79 -0
- models/DocTr-Plus/__init__.py +0 -0
- models/DocTr-Plus/__pycache__/GeoTr.cpython-38.pyc +0 -0
- models/DocTr-Plus/__pycache__/GeoTr.cpython-39.pyc +0 -0
- models/DocTr-Plus/__pycache__/__init__.cpython-38.pyc +0 -0
- models/DocTr-Plus/__pycache__/__init__.cpython-39.pyc +0 -0
- models/DocTr-Plus/__pycache__/extractor.cpython-38.pyc +0 -0
- models/DocTr-Plus/__pycache__/extractor.cpython-39.pyc +0 -0
- models/DocTr-Plus/__pycache__/inference.cpython-38.pyc +0 -0
- models/DocTr-Plus/__pycache__/inference.cpython-39.pyc +0 -0
- models/DocTr-Plus/__pycache__/position_encoding.cpython-38.pyc +0 -0
- models/DocTr-Plus/__pycache__/position_encoding.cpython-39.pyc +0 -0
- models/DocTr-Plus/evalUnwarp.m +46 -0
- models/DocTr-Plus/extractor.py +117 -0
- models/DocTr-Plus/inference.py +51 -0
- models/DocTr-Plus/position_encoding.py +125 -0
- models/DocTr-Plus/pyimagesearch/__init__.py +0 -0
- models/DocTr-Plus/pyimagesearch/transform.py +64 -0
- models/DocTr-Plus/requirements.txt +7 -0
- models/DocTr-Plus/ssimm_ldm_eval.m +36 -0
- models/Document-Image-Unwarping-pytorch +1 -0
models/DocScanner/LICENSE.md
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# License
|
2 |
+
|
3 |
+
Copyright © Hao Feng 2024. All Rights Reserved.
|
4 |
+
|
5 |
+
## 1. Definitions
|
6 |
+
|
7 |
+
1.1 "Algorithm" refers to the deep learning algorithm contained in this repository, including all associated code, documentation, and data.
|
8 |
+
|
9 |
+
1.2 "Author" refers to Hao Feng, the creator and copyright holder of the Algorithm.
|
10 |
+
|
11 |
+
1.3 "Non-Commercial Use" means use for academic research, personal study, or non-profit projects, without any direct or indirect commercial advantage.
|
12 |
+
|
13 |
+
1.4 "Commercial Use" means any use intended for or directed toward commercial advantage or monetary compensation.
|
14 |
+
|
15 |
+
## 2. Grant of Rights
|
16 |
+
|
17 |
+
2.1 Non-Commercial Use: The Author hereby grants you a worldwide, royalty-free, non-exclusive license to use, copy, modify, and distribute the Algorithm for Non-Commercial Use, subject to the conditions in Section 3.
|
18 |
+
|
19 |
+
2.2 Commercial Use: Any Commercial Use of the Algorithm is strictly prohibited without explicit prior written permission from the Author.
|
20 |
+
|
21 |
+
## 3. Conditions
|
22 |
+
|
23 |
+
3.1 For Non-Commercial Use:
|
24 |
+
a) Attribution: You must give appropriate credit to the Author, provide a link to this license, and indicate if changes were made.
|
25 |
+
b) Share-Alike: If you modify, transform, or build upon the Algorithm, you must distribute your contributions under the same license as this one.
|
26 |
+
c) No additional restrictions: You may not apply legal terms or technological measures that legally restrict others from doing anything this license permits.
|
27 |
+
|
28 |
+
3.2 For Commercial Use:
|
29 |
+
a) Prior Contact: Before any Commercial Use, you must contact the Author at haof@mail.ustc.edu.cn and obtain explicit written permission.
|
30 |
+
b) Separate Agreement: Commercial Use terms will be stipulated in a separate commercial license agreement.
|
31 |
+
|
32 |
+
## 4. Disclaimer of Warranty
|
33 |
+
|
34 |
+
The Algorithm is provided "as is", without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and non-infringement. In no event shall the Author be liable for any claim, damages, or other liability arising from, out of, or in connection with the Algorithm or the use or other dealings in the Algorithm.
|
35 |
+
|
36 |
+
## 5. Limitation of Liability
|
37 |
+
|
38 |
+
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, shall the Author be liable to you for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this license or out of the use or inability to use the Algorithm.
|
39 |
+
|
40 |
+
## 6. Termination
|
41 |
+
|
42 |
+
6.1 This license and the rights granted hereunder will terminate automatically upon any breach by you of the terms of this license.
|
43 |
+
|
44 |
+
6.2 All sections which by their nature should survive the termination of this license shall survive such termination.
|
45 |
+
|
46 |
+
## 7. Miscellaneous
|
47 |
+
|
48 |
+
7.1 If any provision of this license is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable.
|
49 |
+
|
50 |
+
7.2 This license represents the complete agreement concerning the subject matter hereof.
|
51 |
+
|
52 |
+
By using the Algorithm, you acknowledge that you have read this license, understand it, and agree to be bound by its terms and conditions. If you do not agree to the terms and conditions of this license, do not use, modify, or distribute the Algorithm.
|
53 |
+
|
54 |
+
For permissions beyond the scope of this license, please contact the Author at haof@mail.ustc.edu.cn.
|
models/DocScanner/OCR_eval.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def Levenshtein_Distance(str1, str2):
|
2 |
+
matrix = [[i + j for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]
|
3 |
+
for i in range(1, len(str1) + 1):
|
4 |
+
for j in range(1, len(str2) + 1):
|
5 |
+
if str1[i - 1] == str2[j - 1]:
|
6 |
+
d = 0
|
7 |
+
else:
|
8 |
+
d = 1
|
9 |
+
matrix[i][j] = min(
|
10 |
+
matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, matrix[i - 1][j - 1] + d
|
11 |
+
)
|
12 |
+
|
13 |
+
return matrix[len(str1)][len(str2)]
|
14 |
+
|
15 |
+
|
16 |
+
def cal_cer_ed(path_ours, tail="_rec"):
|
17 |
+
path_gt = "./GT/"
|
18 |
+
N = 66
|
19 |
+
cer1 = []
|
20 |
+
cer2 = []
|
21 |
+
ed1 = []
|
22 |
+
ed2 = []
|
23 |
+
check = [0 for _ in range(N + 1)]
|
24 |
+
lis = [
|
25 |
+
1,
|
26 |
+
2,
|
27 |
+
3,
|
28 |
+
4,
|
29 |
+
5,
|
30 |
+
6,
|
31 |
+
7,
|
32 |
+
9,
|
33 |
+
10,
|
34 |
+
21,
|
35 |
+
22,
|
36 |
+
23,
|
37 |
+
24,
|
38 |
+
27,
|
39 |
+
30,
|
40 |
+
31,
|
41 |
+
32,
|
42 |
+
36,
|
43 |
+
38,
|
44 |
+
40,
|
45 |
+
41,
|
46 |
+
44,
|
47 |
+
45,
|
48 |
+
46,
|
49 |
+
47,
|
50 |
+
48,
|
51 |
+
50,
|
52 |
+
51,
|
53 |
+
52,
|
54 |
+
53,
|
55 |
+
] # DocTr (Setting 1)
|
56 |
+
# lis=[1,9,10,12,19,20,21,22,23,24,30,31,32,34,35,36,37,38,39,40,44,45,46,47,49] # DewarpNet (Setting 2)
|
57 |
+
for i in range(1, N):
|
58 |
+
if i not in lis:
|
59 |
+
continue
|
60 |
+
gt = Image.open(path_gt + str(i) + ".png")
|
61 |
+
img1 = Image.open(path_ours + str(i) + "_1" + tail)
|
62 |
+
img2 = Image.open(path_ours + str(i) + "_2" + tail)
|
63 |
+
content_gt = pytesseract.image_to_string(gt)
|
64 |
+
content1 = pytesseract.image_to_string(img1)
|
65 |
+
content2 = pytesseract.image_to_string(img2)
|
66 |
+
l1 = Levenshtein_Distance(content_gt, content1)
|
67 |
+
l2 = Levenshtein_Distance(content_gt, content2)
|
68 |
+
ed1.append(l1)
|
69 |
+
ed2.append(l2)
|
70 |
+
cer1.append(l1 / len(content_gt))
|
71 |
+
cer2.append(l2 / len(content_gt))
|
72 |
+
check[i] = cer1[-1]
|
73 |
+
print("CER: ", (np.mean(cer1) + np.mean(cer2)) / 2.0)
|
74 |
+
print("ED: ", (np.mean(ed1) + np.mean(ed2)) / 2.0)
|
75 |
+
|
76 |
+
|
77 |
+
def evalu(path_ours, tail):
|
78 |
+
cal_cer_ed(path_ours, tail)
|
models/DocScanner/README.md
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
🔥 ***2024.4.28:*** **Good news! The code and pre-trained model of DocScanner are now released!**
|
2 |
+
|
3 |
+
🚀 **Good news! The [online demo](https://docai.doctrp.top:20443/) for DocScanner is now live, allowing for easy image upload and correction.**
|
4 |
+
|
5 |
+
🔥 **Good news! Our new work [DocTr++: Deep Unrestricted Document Image Rectification](https://github.com/fh2019ustc/DocTr-Plus) comes out, capable of rectifying various distorted document images in the wild.**
|
6 |
+
|
7 |
+
🔥 **Good news! A comprehensive list of [Awesome Document Image Rectification](https://github.com/fh2019ustc/Awesome-Document-Image-Rectification) methods is available.**
|
8 |
+
|
9 |
+
# DocScanner
|
10 |
+
|
11 |
+
<p>
|
12 |
+
<a href='https://drive.google.com/file/d/1mmCUj90rHyuO1SmpLt361youh-07Y0sD/view?usp=share_link' target="_blank"><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
|
13 |
+
<a href='https://docai.doctrp.top:20443/' target="_blank"><img src='https://img.shields.io/badge/Online-Demo-green'></a>
|
14 |
+
</p>
|
15 |
+
|
16 |
+
|
17 |
+
This is a PyTorch/GPU re-implementation of the paper [DocScanner: Robust Document Image Rectification with Progressive Learning](https://drive.google.com/file/d/1mmCUj90rHyuO1SmpLt361youh-07Y0sD/view?usp=share_link).
|
18 |
+
|
19 |
+
![image](https://user-images.githubusercontent.com/50725551/209266364-aee68a88-090d-4f21-919a-092f19570d86.png)
|
20 |
+
|
21 |
+
|
22 |
+
## 🚀 Demo [(Link)](https://docai.doctrp.top:20443/)
|
23 |
+
***Note***:The model version used in the demo corresponds to ***"DocScanner-L"*** as described in the paper.
|
24 |
+
1. Upload the distorted document image to be rectified in the left box.
|
25 |
+
2. Click the "Submit" button.
|
26 |
+
3. The rectified image will be displayed in the right box.
|
27 |
+
|
28 |
+
<img width="1534" alt="image" src="https://github.com/fh2019ustc/DocScanner/assets/50725551/9eca3f7d-1570-4246-a3db-0a1cf1eece2d">
|
29 |
+
|
30 |
+
### Examples
|
31 |
+
![image](https://user-images.githubusercontent.com/50725551/223947040-eac8389c-bed8-433d-b23b-679c926fba8f.png)
|
32 |
+
![image](https://user-images.githubusercontent.com/50725551/223946953-3a46d6a3-4361-41ef-bb5c-f235392e1f88.png)
|
33 |
+
|
34 |
+
|
35 |
+
## Training
|
36 |
+
- We train the **Document Localization Module** using the [Doc3D](https://github.com/fh2019ustc/doc3D-dataset) dataset. Besides, [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/) dataset is exploited for background data enhancement.
|
37 |
+
- We train the **Progressive Rectification Module** using the [Doc3D](https://github.com/fh2019ustc/doc3D-dataset) dataset. Here we use the background-excluded document images for training.
|
38 |
+
|
39 |
+
## Inference
|
40 |
+
1. Put the [pre-trained DocScanner-L](https://drive.google.com/drive/folders/1W1_DJU8dfEh6FqDYqFQ7ypR38Z8c5r4D?usp=sharing) to `$ROOT/model_pretrained/`.
|
41 |
+
2. Put the distorted images in `$ROOT/distorted/`.
|
42 |
+
3. Run the script and the rectified images are saved in `$ROOT/rectified/` by default.
|
43 |
+
```
|
44 |
+
python inference.py
|
45 |
+
```
|
46 |
+
|
47 |
+
## Evaluation
|
48 |
+
- ***Important.*** In the [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html), the '64_1.png' and '64_2.png' distorted images are rotated by 180 degrees, which do not match the GT documents. It is ignored by most of the existing works. Before the evaluation, please make a check. Note that the performances in most of the existing work are computed with these two ***mistaken*** samples.
|
49 |
+
- For reproducing the following quantitative performance on the ***corrected*** [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html), please use the geometric rectified images available from [Google Drive](https://drive.google.com/drive/folders/1QBe26xJwIl38sWqK2ZE9ke5nu0Mpr4dW?usp=sharing). For the ***corrected*** performance of [other methods](https://github.com/fh2019ustc/Awesome-Document-Image-Rectification), please refer to the paper [DocScanner](https://arxiv.org/pdf/2110.14968v2.pdf).
|
50 |
+
- ***Image Metrics:*** We use the same evaluation code for MS-SSIM and LD as [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html) dataset based on Matlab 2019a. Please compare the scores according to your Matlab version. We provide our Matlab interface file at ```$ROOT/ssim_ld_eval.m```.
|
51 |
+
- ***OCR Metrics:*** The index of 30 documents (60 images) of [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html) used for our OCR evaluation is ```$ROOT/ocr_img.txt``` (*Setting 1*). Please refer to [DewarpNet](https://github.com/cvlab-stonybrook/DewarpNet) for the index of 25 documents (50 images) of [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html) used for their OCR evaluation (*Setting 2*). We provide the OCR evaluation code at ```$ROOT/OCR_eval.py```. The version of pytesseract is 0.3.8, and the version of [Tesseract](https://digi.bib.uni-mannheim.de/tesseract/) in Windows is recent 5.0.1.20220118. Note that in different operating systems, the calculated performance has slight differences.
|
52 |
+
- ***W_v and W_h Index:*** The layout results of [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html) is available at [Google Drive](https://drive.google.com/drive/folders/1PcfWIowjM0AVKhZrRwGChM-2VAcUwWrF?usp=sharing).
|
53 |
+
|
54 |
+
| Method | MS-SSIM | LD | Li-D | ED (*Setting 1*) | CER | ED (*Setting 2*) | CER | Para. (M) |
|
55 |
+
|:-----------------------:|:------------:|:-----------:| :-------:|:----------------:|:--------------:|:---------------------:|:--------------:|:--------------:|
|
56 |
+
| *DocScanner-T* | 0.5123 | 7.92 | 2.04 | 501.82 | 0.1823 | 809.46 | 0.2068 | 2.6 |
|
57 |
+
| *DocScanner-B* | 0.5134 | 7.62 | 1.88 | 434.11 | 0.1652 | 671.48 | 0.1789 | 5.2 |
|
58 |
+
| *DocScanner-L* | 0.5178 | 7.45 | 1.86 | 390.43 | 0.1486 | 632.34 | 0.1648 | 8.5 |
|
59 |
+
|
60 |
+
## Citation
|
61 |
+
Please cite the related works in your publications if it helps your research:
|
62 |
+
|
63 |
+
```
|
64 |
+
@inproceedings{feng2021doctr,
|
65 |
+
title={DocTr: Document Image Transformer for Geometric Unwarping and Illumination Correction},
|
66 |
+
author={Feng, Hao and Wang, Yuechen and Zhou, Wengang and Deng, Jiajun and Li, Houqiang},
|
67 |
+
booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
|
68 |
+
pages={273--281},
|
69 |
+
year={2021}
|
70 |
+
}
|
71 |
+
```
|
72 |
+
|
73 |
+
```
|
74 |
+
@inproceedings{feng2022docgeonet,
|
75 |
+
title={Geometric Representation Learning for Document Image Rectification},
|
76 |
+
author={Feng, Hao and Zhou, Wengang and Deng, Jiajun and Wang, Yuechen and Li, Houqiang},
|
77 |
+
booktitle={Proceedings of the European Conference on Computer Vision},
|
78 |
+
year={2022}
|
79 |
+
}
|
80 |
+
```
|
81 |
+
|
82 |
+
```
|
83 |
+
@article{feng2021docscanner,
|
84 |
+
title={DocScanner: robust document image rectification with progressive learning},
|
85 |
+
author={Feng, Hao and Zhou, Wengang and Deng, Jiajun and Tian, Qi and Li, Houqiang},
|
86 |
+
journal={arXiv preprint arXiv:2110.14968},
|
87 |
+
year={2021}
|
88 |
+
}
|
89 |
+
```
|
90 |
+
|
91 |
+
## Acknowledgement
|
92 |
+
The codes are largely based on [DocUNet](https://www3.cs.stonybrook.edu/~cvl/docunet.html) and [DewarpNet](https://github.com/cvlab-stonybrook/DewarpNet). Thanks for their wonderful works.
|
93 |
+
|
94 |
+
## Contact
|
95 |
+
For commercial usage, please contact Professor Wengang Zhou ([zhwg@ustc.edu.cn](zhwg@ustc.edu.cn)) and Hao Feng ([haof@mail.ustc.edu.cn](haof@mail.ustc.edu.cn)).
|
96 |
+
|
models/DocScanner/__init__.py
ADDED
File without changes
|
models/DocScanner/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (154 Bytes). View file
|
|
models/DocScanner/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (154 Bytes). View file
|
|
models/DocScanner/__pycache__/extractor.cpython-38.pyc
ADDED
Binary file (3.88 kB). View file
|
|
models/DocScanner/__pycache__/extractor.cpython-39.pyc
ADDED
Binary file (3.86 kB). View file
|
|
models/DocScanner/__pycache__/inference.cpython-38.pyc
ADDED
Binary file (2.22 kB). View file
|
|
models/DocScanner/__pycache__/inference.cpython-39.pyc
ADDED
Binary file (2.2 kB). View file
|
|
models/DocScanner/__pycache__/model.cpython-38.pyc
ADDED
Binary file (3.37 kB). View file
|
|
models/DocScanner/__pycache__/model.cpython-39.pyc
ADDED
Binary file (3.37 kB). View file
|
|
models/DocScanner/__pycache__/seg.cpython-38.pyc
ADDED
Binary file (12 kB). View file
|
|
models/DocScanner/__pycache__/seg.cpython-39.pyc
ADDED
Binary file (12.1 kB). View file
|
|
models/DocScanner/__pycache__/update.cpython-38.pyc
ADDED
Binary file (4.3 kB). View file
|
|
models/DocScanner/__pycache__/update.cpython-39.pyc
ADDED
Binary file (4.27 kB). View file
|
|
models/DocScanner/eval.m
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
path_rec = "xxx"; % rectified image path
|
2 |
+
path_scan = './scan/'; % scan image path
|
3 |
+
label_path = './layout/'; % layout result path
|
4 |
+
|
5 |
+
tarea = 598400;
|
6 |
+
ms1 = 0;
|
7 |
+
ld1 = 0;
|
8 |
+
lid1 = 0;
|
9 |
+
ms2 = 0;
|
10 |
+
ld2 = 0;
|
11 |
+
lid2 = 0;
|
12 |
+
wv = 0;
|
13 |
+
wh = 0;
|
14 |
+
|
15 |
+
sprintf(path_rec)
|
16 |
+
for i=1:65
|
17 |
+
path_rec_1 = sprintf("%s%d%s", path_rec, i, '_1 copy_rec.png'); % rectified image path
|
18 |
+
path_rec_2 = sprintf("%s%d%s", path_rec, i, '_2 copy_rec.png'); % rectified image path
|
19 |
+
path_scan_new = sprintf("%s%d%s", path_scan, i, '.png'); % corresponding scan image path
|
20 |
+
bbox_i_path = sprintf("%s%d%s", label_path, i, '.txt'); % corresponding layout txt path
|
21 |
+
|
22 |
+
% imread and rgb2gray
|
23 |
+
A1 = imread(path_rec_1);
|
24 |
+
A2 = imread(path_rec_2);
|
25 |
+
|
26 |
+
% if i == 64
|
27 |
+
% A1 = rot90(A1,-2);
|
28 |
+
% A2 = rot90(A2,-2);
|
29 |
+
% end
|
30 |
+
|
31 |
+
ref = imread(path_scan_new);
|
32 |
+
A1 = rgb2gray(A1);
|
33 |
+
A2 = rgb2gray(A2);
|
34 |
+
ref = rgb2gray(ref);
|
35 |
+
bbox_i = read_txt(bbox_i_path);
|
36 |
+
bbox_i = bbox_i + 1; % python index starts from 0
|
37 |
+
|
38 |
+
% resize
|
39 |
+
b = sqrt(tarea/size(ref,1)/size(ref,2));
|
40 |
+
ref = imresize(ref,b);
|
41 |
+
A1 = imresize(A1,[size(ref,1),size(ref,2)]);
|
42 |
+
A2 = imresize(A2,[size(ref,1),size(ref,2)]);
|
43 |
+
scaled_bbox_i = bbox_i * b * 0.5;
|
44 |
+
scaled_bbox_i = round(scaled_bbox_i);
|
45 |
+
scaled_bbox_i = max(scaled_bbox_i, 1);
|
46 |
+
|
47 |
+
% calculate
|
48 |
+
[ms_1, ld_1, lid_1, W_v_1, W_h_1] = evalUnwarp(A1, ref, scaled_bbox_i);
|
49 |
+
[ms_2, ld_2, lid_2, W_v_2, W_h_2] = evalUnwarp(A2, ref, scaled_bbox_i);
|
50 |
+
ms1 = ms1 + ms_1;
|
51 |
+
ms2 = ms2 + ms_2;
|
52 |
+
ld1 = ld1 + ld_1;
|
53 |
+
ld2 = ld2 + ld_2;
|
54 |
+
lid1 = lid1 + lid_1;
|
55 |
+
lid2 = lid2 + lid_2;
|
56 |
+
wv = wv + W_v_1 + W_v_2;
|
57 |
+
wh = wh + W_h_1 + W_h_2;
|
58 |
+
end
|
59 |
+
|
60 |
+
ms = (ms1 + ms2) / 130 % MS-SSIM
|
61 |
+
ld = (ld1 + ld2) / 130 % local distortion
|
62 |
+
li_d = (lid1 + lid2) / 130 % line distortion
|
63 |
+
wv = wv / 130 % wv index
|
64 |
+
wh = wh / 130 % wh index
|
models/DocScanner/evalUnwarp.m
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
function [ms, ld, li_d, wv, wh] = evalUnwarp(A, ref, data)
|
2 |
+
%EVALUNWARP compute MSSSIM and LD between the unwarped image and the scan
|
3 |
+
% A: unwarped image
|
4 |
+
% ref: reference image, the scan image
|
5 |
+
% ms: returned MS-SSIM value
|
6 |
+
% ld: returned local distortion value
|
7 |
+
% Matlab image processing toolbox is necessary to compute ssim. The weights
|
8 |
+
% for multi-scale ssim is directly adopted from:
|
9 |
+
%
|
10 |
+
% Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale structural
|
11 |
+
% similarity for image quality assessment." In Signals, Systems and Computers,
|
12 |
+
% 2004. Conference Record of the Thirty-Seventh Asilomar Conference on, 2003.
|
13 |
+
%
|
14 |
+
% Local distortion relies on the paper:
|
15 |
+
% Liu, Ce, Jenny Yuen, and Antonio Torralba. "Sift flow: Dense correspondence
|
16 |
+
% across scenes and its applications." In PAMI, 2010.
|
17 |
+
%
|
18 |
+
% and its implementation:
|
19 |
+
% https://people.csail.mit.edu/celiu/SIFTflow/
|
20 |
+
|
21 |
+
x = A;
|
22 |
+
y = ref;
|
23 |
+
|
24 |
+
im1=imresize(imfilter(y,fspecial('gaussian',7,1.),'same','replicate'),0.5,'bicubic');
|
25 |
+
im2=imresize(imfilter(x,fspecial('gaussian',7,1.),'same','replicate'),0.5,'bicubic');
|
26 |
+
|
27 |
+
im1=im2double(im1);
|
28 |
+
im2=im2double(im2);
|
29 |
+
|
30 |
+
cellsize=3;
|
31 |
+
gridspacing=1;
|
32 |
+
|
33 |
+
sift1 = mexDenseSIFT(im1,cellsize,gridspacing);
|
34 |
+
sift2 = mexDenseSIFT(im2,cellsize,gridspacing);
|
35 |
+
|
36 |
+
SIFTflowpara.alpha=2*255;
|
37 |
+
SIFTflowpara.d=40*255;
|
38 |
+
SIFTflowpara.gamma=0.005*255;
|
39 |
+
SIFTflowpara.nlevels=4;
|
40 |
+
SIFTflowpara.wsize=2;
|
41 |
+
SIFTflowpara.topwsize=10;
|
42 |
+
SIFTflowpara.nTopIterations = 60;
|
43 |
+
SIFTflowpara.nIterations= 30;
|
44 |
+
|
45 |
+
|
46 |
+
[vx,vy,~]=SIFTflowc2f(sift1,sift2,SIFTflowpara);
|
47 |
+
|
48 |
+
rows1p = size(im1,1);
|
49 |
+
cols1p = size(im1,2);
|
50 |
+
|
51 |
+
% Li-D
|
52 |
+
rowstd_sum = 0;
|
53 |
+
for i = 1:rows1p
|
54 |
+
rowstd = std(vy(i, :),1);
|
55 |
+
rowstd_sum = rowstd_sum + rowstd;
|
56 |
+
end
|
57 |
+
rowstd_mean = rowstd_sum / rows1p;
|
58 |
+
|
59 |
+
colstd_sum = 0;
|
60 |
+
for i = 1:cols1p
|
61 |
+
colstd = std(vx(:, i),1);
|
62 |
+
colstd_sum = colstd_sum + colstd;
|
63 |
+
end
|
64 |
+
colstd_mean = colstd_sum / cols1p;
|
65 |
+
|
66 |
+
li_d = (rowstd_mean + colstd_mean) / 2;
|
67 |
+
|
68 |
+
|
69 |
+
% LD
|
70 |
+
d = sqrt(vx.^2 + vy.^2);
|
71 |
+
ld = mean(d(:));
|
72 |
+
|
73 |
+
|
74 |
+
% MS-SSIM
|
75 |
+
wt = [0.0448 0.2856 0.3001 0.2363 0.1333];
|
76 |
+
ss = zeros(5, 1);
|
77 |
+
for s = 1 : 5
|
78 |
+
ss(s) = ssim(x, y);
|
79 |
+
x = impyramid(x, 'reduce');
|
80 |
+
y = impyramid(y, 'reduce');
|
81 |
+
end
|
82 |
+
ms = wt * ss;
|
83 |
+
|
84 |
+
|
85 |
+
% wv and wh
|
86 |
+
rowstd_sum = 0;
|
87 |
+
for i = 1:size(data, 1)
|
88 |
+
rowstd_top = std(vy(data(i,2), data(i,1):data(i,3)),1) / (data(i,3)-data(i,1));
|
89 |
+
rowstd_bot = std(vy(data(i,4), data(i,1):data(i,3)),1) / (data(i,3)-data(i,1));
|
90 |
+
rowstd_sum = rowstd_sum + rowstd_top + rowstd_bot;
|
91 |
+
end
|
92 |
+
wv = rowstd_sum / (2 * size(data, 1));
|
93 |
+
|
94 |
+
colstd_sum = 0;
|
95 |
+
for i = 1:size(data, 1)
|
96 |
+
colstd_left = std(vx(data(i,2):data(i,4), data(i,1)),1) / (data(i,4)- data(i,2));
|
97 |
+
colstd_right = std(vx(data(i,2):data(i,4), data(i,3)),1) / (data(i,4)- data(i,2));
|
98 |
+
colstd_sum = colstd_sum + colstd_left + colstd_right;
|
99 |
+
end
|
100 |
+
wh = colstd_sum / (2 * size(data, 1));
|
101 |
+
|
102 |
+
end
|
models/DocScanner/extractor.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
class ResidualBlock(nn.Module):
|
5 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
6 |
+
super(ResidualBlock, self).__init__()
|
7 |
+
|
8 |
+
self.conv1 = nn.Conv2d(
|
9 |
+
in_planes, planes, kernel_size=3, padding=1, stride=stride
|
10 |
+
)
|
11 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
|
12 |
+
self.relu = nn.ReLU(inplace=True)
|
13 |
+
|
14 |
+
num_groups = planes // 8
|
15 |
+
|
16 |
+
if norm_fn == "batch":
|
17 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
18 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
19 |
+
if not stride == 1:
|
20 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
21 |
+
|
22 |
+
elif norm_fn == "instance":
|
23 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
24 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
25 |
+
if not stride == 1:
|
26 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
27 |
+
|
28 |
+
if stride == 1:
|
29 |
+
self.downsample = None
|
30 |
+
else:
|
31 |
+
self.downsample = nn.Sequential(
|
32 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
33 |
+
)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
y = x
|
37 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
38 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
39 |
+
|
40 |
+
if self.downsample is not None:
|
41 |
+
x = self.downsample(x)
|
42 |
+
|
43 |
+
return self.relu(x + y)
|
44 |
+
|
45 |
+
|
46 |
+
class BottleneckBlock(nn.Module):
|
47 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
48 |
+
super(BottleneckBlock, self).__init__()
|
49 |
+
|
50 |
+
self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
|
51 |
+
self.conv2 = nn.Conv2d(
|
52 |
+
planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
|
53 |
+
)
|
54 |
+
self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
|
55 |
+
self.relu = nn.ReLU(inplace=True)
|
56 |
+
|
57 |
+
if norm_fn == "batch":
|
58 |
+
self.norm1 = nn.BatchNorm2d(planes // 4)
|
59 |
+
self.norm2 = nn.BatchNorm2d(planes // 4)
|
60 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
61 |
+
if not stride == 1:
|
62 |
+
self.norm4 = nn.BatchNorm2d(planes)
|
63 |
+
|
64 |
+
elif norm_fn == "instance":
|
65 |
+
self.norm1 = nn.InstanceNorm2d(planes // 4)
|
66 |
+
self.norm2 = nn.InstanceNorm2d(planes // 4)
|
67 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
68 |
+
if not stride == 1:
|
69 |
+
self.norm4 = nn.InstanceNorm2d(planes)
|
70 |
+
|
71 |
+
if stride == 1:
|
72 |
+
self.downsample = None
|
73 |
+
else:
|
74 |
+
self.downsample = nn.Sequential(
|
75 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
|
76 |
+
)
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
y = x
|
80 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
81 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
82 |
+
y = self.relu(self.norm3(self.conv3(y)))
|
83 |
+
|
84 |
+
if self.downsample is not None:
|
85 |
+
x = self.downsample(x)
|
86 |
+
|
87 |
+
return self.relu(x + y)
|
88 |
+
|
89 |
+
|
90 |
+
class BasicEncoder(nn.Module):
|
91 |
+
def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
|
92 |
+
super(BasicEncoder, self).__init__()
|
93 |
+
self.norm_fn = norm_fn
|
94 |
+
|
95 |
+
if self.norm_fn == "batch":
|
96 |
+
self.norm1 = nn.BatchNorm2d(64)
|
97 |
+
|
98 |
+
elif self.norm_fn == "instance":
|
99 |
+
self.norm1 = nn.InstanceNorm2d(64)
|
100 |
+
|
101 |
+
self.conv1 = nn.Conv2d(3, 80, kernel_size=7, stride=2, padding=3)
|
102 |
+
self.relu1 = nn.ReLU(inplace=True)
|
103 |
+
|
104 |
+
self.in_planes = 80
|
105 |
+
self.layer1 = self._make_layer(80, stride=1)
|
106 |
+
self.layer2 = self._make_layer(160, stride=2)
|
107 |
+
self.layer3 = self._make_layer(240, stride=2)
|
108 |
+
|
109 |
+
# output convolution
|
110 |
+
self.conv2 = nn.Conv2d(240, output_dim, kernel_size=1)
|
111 |
+
|
112 |
+
for m in self.modules():
|
113 |
+
if isinstance(m, nn.Conv2d):
|
114 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
115 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
116 |
+
if m.weight is not None:
|
117 |
+
nn.init.constant_(m.weight, 1)
|
118 |
+
if m.bias is not None:
|
119 |
+
nn.init.constant_(m.bias, 0)
|
120 |
+
|
121 |
+
def _make_layer(self, dim, stride=1):
|
122 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
123 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
124 |
+
layers = (layer1, layer2)
|
125 |
+
|
126 |
+
self.in_planes = dim
|
127 |
+
return nn.Sequential(*layers)
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
x = self.conv1(x)
|
131 |
+
x = self.norm1(x)
|
132 |
+
x = self.relu1(x)
|
133 |
+
|
134 |
+
x = self.layer1(x)
|
135 |
+
x = self.layer2(x)
|
136 |
+
x = self.layer3(x)
|
137 |
+
|
138 |
+
x = self.conv2(x)
|
139 |
+
|
140 |
+
return x
|
models/DocScanner/inference.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
5 |
+
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import os
|
9 |
+
import warnings
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from DocScanner.model import DocScanner
|
17 |
+
from DocScanner.seg import U2NETP
|
18 |
+
from PIL import Image
|
19 |
+
|
20 |
+
warnings.filterwarnings("ignore")
|
21 |
+
|
22 |
+
|
23 |
+
class Net(nn.Module):
|
24 |
+
def __init__(self):
|
25 |
+
super(Net, self).__init__()
|
26 |
+
self.msk = U2NETP(3, 1)
|
27 |
+
self.bm = DocScanner() # 矫正
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
msk, _1, _2, _3, _4, _5, _6 = self.msk(x)
|
31 |
+
msk = (msk > 0.5).float()
|
32 |
+
x = msk * x
|
33 |
+
|
34 |
+
bm = self.bm(x, iters=12, test_mode=True)
|
35 |
+
bm = (2 * (bm / 286.8) - 1) * 0.99
|
36 |
+
|
37 |
+
return bm, msk
|
38 |
+
|
39 |
+
|
40 |
+
def reload_seg_model(model, path=""):
|
41 |
+
if not bool(path):
|
42 |
+
return model
|
43 |
+
else:
|
44 |
+
model_dict = model.state_dict()
|
45 |
+
pretrained_dict = torch.load(path, map_location="cuda:0")
|
46 |
+
pretrained_dict = {
|
47 |
+
k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict
|
48 |
+
}
|
49 |
+
model_dict.update(pretrained_dict)
|
50 |
+
model.load_state_dict(model_dict)
|
51 |
+
|
52 |
+
return model
|
53 |
+
|
54 |
+
|
55 |
+
def reload_rec_model(model, path=""):
|
56 |
+
if not bool(path):
|
57 |
+
return model
|
58 |
+
else:
|
59 |
+
model_dict = model.state_dict()
|
60 |
+
pretrained_dict = torch.load(path, map_location="cuda:0")
|
61 |
+
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
62 |
+
model_dict.update(pretrained_dict)
|
63 |
+
model.load_state_dict(model_dict)
|
64 |
+
|
65 |
+
return model
|
models/DocScanner/model.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from DocScanner.extractor import BasicEncoder
|
10 |
+
from DocScanner.update import BasicUpdateBlock
|
11 |
+
|
12 |
+
|
13 |
+
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
|
14 |
+
"""Wrapper for grid_sample, uses pixel coordinates"""
|
15 |
+
H, W = img.shape[-2:]
|
16 |
+
xgrid, ygrid = coords.split([1, 1], dim=-1)
|
17 |
+
xgrid = 2 * xgrid / (W - 1) - 1
|
18 |
+
ygrid = 2 * ygrid / (H - 1) - 1
|
19 |
+
|
20 |
+
grid = torch.cat([xgrid, ygrid], dim=-1)
|
21 |
+
img = F.grid_sample(img, grid, align_corners=True)
|
22 |
+
if mask:
|
23 |
+
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
24 |
+
return img, mask.float()
|
25 |
+
|
26 |
+
return img
|
27 |
+
|
28 |
+
|
29 |
+
def coords_grid(batch, ht, wd):
|
30 |
+
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
31 |
+
coords = torch.stack(coords[::-1], dim=0).float()
|
32 |
+
return coords[None].repeat(batch, 1, 1, 1)
|
33 |
+
|
34 |
+
|
35 |
+
class DocScanner(nn.Module):
|
36 |
+
def __init__(self):
|
37 |
+
super(DocScanner, self).__init__()
|
38 |
+
|
39 |
+
self.hidden_dim = hdim = 160
|
40 |
+
self.context_dim = 160
|
41 |
+
|
42 |
+
self.fnet = BasicEncoder(output_dim=320, norm_fn="instance")
|
43 |
+
self.update_block = BasicUpdateBlock(hidden_dim=hdim)
|
44 |
+
|
45 |
+
def freeze_bn(self):
|
46 |
+
for m in self.modules():
|
47 |
+
if isinstance(m, nn.BatchNorm2d):
|
48 |
+
m.eval()
|
49 |
+
|
50 |
+
def initialize_flow(self, img):
|
51 |
+
N, C, H, W = img.shape
|
52 |
+
coodslar = coords_grid(N, H, W).to(img.device)
|
53 |
+
coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
|
54 |
+
coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
|
55 |
+
|
56 |
+
return coodslar, coords0, coords1
|
57 |
+
|
58 |
+
def upsample_flow(self, flow, mask):
|
59 |
+
N, _, H, W = flow.shape
|
60 |
+
mask = mask.view(N, 1, 9, 8, 8, H, W)
|
61 |
+
mask = torch.softmax(mask, dim=2)
|
62 |
+
|
63 |
+
up_flow = F.unfold(8 * flow, [3, 3], padding=1)
|
64 |
+
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
|
65 |
+
|
66 |
+
up_flow = torch.sum(mask * up_flow, dim=2)
|
67 |
+
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
|
68 |
+
|
69 |
+
return up_flow.reshape(N, 2, 8 * H, 8 * W)
|
70 |
+
|
71 |
+
def forward(self, image1, iters=12, flow_init=None, test_mode=False):
|
72 |
+
image1 = image1.contiguous()
|
73 |
+
|
74 |
+
fmap1 = self.fnet(image1)
|
75 |
+
|
76 |
+
warpfea = fmap1
|
77 |
+
|
78 |
+
net, inp = torch.split(fmap1, [160, 160], dim=1)
|
79 |
+
net = torch.tanh(net)
|
80 |
+
inp = torch.relu(inp)
|
81 |
+
|
82 |
+
coodslar, coords0, coords1 = self.initialize_flow(image1)
|
83 |
+
|
84 |
+
if flow_init is not None:
|
85 |
+
coords1 = coords1 + flow_init
|
86 |
+
|
87 |
+
flow_predictions = []
|
88 |
+
for itr in range(iters):
|
89 |
+
coords1 = coords1.detach()
|
90 |
+
flow = coords1 - coords0
|
91 |
+
|
92 |
+
net, up_mask, delta_flow = self.update_block(net, inp, warpfea, flow)
|
93 |
+
|
94 |
+
coords1 = coords1 + delta_flow
|
95 |
+
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
|
96 |
+
bm_up = coodslar + flow_up
|
97 |
+
|
98 |
+
warpfea = bilinear_sampler(fmap1, coords1.permute(0, 2, 3, 1))
|
99 |
+
flow_predictions.append(bm_up)
|
100 |
+
|
101 |
+
if test_mode:
|
102 |
+
return bm_up
|
103 |
+
|
104 |
+
return flow_predictions
|
models/DocScanner/ocr_img.txt
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The images for OCR evaluation of DocUNet Benchmark.
|
2 |
+
# Setting 1 (Setting from DocTr)
|
3 |
+
# Total 30 * 2 = 60 images.
|
4 |
+
./scan/1.png
|
5 |
+
./scan/2.png
|
6 |
+
./scan/3.png
|
7 |
+
./scan/4.png
|
8 |
+
./scan/5.png
|
9 |
+
./scan/6.png
|
10 |
+
./scan/7.png
|
11 |
+
./scan/9.png
|
12 |
+
./scan/10.png
|
13 |
+
./scan/21.png
|
14 |
+
./scan/22.png
|
15 |
+
./scan/23.png
|
16 |
+
./scan/24.png
|
17 |
+
./scan/27.png
|
18 |
+
./scan/30.png
|
19 |
+
./scan/31.png
|
20 |
+
./scan/32.png
|
21 |
+
./scan/36.png
|
22 |
+
./scan/38.png
|
23 |
+
./scan/40.png
|
24 |
+
./scan/41.png
|
25 |
+
./scan/44.png
|
26 |
+
./scan/45.png
|
27 |
+
./scan/46.png
|
28 |
+
./scan/47.png
|
29 |
+
./scan/48.png
|
30 |
+
./scan/50.png
|
31 |
+
./scan/51.png
|
32 |
+
./scan/52.png
|
33 |
+
./scan/53.png
|
34 |
+
|
35 |
+
# Setting 2 (Setting from DewarpNet)
|
36 |
+
# Link: https://github.com/cvlab-stonybrook/DewarpNet/blob/master/eval/ocr_eval/ocr_files.txt
|
37 |
+
# Total 25 * 2 = 50 images.
|
38 |
+
./scan/1.png
|
39 |
+
./scan/9.png
|
40 |
+
./scan/10.png
|
41 |
+
./scan/12.png
|
42 |
+
./scan/19.png
|
43 |
+
./scan/20.png
|
44 |
+
./scan/21.png
|
45 |
+
./scan/22.png
|
46 |
+
./scan/23.png
|
47 |
+
./scan/24.png
|
48 |
+
./scan/30.png
|
49 |
+
./scan/31.png
|
50 |
+
./scan/32.png
|
51 |
+
./scan/34.png
|
52 |
+
./scan/35.png
|
53 |
+
./scan/36.png
|
54 |
+
./scan/37.png
|
55 |
+
./scan/38.png
|
56 |
+
./scan/39.png
|
57 |
+
./scan/40.png
|
58 |
+
./scan/44.png
|
59 |
+
./scan/45.png
|
60 |
+
./scan/46.png
|
61 |
+
./scan/47.png
|
62 |
+
./scan/49.png
|
models/DocScanner/requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy==1.19.0
|
2 |
+
opencv_python==4.2.0.34
|
3 |
+
Pillow==9.4.0
|
4 |
+
scikit_image==0.17.2
|
5 |
+
skimage==0.0
|
6 |
+
torch==1.5.1+cu101
|
models/DocScanner/seg.py
ADDED
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision import models
|
6 |
+
|
7 |
+
|
8 |
+
class sobel_net(nn.Module):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__()
|
11 |
+
self.conv_opx = nn.Conv2d(1, 1, 3, bias=False)
|
12 |
+
self.conv_opy = nn.Conv2d(1, 1, 3, bias=False)
|
13 |
+
sobel_kernelx = np.array(
|
14 |
+
[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype="float32"
|
15 |
+
).reshape((1, 1, 3, 3))
|
16 |
+
sobel_kernely = np.array(
|
17 |
+
[[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype="float32"
|
18 |
+
).reshape((1, 1, 3, 3))
|
19 |
+
self.conv_opx.weight.data = torch.from_numpy(sobel_kernelx)
|
20 |
+
self.conv_opy.weight.data = torch.from_numpy(sobel_kernely)
|
21 |
+
|
22 |
+
for p in self.parameters():
|
23 |
+
p.requires_grad = False
|
24 |
+
|
25 |
+
def forward(self, im): # input rgb
|
26 |
+
x = (
|
27 |
+
0.299 * im[:, 0, :, :] + 0.587 * im[:, 1, :, :] + 0.114 * im[:, 2, :, :]
|
28 |
+
).unsqueeze(
|
29 |
+
1
|
30 |
+
) # rgb2gray
|
31 |
+
gradx = self.conv_opx(x)
|
32 |
+
grady = self.conv_opy(x)
|
33 |
+
|
34 |
+
x = (gradx**2 + grady**2) ** 0.5
|
35 |
+
x = (x - x.min()) / (x.max() - x.min())
|
36 |
+
x = F.pad(x, (1, 1, 1, 1))
|
37 |
+
|
38 |
+
x = torch.cat([im, x], dim=1)
|
39 |
+
return x
|
40 |
+
|
41 |
+
|
42 |
+
class REBNCONV(nn.Module):
|
43 |
+
def __init__(self, in_ch=3, out_ch=3, dirate=1):
|
44 |
+
super(REBNCONV, self).__init__()
|
45 |
+
|
46 |
+
self.conv_s1 = nn.Conv2d(
|
47 |
+
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
|
48 |
+
)
|
49 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
50 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
hx = x
|
54 |
+
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
55 |
+
|
56 |
+
return xout
|
57 |
+
|
58 |
+
|
59 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
60 |
+
def _upsample_like(src, tar):
|
61 |
+
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
|
62 |
+
|
63 |
+
return src
|
64 |
+
|
65 |
+
|
66 |
+
### RSU-7 ###
|
67 |
+
class RSU7(nn.Module): # UNet07DRES(nn.Module):
|
68 |
+
|
69 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
70 |
+
super(RSU7, self).__init__()
|
71 |
+
|
72 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
73 |
+
|
74 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
75 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
76 |
+
|
77 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
78 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
79 |
+
|
80 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
81 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
82 |
+
|
83 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
84 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
85 |
+
|
86 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
87 |
+
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
88 |
+
|
89 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
90 |
+
|
91 |
+
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
92 |
+
|
93 |
+
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
94 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
95 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
96 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
97 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
98 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
hx = x
|
102 |
+
hxin = self.rebnconvin(hx)
|
103 |
+
|
104 |
+
hx1 = self.rebnconv1(hxin)
|
105 |
+
hx = self.pool1(hx1)
|
106 |
+
|
107 |
+
hx2 = self.rebnconv2(hx)
|
108 |
+
hx = self.pool2(hx2)
|
109 |
+
|
110 |
+
hx3 = self.rebnconv3(hx)
|
111 |
+
hx = self.pool3(hx3)
|
112 |
+
|
113 |
+
hx4 = self.rebnconv4(hx)
|
114 |
+
hx = self.pool4(hx4)
|
115 |
+
|
116 |
+
hx5 = self.rebnconv5(hx)
|
117 |
+
hx = self.pool5(hx5)
|
118 |
+
|
119 |
+
hx6 = self.rebnconv6(hx)
|
120 |
+
|
121 |
+
hx7 = self.rebnconv7(hx6)
|
122 |
+
|
123 |
+
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
124 |
+
hx6dup = _upsample_like(hx6d, hx5)
|
125 |
+
|
126 |
+
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
127 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
128 |
+
|
129 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
130 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
131 |
+
|
132 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
133 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
134 |
+
|
135 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
136 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
137 |
+
|
138 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
139 |
+
|
140 |
+
return hx1d + hxin
|
141 |
+
|
142 |
+
|
143 |
+
### RSU-6 ###
|
144 |
+
class RSU6(nn.Module): # UNet06DRES(nn.Module):
|
145 |
+
|
146 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
147 |
+
super(RSU6, self).__init__()
|
148 |
+
|
149 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
150 |
+
|
151 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
152 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
153 |
+
|
154 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
155 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
156 |
+
|
157 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
158 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
159 |
+
|
160 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
161 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
162 |
+
|
163 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
164 |
+
|
165 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
166 |
+
|
167 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
168 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
169 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
170 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
171 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
hx = x
|
175 |
+
|
176 |
+
hxin = self.rebnconvin(hx)
|
177 |
+
|
178 |
+
hx1 = self.rebnconv1(hxin)
|
179 |
+
hx = self.pool1(hx1)
|
180 |
+
|
181 |
+
hx2 = self.rebnconv2(hx)
|
182 |
+
hx = self.pool2(hx2)
|
183 |
+
|
184 |
+
hx3 = self.rebnconv3(hx)
|
185 |
+
hx = self.pool3(hx3)
|
186 |
+
|
187 |
+
hx4 = self.rebnconv4(hx)
|
188 |
+
hx = self.pool4(hx4)
|
189 |
+
|
190 |
+
hx5 = self.rebnconv5(hx)
|
191 |
+
|
192 |
+
hx6 = self.rebnconv6(hx5)
|
193 |
+
|
194 |
+
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
|
195 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
196 |
+
|
197 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
198 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
199 |
+
|
200 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
201 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
202 |
+
|
203 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
204 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
205 |
+
|
206 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
207 |
+
|
208 |
+
return hx1d + hxin
|
209 |
+
|
210 |
+
|
211 |
+
### RSU-5 ###
|
212 |
+
class RSU5(nn.Module): # UNet05DRES(nn.Module):
|
213 |
+
|
214 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
215 |
+
super(RSU5, self).__init__()
|
216 |
+
|
217 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
218 |
+
|
219 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
220 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
221 |
+
|
222 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
223 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
224 |
+
|
225 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
226 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
227 |
+
|
228 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
229 |
+
|
230 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
231 |
+
|
232 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
233 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
234 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
235 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
236 |
+
|
237 |
+
def forward(self, x):
|
238 |
+
hx = x
|
239 |
+
|
240 |
+
hxin = self.rebnconvin(hx)
|
241 |
+
|
242 |
+
hx1 = self.rebnconv1(hxin)
|
243 |
+
hx = self.pool1(hx1)
|
244 |
+
|
245 |
+
hx2 = self.rebnconv2(hx)
|
246 |
+
hx = self.pool2(hx2)
|
247 |
+
|
248 |
+
hx3 = self.rebnconv3(hx)
|
249 |
+
hx = self.pool3(hx3)
|
250 |
+
|
251 |
+
hx4 = self.rebnconv4(hx)
|
252 |
+
|
253 |
+
hx5 = self.rebnconv5(hx4)
|
254 |
+
|
255 |
+
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
|
256 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
257 |
+
|
258 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
259 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
260 |
+
|
261 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
262 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
263 |
+
|
264 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
265 |
+
|
266 |
+
return hx1d + hxin
|
267 |
+
|
268 |
+
|
269 |
+
### RSU-4 ###
|
270 |
+
class RSU4(nn.Module): # UNet04DRES(nn.Module):
|
271 |
+
|
272 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
273 |
+
super(RSU4, self).__init__()
|
274 |
+
|
275 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
276 |
+
|
277 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
278 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
279 |
+
|
280 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
281 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
282 |
+
|
283 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
284 |
+
|
285 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
286 |
+
|
287 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
288 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
289 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
290 |
+
|
291 |
+
def forward(self, x):
|
292 |
+
hx = x
|
293 |
+
|
294 |
+
hxin = self.rebnconvin(hx)
|
295 |
+
|
296 |
+
hx1 = self.rebnconv1(hxin)
|
297 |
+
hx = self.pool1(hx1)
|
298 |
+
|
299 |
+
hx2 = self.rebnconv2(hx)
|
300 |
+
hx = self.pool2(hx2)
|
301 |
+
|
302 |
+
hx3 = self.rebnconv3(hx)
|
303 |
+
|
304 |
+
hx4 = self.rebnconv4(hx3)
|
305 |
+
|
306 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
307 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
308 |
+
|
309 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
310 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
311 |
+
|
312 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
313 |
+
|
314 |
+
return hx1d + hxin
|
315 |
+
|
316 |
+
|
317 |
+
### RSU-4F ###
|
318 |
+
class RSU4F(nn.Module): # UNet04FRES(nn.Module):
|
319 |
+
|
320 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
321 |
+
super(RSU4F, self).__init__()
|
322 |
+
|
323 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
324 |
+
|
325 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
326 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
327 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
|
328 |
+
|
329 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
|
330 |
+
|
331 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
|
332 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
|
333 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
334 |
+
|
335 |
+
def forward(self, x):
|
336 |
+
hx = x
|
337 |
+
|
338 |
+
hxin = self.rebnconvin(hx)
|
339 |
+
|
340 |
+
hx1 = self.rebnconv1(hxin)
|
341 |
+
hx2 = self.rebnconv2(hx1)
|
342 |
+
hx3 = self.rebnconv3(hx2)
|
343 |
+
|
344 |
+
hx4 = self.rebnconv4(hx3)
|
345 |
+
|
346 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
347 |
+
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
|
348 |
+
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
|
349 |
+
|
350 |
+
return hx1d + hxin
|
351 |
+
|
352 |
+
|
353 |
+
##### U^2-Net ####
|
354 |
+
class U2NET(nn.Module):
|
355 |
+
|
356 |
+
def __init__(self, in_ch=3, out_ch=1):
|
357 |
+
super(U2NET, self).__init__()
|
358 |
+
self.edge = sobel_net()
|
359 |
+
|
360 |
+
self.stage1 = RSU7(in_ch, 32, 64)
|
361 |
+
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
362 |
+
|
363 |
+
self.stage2 = RSU6(64, 32, 128)
|
364 |
+
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
365 |
+
|
366 |
+
self.stage3 = RSU5(128, 64, 256)
|
367 |
+
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
368 |
+
|
369 |
+
self.stage4 = RSU4(256, 128, 512)
|
370 |
+
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
371 |
+
|
372 |
+
self.stage5 = RSU4F(512, 256, 512)
|
373 |
+
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
374 |
+
|
375 |
+
self.stage6 = RSU4F(512, 256, 512)
|
376 |
+
|
377 |
+
# decoder
|
378 |
+
self.stage5d = RSU4F(1024, 256, 512)
|
379 |
+
self.stage4d = RSU4(1024, 128, 256)
|
380 |
+
self.stage3d = RSU5(512, 64, 128)
|
381 |
+
self.stage2d = RSU6(256, 32, 64)
|
382 |
+
self.stage1d = RSU7(128, 16, 64)
|
383 |
+
|
384 |
+
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
385 |
+
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
386 |
+
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
|
387 |
+
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
|
388 |
+
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
|
389 |
+
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
|
390 |
+
|
391 |
+
self.outconv = nn.Conv2d(6, out_ch, 1)
|
392 |
+
|
393 |
+
def forward(self, x):
|
394 |
+
x = self.edge(x)
|
395 |
+
hx = x
|
396 |
+
|
397 |
+
# stage 1
|
398 |
+
hx1 = self.stage1(hx)
|
399 |
+
hx = self.pool12(hx1)
|
400 |
+
|
401 |
+
# stage 2
|
402 |
+
hx2 = self.stage2(hx)
|
403 |
+
hx = self.pool23(hx2)
|
404 |
+
|
405 |
+
# stage 3
|
406 |
+
hx3 = self.stage3(hx)
|
407 |
+
hx = self.pool34(hx3)
|
408 |
+
|
409 |
+
# stage 4
|
410 |
+
hx4 = self.stage4(hx)
|
411 |
+
hx = self.pool45(hx4)
|
412 |
+
|
413 |
+
# stage 5
|
414 |
+
hx5 = self.stage5(hx)
|
415 |
+
hx = self.pool56(hx5)
|
416 |
+
|
417 |
+
# stage 6
|
418 |
+
hx6 = self.stage6(hx)
|
419 |
+
hx6up = _upsample_like(hx6, hx5)
|
420 |
+
|
421 |
+
# -------------------- decoder --------------------
|
422 |
+
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
423 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
424 |
+
|
425 |
+
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
426 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
427 |
+
|
428 |
+
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
429 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
430 |
+
|
431 |
+
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
432 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
433 |
+
|
434 |
+
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
435 |
+
|
436 |
+
# side output
|
437 |
+
d1 = self.side1(hx1d)
|
438 |
+
|
439 |
+
d2 = self.side2(hx2d)
|
440 |
+
d2 = _upsample_like(d2, d1)
|
441 |
+
|
442 |
+
d3 = self.side3(hx3d)
|
443 |
+
d3 = _upsample_like(d3, d1)
|
444 |
+
|
445 |
+
d4 = self.side4(hx4d)
|
446 |
+
d4 = _upsample_like(d4, d1)
|
447 |
+
|
448 |
+
d5 = self.side5(hx5d)
|
449 |
+
d5 = _upsample_like(d5, d1)
|
450 |
+
|
451 |
+
d6 = self.side6(hx6)
|
452 |
+
d6 = _upsample_like(d6, d1)
|
453 |
+
|
454 |
+
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
|
455 |
+
|
456 |
+
return (
|
457 |
+
torch.sigmoid(d0),
|
458 |
+
torch.sigmoid(d1),
|
459 |
+
torch.sigmoid(d2),
|
460 |
+
torch.sigmoid(d3),
|
461 |
+
torch.sigmoid(d4),
|
462 |
+
torch.sigmoid(d5),
|
463 |
+
torch.sigmoid(d6),
|
464 |
+
)
|
465 |
+
|
466 |
+
|
467 |
+
### U^2-Net small ###
|
468 |
+
class U2NETP(nn.Module):
|
469 |
+
|
470 |
+
def __init__(self, in_ch=3, out_ch=1):
|
471 |
+
super(U2NETP, self).__init__()
|
472 |
+
|
473 |
+
self.stage1 = RSU7(in_ch, 16, 64)
|
474 |
+
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
475 |
+
|
476 |
+
self.stage2 = RSU6(64, 16, 64)
|
477 |
+
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
478 |
+
|
479 |
+
self.stage3 = RSU5(64, 16, 64)
|
480 |
+
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
481 |
+
|
482 |
+
self.stage4 = RSU4(64, 16, 64)
|
483 |
+
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
484 |
+
|
485 |
+
self.stage5 = RSU4F(64, 16, 64)
|
486 |
+
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
487 |
+
|
488 |
+
self.stage6 = RSU4F(64, 16, 64)
|
489 |
+
|
490 |
+
# decoder
|
491 |
+
self.stage5d = RSU4F(128, 16, 64)
|
492 |
+
self.stage4d = RSU4(128, 16, 64)
|
493 |
+
self.stage3d = RSU5(128, 16, 64)
|
494 |
+
self.stage2d = RSU6(128, 16, 64)
|
495 |
+
self.stage1d = RSU7(128, 16, 64)
|
496 |
+
|
497 |
+
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
498 |
+
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
499 |
+
self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
|
500 |
+
self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
|
501 |
+
self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
|
502 |
+
self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
|
503 |
+
|
504 |
+
self.outconv = nn.Conv2d(6, out_ch, 1)
|
505 |
+
|
506 |
+
def forward(self, x):
|
507 |
+
hx = x
|
508 |
+
|
509 |
+
# stage 1
|
510 |
+
hx1 = self.stage1(hx)
|
511 |
+
hx = self.pool12(hx1)
|
512 |
+
|
513 |
+
# stage 2
|
514 |
+
hx2 = self.stage2(hx)
|
515 |
+
hx = self.pool23(hx2)
|
516 |
+
|
517 |
+
# stage 3
|
518 |
+
hx3 = self.stage3(hx)
|
519 |
+
hx = self.pool34(hx3)
|
520 |
+
|
521 |
+
# stage 4
|
522 |
+
hx4 = self.stage4(hx)
|
523 |
+
hx = self.pool45(hx4)
|
524 |
+
|
525 |
+
# stage 5
|
526 |
+
hx5 = self.stage5(hx)
|
527 |
+
hx = self.pool56(hx5)
|
528 |
+
|
529 |
+
# stage 6
|
530 |
+
hx6 = self.stage6(hx)
|
531 |
+
hx6up = _upsample_like(hx6, hx5)
|
532 |
+
|
533 |
+
# decoder
|
534 |
+
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
535 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
536 |
+
|
537 |
+
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
538 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
539 |
+
|
540 |
+
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
541 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
542 |
+
|
543 |
+
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
544 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
545 |
+
|
546 |
+
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
547 |
+
|
548 |
+
# side output
|
549 |
+
d1 = self.side1(hx1d)
|
550 |
+
|
551 |
+
d2 = self.side2(hx2d)
|
552 |
+
d2 = _upsample_like(d2, d1)
|
553 |
+
|
554 |
+
d3 = self.side3(hx3d)
|
555 |
+
d3 = _upsample_like(d3, d1)
|
556 |
+
|
557 |
+
d4 = self.side4(hx4d)
|
558 |
+
d4 = _upsample_like(d4, d1)
|
559 |
+
|
560 |
+
d5 = self.side5(hx5d)
|
561 |
+
d5 = _upsample_like(d5, d1)
|
562 |
+
|
563 |
+
d6 = self.side6(hx6)
|
564 |
+
d6 = _upsample_like(d6, d1)
|
565 |
+
|
566 |
+
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
|
567 |
+
|
568 |
+
return (
|
569 |
+
torch.sigmoid(d0),
|
570 |
+
torch.sigmoid(d1),
|
571 |
+
torch.sigmoid(d2),
|
572 |
+
torch.sigmoid(d3),
|
573 |
+
torch.sigmoid(d4),
|
574 |
+
torch.sigmoid(d5),
|
575 |
+
torch.sigmoid(d6),
|
576 |
+
)
|
models/DocScanner/update.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class FlowHead(nn.Module):
|
7 |
+
def __init__(self, input_dim=128, hidden_dim=256):
|
8 |
+
super(FlowHead, self).__init__()
|
9 |
+
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
10 |
+
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
|
11 |
+
self.relu = nn.ReLU(inplace=True)
|
12 |
+
|
13 |
+
def forward(self, x):
|
14 |
+
return self.conv2(self.relu(self.conv1(x)))
|
15 |
+
|
16 |
+
|
17 |
+
class ConvGRU(nn.Module):
|
18 |
+
def __init__(self, hidden_dim=128, input_dim=192 + 128):
|
19 |
+
super(ConvGRU, self).__init__()
|
20 |
+
self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
|
21 |
+
self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
|
22 |
+
self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
|
23 |
+
|
24 |
+
def forward(self, h, x):
|
25 |
+
hx = torch.cat([h, x], dim=1)
|
26 |
+
|
27 |
+
z = torch.sigmoid(self.convz(hx))
|
28 |
+
r = torch.sigmoid(self.convr(hx))
|
29 |
+
q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
|
30 |
+
|
31 |
+
h = (1 - z) * h + z * q
|
32 |
+
return h
|
33 |
+
|
34 |
+
|
35 |
+
class SepConvGRU(nn.Module):
|
36 |
+
def __init__(self, hidden_dim=128, input_dim=192 + 128):
|
37 |
+
super(SepConvGRU, self).__init__()
|
38 |
+
self.convz1 = nn.Conv2d(
|
39 |
+
hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
|
40 |
+
)
|
41 |
+
self.convr1 = nn.Conv2d(
|
42 |
+
hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
|
43 |
+
)
|
44 |
+
self.convq1 = nn.Conv2d(
|
45 |
+
hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
|
46 |
+
)
|
47 |
+
|
48 |
+
self.convz2 = nn.Conv2d(
|
49 |
+
hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
|
50 |
+
)
|
51 |
+
self.convr2 = nn.Conv2d(
|
52 |
+
hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
|
53 |
+
)
|
54 |
+
self.convq2 = nn.Conv2d(
|
55 |
+
hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
|
56 |
+
)
|
57 |
+
|
58 |
+
def forward(self, h, x):
|
59 |
+
# horizontal
|
60 |
+
hx = torch.cat([h, x], dim=1)
|
61 |
+
z = torch.sigmoid(self.convz1(hx))
|
62 |
+
r = torch.sigmoid(self.convr1(hx))
|
63 |
+
q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
|
64 |
+
h = (1 - z) * h + z * q
|
65 |
+
|
66 |
+
# vertical
|
67 |
+
hx = torch.cat([h, x], dim=1)
|
68 |
+
z = torch.sigmoid(self.convz2(hx))
|
69 |
+
r = torch.sigmoid(self.convr2(hx))
|
70 |
+
q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
|
71 |
+
h = (1 - z) * h + z * q
|
72 |
+
|
73 |
+
return h
|
74 |
+
|
75 |
+
|
76 |
+
class BasicMotionEncoder(nn.Module):
|
77 |
+
def __init__(self):
|
78 |
+
super(BasicMotionEncoder, self).__init__()
|
79 |
+
self.convc1 = nn.Conv2d(320, 240, 1, padding=0)
|
80 |
+
self.convc2 = nn.Conv2d(240, 160, 3, padding=1)
|
81 |
+
self.convf1 = nn.Conv2d(2, 160, 7, padding=3)
|
82 |
+
self.convf2 = nn.Conv2d(160, 80, 3, padding=1)
|
83 |
+
self.conv = nn.Conv2d(160 + 80, 160 - 2, 3, padding=1)
|
84 |
+
|
85 |
+
def forward(self, flow, corr):
|
86 |
+
cor = F.relu(self.convc1(corr))
|
87 |
+
cor = F.relu(self.convc2(cor))
|
88 |
+
flo = F.relu(self.convf1(flow))
|
89 |
+
flo = F.relu(self.convf2(flo))
|
90 |
+
|
91 |
+
cor_flo = torch.cat([cor, flo], dim=1)
|
92 |
+
out = F.relu(self.conv(cor_flo))
|
93 |
+
return torch.cat([out, flow], dim=1)
|
94 |
+
|
95 |
+
|
96 |
+
class BasicUpdateBlock(nn.Module):
|
97 |
+
def __init__(self, hidden_dim=128):
|
98 |
+
super(BasicUpdateBlock, self).__init__()
|
99 |
+
self.encoder = BasicMotionEncoder()
|
100 |
+
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=160 + 160)
|
101 |
+
self.flow_head = FlowHead(hidden_dim, hidden_dim=320)
|
102 |
+
|
103 |
+
self.mask = nn.Sequential(
|
104 |
+
nn.Conv2d(hidden_dim, 288, 3, padding=1),
|
105 |
+
nn.ReLU(inplace=True),
|
106 |
+
nn.Conv2d(288, 64 * 9, 1, padding=0),
|
107 |
+
)
|
108 |
+
|
109 |
+
def forward(self, net, inp, corr, flow):
|
110 |
+
motion_features = self.encoder(flow, corr)
|
111 |
+
inp = torch.cat([inp, motion_features], dim=1)
|
112 |
+
|
113 |
+
net = self.gru(net, inp)
|
114 |
+
|
115 |
+
delta_flow = self.flow_head(net)
|
116 |
+
|
117 |
+
mask = 0.25 * self.mask(net)
|
118 |
+
|
119 |
+
return net, mask, delta_flow
|
models/DocTr-Plus/GeoTr.py
ADDED
@@ -0,0 +1,960 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
5 |
+
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import copy
|
9 |
+
from typing import Optional
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from torch import Tensor, nn
|
15 |
+
|
16 |
+
from .extractor import BasicEncoder
|
17 |
+
from .position_encoding import build_position_encoding
|
18 |
+
|
19 |
+
|
20 |
+
class attnLayer(nn.Module):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
d_model,
|
24 |
+
nhead=8,
|
25 |
+
dim_feedforward=2048,
|
26 |
+
dropout=0.1,
|
27 |
+
activation="relu",
|
28 |
+
normalize_before=False,
|
29 |
+
):
|
30 |
+
super().__init__()
|
31 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
32 |
+
self.multihead_attn_list = nn.ModuleList(
|
33 |
+
[
|
34 |
+
copy.deepcopy(nn.MultiheadAttention(d_model, nhead, dropout=dropout))
|
35 |
+
for i in range(2)
|
36 |
+
]
|
37 |
+
)
|
38 |
+
# Implementation of Feedforward model
|
39 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
40 |
+
self.dropout = nn.Dropout(dropout)
|
41 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
42 |
+
|
43 |
+
self.norm1 = nn.LayerNorm(d_model)
|
44 |
+
self.norm2_list = nn.ModuleList(
|
45 |
+
[copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)]
|
46 |
+
)
|
47 |
+
|
48 |
+
self.norm3 = nn.LayerNorm(d_model)
|
49 |
+
self.dropout1 = nn.Dropout(dropout)
|
50 |
+
self.dropout2_list = nn.ModuleList(
|
51 |
+
[copy.deepcopy(nn.Dropout(dropout)) for i in range(2)]
|
52 |
+
)
|
53 |
+
self.dropout3 = nn.Dropout(dropout)
|
54 |
+
|
55 |
+
self.activation = _get_activation_fn(activation)
|
56 |
+
self.normalize_before = normalize_before
|
57 |
+
|
58 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
59 |
+
return tensor if pos is None else tensor + pos
|
60 |
+
|
61 |
+
def forward_post(
|
62 |
+
self,
|
63 |
+
tgt,
|
64 |
+
memory_list,
|
65 |
+
tgt_mask=None,
|
66 |
+
memory_mask=None,
|
67 |
+
tgt_key_padding_mask=None,
|
68 |
+
memory_key_padding_mask=None,
|
69 |
+
pos=None,
|
70 |
+
memory_pos=None,
|
71 |
+
):
|
72 |
+
q = k = self.with_pos_embed(tgt, pos)
|
73 |
+
tgt2 = self.self_attn(
|
74 |
+
q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
75 |
+
)[0]
|
76 |
+
tgt = tgt + self.dropout1(tgt2)
|
77 |
+
tgt = self.norm1(tgt)
|
78 |
+
for memory, multihead_attn, norm2, dropout2, m_pos in zip(
|
79 |
+
memory_list,
|
80 |
+
self.multihead_attn_list,
|
81 |
+
self.norm2_list,
|
82 |
+
self.dropout2_list,
|
83 |
+
memory_pos,
|
84 |
+
):
|
85 |
+
tgt2 = multihead_attn(
|
86 |
+
query=self.with_pos_embed(tgt, pos),
|
87 |
+
key=self.with_pos_embed(memory, m_pos),
|
88 |
+
value=memory,
|
89 |
+
attn_mask=memory_mask,
|
90 |
+
key_padding_mask=memory_key_padding_mask,
|
91 |
+
)[0]
|
92 |
+
tgt = tgt + dropout2(tgt2)
|
93 |
+
tgt = norm2(tgt)
|
94 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
95 |
+
tgt = tgt + self.dropout3(tgt2)
|
96 |
+
tgt = self.norm3(tgt)
|
97 |
+
return tgt
|
98 |
+
|
99 |
+
def forward_pre(
|
100 |
+
self,
|
101 |
+
tgt,
|
102 |
+
memory,
|
103 |
+
tgt_mask=None,
|
104 |
+
memory_mask=None,
|
105 |
+
tgt_key_padding_mask=None,
|
106 |
+
memory_key_padding_mask=None,
|
107 |
+
pos=None,
|
108 |
+
memory_pos=None,
|
109 |
+
):
|
110 |
+
tgt2 = self.norm1(tgt)
|
111 |
+
q = k = self.with_pos_embed(tgt2, pos)
|
112 |
+
tgt2 = self.self_attn(
|
113 |
+
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
114 |
+
)[0]
|
115 |
+
tgt = tgt + self.dropout1(tgt2)
|
116 |
+
tgt2 = self.norm2(tgt)
|
117 |
+
tgt2 = self.multihead_attn(
|
118 |
+
query=self.with_pos_embed(tgt2, pos),
|
119 |
+
key=self.with_pos_embed(memory, memory_pos),
|
120 |
+
value=memory,
|
121 |
+
attn_mask=memory_mask,
|
122 |
+
key_padding_mask=memory_key_padding_mask,
|
123 |
+
)[0]
|
124 |
+
tgt = tgt + self.dropout2(tgt2)
|
125 |
+
tgt2 = self.norm3(tgt)
|
126 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
127 |
+
tgt = tgt + self.dropout3(tgt2)
|
128 |
+
return tgt
|
129 |
+
|
130 |
+
def forward(
|
131 |
+
self,
|
132 |
+
tgt,
|
133 |
+
memory_list,
|
134 |
+
tgt_mask=None,
|
135 |
+
memory_mask=None,
|
136 |
+
tgt_key_padding_mask=None,
|
137 |
+
memory_key_padding_mask=None,
|
138 |
+
pos=None,
|
139 |
+
memory_pos=None,
|
140 |
+
):
|
141 |
+
if self.normalize_before:
|
142 |
+
return self.forward_pre(
|
143 |
+
tgt,
|
144 |
+
memory_list,
|
145 |
+
tgt_mask,
|
146 |
+
memory_mask,
|
147 |
+
tgt_key_padding_mask,
|
148 |
+
memory_key_padding_mask,
|
149 |
+
pos,
|
150 |
+
memory_pos,
|
151 |
+
)
|
152 |
+
return self.forward_post(
|
153 |
+
tgt,
|
154 |
+
memory_list,
|
155 |
+
tgt_mask,
|
156 |
+
memory_mask,
|
157 |
+
tgt_key_padding_mask,
|
158 |
+
memory_key_padding_mask,
|
159 |
+
pos,
|
160 |
+
memory_pos,
|
161 |
+
)
|
162 |
+
|
163 |
+
|
164 |
+
def _get_clones(module, N):
|
165 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
166 |
+
|
167 |
+
|
168 |
+
def _get_activation_fn(activation):
|
169 |
+
"""Return an activation function given a string"""
|
170 |
+
if activation == "relu":
|
171 |
+
return F.relu
|
172 |
+
if activation == "gelu":
|
173 |
+
return F.gelu
|
174 |
+
if activation == "glu":
|
175 |
+
return F.glu
|
176 |
+
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
177 |
+
|
178 |
+
|
179 |
+
class TransDecoder(nn.Module):
|
180 |
+
def __init__(self, num_attn_layers, hidden_dim=128):
|
181 |
+
super(TransDecoder, self).__init__()
|
182 |
+
attn_layer = attnLayer(hidden_dim)
|
183 |
+
self.layers = _get_clones(attn_layer, num_attn_layers)
|
184 |
+
self.position_embedding = build_position_encoding(hidden_dim)
|
185 |
+
|
186 |
+
def forward(self, imgf, query_embed):
|
187 |
+
pos = self.position_embedding(
|
188 |
+
torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool().cuda()
|
189 |
+
) # torch.Size([1, 128, 36, 36])
|
190 |
+
|
191 |
+
bs, c, h, w = imgf.shape
|
192 |
+
imgf = imgf.flatten(2).permute(2, 0, 1)
|
193 |
+
# query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
194 |
+
pos = pos.flatten(2).permute(2, 0, 1)
|
195 |
+
|
196 |
+
for layer in self.layers:
|
197 |
+
query_embed = layer(query_embed, [imgf], pos=pos, memory_pos=[pos, pos])
|
198 |
+
query_embed = query_embed.permute(1, 2, 0).reshape(bs, c, h, w)
|
199 |
+
|
200 |
+
return query_embed
|
201 |
+
|
202 |
+
|
203 |
+
class TransEncoder(nn.Module):
|
204 |
+
def __init__(self, num_attn_layers, hidden_dim=128):
|
205 |
+
super(TransEncoder, self).__init__()
|
206 |
+
attn_layer = attnLayer(hidden_dim)
|
207 |
+
self.layers = _get_clones(attn_layer, num_attn_layers)
|
208 |
+
self.position_embedding = build_position_encoding(hidden_dim)
|
209 |
+
|
210 |
+
def forward(self, imgf):
|
211 |
+
pos = self.position_embedding(
|
212 |
+
torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool().cuda()
|
213 |
+
) # torch.Size([1, 128, 36, 36])
|
214 |
+
bs, c, h, w = imgf.shape
|
215 |
+
imgf = imgf.flatten(2).permute(2, 0, 1)
|
216 |
+
pos = pos.flatten(2).permute(2, 0, 1)
|
217 |
+
|
218 |
+
for layer in self.layers:
|
219 |
+
imgf = layer(imgf, [imgf], pos=pos, memory_pos=[pos, pos])
|
220 |
+
imgf = imgf.permute(1, 2, 0).reshape(bs, c, h, w)
|
221 |
+
|
222 |
+
return imgf
|
223 |
+
|
224 |
+
|
225 |
+
class FlowHead(nn.Module):
|
226 |
+
def __init__(self, input_dim=128, hidden_dim=256):
|
227 |
+
super(FlowHead, self).__init__()
|
228 |
+
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
229 |
+
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
|
230 |
+
self.relu = nn.ReLU(inplace=True)
|
231 |
+
|
232 |
+
def forward(self, x):
|
233 |
+
return self.conv2(self.relu(self.conv1(x)))
|
234 |
+
|
235 |
+
|
236 |
+
class UpdateBlock(nn.Module):
|
237 |
+
def __init__(self, hidden_dim=128):
|
238 |
+
super(UpdateBlock, self).__init__()
|
239 |
+
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
|
240 |
+
self.mask = nn.Sequential(
|
241 |
+
nn.Conv2d(hidden_dim, 256, 3, padding=1),
|
242 |
+
nn.ReLU(inplace=True),
|
243 |
+
nn.Conv2d(256, 64 * 9, 1, padding=0),
|
244 |
+
)
|
245 |
+
|
246 |
+
def forward(self, imgf, coords1):
|
247 |
+
mask = 0.25 * self.mask(imgf) # scale mask to balence gradients
|
248 |
+
dflow = self.flow_head(imgf)
|
249 |
+
coords1 = coords1 + dflow
|
250 |
+
|
251 |
+
return mask, coords1
|
252 |
+
|
253 |
+
|
254 |
+
def coords_grid(batch, ht, wd):
|
255 |
+
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
256 |
+
coords = torch.stack(coords[::-1], dim=0).float()
|
257 |
+
return coords[None].repeat(batch, 1, 1, 1)
|
258 |
+
|
259 |
+
|
260 |
+
def upflow8(flow, mode="bilinear"):
|
261 |
+
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
|
262 |
+
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
|
263 |
+
|
264 |
+
|
265 |
+
class OverlapPatchEmbed(nn.Module):
|
266 |
+
"""Image to Patch Embedding"""
|
267 |
+
|
268 |
+
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
|
269 |
+
super().__init__()
|
270 |
+
img_size = to_2tuple(img_size)
|
271 |
+
patch_size = to_2tuple(patch_size)
|
272 |
+
|
273 |
+
self.img_size = img_size
|
274 |
+
self.patch_size = patch_size
|
275 |
+
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
276 |
+
self.num_patches = self.H * self.W
|
277 |
+
self.proj = nn.Conv2d(
|
278 |
+
in_chans,
|
279 |
+
embed_dim,
|
280 |
+
kernel_size=patch_size,
|
281 |
+
stride=stride,
|
282 |
+
padding=(patch_size[0] // 2, patch_size[1] // 2),
|
283 |
+
)
|
284 |
+
self.norm = nn.LayerNorm(embed_dim)
|
285 |
+
|
286 |
+
self.apply(self._init_weights)
|
287 |
+
|
288 |
+
def _init_weights(self, m):
|
289 |
+
if isinstance(m, nn.Linear):
|
290 |
+
trunc_normal_(m.weight, std=0.02)
|
291 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
292 |
+
nn.init.constant_(m.bias, 0)
|
293 |
+
elif isinstance(m, nn.LayerNorm):
|
294 |
+
nn.init.constant_(m.bias, 0)
|
295 |
+
nn.init.constant_(m.weight, 1.0)
|
296 |
+
elif isinstance(m, nn.Conv2d):
|
297 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
298 |
+
fan_out //= m.groups
|
299 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
300 |
+
if m.bias is not None:
|
301 |
+
m.bias.data.zero_()
|
302 |
+
|
303 |
+
def forward(self, x):
|
304 |
+
x = self.proj(x)
|
305 |
+
_, _, H, W = x.shape
|
306 |
+
x = x.flatten(2).transpose(1, 2)
|
307 |
+
x = self.norm(x)
|
308 |
+
|
309 |
+
return x, H, W
|
310 |
+
|
311 |
+
|
312 |
+
class GeoTr(nn.Module):
|
313 |
+
def __init__(self):
|
314 |
+
super(GeoTr, self).__init__()
|
315 |
+
|
316 |
+
self.hidden_dim = hdim = 256
|
317 |
+
|
318 |
+
self.fnet = BasicEncoder(output_dim=hdim, norm_fn="instance")
|
319 |
+
|
320 |
+
self.encoder_block = ["encoder_block" + str(i) for i in range(3)]
|
321 |
+
for i in self.encoder_block:
|
322 |
+
self.__setattr__(i, TransEncoder(2, hidden_dim=hdim))
|
323 |
+
self.down_layer = ["down_layer" + str(i) for i in range(2)]
|
324 |
+
for i in self.down_layer:
|
325 |
+
self.__setattr__(i, nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1))
|
326 |
+
|
327 |
+
self.decoder_block = ["decoder_block" + str(i) for i in range(3)]
|
328 |
+
for i in self.decoder_block:
|
329 |
+
self.__setattr__(i, TransDecoder(2, hidden_dim=hdim))
|
330 |
+
self.up_layer = ["up_layer" + str(i) for i in range(2)]
|
331 |
+
for i in self.up_layer:
|
332 |
+
self.__setattr__(
|
333 |
+
i, nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
|
334 |
+
)
|
335 |
+
|
336 |
+
self.query_embed = nn.Embedding(81, self.hidden_dim)
|
337 |
+
|
338 |
+
self.update_block = UpdateBlock(self.hidden_dim)
|
339 |
+
|
340 |
+
def initialize_flow(self, img):
|
341 |
+
N, C, H, W = img.shape
|
342 |
+
coodslar = coords_grid(N, H, W).to(img.device)
|
343 |
+
coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
|
344 |
+
coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
|
345 |
+
|
346 |
+
return coodslar, coords0, coords1
|
347 |
+
|
348 |
+
def upsample_flow(self, flow, mask):
|
349 |
+
N, _, H, W = flow.shape
|
350 |
+
mask = mask.view(N, 1, 9, 8, 8, H, W)
|
351 |
+
mask = torch.softmax(mask, dim=2)
|
352 |
+
|
353 |
+
up_flow = F.unfold(8 * flow, [3, 3], padding=1)
|
354 |
+
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
|
355 |
+
|
356 |
+
up_flow = torch.sum(mask * up_flow, dim=2)
|
357 |
+
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
|
358 |
+
|
359 |
+
return up_flow.reshape(N, 2, 8 * H, 8 * W)
|
360 |
+
|
361 |
+
def forward(self, image1):
|
362 |
+
fmap = self.fnet(image1)
|
363 |
+
fmap = torch.relu(fmap)
|
364 |
+
|
365 |
+
# fmap = self.TransEncoder(fmap)
|
366 |
+
fmap1 = self.__getattr__(self.encoder_block[0])(fmap)
|
367 |
+
fmap1d = self.__getattr__(self.down_layer[0])(fmap1)
|
368 |
+
fmap2 = self.__getattr__(self.encoder_block[1])(fmap1d)
|
369 |
+
fmap2d = self.__getattr__(self.down_layer[1])(fmap2)
|
370 |
+
fmap3 = self.__getattr__(self.encoder_block[2])(fmap2d)
|
371 |
+
|
372 |
+
query_embed0 = self.query_embed.weight.unsqueeze(1).repeat(1, fmap3.size(0), 1)
|
373 |
+
fmap3d_ = self.__getattr__(self.decoder_block[0])(fmap3, query_embed0)
|
374 |
+
fmap3du_ = (
|
375 |
+
self.__getattr__(self.up_layer[0])(fmap3d_).flatten(2).permute(2, 0, 1)
|
376 |
+
)
|
377 |
+
fmap2d_ = self.__getattr__(self.decoder_block[1])(fmap2, fmap3du_)
|
378 |
+
fmap2du_ = (
|
379 |
+
self.__getattr__(self.up_layer[1])(fmap2d_).flatten(2).permute(2, 0, 1)
|
380 |
+
)
|
381 |
+
fmap_out = self.__getattr__(self.decoder_block[2])(fmap1, fmap2du_)
|
382 |
+
|
383 |
+
# convex upsample baesd on fmap_out
|
384 |
+
coodslar, coords0, coords1 = self.initialize_flow(image1)
|
385 |
+
coords1 = coords1.detach()
|
386 |
+
mask, coords1 = self.update_block(fmap_out, coords1)
|
387 |
+
flow_up = self.upsample_flow(coords1 - coords0, mask)
|
388 |
+
bm_up = coodslar + flow_up
|
389 |
+
|
390 |
+
return bm_up
|
391 |
+
|
392 |
+
|
393 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
394 |
+
def _upsample_like(src, tar):
|
395 |
+
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
|
396 |
+
|
397 |
+
return src
|
398 |
+
|
399 |
+
|
400 |
+
class REBNCONV(nn.Module):
|
401 |
+
def __init__(self, in_ch=3, out_ch=3, dirate=1):
|
402 |
+
super(REBNCONV, self).__init__()
|
403 |
+
|
404 |
+
self.conv_s1 = nn.Conv2d(
|
405 |
+
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
|
406 |
+
)
|
407 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
408 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
409 |
+
|
410 |
+
def forward(self, x):
|
411 |
+
hx = x
|
412 |
+
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
413 |
+
|
414 |
+
return xout
|
415 |
+
|
416 |
+
|
417 |
+
### RSU-4 ###
|
418 |
+
class RSU4(nn.Module): # UNet04DRES(nn.Module):
|
419 |
+
|
420 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
421 |
+
super(RSU4, self).__init__()
|
422 |
+
|
423 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
424 |
+
|
425 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
426 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
427 |
+
|
428 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
429 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
430 |
+
|
431 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
432 |
+
|
433 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
434 |
+
|
435 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
436 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
437 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
438 |
+
|
439 |
+
def forward(self, x):
|
440 |
+
hx = x
|
441 |
+
|
442 |
+
hxin = self.rebnconvin(hx)
|
443 |
+
|
444 |
+
hx1 = self.rebnconv1(hxin)
|
445 |
+
hx = self.pool1(hx1)
|
446 |
+
|
447 |
+
hx2 = self.rebnconv2(hx)
|
448 |
+
hx = self.pool2(hx2)
|
449 |
+
|
450 |
+
hx3 = self.rebnconv3(hx)
|
451 |
+
|
452 |
+
hx4 = self.rebnconv4(hx3)
|
453 |
+
|
454 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
455 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
456 |
+
|
457 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
458 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
459 |
+
|
460 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
461 |
+
|
462 |
+
return hx1d + hxin
|
463 |
+
|
464 |
+
|
465 |
+
### RSU-4F ###
|
466 |
+
class RSU4F(nn.Module): # UNet04FRES(nn.Module):
|
467 |
+
|
468 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
469 |
+
super(RSU4F, self).__init__()
|
470 |
+
|
471 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
472 |
+
|
473 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
474 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
475 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
|
476 |
+
|
477 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
|
478 |
+
|
479 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
|
480 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
|
481 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
482 |
+
|
483 |
+
def forward(self, x):
|
484 |
+
hx = x
|
485 |
+
|
486 |
+
hxin = self.rebnconvin(hx)
|
487 |
+
|
488 |
+
hx1 = self.rebnconv1(hxin)
|
489 |
+
hx2 = self.rebnconv2(hx1)
|
490 |
+
hx3 = self.rebnconv3(hx2)
|
491 |
+
|
492 |
+
hx4 = self.rebnconv4(hx3)
|
493 |
+
|
494 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
495 |
+
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
|
496 |
+
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
|
497 |
+
|
498 |
+
return hx1d + hxin
|
499 |
+
|
500 |
+
|
501 |
+
class sobel_net(nn.Module):
|
502 |
+
def __init__(self):
|
503 |
+
super().__init__()
|
504 |
+
self.conv_opx = nn.Conv2d(1, 1, 3, bias=False)
|
505 |
+
self.conv_opy = nn.Conv2d(1, 1, 3, bias=False)
|
506 |
+
sobel_kernelx = np.array(
|
507 |
+
[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype="float32"
|
508 |
+
).reshape((1, 1, 3, 3))
|
509 |
+
sobel_kernely = np.array(
|
510 |
+
[[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype="float32"
|
511 |
+
).reshape((1, 1, 3, 3))
|
512 |
+
self.conv_opx.weight.data = torch.from_numpy(sobel_kernelx)
|
513 |
+
self.conv_opy.weight.data = torch.from_numpy(sobel_kernely)
|
514 |
+
|
515 |
+
for p in self.parameters():
|
516 |
+
p.requires_grad = False
|
517 |
+
|
518 |
+
def forward(self, im): # input rgb
|
519 |
+
x = (
|
520 |
+
0.299 * im[:, 0, :, :] + 0.587 * im[:, 1, :, :] + 0.114 * im[:, 2, :, :]
|
521 |
+
).unsqueeze(
|
522 |
+
1
|
523 |
+
) # rgb2gray
|
524 |
+
gradx = self.conv_opx(x)
|
525 |
+
grady = self.conv_opy(x)
|
526 |
+
|
527 |
+
x = (gradx**2 + grady**2) ** 0.5
|
528 |
+
x = (x - x.min()) / (x.max() - x.min())
|
529 |
+
x = F.pad(x, (1, 1, 1, 1))
|
530 |
+
|
531 |
+
x = torch.cat([im, x], dim=1)
|
532 |
+
return x
|
533 |
+
|
534 |
+
|
535 |
+
##### U^2-Net ####
|
536 |
+
class U2NET(nn.Module):
|
537 |
+
|
538 |
+
def __init__(self, in_ch=3, out_ch=1):
|
539 |
+
super(U2NET, self).__init__()
|
540 |
+
self.edge = sobel_net()
|
541 |
+
|
542 |
+
self.stage1 = RSU7(in_ch, 32, 64)
|
543 |
+
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
544 |
+
|
545 |
+
self.stage2 = RSU6(64, 32, 128)
|
546 |
+
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
547 |
+
|
548 |
+
self.stage3 = RSU5(128, 64, 256)
|
549 |
+
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
550 |
+
|
551 |
+
self.stage4 = RSU4(256, 128, 512)
|
552 |
+
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
553 |
+
|
554 |
+
self.stage5 = RSU4F(512, 256, 512)
|
555 |
+
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
556 |
+
|
557 |
+
self.stage6 = RSU4F(512, 256, 512)
|
558 |
+
|
559 |
+
# decoder
|
560 |
+
self.stage5d = RSU4F(1024, 256, 512)
|
561 |
+
self.stage4d = RSU4(1024, 128, 256)
|
562 |
+
self.stage3d = RSU5(512, 64, 128)
|
563 |
+
self.stage2d = RSU6(256, 32, 64)
|
564 |
+
self.stage1d = RSU7(128, 16, 64)
|
565 |
+
|
566 |
+
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
567 |
+
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
568 |
+
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
|
569 |
+
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
|
570 |
+
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
|
571 |
+
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
|
572 |
+
|
573 |
+
self.outconv = nn.Conv2d(6, out_ch, 1)
|
574 |
+
|
575 |
+
def forward(self, x):
|
576 |
+
x = self.edge(x)
|
577 |
+
hx = x
|
578 |
+
|
579 |
+
# stage 1
|
580 |
+
hx1 = self.stage1(hx)
|
581 |
+
hx = self.pool12(hx1)
|
582 |
+
|
583 |
+
# stage 2
|
584 |
+
hx2 = self.stage2(hx)
|
585 |
+
hx = self.pool23(hx2)
|
586 |
+
|
587 |
+
# stage 3
|
588 |
+
hx3 = self.stage3(hx)
|
589 |
+
hx = self.pool34(hx3)
|
590 |
+
|
591 |
+
# stage 4
|
592 |
+
hx4 = self.stage4(hx)
|
593 |
+
hx = self.pool45(hx4)
|
594 |
+
|
595 |
+
# stage 5
|
596 |
+
hx5 = self.stage5(hx)
|
597 |
+
hx = self.pool56(hx5)
|
598 |
+
|
599 |
+
# stage 6
|
600 |
+
hx6 = self.stage6(hx)
|
601 |
+
hx6up = _upsample_like(hx6, hx5)
|
602 |
+
|
603 |
+
# -------------------- decoder --------------------
|
604 |
+
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
605 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
606 |
+
|
607 |
+
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
608 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
609 |
+
|
610 |
+
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
611 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
612 |
+
|
613 |
+
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
614 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
615 |
+
|
616 |
+
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
617 |
+
|
618 |
+
# side output
|
619 |
+
d1 = self.side1(hx1d)
|
620 |
+
|
621 |
+
d2 = self.side2(hx2d)
|
622 |
+
d2 = _upsample_like(d2, d1)
|
623 |
+
|
624 |
+
d3 = self.side3(hx3d)
|
625 |
+
d3 = _upsample_like(d3, d1)
|
626 |
+
|
627 |
+
d4 = self.side4(hx4d)
|
628 |
+
d4 = _upsample_like(d4, d1)
|
629 |
+
|
630 |
+
d5 = self.side5(hx5d)
|
631 |
+
d5 = _upsample_like(d5, d1)
|
632 |
+
|
633 |
+
d6 = self.side6(hx6)
|
634 |
+
d6 = _upsample_like(d6, d1)
|
635 |
+
|
636 |
+
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
|
637 |
+
|
638 |
+
return (
|
639 |
+
torch.sigmoid(d0),
|
640 |
+
torch.sigmoid(d1),
|
641 |
+
torch.sigmoid(d2),
|
642 |
+
torch.sigmoid(d3),
|
643 |
+
torch.sigmoid(d4),
|
644 |
+
torch.sigmoid(d5),
|
645 |
+
torch.sigmoid(d6),
|
646 |
+
)
|
647 |
+
|
648 |
+
|
649 |
+
### RSU-5 ###
|
650 |
+
class RSU5(nn.Module): # UNet05DRES(nn.Module):
|
651 |
+
|
652 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
653 |
+
super(RSU5, self).__init__()
|
654 |
+
|
655 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
656 |
+
|
657 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
658 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
659 |
+
|
660 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
661 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
662 |
+
|
663 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
664 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
665 |
+
|
666 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
667 |
+
|
668 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
669 |
+
|
670 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
671 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
672 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
673 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
674 |
+
|
675 |
+
def forward(self, x):
|
676 |
+
hx = x
|
677 |
+
|
678 |
+
hxin = self.rebnconvin(hx)
|
679 |
+
|
680 |
+
hx1 = self.rebnconv1(hxin)
|
681 |
+
hx = self.pool1(hx1)
|
682 |
+
|
683 |
+
hx2 = self.rebnconv2(hx)
|
684 |
+
hx = self.pool2(hx2)
|
685 |
+
|
686 |
+
hx3 = self.rebnconv3(hx)
|
687 |
+
hx = self.pool3(hx3)
|
688 |
+
|
689 |
+
hx4 = self.rebnconv4(hx)
|
690 |
+
|
691 |
+
hx5 = self.rebnconv5(hx4)
|
692 |
+
|
693 |
+
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
|
694 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
695 |
+
|
696 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
697 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
698 |
+
|
699 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
700 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
701 |
+
|
702 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
703 |
+
|
704 |
+
return hx1d + hxin
|
705 |
+
|
706 |
+
|
707 |
+
### RSU-6 ###
|
708 |
+
class RSU6(nn.Module): # UNet06DRES(nn.Module):
|
709 |
+
|
710 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
711 |
+
super(RSU6, self).__init__()
|
712 |
+
|
713 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
714 |
+
|
715 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
716 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
717 |
+
|
718 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
719 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
720 |
+
|
721 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
722 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
723 |
+
|
724 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
725 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
726 |
+
|
727 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
728 |
+
|
729 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
730 |
+
|
731 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
732 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
733 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
734 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
735 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
736 |
+
|
737 |
+
def forward(self, x):
|
738 |
+
hx = x
|
739 |
+
|
740 |
+
hxin = self.rebnconvin(hx)
|
741 |
+
|
742 |
+
hx1 = self.rebnconv1(hxin)
|
743 |
+
hx = self.pool1(hx1)
|
744 |
+
|
745 |
+
hx2 = self.rebnconv2(hx)
|
746 |
+
hx = self.pool2(hx2)
|
747 |
+
|
748 |
+
hx3 = self.rebnconv3(hx)
|
749 |
+
hx = self.pool3(hx3)
|
750 |
+
|
751 |
+
hx4 = self.rebnconv4(hx)
|
752 |
+
hx = self.pool4(hx4)
|
753 |
+
|
754 |
+
hx5 = self.rebnconv5(hx)
|
755 |
+
|
756 |
+
hx6 = self.rebnconv6(hx5)
|
757 |
+
|
758 |
+
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
|
759 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
760 |
+
|
761 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
762 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
763 |
+
|
764 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
765 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
766 |
+
|
767 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
768 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
769 |
+
|
770 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
771 |
+
|
772 |
+
return hx1d + hxin
|
773 |
+
|
774 |
+
|
775 |
+
### RSU-7 ###
|
776 |
+
class RSU7(nn.Module): # UNet07DRES(nn.Module):
|
777 |
+
|
778 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
779 |
+
super(RSU7, self).__init__()
|
780 |
+
|
781 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
782 |
+
|
783 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
784 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
785 |
+
|
786 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
787 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
788 |
+
|
789 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
790 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
791 |
+
|
792 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
793 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
794 |
+
|
795 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
796 |
+
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
797 |
+
|
798 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
799 |
+
|
800 |
+
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
801 |
+
|
802 |
+
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
803 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
804 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
805 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
806 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
807 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
808 |
+
|
809 |
+
def forward(self, x):
|
810 |
+
hx = x
|
811 |
+
hxin = self.rebnconvin(hx)
|
812 |
+
|
813 |
+
hx1 = self.rebnconv1(hxin)
|
814 |
+
hx = self.pool1(hx1)
|
815 |
+
|
816 |
+
hx2 = self.rebnconv2(hx)
|
817 |
+
hx = self.pool2(hx2)
|
818 |
+
|
819 |
+
hx3 = self.rebnconv3(hx)
|
820 |
+
hx = self.pool3(hx3)
|
821 |
+
|
822 |
+
hx4 = self.rebnconv4(hx)
|
823 |
+
hx = self.pool4(hx4)
|
824 |
+
|
825 |
+
hx5 = self.rebnconv5(hx)
|
826 |
+
hx = self.pool5(hx5)
|
827 |
+
|
828 |
+
hx6 = self.rebnconv6(hx)
|
829 |
+
|
830 |
+
hx7 = self.rebnconv7(hx6)
|
831 |
+
|
832 |
+
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
833 |
+
hx6dup = _upsample_like(hx6d, hx5)
|
834 |
+
|
835 |
+
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
836 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
837 |
+
|
838 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
839 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
840 |
+
|
841 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
842 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
843 |
+
|
844 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
845 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
846 |
+
|
847 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
848 |
+
|
849 |
+
return hx1d + hxin
|
850 |
+
|
851 |
+
|
852 |
+
class U2NETP(nn.Module):
|
853 |
+
|
854 |
+
def __init__(self, in_ch=3, out_ch=1):
|
855 |
+
super(U2NETP, self).__init__()
|
856 |
+
|
857 |
+
self.stage1 = RSU7(in_ch, 16, 64)
|
858 |
+
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
859 |
+
|
860 |
+
self.stage2 = RSU6(64, 16, 64)
|
861 |
+
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
862 |
+
|
863 |
+
self.stage3 = RSU5(64, 16, 64)
|
864 |
+
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
865 |
+
|
866 |
+
self.stage4 = RSU4(64, 16, 64)
|
867 |
+
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
868 |
+
|
869 |
+
self.stage5 = RSU4F(64, 16, 64)
|
870 |
+
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
871 |
+
|
872 |
+
self.stage6 = RSU4F(64, 16, 64)
|
873 |
+
|
874 |
+
# decoder
|
875 |
+
self.stage5d = RSU4F(128, 16, 64)
|
876 |
+
self.stage4d = RSU4(128, 16, 64)
|
877 |
+
self.stage3d = RSU5(128, 16, 64)
|
878 |
+
self.stage2d = RSU6(128, 16, 64)
|
879 |
+
self.stage1d = RSU7(128, 16, 64)
|
880 |
+
|
881 |
+
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
882 |
+
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
883 |
+
self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
|
884 |
+
self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
|
885 |
+
self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
|
886 |
+
self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
|
887 |
+
|
888 |
+
self.outconv = nn.Conv2d(6, out_ch, 1)
|
889 |
+
|
890 |
+
def forward(self, x):
|
891 |
+
hx = x
|
892 |
+
|
893 |
+
# stage 1
|
894 |
+
hx1 = self.stage1(hx)
|
895 |
+
hx = self.pool12(hx1)
|
896 |
+
|
897 |
+
# stage 2
|
898 |
+
hx2 = self.stage2(hx)
|
899 |
+
hx = self.pool23(hx2)
|
900 |
+
|
901 |
+
# stage 3
|
902 |
+
hx3 = self.stage3(hx)
|
903 |
+
hx = self.pool34(hx3)
|
904 |
+
|
905 |
+
# stage 4
|
906 |
+
hx4 = self.stage4(hx)
|
907 |
+
hx = self.pool45(hx4)
|
908 |
+
|
909 |
+
# stage 5
|
910 |
+
hx5 = self.stage5(hx)
|
911 |
+
hx = self.pool56(hx5)
|
912 |
+
|
913 |
+
# stage 6
|
914 |
+
hx6 = self.stage6(hx)
|
915 |
+
hx6up = _upsample_like(hx6, hx5)
|
916 |
+
|
917 |
+
# decoder
|
918 |
+
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
919 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
920 |
+
|
921 |
+
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
922 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
923 |
+
|
924 |
+
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
925 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
926 |
+
|
927 |
+
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
928 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
929 |
+
|
930 |
+
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
931 |
+
|
932 |
+
# side output
|
933 |
+
d1 = self.side1(hx1d)
|
934 |
+
|
935 |
+
d2 = self.side2(hx2d)
|
936 |
+
d2 = _upsample_like(d2, d1)
|
937 |
+
|
938 |
+
d3 = self.side3(hx3d)
|
939 |
+
d3 = _upsample_like(d3, d1)
|
940 |
+
|
941 |
+
d4 = self.side4(hx4d)
|
942 |
+
d4 = _upsample_like(d4, d1)
|
943 |
+
|
944 |
+
d5 = self.side5(hx5d)
|
945 |
+
d5 = _upsample_like(d5, d1)
|
946 |
+
|
947 |
+
d6 = self.side6(hx6)
|
948 |
+
d6 = _upsample_like(d6, d1)
|
949 |
+
|
950 |
+
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
|
951 |
+
|
952 |
+
return (
|
953 |
+
torch.sigmoid(d0),
|
954 |
+
torch.sigmoid(d1),
|
955 |
+
torch.sigmoid(d2),
|
956 |
+
torch.sigmoid(d3),
|
957 |
+
torch.sigmoid(d4),
|
958 |
+
torch.sigmoid(d5),
|
959 |
+
torch.sigmoid(d6),
|
960 |
+
)
|
models/DocTr-Plus/LICENSE.md
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# License
|
2 |
+
|
3 |
+
Copyright © Hao Feng 2024. All Rights Reserved.
|
4 |
+
|
5 |
+
## 1. Definitions
|
6 |
+
|
7 |
+
1.1 "Algorithm" refers to the deep learning algorithm contained in this repository, including all associated code, documentation, and data.
|
8 |
+
|
9 |
+
1.2 "Author" refers to Hao Feng, the creator and copyright holder of the Algorithm.
|
10 |
+
|
11 |
+
1.3 "Non-Commercial Use" means use for academic research, personal study, or non-profit projects, without any direct or indirect commercial advantage.
|
12 |
+
|
13 |
+
1.4 "Commercial Use" means any use intended for or directed toward commercial advantage or monetary compensation.
|
14 |
+
|
15 |
+
## 2. Grant of Rights
|
16 |
+
|
17 |
+
2.1 Non-Commercial Use: The Author hereby grants you a worldwide, royalty-free, non-exclusive license to use, copy, modify, and distribute the Algorithm for Non-Commercial Use, subject to the conditions in Section 3.
|
18 |
+
|
19 |
+
2.2 Commercial Use: Any Commercial Use of the Algorithm is strictly prohibited without explicit prior written permission from the Author.
|
20 |
+
|
21 |
+
## 3. Conditions
|
22 |
+
|
23 |
+
3.1 For Non-Commercial Use:
|
24 |
+
a) Attribution: You must give appropriate credit to the Author, provide a link to this license, and indicate if changes were made.
|
25 |
+
b) Share-Alike: If you modify, transform, or build upon the Algorithm, you must distribute your contributions under the same license as this one.
|
26 |
+
c) No additional restrictions: You may not apply legal terms or technological measures that legally restrict others from doing anything this license permits.
|
27 |
+
|
28 |
+
3.2 For Commercial Use:
|
29 |
+
a) Prior Contact: Before any Commercial Use, you must contact the Author at haof@mail.ustc.edu.cn and obtain explicit written permission.
|
30 |
+
b) Separate Agreement: Commercial Use terms will be stipulated in a separate commercial license agreement.
|
31 |
+
|
32 |
+
## 4. Disclaimer of Warranty
|
33 |
+
|
34 |
+
The Algorithm is provided "as is", without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and non-infringement. In no event shall the Author be liable for any claim, damages, or other liability arising from, out of, or in connection with the Algorithm or the use or other dealings in the Algorithm.
|
35 |
+
|
36 |
+
## 5. Limitation of Liability
|
37 |
+
|
38 |
+
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, shall the Author be liable to you for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this license or out of the use or inability to use the Algorithm.
|
39 |
+
|
40 |
+
## 6. Termination
|
41 |
+
|
42 |
+
6.1 This license and the rights granted hereunder will terminate automatically upon any breach by you of the terms of this license.
|
43 |
+
|
44 |
+
6.2 All sections which by their nature should survive the termination of this license shall survive such termination.
|
45 |
+
|
46 |
+
## 7. Miscellaneous
|
47 |
+
|
48 |
+
7.1 If any provision of this license is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable.
|
49 |
+
|
50 |
+
7.2 This license represents the complete agreement concerning the subject matter hereof.
|
51 |
+
|
52 |
+
By using the Algorithm, you acknowledge that you have read this license, understand it, and agree to be bound by its terms and conditions. If you do not agree to the terms and conditions of this license, do not use, modify, or distribute the Algorithm.
|
53 |
+
|
54 |
+
For permissions beyond the scope of this license, please contact the Author at haof@mail.ustc.edu.cn.
|
models/DocTr-Plus/OCR_eval.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import pytesseract
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
pytesseract.get_tesseract_version()
|
8 |
+
|
9 |
+
|
10 |
+
def Levenshtein_Distance(str1, str2):
|
11 |
+
matrix = [[i + j for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]
|
12 |
+
for i in range(1, len(str1) + 1):
|
13 |
+
for j in range(1, len(str2) + 1):
|
14 |
+
if str1[i - 1] == str2[j - 1]:
|
15 |
+
d = 0
|
16 |
+
else:
|
17 |
+
d = 1
|
18 |
+
matrix[i][j] = min(
|
19 |
+
matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, matrix[i - 1][j - 1] + d
|
20 |
+
)
|
21 |
+
|
22 |
+
return matrix[len(str1)][len(str2)]
|
23 |
+
|
24 |
+
|
25 |
+
def cal_cer_ed(path_ours, tail="_rec"):
|
26 |
+
print(path_ours, "start")
|
27 |
+
print(f"started at {time.strftime('%H:%M:%S')}")
|
28 |
+
path_gt = "./scan/"
|
29 |
+
N = 196
|
30 |
+
cer1 = []
|
31 |
+
ed1 = []
|
32 |
+
check = [0 for _ in range(N + 1)]
|
33 |
+
# img index in UDIR test set for OCR evaluation
|
34 |
+
lis = [
|
35 |
+
2,
|
36 |
+
5,
|
37 |
+
17,
|
38 |
+
19,
|
39 |
+
20,
|
40 |
+
23,
|
41 |
+
31,
|
42 |
+
37,
|
43 |
+
38,
|
44 |
+
39,
|
45 |
+
40,
|
46 |
+
41,
|
47 |
+
43,
|
48 |
+
45,
|
49 |
+
47,
|
50 |
+
48,
|
51 |
+
51,
|
52 |
+
54,
|
53 |
+
57,
|
54 |
+
60,
|
55 |
+
61,
|
56 |
+
62,
|
57 |
+
64,
|
58 |
+
65,
|
59 |
+
67,
|
60 |
+
68,
|
61 |
+
70,
|
62 |
+
75,
|
63 |
+
76,
|
64 |
+
77,
|
65 |
+
78,
|
66 |
+
80,
|
67 |
+
81,
|
68 |
+
83,
|
69 |
+
84,
|
70 |
+
85,
|
71 |
+
87,
|
72 |
+
88,
|
73 |
+
90,
|
74 |
+
91,
|
75 |
+
93,
|
76 |
+
96,
|
77 |
+
99,
|
78 |
+
100,
|
79 |
+
101,
|
80 |
+
102,
|
81 |
+
103,
|
82 |
+
104,
|
83 |
+
105,
|
84 |
+
134,
|
85 |
+
137,
|
86 |
+
138,
|
87 |
+
140,
|
88 |
+
150,
|
89 |
+
151,
|
90 |
+
155,
|
91 |
+
158,
|
92 |
+
162,
|
93 |
+
163,
|
94 |
+
164,
|
95 |
+
165,
|
96 |
+
166,
|
97 |
+
169,
|
98 |
+
170,
|
99 |
+
172,
|
100 |
+
173,
|
101 |
+
175,
|
102 |
+
177,
|
103 |
+
178,
|
104 |
+
182,
|
105 |
+
]
|
106 |
+
for i in range(1, N):
|
107 |
+
if i not in lis:
|
108 |
+
continue
|
109 |
+
gt = Image.open(path_gt + str(i) + ".png")
|
110 |
+
img1 = Image.open(path_ours + str(i) + tail)
|
111 |
+
content_gt = pytesseract.image_to_string(gt)
|
112 |
+
content1 = pytesseract.image_to_string(img1)
|
113 |
+
l1 = Levenshtein_Distance(content_gt, content1)
|
114 |
+
ed1.append(l1)
|
115 |
+
cer1.append(l1 / len(content_gt))
|
116 |
+
check[i] = cer1[-1]
|
117 |
+
|
118 |
+
CER = np.mean(cer1)
|
119 |
+
ED = np.mean(ed1)
|
120 |
+
print(f"finished at {time.strftime('%H:%M:%S')}")
|
121 |
+
return [path_ours, CER, ED]
|
models/DocTr-Plus/README.md
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
🔥 **Good news! Our demo has already exceeded 20,000 calls!**
|
2 |
+
|
3 |
+
🔥 **Good news! Our work has been accepted by IEEE Transactions on Multimedia.**
|
4 |
+
|
5 |
+
🚀 **Exciting update! We have created a demo for our paper, showcasing the generic rectification capabilities of our method. [Check it out here!](https://doctrp.docscanner.top/)**
|
6 |
+
|
7 |
+
🔥 **Good news! Our new work exhibits state-of-the-art performances on the [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html) dataset:
|
8 |
+
[DocScanner: Robust Document Image Rectification with Progressive Learning](https://drive.google.com/file/d/1mmCUj90rHyuO1SmpLt361youh-07Y0sD/view?usp=share_link)** with [Repo](https://github.com/fh2019ustc/DocScanner).
|
9 |
+
|
10 |
+
🔥 **Good news! A comprehensive list of [Awesome Document Image Rectification](https://github.com/fh2019ustc/Awesome-Document-Image-Rectification) methods is available.**
|
11 |
+
|
12 |
+
# DocTr++
|
13 |
+
|
14 |
+
<p>
|
15 |
+
<a href='https://project.doctrp.top/' target="_blank"><img src='https://img.shields.io/badge/Project-Page-Green'></a>
|
16 |
+
<a href='https://arxiv.org/abs/2304.08796' target="_blank"><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
|
17 |
+
<a href='https://demo.doctrp.top/' target="_blank"><img src='https://img.shields.io/badge/Online-Demo-green'></a>
|
18 |
+
</p>
|
19 |
+
|
20 |
+
|
21 |
+
![Demo](assets/github_demo.png)
|
22 |
+
![Demo](assets/github_demo_v2.png)
|
23 |
+
> **[DocTr++: Deep Unrestricted Document Image Rectification](https://arxiv.org/abs/2304.08796)**
|
24 |
+
|
25 |
+
> DocTr++ is an enhanced version of the original [DocTr: Document Image Transformer for Geometric Unwarping and Illumination Correction](https://github.com/fh2019ustc/DocTr), aiming to rectify various distorted document images in the wild,
|
26 |
+
whether or not the document is fully present in the image.
|
27 |
+
|
28 |
+
Any questions or discussions are welcomed!
|
29 |
+
|
30 |
+
|
31 |
+
## 🚀 Demo [(Link)](https://demo.doctrp.top/)
|
32 |
+
1. Upload the distorted document image to be rectified in the left box.
|
33 |
+
2. Click the "Submit" button.
|
34 |
+
3. The rectified image will be displayed in the right box.
|
35 |
+
4. Our demo environment is based on a CPU infrastructure, and due to image transmission over the network, some display latency may be experienced.
|
36 |
+
|
37 |
+
[![Alt text](https://user-images.githubusercontent.com/50725551/232952015-15508ad6-e38c-475b-bf9e-91cb74bc5fea.png)](https://demo.doctrp.top/)
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
## Inference
|
42 |
+
1. Put the pretrained model to `$ROOT/model_pretrained/`.
|
43 |
+
2. Put the distorted images in `$ROOT/distorted/`.
|
44 |
+
3. Run the script and the rectified images are saved in `$ROOT/rectified/` by default.
|
45 |
+
```
|
46 |
+
python inference.py
|
47 |
+
```
|
48 |
+
|
49 |
+
## Evaluation
|
50 |
+
- ***Image Metrics:*** We propose the metrics MS-SSIM-M and LD-M, different from that for [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html) dataset. We use Matlab 2019a. Please compare the scores according to your Matlab version. We provide our Matlab interface file at ```$ROOT/ssim_ld_eval.m```.
|
51 |
+
- ***OCR Metrics:*** The index of 70 document (70 images) in [UDIR test set](https://drive.google.com/drive/folders/15rknyt7XE2k6jrxaTc_n5dzXIdCukJLh?usp=share_link) used for our OCR evaluation is provided in ```$ROOT/ocr_eval.py```.
|
52 |
+
The version of pytesseract is 0.3.8, and the version of [Tesseract](https://digi.bib.uni-mannheim.de/tesseract/) in Windows is recent 5.0.1.20220118.
|
53 |
+
Note that in different operating systems, the calculated performance has slight differences.
|
54 |
+
|
55 |
+
## Citation
|
56 |
+
|
57 |
+
If you find this code useful for your research, please use the following BibTeX entry.
|
58 |
+
|
59 |
+
```
|
60 |
+
@inproceedings{feng2021doctr,
|
61 |
+
title={DocTr: Document Image Transformer for Geometric Unwarping and Illumination Correction},
|
62 |
+
author={Feng, Hao and Wang, Yuechen and Zhou, Wengang and Deng, Jiajun and Li, Houqiang},
|
63 |
+
booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
|
64 |
+
pages={273--281},
|
65 |
+
year={2021}
|
66 |
+
}
|
67 |
+
```
|
68 |
+
|
69 |
+
```
|
70 |
+
@article{feng2023doctrp,
|
71 |
+
title={Deep Unrestricted Document Image Rectification},
|
72 |
+
author={Feng, Hao and Liu, Shaokai and Deng, Jiajun and Zhou, Wengang and Li, Houqiang},
|
73 |
+
journal={IEEE Transactions on Multimedia},
|
74 |
+
year={2023}
|
75 |
+
}
|
76 |
+
```
|
77 |
+
|
78 |
+
## Contact
|
79 |
+
For commercial usage, please contact Professor Wengang Zhou ([zhwg@ustc.edu.cn](zhwg@ustc.edu.cn)) and Hao Feng ([haof@mail.ustc.edu.cn](haof@mail.ustc.edu.cn)).
|
models/DocTr-Plus/__init__.py
ADDED
File without changes
|
models/DocTr-Plus/__pycache__/GeoTr.cpython-38.pyc
ADDED
Binary file (23.5 kB). View file
|
|
models/DocTr-Plus/__pycache__/GeoTr.cpython-39.pyc
ADDED
Binary file (23.5 kB). View file
|
|
models/DocTr-Plus/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (154 Bytes). View file
|
|
models/DocTr-Plus/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (154 Bytes). View file
|
|
models/DocTr-Plus/__pycache__/extractor.cpython-38.pyc
ADDED
Binary file (3.13 kB). View file
|
|
models/DocTr-Plus/__pycache__/extractor.cpython-39.pyc
ADDED
Binary file (3.12 kB). View file
|
|
models/DocTr-Plus/__pycache__/inference.cpython-38.pyc
ADDED
Binary file (1.61 kB). View file
|
|
models/DocTr-Plus/__pycache__/inference.cpython-39.pyc
ADDED
Binary file (1.61 kB). View file
|
|
models/DocTr-Plus/__pycache__/position_encoding.cpython-38.pyc
ADDED
Binary file (4.38 kB). View file
|
|
models/DocTr-Plus/__pycache__/position_encoding.cpython-39.pyc
ADDED
Binary file (4.34 kB). View file
|
|
models/DocTr-Plus/evalUnwarp.m
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
function [ms, ld] = evalUnwarp(A, ref, ref_msk)
|
2 |
+
|
3 |
+
x = A;
|
4 |
+
y = ref;
|
5 |
+
z = ref_msk;
|
6 |
+
|
7 |
+
im1=imresize(imfilter(x,fspecial('gaussian',7,1.),'same','replicate'),0.5,'bicubic');
|
8 |
+
im2=imresize(imfilter(y,fspecial('gaussian',7,1.),'same','replicate'),0.5,'bicubic');
|
9 |
+
im3=imresize(imfilter(z,fspecial('gaussian',7,1.),'same','replicate'),0.5,'bicubic');
|
10 |
+
|
11 |
+
im1=im2double(im1);
|
12 |
+
im2=im2double(im2);
|
13 |
+
im3=im2double(im3);
|
14 |
+
|
15 |
+
cellsize=3;
|
16 |
+
gridspacing=1;
|
17 |
+
|
18 |
+
sift1 = mexDenseSIFT(im1,cellsize,gridspacing);
|
19 |
+
sift2 = mexDenseSIFT(im2,cellsize,gridspacing);
|
20 |
+
|
21 |
+
SIFTflowpara.alpha=2*255;
|
22 |
+
SIFTflowpara.d=40*255;
|
23 |
+
SIFTflowpara.gamma=0.005*255;
|
24 |
+
SIFTflowpara.nlevels=4;
|
25 |
+
SIFTflowpara.wsize=2;
|
26 |
+
SIFTflowpara.topwsize=10;
|
27 |
+
SIFTflowpara.nTopIterations = 60;
|
28 |
+
SIFTflowpara.nIterations= 30;
|
29 |
+
|
30 |
+
|
31 |
+
[vx,vy,~]=SIFTflowc2f(sift1,sift2,SIFTflowpara);
|
32 |
+
|
33 |
+
d = sqrt(vx.^2 + vy.^2);
|
34 |
+
mskk = (im3==0);
|
35 |
+
ld = mean(d(~mskk));
|
36 |
+
|
37 |
+
wt = [0.0448 0.2856 0.3001 0.2363 0.1333];
|
38 |
+
ss = zeros(5, 1);
|
39 |
+
for s = 1 : 5
|
40 |
+
ss(s) = ssim(x, z);
|
41 |
+
x = impyramid(x, 'reduce');
|
42 |
+
z = impyramid(z, 'reduce');
|
43 |
+
end
|
44 |
+
ms = wt * ss;
|
45 |
+
|
46 |
+
end
|
models/DocTr-Plus/extractor.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class ResidualBlock(nn.Module):
|
7 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
8 |
+
super(ResidualBlock, self).__init__()
|
9 |
+
|
10 |
+
self.conv1 = nn.Conv2d(
|
11 |
+
in_planes, planes, kernel_size=3, padding=1, stride=stride
|
12 |
+
)
|
13 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
|
14 |
+
self.relu = nn.ReLU(inplace=True)
|
15 |
+
|
16 |
+
num_groups = planes // 8
|
17 |
+
|
18 |
+
if norm_fn == "group":
|
19 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
20 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
21 |
+
if not stride == 1:
|
22 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
23 |
+
|
24 |
+
elif norm_fn == "batch":
|
25 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
26 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
27 |
+
if not stride == 1:
|
28 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
29 |
+
|
30 |
+
elif norm_fn == "instance":
|
31 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
32 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
33 |
+
if not stride == 1:
|
34 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
35 |
+
|
36 |
+
elif norm_fn == "none":
|
37 |
+
self.norm1 = nn.Sequential()
|
38 |
+
self.norm2 = nn.Sequential()
|
39 |
+
if not stride == 1:
|
40 |
+
self.norm3 = nn.Sequential()
|
41 |
+
|
42 |
+
if stride == 1:
|
43 |
+
self.downsample = None
|
44 |
+
|
45 |
+
else:
|
46 |
+
self.downsample = nn.Sequential(
|
47 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
48 |
+
)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
y = x
|
52 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
53 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
54 |
+
|
55 |
+
if self.downsample is not None:
|
56 |
+
x = self.downsample(x)
|
57 |
+
|
58 |
+
return self.relu(x + y)
|
59 |
+
|
60 |
+
|
61 |
+
class BasicEncoder(nn.Module):
|
62 |
+
def __init__(self, output_dim=128, norm_fn="batch"):
|
63 |
+
super(BasicEncoder, self).__init__()
|
64 |
+
self.norm_fn = norm_fn
|
65 |
+
|
66 |
+
if self.norm_fn == "group":
|
67 |
+
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
|
68 |
+
|
69 |
+
elif self.norm_fn == "batch":
|
70 |
+
self.norm1 = nn.BatchNorm2d(64)
|
71 |
+
|
72 |
+
elif self.norm_fn == "instance":
|
73 |
+
self.norm1 = nn.InstanceNorm2d(64)
|
74 |
+
|
75 |
+
elif self.norm_fn == "none":
|
76 |
+
self.norm1 = nn.Sequential()
|
77 |
+
|
78 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
79 |
+
self.relu1 = nn.ReLU(inplace=True)
|
80 |
+
|
81 |
+
self.in_planes = 64
|
82 |
+
self.layer1 = self._make_layer(64, stride=1)
|
83 |
+
self.layer2 = self._make_layer(128, stride=2)
|
84 |
+
self.layer3 = self._make_layer(192, stride=2)
|
85 |
+
|
86 |
+
# output convolution
|
87 |
+
self.conv2 = nn.Conv2d(192, output_dim, kernel_size=1)
|
88 |
+
|
89 |
+
for m in self.modules():
|
90 |
+
if isinstance(m, nn.Conv2d):
|
91 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
92 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
93 |
+
if m.weight is not None:
|
94 |
+
nn.init.constant_(m.weight, 1)
|
95 |
+
if m.bias is not None:
|
96 |
+
nn.init.constant_(m.bias, 0)
|
97 |
+
|
98 |
+
def _make_layer(self, dim, stride=1):
|
99 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
100 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
101 |
+
layers = (layer1, layer2)
|
102 |
+
|
103 |
+
self.in_planes = dim
|
104 |
+
return nn.Sequential(*layers)
|
105 |
+
|
106 |
+
def forward(self, x):
|
107 |
+
x = self.conv1(x)
|
108 |
+
x = self.norm1(x)
|
109 |
+
x = self.relu1(x)
|
110 |
+
|
111 |
+
x = self.layer1(x)
|
112 |
+
x = self.layer2(x)
|
113 |
+
x = self.layer3(x)
|
114 |
+
|
115 |
+
x = self.conv2(x)
|
116 |
+
|
117 |
+
return x
|
models/DocTr-Plus/inference.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import glob
|
8 |
+
import os
|
9 |
+
import warnings
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
import numpy as np
|
13 |
+
import skimage.io as io
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from PIL import Image
|
18 |
+
|
19 |
+
from .GeoTr import U2NETP, GeoTr
|
20 |
+
|
21 |
+
warnings.filterwarnings("ignore")
|
22 |
+
|
23 |
+
|
24 |
+
class GeoTrP(nn.Module):
|
25 |
+
def __init__(self):
|
26 |
+
super(GeoTrP, self).__init__()
|
27 |
+
self.GeoTr = GeoTr()
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
bm = self.GeoTr(x) # [0]
|
31 |
+
bm = 2 * (bm / 288) - 1
|
32 |
+
|
33 |
+
bm = (bm + 1) / 2 * 2560
|
34 |
+
|
35 |
+
bm = F.interpolate(bm, size=(2560, 2560), mode="bilinear", align_corners=True)
|
36 |
+
|
37 |
+
return bm
|
38 |
+
|
39 |
+
|
40 |
+
def reload_model(model, path=""):
|
41 |
+
if not bool(path):
|
42 |
+
return model
|
43 |
+
else:
|
44 |
+
model_dict = model.state_dict()
|
45 |
+
pretrained_dict = torch.load(path, map_location="cuda:0")
|
46 |
+
print(len(pretrained_dict.keys()))
|
47 |
+
print(len(pretrained_dict.keys()))
|
48 |
+
model_dict.update(pretrained_dict)
|
49 |
+
model.load_state_dict(model_dict)
|
50 |
+
|
51 |
+
return model
|
models/DocTr-Plus/position_encoding.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
Various positional encodings for the transformer.
|
4 |
+
"""
|
5 |
+
import math
|
6 |
+
from typing import List, Optional
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch import Tensor, nn
|
10 |
+
|
11 |
+
|
12 |
+
class NestedTensor(object):
|
13 |
+
def __init__(self, tensors, mask: Optional[Tensor]):
|
14 |
+
self.tensors = tensors
|
15 |
+
self.mask = mask
|
16 |
+
|
17 |
+
def to(self, device):
|
18 |
+
# type: (Device) -> NestedTensor # noqa
|
19 |
+
cast_tensor = self.tensors.to(device)
|
20 |
+
mask = self.mask
|
21 |
+
if mask is not None:
|
22 |
+
assert mask is not None
|
23 |
+
cast_mask = mask.to(device)
|
24 |
+
else:
|
25 |
+
cast_mask = None
|
26 |
+
return NestedTensor(cast_tensor, cast_mask)
|
27 |
+
|
28 |
+
def decompose(self):
|
29 |
+
return self.tensors, self.mask
|
30 |
+
|
31 |
+
def __repr__(self):
|
32 |
+
return str(self.tensors)
|
33 |
+
|
34 |
+
|
35 |
+
class PositionEmbeddingSine(nn.Module):
|
36 |
+
"""
|
37 |
+
This is a more standard version of the position embedding, very similar to the one
|
38 |
+
used by the Attention is all you need paper, generalized to work on images.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(
|
42 |
+
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
|
43 |
+
):
|
44 |
+
super().__init__()
|
45 |
+
self.num_pos_feats = num_pos_feats
|
46 |
+
self.temperature = temperature
|
47 |
+
self.normalize = normalize
|
48 |
+
if scale is not None and normalize is False:
|
49 |
+
raise ValueError("normalize should be True if scale is passed")
|
50 |
+
if scale is None:
|
51 |
+
scale = 2 * math.pi
|
52 |
+
self.scale = scale
|
53 |
+
|
54 |
+
def forward(self, mask):
|
55 |
+
assert mask is not None
|
56 |
+
y_embed = mask.cumsum(1, dtype=torch.float32)
|
57 |
+
x_embed = mask.cumsum(2, dtype=torch.float32)
|
58 |
+
if self.normalize:
|
59 |
+
eps = 1e-6
|
60 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
61 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
62 |
+
|
63 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32).cuda()
|
64 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
65 |
+
|
66 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
67 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
68 |
+
pos_x = torch.stack(
|
69 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
70 |
+
).flatten(3)
|
71 |
+
pos_y = torch.stack(
|
72 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
73 |
+
).flatten(3)
|
74 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
75 |
+
# print(pos.shape)
|
76 |
+
return pos
|
77 |
+
|
78 |
+
|
79 |
+
class PositionEmbeddingLearned(nn.Module):
|
80 |
+
"""
|
81 |
+
Absolute pos embedding, learned.
|
82 |
+
"""
|
83 |
+
|
84 |
+
def __init__(self, num_pos_feats=256):
|
85 |
+
super().__init__()
|
86 |
+
self.row_embed = nn.Embedding(50, num_pos_feats)
|
87 |
+
self.col_embed = nn.Embedding(50, num_pos_feats)
|
88 |
+
self.reset_parameters()
|
89 |
+
|
90 |
+
def reset_parameters(self):
|
91 |
+
nn.init.uniform_(self.row_embed.weight)
|
92 |
+
nn.init.uniform_(self.col_embed.weight)
|
93 |
+
|
94 |
+
def forward(self, tensor_list: NestedTensor):
|
95 |
+
x = tensor_list.tensors
|
96 |
+
h, w = x.shape[-2:]
|
97 |
+
i = torch.arange(w, device=x.device)
|
98 |
+
j = torch.arange(h, device=x.device)
|
99 |
+
x_emb = self.col_embed(i)
|
100 |
+
y_emb = self.row_embed(j)
|
101 |
+
pos = (
|
102 |
+
torch.cat(
|
103 |
+
[
|
104 |
+
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
105 |
+
y_emb.unsqueeze(1).repeat(1, w, 1),
|
106 |
+
],
|
107 |
+
dim=-1,
|
108 |
+
)
|
109 |
+
.permute(2, 0, 1)
|
110 |
+
.unsqueeze(0)
|
111 |
+
.repeat(x.shape[0], 1, 1, 1)
|
112 |
+
)
|
113 |
+
return pos
|
114 |
+
|
115 |
+
|
116 |
+
def build_position_encoding(hidden_dim=512, position_embedding="sine"):
|
117 |
+
N_steps = hidden_dim // 2
|
118 |
+
if position_embedding in ("v2", "sine"):
|
119 |
+
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
120 |
+
elif position_embedding in ("v3", "learned"):
|
121 |
+
position_embedding = PositionEmbeddingLearned(N_steps)
|
122 |
+
else:
|
123 |
+
raise ValueError(f"not supported {position_embedding}")
|
124 |
+
|
125 |
+
return position_embedding
|
models/DocTr-Plus/pyimagesearch/__init__.py
ADDED
File without changes
|
models/DocTr-Plus/pyimagesearch/transform.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def order_points(pts):
|
6 |
+
# initialize a list of coordinates that will be ordered
|
7 |
+
# such that the first entry in the list is the top-left,
|
8 |
+
# the second entry is the top-right, the third is the
|
9 |
+
# bottom-right, and the fourth is the bottom-left
|
10 |
+
rect = np.zeros((4, 2), dtype="float32")
|
11 |
+
|
12 |
+
# the top-left point will have the smallest sum, whereas
|
13 |
+
# the bottom-right point will have the largest sum
|
14 |
+
s = pts.sum(axis=1)
|
15 |
+
rect[0] = pts[np.argmin(s)]
|
16 |
+
rect[2] = pts[np.argmax(s)]
|
17 |
+
|
18 |
+
# now, compute the difference between the points, the
|
19 |
+
# top-right point will have the smallest difference,
|
20 |
+
# whereas the bottom-left will have the largest difference
|
21 |
+
diff = np.diff(pts, axis=1)
|
22 |
+
rect[1] = pts[np.argmin(diff)]
|
23 |
+
rect[3] = pts[np.argmax(diff)]
|
24 |
+
|
25 |
+
# return the ordered coordinates
|
26 |
+
return rect
|
27 |
+
|
28 |
+
|
29 |
+
def four_point_transform(image, pts):
|
30 |
+
# obtain a consistent order of the points and unpack them
|
31 |
+
# individually
|
32 |
+
rect = order_points(pts)
|
33 |
+
(tl, tr, br, bl) = rect
|
34 |
+
|
35 |
+
# compute the width of the new image, which will be the
|
36 |
+
# maximum distance between bottom-right and bottom-left
|
37 |
+
# x-coordiates or the top-right and top-left x-coordinates
|
38 |
+
widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
|
39 |
+
widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
|
40 |
+
maxWidth = max(int(widthA), int(widthB))
|
41 |
+
|
42 |
+
# compute the height of the new image, which will be the
|
43 |
+
# maximum distance between the top-right and bottom-right
|
44 |
+
# y-coordinates or the top-left and bottom-left y-coordinates
|
45 |
+
heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
|
46 |
+
heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
|
47 |
+
maxHeight = max(int(heightA), int(heightB))
|
48 |
+
|
49 |
+
# now that we have the dimensions of the new image, construct
|
50 |
+
# the set of destination points to obtain a "birds eye view",
|
51 |
+
# (i.e. top-down view) of the image, again specifying points
|
52 |
+
# in the top-left, top-right, bottom-right, and bottom-left
|
53 |
+
# order
|
54 |
+
dst = np.array(
|
55 |
+
[[0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1], [0, maxHeight - 1]],
|
56 |
+
dtype="float32",
|
57 |
+
)
|
58 |
+
|
59 |
+
# compute the perspective transform matrix and then apply it
|
60 |
+
M = cv2.getPerspectiveTransform(rect, dst)
|
61 |
+
warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
|
62 |
+
|
63 |
+
# return the warped image
|
64 |
+
return warped
|
models/DocTr-Plus/requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy==1.19.0
|
2 |
+
opencv_python==4.2.0.34
|
3 |
+
Pillow==9.4.0
|
4 |
+
scikit_image==0.17.2
|
5 |
+
skimage==0.0
|
6 |
+
thop==0.1.1.post2209072238
|
7 |
+
torch==1.5.1+cu101
|
models/DocTr-Plus/ssimm_ldm_eval.m
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
path_rec = 'xxx'; % rectified image path
|
2 |
+
path_scan = './UDIR/gt/'; % scan image path
|
3 |
+
|
4 |
+
tarea=598400;
|
5 |
+
ms=0;
|
6 |
+
ld=0;
|
7 |
+
|
8 |
+
for i=1:195
|
9 |
+
path_rec_1 = sprintf("%s%d%s", path_rec, i, '.png'); % rectified image path
|
10 |
+
path_scan_new = sprintf("%s%d%s", path_scan, i, '.png'); % corresponding scan image path
|
11 |
+
|
12 |
+
% imread and rgb2gray
|
13 |
+
A1 = imread(path_rec_1);
|
14 |
+
ref = imread(path_scan_new);
|
15 |
+
A1 = rgb2gray(A1);
|
16 |
+
ref = rgb2gray(ref);
|
17 |
+
|
18 |
+
% resize
|
19 |
+
b = sqrt(tarea/size(ref,1)/size(ref,2));
|
20 |
+
ref = imresize(ref,b);
|
21 |
+
ref_msk = ref;
|
22 |
+
A1 = imresize(A1,[size(ref,1),size(ref,2)]);
|
23 |
+
|
24 |
+
% mask the gt image
|
25 |
+
m1 = A1 == 0;
|
26 |
+
ref_msk(m1) = 0;
|
27 |
+
|
28 |
+
% calculate
|
29 |
+
[ms_1,ld_1] = evalUnwarp(A1, ref, ref_msk);
|
30 |
+
ms = ms + ms_1;
|
31 |
+
ld = ld + ld_1;
|
32 |
+
|
33 |
+
end
|
34 |
+
|
35 |
+
ms_m = ms / 195
|
36 |
+
ld_m = ld / 195
|
models/Document-Image-Unwarping-pytorch
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 92b29172b981d132f7b31e767505524f8cc7af7a
|