HERIUN committed on
Commit 591ba45
1 Parent(s): 6a07cb2

add models

Files changed (49)
  1. models/DocScanner/LICENSE.md +54 -0
  2. models/DocScanner/OCR_eval.py +78 -0
  3. models/DocScanner/README.md +96 -0
  4. models/DocScanner/__init__.py +0 -0
  5. models/DocScanner/__pycache__/__init__.cpython-38.pyc +0 -0
  6. models/DocScanner/__pycache__/__init__.cpython-39.pyc +0 -0
  7. models/DocScanner/__pycache__/extractor.cpython-38.pyc +0 -0
  8. models/DocScanner/__pycache__/extractor.cpython-39.pyc +0 -0
  9. models/DocScanner/__pycache__/inference.cpython-38.pyc +0 -0
  10. models/DocScanner/__pycache__/inference.cpython-39.pyc +0 -0
  11. models/DocScanner/__pycache__/model.cpython-38.pyc +0 -0
  12. models/DocScanner/__pycache__/model.cpython-39.pyc +0 -0
  13. models/DocScanner/__pycache__/seg.cpython-38.pyc +0 -0
  14. models/DocScanner/__pycache__/seg.cpython-39.pyc +0 -0
  15. models/DocScanner/__pycache__/update.cpython-38.pyc +0 -0
  16. models/DocScanner/__pycache__/update.cpython-39.pyc +0 -0
  17. models/DocScanner/eval.m +64 -0
  18. models/DocScanner/evalUnwarp.m +102 -0
  19. models/DocScanner/extractor.py +140 -0
  20. models/DocScanner/inference.py +65 -0
  21. models/DocScanner/model.py +104 -0
  22. models/DocScanner/ocr_img.txt +62 -0
  23. models/DocScanner/requirements.txt +6 -0
  24. models/DocScanner/seg.py +576 -0
  25. models/DocScanner/update.py +119 -0
  26. models/DocTr-Plus/GeoTr.py +960 -0
  27. models/DocTr-Plus/LICENSE.md +54 -0
  28. models/DocTr-Plus/OCR_eval.py +121 -0
  29. models/DocTr-Plus/README.md +79 -0
  30. models/DocTr-Plus/__init__.py +0 -0
  31. models/DocTr-Plus/__pycache__/GeoTr.cpython-38.pyc +0 -0
  32. models/DocTr-Plus/__pycache__/GeoTr.cpython-39.pyc +0 -0
  33. models/DocTr-Plus/__pycache__/__init__.cpython-38.pyc +0 -0
  34. models/DocTr-Plus/__pycache__/__init__.cpython-39.pyc +0 -0
  35. models/DocTr-Plus/__pycache__/extractor.cpython-38.pyc +0 -0
  36. models/DocTr-Plus/__pycache__/extractor.cpython-39.pyc +0 -0
  37. models/DocTr-Plus/__pycache__/inference.cpython-38.pyc +0 -0
  38. models/DocTr-Plus/__pycache__/inference.cpython-39.pyc +0 -0
  39. models/DocTr-Plus/__pycache__/position_encoding.cpython-38.pyc +0 -0
  40. models/DocTr-Plus/__pycache__/position_encoding.cpython-39.pyc +0 -0
  41. models/DocTr-Plus/evalUnwarp.m +46 -0
  42. models/DocTr-Plus/extractor.py +117 -0
  43. models/DocTr-Plus/inference.py +51 -0
  44. models/DocTr-Plus/position_encoding.py +125 -0
  45. models/DocTr-Plus/pyimagesearch/__init__.py +0 -0
  46. models/DocTr-Plus/pyimagesearch/transform.py +64 -0
  47. models/DocTr-Plus/requirements.txt +7 -0
  48. models/DocTr-Plus/ssimm_ldm_eval.m +36 -0
  49. models/Document-Image-Unwarping-pytorch +1 -0
models/DocScanner/LICENSE.md ADDED
@@ -0,0 +1,54 @@
+ # License
+
+ Copyright © Hao Feng 2024. All Rights Reserved.
+
+ ## 1. Definitions
+
+ 1.1 "Algorithm" refers to the deep learning algorithm contained in this repository, including all associated code, documentation, and data.
+
+ 1.2 "Author" refers to Hao Feng, the creator and copyright holder of the Algorithm.
+
+ 1.3 "Non-Commercial Use" means use for academic research, personal study, or non-profit projects, without any direct or indirect commercial advantage.
+
+ 1.4 "Commercial Use" means any use intended for or directed toward commercial advantage or monetary compensation.
+
+ ## 2. Grant of Rights
+
+ 2.1 Non-Commercial Use: The Author hereby grants you a worldwide, royalty-free, non-exclusive license to use, copy, modify, and distribute the Algorithm for Non-Commercial Use, subject to the conditions in Section 3.
+
+ 2.2 Commercial Use: Any Commercial Use of the Algorithm is strictly prohibited without explicit prior written permission from the Author.
+
+ ## 3. Conditions
+
+ 3.1 For Non-Commercial Use:
+ a) Attribution: You must give appropriate credit to the Author, provide a link to this license, and indicate if changes were made.
+ b) Share-Alike: If you modify, transform, or build upon the Algorithm, you must distribute your contributions under the same license as this one.
+ c) No additional restrictions: You may not apply legal terms or technological measures that legally restrict others from doing anything this license permits.
+
+ 3.2 For Commercial Use:
+ a) Prior Contact: Before any Commercial Use, you must contact the Author at haof@mail.ustc.edu.cn and obtain explicit written permission.
+ b) Separate Agreement: Commercial Use terms will be stipulated in a separate commercial license agreement.
+
+ ## 4. Disclaimer of Warranty
+
+ The Algorithm is provided "as is", without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and non-infringement. In no event shall the Author be liable for any claim, damages, or other liability arising from, out of, or in connection with the Algorithm or the use or other dealings in the Algorithm.
+
+ ## 5. Limitation of Liability
+
+ In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, shall the Author be liable to you for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this license or out of the use or inability to use the Algorithm.
+
+ ## 6. Termination
+
+ 6.1 This license and the rights granted hereunder will terminate automatically upon any breach by you of the terms of this license.
+
+ 6.2 All sections which by their nature should survive the termination of this license shall survive such termination.
+
+ ## 7. Miscellaneous
+
+ 7.1 If any provision of this license is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable.
+
+ 7.2 This license represents the complete agreement concerning the subject matter hereof.
+
+ By using the Algorithm, you acknowledge that you have read this license, understand it, and agree to be bound by its terms and conditions. If you do not agree to the terms and conditions of this license, do not use, modify, or distribute the Algorithm.
+
+ For permissions beyond the scope of this license, please contact the Author at haof@mail.ustc.edu.cn.
models/DocScanner/OCR_eval.py ADDED
@@ -0,0 +1,78 @@
+ import numpy as np
+ import pytesseract
+ from PIL import Image
+
+
+ def Levenshtein_Distance(str1, str2):
+     matrix = [[i + j for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]
+     for i in range(1, len(str1) + 1):
+         for j in range(1, len(str2) + 1):
+             if str1[i - 1] == str2[j - 1]:
+                 d = 0
+             else:
+                 d = 1
+             matrix[i][j] = min(
+                 matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, matrix[i - 1][j - 1] + d
+             )
+
+     return matrix[len(str1)][len(str2)]
+
+
+ def cal_cer_ed(path_ours, tail="_rec"):
+     path_gt = "./GT/"
+     N = 66
+     cer1 = []
+     cer2 = []
+     ed1 = []
+     ed2 = []
+     check = [0 for _ in range(N + 1)]
+     # Document indices used for OCR evaluation: DocTr protocol (Setting 1).
+     lis = [
+         1, 2, 3, 4, 5, 6, 7, 9, 10, 21, 22, 23, 24, 27, 30,
+         31, 32, 36, 38, 40, 41, 44, 45, 46, 47, 48, 50, 51, 52, 53,
+     ]
+     # lis = [1, 9, 10, 12, 19, 20, 21, 22, 23, 24, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 44, 45, 46, 47, 49]  # DewarpNet (Setting 2)
+     for i in range(1, N):
+         if i not in lis:
+             continue
+         gt = Image.open(path_gt + str(i) + ".png")
+         img1 = Image.open(path_ours + str(i) + "_1" + tail)
+         img2 = Image.open(path_ours + str(i) + "_2" + tail)
+         content_gt = pytesseract.image_to_string(gt)
+         content1 = pytesseract.image_to_string(img1)
+         content2 = pytesseract.image_to_string(img2)
+         l1 = Levenshtein_Distance(content_gt, content1)
+         l2 = Levenshtein_Distance(content_gt, content2)
+         ed1.append(l1)
+         ed2.append(l2)
+         cer1.append(l1 / len(content_gt))
+         cer2.append(l2 / len(content_gt))
+         check[i] = cer1[-1]
+     print("CER: ", (np.mean(cer1) + np.mean(cer2)) / 2.0)
+     print("ED: ", (np.mean(ed1) + np.mean(ed2)) / 2.0)
+
+
+ def evalu(path_ours, tail):
+     cal_cer_ed(path_ours, tail)
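The helpers above assume ground-truth scans in `./GT/` and rectified outputs named `<idx>_1<tail>` / `<idx>_2<tail>`. A minimal usage sketch (the file names below are illustrative placeholders, not files shipped with this commit):

```python
# Hedged usage sketch for the OCR metrics above; paths are placeholders.
import pytesseract
from PIL import Image

gt_text = pytesseract.image_to_string(Image.open("./GT/1.png"))
rec_text = pytesseract.image_to_string(Image.open("./rectified/1_1_rec.png"))

ed = Levenshtein_Distance(gt_text, rec_text)  # character-level edit distance
cer = ed / len(gt_text)                       # character error rate
print(f"ED={ed}, CER={cer:.4f}")

# Or evaluate a whole result folder with the Setting-1 protocol above
# (note that `tail` must include the file extension):
# evalu("./rectified/", "_rec.png")
```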
models/DocScanner/README.md ADDED
@@ -0,0 +1,96 @@
+ 🔥 ***2024.4.28:*** **Good news! The code and pre-trained model of DocScanner are now released!**
+
+ 🚀 **Good news! The [online demo](https://docai.doctrp.top:20443/) for DocScanner is now live, allowing for easy image upload and correction.**
+
+ 🔥 **Good news! Our new work [DocTr++: Deep Unrestricted Document Image Rectification](https://github.com/fh2019ustc/DocTr-Plus) is out, capable of rectifying various distorted document images in the wild.**
+
+ 🔥 **Good news! A comprehensive list of [Awesome Document Image Rectification](https://github.com/fh2019ustc/Awesome-Document-Image-Rectification) methods is available.**
+
+ # DocScanner
+
+ <p>
+     <a href='https://drive.google.com/file/d/1mmCUj90rHyuO1SmpLt361youh-07Y0sD/view?usp=share_link' target="_blank"><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
+     <a href='https://docai.doctrp.top:20443/' target="_blank"><img src='https://img.shields.io/badge/Online-Demo-green'></a>
+ </p>
+
+ This is a PyTorch/GPU re-implementation of the paper [DocScanner: Robust Document Image Rectification with Progressive Learning](https://drive.google.com/file/d/1mmCUj90rHyuO1SmpLt361youh-07Y0sD/view?usp=share_link).
+
+ ![image](https://user-images.githubusercontent.com/50725551/209266364-aee68a88-090d-4f21-919a-092f19570d86.png)
+
+ ## 🚀 Demo [(Link)](https://docai.doctrp.top:20443/)
+ ***Note***: The model version used in the demo corresponds to ***"DocScanner-L"*** as described in the paper.
+ 1. Upload the distorted document image to be rectified in the left box.
+ 2. Click the "Submit" button.
+ 3. The rectified image will be displayed in the right box.
+
+ <img width="1534" alt="image" src="https://github.com/fh2019ustc/DocScanner/assets/50725551/9eca3f7d-1570-4246-a3db-0a1cf1eece2d">
+
+ ### Examples
+ ![image](https://user-images.githubusercontent.com/50725551/223947040-eac8389c-bed8-433d-b23b-679c926fba8f.png)
+ ![image](https://user-images.githubusercontent.com/50725551/223946953-3a46d6a3-4361-41ef-bb5c-f235392e1f88.png)
+
+ ## Training
+ - We train the **Document Localization Module** on the [Doc3D](https://github.com/fh2019ustc/doc3D-dataset) dataset. In addition, the [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/) dataset is used for background data augmentation.
+ - We train the **Progressive Rectification Module** on the [Doc3D](https://github.com/fh2019ustc/doc3D-dataset) dataset, using the background-excluded document images.
+
+ ## Inference
+ 1. Put the [pre-trained DocScanner-L](https://drive.google.com/drive/folders/1W1_DJU8dfEh6FqDYqFQ7ypR38Z8c5r4D?usp=sharing) in `$ROOT/model_pretrained/`.
+ 2. Put the distorted images in `$ROOT/distorted/`.
+ 3. Run the script below; the rectified images are saved in `$ROOT/rectified/` by default.
+ ```
+ python inference.py
+ ```
+
+ ## Evaluation
+ - ***Important.*** In the [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html), the distorted images '64_1.png' and '64_2.png' are rotated by 180 degrees and therefore do not match their GT documents. Most existing works overlook this, so please check before evaluating. Note that the performance reported in most existing works is computed with these two ***mistaken*** samples.
+ - To reproduce the following quantitative results on the ***corrected*** [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html), please use the geometrically rectified images available from [Google Drive](https://drive.google.com/drive/folders/1QBe26xJwIl38sWqK2ZE9ke5nu0Mpr4dW?usp=sharing). For the ***corrected*** performance of [other methods](https://github.com/fh2019ustc/Awesome-Document-Image-Rectification), please refer to the [DocScanner](https://arxiv.org/pdf/2110.14968v2.pdf) paper.
+ - ***Image Metrics:*** We use the same MS-SSIM and LD evaluation code as the [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html) dataset, based on Matlab 2019a. Scores may differ slightly across Matlab versions. Our Matlab interface file is provided at ```$ROOT/ssim_ld_eval.m```.
+ - ***OCR Metrics:*** The indices of the 30 documents (60 images) of the [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html) used for our OCR evaluation are listed in ```$ROOT/ocr_img.txt``` (*Setting 1*). Please refer to [DewarpNet](https://github.com/cvlab-stonybrook/DewarpNet) for the indices of the 25 documents (50 images) used for their OCR evaluation (*Setting 2*). The OCR evaluation code is provided at ```$ROOT/OCR_eval.py```. We use pytesseract 0.3.8 with [Tesseract](https://digi.bib.uni-mannheim.de/tesseract/) 5.0.1.20220118 on Windows. Note that the computed performance differs slightly across operating systems.
+ - ***W_v and W_h Index:*** The layout results of the [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html) are available at [Google Drive](https://drive.google.com/drive/folders/1PcfWIowjM0AVKhZrRwGChM-2VAcUwWrF?usp=sharing).
+
+ | Method | MS-SSIM | LD | Li-D | ED (*Setting 1*) | CER | ED (*Setting 2*) | CER | Para. (M) |
+ |:-----------------------:|:------------:|:-----------:|:-------:|:----------------:|:--------------:|:---------------------:|:--------------:|:--------------:|
+ | *DocScanner-T* | 0.5123 | 7.92 | 2.04 | 501.82 | 0.1823 | 809.46 | 0.2068 | 2.6 |
+ | *DocScanner-B* | 0.5134 | 7.62 | 1.88 | 434.11 | 0.1652 | 671.48 | 0.1789 | 5.2 |
+ | *DocScanner-L* | 0.5178 | 7.45 | 1.86 | 390.43 | 0.1486 | 632.34 | 0.1648 | 8.5 |
+
+ ## Citation
+ Please cite the related works in your publications if they help your research:
+
+ ```
+ @inproceedings{feng2021doctr,
+   title={DocTr: Document Image Transformer for Geometric Unwarping and Illumination Correction},
+   author={Feng, Hao and Wang, Yuechen and Zhou, Wengang and Deng, Jiajun and Li, Houqiang},
+   booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
+   pages={273--281},
+   year={2021}
+ }
+ ```
+
+ ```
+ @inproceedings{feng2022docgeonet,
+   title={Geometric Representation Learning for Document Image Rectification},
+   author={Feng, Hao and Zhou, Wengang and Deng, Jiajun and Wang, Yuechen and Li, Houqiang},
+   booktitle={Proceedings of the European Conference on Computer Vision},
+   year={2022}
+ }
+ ```
+
+ ```
+ @article{feng2021docscanner,
+   title={DocScanner: robust document image rectification with progressive learning},
+   author={Feng, Hao and Zhou, Wengang and Deng, Jiajun and Tian, Qi and Li, Houqiang},
+   journal={arXiv preprint arXiv:2110.14968},
+   year={2021}
+ }
+ ```
+
+ ## Acknowledgement
+ The code is largely based on [DocUNet](https://www3.cs.stonybrook.edu/~cvl/docunet.html) and [DewarpNet](https://github.com/cvlab-stonybrook/DewarpNet). Thanks for their wonderful work.
+
+ ## Contact
+ For commercial usage, please contact Professor Wengang Zhou ([zhwg@ustc.edu.cn](zhwg@ustc.edu.cn)) and Hao Feng ([haof@mail.ustc.edu.cn](haof@mail.ustc.edu.cn)).
models/DocScanner/__init__.py ADDED
File without changes
models/DocScanner/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (154 Bytes).
models/DocScanner/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (154 Bytes).
models/DocScanner/__pycache__/extractor.cpython-38.pyc ADDED
Binary file (3.88 kB).
models/DocScanner/__pycache__/extractor.cpython-39.pyc ADDED
Binary file (3.86 kB).
models/DocScanner/__pycache__/inference.cpython-38.pyc ADDED
Binary file (2.22 kB).
models/DocScanner/__pycache__/inference.cpython-39.pyc ADDED
Binary file (2.2 kB).
models/DocScanner/__pycache__/model.cpython-38.pyc ADDED
Binary file (3.37 kB).
models/DocScanner/__pycache__/model.cpython-39.pyc ADDED
Binary file (3.37 kB).
models/DocScanner/__pycache__/seg.cpython-38.pyc ADDED
Binary file (12 kB).
models/DocScanner/__pycache__/seg.cpython-39.pyc ADDED
Binary file (12.1 kB).
models/DocScanner/__pycache__/update.cpython-38.pyc ADDED
Binary file (4.3 kB).
models/DocScanner/__pycache__/update.cpython-39.pyc ADDED
Binary file (4.27 kB).
models/DocScanner/eval.m ADDED
@@ -0,0 +1,64 @@
+ path_rec = "xxx"; % rectified image path
+ path_scan = './scan/'; % scan image path
+ label_path = './layout/'; % layout result path
+
+ tarea = 598400;
+ ms1 = 0;
+ ld1 = 0;
+ lid1 = 0;
+ ms2 = 0;
+ ld2 = 0;
+ lid2 = 0;
+ wv = 0;
+ wh = 0;
+
+ sprintf(path_rec)
+ for i = 1:65
+     path_rec_1 = sprintf("%s%d%s", path_rec, i, '_1 copy_rec.png'); % rectified image path
+     path_rec_2 = sprintf("%s%d%s", path_rec, i, '_2 copy_rec.png'); % rectified image path
+     path_scan_new = sprintf("%s%d%s", path_scan, i, '.png'); % corresponding scan image path
+     bbox_i_path = sprintf("%s%d%s", label_path, i, '.txt'); % corresponding layout txt path
+
+     % imread and rgb2gray
+     A1 = imread(path_rec_1);
+     A2 = imread(path_rec_2);
+
+     % if i == 64
+     %     A1 = rot90(A1,-2);
+     %     A2 = rot90(A2,-2);
+     % end
+
+     ref = imread(path_scan_new);
+     A1 = rgb2gray(A1);
+     A2 = rgb2gray(A2);
+     ref = rgb2gray(ref);
+     bbox_i = read_txt(bbox_i_path);
+     bbox_i = bbox_i + 1; % python index starts from 0
+
+     % resize
+     b = sqrt(tarea/size(ref,1)/size(ref,2));
+     ref = imresize(ref,b);
+     A1 = imresize(A1,[size(ref,1),size(ref,2)]);
+     A2 = imresize(A2,[size(ref,1),size(ref,2)]);
+     scaled_bbox_i = bbox_i * b * 0.5;
+     scaled_bbox_i = round(scaled_bbox_i);
+     scaled_bbox_i = max(scaled_bbox_i, 1);
+
+     % calculate
+     [ms_1, ld_1, lid_1, W_v_1, W_h_1] = evalUnwarp(A1, ref, scaled_bbox_i);
+     [ms_2, ld_2, lid_2, W_v_2, W_h_2] = evalUnwarp(A2, ref, scaled_bbox_i);
+     ms1 = ms1 + ms_1;
+     ms2 = ms2 + ms_2;
+     ld1 = ld1 + ld_1;
+     ld2 = ld2 + ld_2;
+     lid1 = lid1 + lid_1;
+     lid2 = lid2 + lid_2;
+     wv = wv + W_v_1 + W_v_2;
+     wh = wh + W_h_1 + W_h_2;
+ end
+
+ ms = (ms1 + ms2) / 130 % MS-SSIM
+ ld = (ld1 + ld2) / 130 % local distortion
+ li_d = (lid1 + lid2) / 130 % line distortion
+ wv = wv / 130 % wv index
+ wh = wh / 130 % wh index
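For reference, the resizing convention used by the script above can be restated in Python. This is only a sketch (OpenCV and the function name are assumptions for illustration); the official evaluation remains the Matlab code above:

```python
# Sketch of eval.m's normalization: the scan is rescaled to ~598,400 pixels and
# the rectified image is resized to match before MS-SSIM / LD are computed.
import cv2

TARGET_AREA = 598400

def load_pair(rec_path, scan_path):
    rec = cv2.imread(rec_path, cv2.IMREAD_GRAYSCALE)
    ref = cv2.imread(scan_path, cv2.IMREAD_GRAYSCALE)
    b = (TARGET_AREA / (ref.shape[0] * ref.shape[1])) ** 0.5
    ref = cv2.resize(ref, None, fx=b, fy=b)              # scan -> target area
    rec = cv2.resize(rec, (ref.shape[1], ref.shape[0]))  # rectified -> same size as scan
    return rec, ref
```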
models/DocScanner/evalUnwarp.m ADDED
@@ -0,0 +1,102 @@
+ function [ms, ld, li_d, wv, wh] = evalUnwarp(A, ref, data)
+ %EVALUNWARP compute MS-SSIM and LD between the unwarped image and the scan
+ %   A: unwarped image
+ %   ref: reference image, the scan image
+ %   ms: returned MS-SSIM value
+ %   ld: returned local distortion value
+ %   The Matlab image processing toolbox is necessary to compute ssim. The weights
+ %   for multi-scale ssim are directly adopted from:
+ %
+ %   Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale structural
+ %   similarity for image quality assessment." In Signals, Systems and Computers,
+ %   2004. Conference Record of the Thirty-Seventh Asilomar Conference on, 2003.
+ %
+ %   Local distortion relies on the paper:
+ %   Liu, Ce, Jenny Yuen, and Antonio Torralba. "SIFT flow: Dense correspondence
+ %   across scenes and its applications." In PAMI, 2010.
+ %
+ %   and its implementation:
+ %   https://people.csail.mit.edu/celiu/SIFTflow/
+
+ x = A;
+ y = ref;
+
+ im1 = imresize(imfilter(y, fspecial('gaussian', 7, 1.), 'same', 'replicate'), 0.5, 'bicubic');
+ im2 = imresize(imfilter(x, fspecial('gaussian', 7, 1.), 'same', 'replicate'), 0.5, 'bicubic');
+
+ im1 = im2double(im1);
+ im2 = im2double(im2);
+
+ cellsize = 3;
+ gridspacing = 1;
+
+ sift1 = mexDenseSIFT(im1, cellsize, gridspacing);
+ sift2 = mexDenseSIFT(im2, cellsize, gridspacing);
+
+ SIFTflowpara.alpha = 2*255;
+ SIFTflowpara.d = 40*255;
+ SIFTflowpara.gamma = 0.005*255;
+ SIFTflowpara.nlevels = 4;
+ SIFTflowpara.wsize = 2;
+ SIFTflowpara.topwsize = 10;
+ SIFTflowpara.nTopIterations = 60;
+ SIFTflowpara.nIterations = 30;
+
+ [vx, vy, ~] = SIFTflowc2f(sift1, sift2, SIFTflowpara);
+
+ rows1p = size(im1, 1);
+ cols1p = size(im1, 2);
+
+ % Li-D
+ rowstd_sum = 0;
+ for i = 1:rows1p
+     rowstd = std(vy(i, :), 1);
+     rowstd_sum = rowstd_sum + rowstd;
+ end
+ rowstd_mean = rowstd_sum / rows1p;
+
+ colstd_sum = 0;
+ for i = 1:cols1p
+     colstd = std(vx(:, i), 1);
+     colstd_sum = colstd_sum + colstd;
+ end
+ colstd_mean = colstd_sum / cols1p;
+
+ li_d = (rowstd_mean + colstd_mean) / 2;
+
+ % LD
+ d = sqrt(vx.^2 + vy.^2);
+ ld = mean(d(:));
+
+ % MS-SSIM
+ wt = [0.0448 0.2856 0.3001 0.2363 0.1333];
+ ss = zeros(5, 1);
+ for s = 1:5
+     ss(s) = ssim(x, y);
+     x = impyramid(x, 'reduce');
+     y = impyramid(y, 'reduce');
+ end
+ ms = wt * ss;
+
+ % wv and wh
+ rowstd_sum = 0;
+ for i = 1:size(data, 1)
+     rowstd_top = std(vy(data(i,2), data(i,1):data(i,3)), 1) / (data(i,3) - data(i,1));
+     rowstd_bot = std(vy(data(i,4), data(i,1):data(i,3)), 1) / (data(i,3) - data(i,1));
+     rowstd_sum = rowstd_sum + rowstd_top + rowstd_bot;
+ end
+ wv = rowstd_sum / (2 * size(data, 1));
+
+ colstd_sum = 0;
+ for i = 1:size(data, 1)
+     colstd_left = std(vx(data(i,2):data(i,4), data(i,1)), 1) / (data(i,4) - data(i,2));
+     colstd_right = std(vx(data(i,2):data(i,4), data(i,3)), 1) / (data(i,4) - data(i,2));
+     colstd_sum = colstd_sum + colstd_left + colstd_right;
+ end
+ wh = colstd_sum / (2 * size(data, 1));
+
+ end
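The two distortion terms computed above translate directly to NumPy. The sketch below assumes the SIFT-flow fields `vx`, `vy` between the downsampled scan and rectified image have already been computed by a separate SIFT-flow implementation; it only restates the formulas:

```python
# LD and Li-D as in evalUnwarp.m, given dense flow fields vx, vy (2-D arrays).
import numpy as np

def local_distortion(vx, vy):
    """LD: mean flow magnitude over all pixels."""
    return np.mean(np.sqrt(vx ** 2 + vy ** 2))

def line_distortion(vx, vy):
    """Li-D: mean per-row std of vy averaged with mean per-column std of vx."""
    row_std = np.std(vy, axis=1).mean()  # deviation along each row
    col_std = np.std(vx, axis=0).mean()  # deviation along each column
    return (row_std + col_std) / 2.0
```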
models/DocScanner/extractor.py ADDED
@@ -0,0 +1,140 @@
+ import torch.nn as nn
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_planes, planes, norm_fn="group", stride=1):
+         super(ResidualBlock, self).__init__()
+
+         self.conv1 = nn.Conv2d(
+             in_planes, planes, kernel_size=3, padding=1, stride=stride
+         )
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+         self.relu = nn.ReLU(inplace=True)
+
+         num_groups = planes // 8
+
+         if norm_fn == "batch":
+             self.norm1 = nn.BatchNorm2d(planes)
+             self.norm2 = nn.BatchNorm2d(planes)
+             if not stride == 1:
+                 self.norm3 = nn.BatchNorm2d(planes)
+
+         elif norm_fn == "instance":
+             self.norm1 = nn.InstanceNorm2d(planes)
+             self.norm2 = nn.InstanceNorm2d(planes)
+             if not stride == 1:
+                 self.norm3 = nn.InstanceNorm2d(planes)
+
+         if stride == 1:
+             self.downsample = None
+         else:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
+             )
+
+     def forward(self, x):
+         y = x
+         y = self.relu(self.norm1(self.conv1(y)))
+         y = self.relu(self.norm2(self.conv2(y)))
+
+         if self.downsample is not None:
+             x = self.downsample(x)
+
+         return self.relu(x + y)
+
+
+ class BottleneckBlock(nn.Module):
+     def __init__(self, in_planes, planes, norm_fn="group", stride=1):
+         super(BottleneckBlock, self).__init__()
+
+         self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
+         self.conv2 = nn.Conv2d(
+             planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
+         )
+         self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
+         self.relu = nn.ReLU(inplace=True)
+
+         if norm_fn == "batch":
+             self.norm1 = nn.BatchNorm2d(planes // 4)
+             self.norm2 = nn.BatchNorm2d(planes // 4)
+             self.norm3 = nn.BatchNorm2d(planes)
+             if not stride == 1:
+                 self.norm4 = nn.BatchNorm2d(planes)
+
+         elif norm_fn == "instance":
+             self.norm1 = nn.InstanceNorm2d(planes // 4)
+             self.norm2 = nn.InstanceNorm2d(planes // 4)
+             self.norm3 = nn.InstanceNorm2d(planes)
+             if not stride == 1:
+                 self.norm4 = nn.InstanceNorm2d(planes)
+
+         if stride == 1:
+             self.downsample = None
+         else:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
+             )
+
+     def forward(self, x):
+         y = x
+         y = self.relu(self.norm1(self.conv1(y)))
+         y = self.relu(self.norm2(self.conv2(y)))
+         y = self.relu(self.norm3(self.conv3(y)))
+
+         if self.downsample is not None:
+             x = self.downsample(x)
+
+         return self.relu(x + y)
+
+
+ class BasicEncoder(nn.Module):
+     def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
+         super(BasicEncoder, self).__init__()
+         self.norm_fn = norm_fn
+
+         if self.norm_fn == "batch":
+             self.norm1 = nn.BatchNorm2d(64)
+
+         elif self.norm_fn == "instance":
+             self.norm1 = nn.InstanceNorm2d(64)
+
+         self.conv1 = nn.Conv2d(3, 80, kernel_size=7, stride=2, padding=3)
+         self.relu1 = nn.ReLU(inplace=True)
+
+         self.in_planes = 80
+         self.layer1 = self._make_layer(80, stride=1)
+         self.layer2 = self._make_layer(160, stride=2)
+         self.layer3 = self._make_layer(240, stride=2)
+
+         # output convolution
+         self.conv2 = nn.Conv2d(240, output_dim, kernel_size=1)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+             elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+                 if m.weight is not None:
+                     nn.init.constant_(m.weight, 1)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def _make_layer(self, dim, stride=1):
+         layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+         layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+         layers = (layer1, layer2)
+
+         self.in_planes = dim
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.norm1(x)
+         x = self.relu1(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+
+         x = self.conv2(x)
+
+         return x
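A quick shape check of the encoder above: one stride-2 stem convolution plus two stride-2 residual stages produce a 1/8-resolution feature map with `output_dim` channels. The sketch instantiates it the way `model.py` does (the 288x288 input size is an assumption for illustration):

```python
# BasicEncoder downsamples by 8: 288 -> 144 (stem) -> 72 (layer2) -> 36 (layer3).
import torch

enc = BasicEncoder(output_dim=320, norm_fn="instance").eval()
x = torch.randn(1, 3, 288, 288)
with torch.no_grad():
    feat = enc(x)
print(feat.shape)  # torch.Size([1, 320, 36, 36])
```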
models/DocScanner/inference.py ADDED
@@ -0,0 +1,65 @@
+ import os
+ import sys
+
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+ import argparse
+ import warnings
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from DocScanner.model import DocScanner
+ from DocScanner.seg import U2NETP
+ from PIL import Image
+
+ warnings.filterwarnings("ignore")
+
+
+ class Net(nn.Module):
+     def __init__(self):
+         super(Net, self).__init__()
+         self.msk = U2NETP(3, 1)
+         self.bm = DocScanner()  # rectification network
+
+     def forward(self, x):
+         msk, _1, _2, _3, _4, _5, _6 = self.msk(x)
+         msk = (msk > 0.5).float()
+         x = msk * x
+
+         bm = self.bm(x, iters=12, test_mode=True)
+         bm = (2 * (bm / 286.8) - 1) * 0.99
+
+         return bm, msk
+
+
+ def reload_seg_model(model, path=""):
+     if not bool(path):
+         return model
+     else:
+         model_dict = model.state_dict()
+         pretrained_dict = torch.load(path, map_location="cuda:0")
+         pretrained_dict = {
+             k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict
+         }
+         model_dict.update(pretrained_dict)
+         model.load_state_dict(model_dict)
+
+         return model
+
+
+ def reload_rec_model(model, path=""):
+     if not bool(path):
+         return model
+     else:
+         model_dict = model.state_dict()
+         pretrained_dict = torch.load(path, map_location="cuda:0")
+         pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+         model_dict.update(pretrained_dict)
+         model.load_state_dict(model_dict)
+
+         return model
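The committed portion above defines the network and the weight-loading helpers. The sketch below shows how they would typically be wired into a driver; the weight filenames, image paths, and the 288x288 working resolution are assumptions for illustration, not values fixed by this commit:

```python
# Illustrative driver for the Net defined above; file names are placeholders.
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from DocScanner.inference import Net, reload_rec_model, reload_seg_model

net = Net().cuda().eval()
reload_seg_model(net.msk, "./model_pretrained/seg.pth")          # hypothetical checkpoint name
reload_rec_model(net.bm, "./model_pretrained/DocScanner-L.pth")  # hypothetical checkpoint name

im_org = cv2.imread("./distorted/demo.png")[:, :, ::-1] / 255.0   # BGR -> RGB in [0, 1]
im = cv2.resize(im_org, (288, 288)).transpose(2, 0, 1)            # HWC -> CHW
with torch.no_grad():
    bm, _ = net(torch.from_numpy(im).float().unsqueeze(0).cuda())                # normalized backward map
    bm = F.interpolate(bm, im_org.shape[:2], mode="bilinear", align_corners=True)
    x = torch.from_numpy(im_org.transpose(2, 0, 1)).float().unsqueeze(0).cuda()
    out = F.grid_sample(x, bm.permute(0, 2, 3, 1), align_corners=True)           # unwarp
rec = (out[0].permute(1, 2, 0).cpu().numpy()[:, :, ::-1] * 255).astype(np.uint8)
cv2.imwrite("./rectified/demo_rec.png", rec)
```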
models/DocScanner/model.py ADDED
@@ -0,0 +1,104 @@
+ import os
+ import sys
+
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from DocScanner.extractor import BasicEncoder
+ from DocScanner.update import BasicUpdateBlock
+
+
+ def bilinear_sampler(img, coords, mode="bilinear", mask=False):
+     """Wrapper for grid_sample, uses pixel coordinates"""
+     H, W = img.shape[-2:]
+     xgrid, ygrid = coords.split([1, 1], dim=-1)
+     xgrid = 2 * xgrid / (W - 1) - 1
+     ygrid = 2 * ygrid / (H - 1) - 1
+
+     grid = torch.cat([xgrid, ygrid], dim=-1)
+     img = F.grid_sample(img, grid, align_corners=True)
+     if mask:
+         mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
+         return img, mask.float()
+
+     return img
+
+
+ def coords_grid(batch, ht, wd):
+     coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
+     coords = torch.stack(coords[::-1], dim=0).float()
+     return coords[None].repeat(batch, 1, 1, 1)
+
+
+ class DocScanner(nn.Module):
+     def __init__(self):
+         super(DocScanner, self).__init__()
+
+         self.hidden_dim = hdim = 160
+         self.context_dim = 160
+
+         self.fnet = BasicEncoder(output_dim=320, norm_fn="instance")
+         self.update_block = BasicUpdateBlock(hidden_dim=hdim)
+
+     def freeze_bn(self):
+         for m in self.modules():
+             if isinstance(m, nn.BatchNorm2d):
+                 m.eval()
+
+     def initialize_flow(self, img):
+         N, C, H, W = img.shape
+         coodslar = coords_grid(N, H, W).to(img.device)
+         coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
+         coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
+
+         return coodslar, coords0, coords1
+
+     def upsample_flow(self, flow, mask):
+         N, _, H, W = flow.shape
+         mask = mask.view(N, 1, 9, 8, 8, H, W)
+         mask = torch.softmax(mask, dim=2)
+
+         up_flow = F.unfold(8 * flow, [3, 3], padding=1)
+         up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
+
+         up_flow = torch.sum(mask * up_flow, dim=2)
+         up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+
+         return up_flow.reshape(N, 2, 8 * H, 8 * W)
+
+     def forward(self, image1, iters=12, flow_init=None, test_mode=False):
+         image1 = image1.contiguous()
+
+         fmap1 = self.fnet(image1)
+
+         warpfea = fmap1
+
+         net, inp = torch.split(fmap1, [160, 160], dim=1)
+         net = torch.tanh(net)
+         inp = torch.relu(inp)
+
+         coodslar, coords0, coords1 = self.initialize_flow(image1)
+
+         if flow_init is not None:
+             coords1 = coords1 + flow_init
+
+         flow_predictions = []
+         for itr in range(iters):
+             coords1 = coords1.detach()
+             flow = coords1 - coords0
+
+             net, up_mask, delta_flow = self.update_block(net, inp, warpfea, flow)
+
+             coords1 = coords1 + delta_flow
+             flow_up = self.upsample_flow(coords1 - coords0, up_mask)
+             bm_up = coodslar + flow_up
+
+             warpfea = bilinear_sampler(fmap1, coords1.permute(0, 2, 3, 1))
+             flow_predictions.append(bm_up)
+
+         if test_mode:
+             return bm_up
+
+         return flow_predictions
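A shape sketch of the iterative refinement above: features live at 1/8 resolution, each iteration predicts a flow increment, and `upsample_flow` expands the 1/8 flow to full resolution with a learned convex combination over 3x3 neighborhoods (RAFT-style), so the returned backward map matches the input size. The 288x288 input size below is an assumption for illustration (it must be a multiple of 8):

```python
# Shape check for DocScanner.forward in test mode.
import torch

model = DocScanner().eval()
x = torch.randn(1, 3, 288, 288)
with torch.no_grad():
    bm = model(x, iters=12, test_mode=True)
print(bm.shape)  # torch.Size([1, 2, 288, 288]) -- dense backward map in pixel coordinates
```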
models/DocScanner/ocr_img.txt ADDED
@@ -0,0 +1,62 @@
+ The images for OCR evaluation of DocUNet Benchmark.
+ # Setting 1 (Setting from DocTr)
+ # Total 30 * 2 = 60 images.
+ ./scan/1.png
+ ./scan/2.png
+ ./scan/3.png
+ ./scan/4.png
+ ./scan/5.png
+ ./scan/6.png
+ ./scan/7.png
+ ./scan/9.png
+ ./scan/10.png
+ ./scan/21.png
+ ./scan/22.png
+ ./scan/23.png
+ ./scan/24.png
+ ./scan/27.png
+ ./scan/30.png
+ ./scan/31.png
+ ./scan/32.png
+ ./scan/36.png
+ ./scan/38.png
+ ./scan/40.png
+ ./scan/41.png
+ ./scan/44.png
+ ./scan/45.png
+ ./scan/46.png
+ ./scan/47.png
+ ./scan/48.png
+ ./scan/50.png
+ ./scan/51.png
+ ./scan/52.png
+ ./scan/53.png
+
+ # Setting 2 (Setting from DewarpNet)
+ # Link: https://github.com/cvlab-stonybrook/DewarpNet/blob/master/eval/ocr_eval/ocr_files.txt
+ # Total 25 * 2 = 50 images.
+ ./scan/1.png
+ ./scan/9.png
+ ./scan/10.png
+ ./scan/12.png
+ ./scan/19.png
+ ./scan/20.png
+ ./scan/21.png
+ ./scan/22.png
+ ./scan/23.png
+ ./scan/24.png
+ ./scan/30.png
+ ./scan/31.png
+ ./scan/32.png
+ ./scan/34.png
+ ./scan/35.png
+ ./scan/36.png
+ ./scan/37.png
+ ./scan/38.png
+ ./scan/39.png
+ ./scan/40.png
+ ./scan/44.png
+ ./scan/45.png
+ ./scan/46.png
+ ./scan/47.png
+ ./scan/49.png
models/DocScanner/requirements.txt ADDED
@@ -0,0 +1,6 @@
+ numpy==1.19.0
+ opencv_python==4.2.0.34
+ Pillow==9.4.0
+ scikit_image==0.17.2
+ skimage==0.0
+ torch==1.5.1+cu101
models/DocScanner/seg.py ADDED
@@ -0,0 +1,576 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import models
+
+
+ class sobel_net(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.conv_opx = nn.Conv2d(1, 1, 3, bias=False)
+         self.conv_opy = nn.Conv2d(1, 1, 3, bias=False)
+         sobel_kernelx = np.array(
+             [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype="float32"
+         ).reshape((1, 1, 3, 3))
+         sobel_kernely = np.array(
+             [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype="float32"
+         ).reshape((1, 1, 3, 3))
+         self.conv_opx.weight.data = torch.from_numpy(sobel_kernelx)
+         self.conv_opy.weight.data = torch.from_numpy(sobel_kernely)
+
+         for p in self.parameters():
+             p.requires_grad = False
+
+     def forward(self, im):  # input rgb
+         x = (
+             0.299 * im[:, 0, :, :] + 0.587 * im[:, 1, :, :] + 0.114 * im[:, 2, :, :]
+         ).unsqueeze(
+             1
+         )  # rgb2gray
+         gradx = self.conv_opx(x)
+         grady = self.conv_opy(x)
+
+         x = (gradx**2 + grady**2) ** 0.5
+         x = (x - x.min()) / (x.max() - x.min())
+         x = F.pad(x, (1, 1, 1, 1))
+
+         x = torch.cat([im, x], dim=1)
+         return x
+
+
+ class REBNCONV(nn.Module):
+     def __init__(self, in_ch=3, out_ch=3, dirate=1):
+         super(REBNCONV, self).__init__()
+
+         self.conv_s1 = nn.Conv2d(
+             in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
+         )
+         self.bn_s1 = nn.BatchNorm2d(out_ch)
+         self.relu_s1 = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         hx = x
+         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+         return xout
+
+
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
+ def _upsample_like(src, tar):
+     src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
+
+     return src
+
+
+ ### RSU-7 ###
+ class RSU7(nn.Module):  # UNet07DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU7, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+         hx = self.pool5(hx5)
+
+         hx6 = self.rebnconv6(hx)
+
+         hx7 = self.rebnconv7(hx6)
+
+         hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+         hx6dup = _upsample_like(hx6d, hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-6 ###
+ class RSU6(nn.Module):  # UNet06DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU6, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+
+         hx6 = self.rebnconv6(hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-5 ###
+ class RSU5(nn.Module):  # UNet05DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU5, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+
+         hx5 = self.rebnconv5(hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-4 ###
+ class RSU4(nn.Module):  # UNet04DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-4F ###
+ class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4F, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx2 = self.rebnconv2(hx1)
+         hx3 = self.rebnconv3(hx2)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+         hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ##### U^2-Net ####
+ class U2NET(nn.Module):
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NET, self).__init__()
+         self.edge = sobel_net()
+
+         self.stage1 = RSU7(in_ch, 32, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 32, 128)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(128, 64, 256)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(256, 128, 512)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(512, 256, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(512, 256, 512)
+
+         # decoder
+         self.stage5d = RSU4F(1024, 256, 512)
+         self.stage4d = RSU4(1024, 128, 256)
+         self.stage3d = RSU5(512, 64, 128)
+         self.stage2d = RSU6(256, 32, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+         self.outconv = nn.Conv2d(6, out_ch, 1)
+
+     def forward(self, x):
+         x = self.edge(x)
+         hx = x
+
+         # stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         # -------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # side output
+         d1 = self.side1(hx1d)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+         return (
+             torch.sigmoid(d0),
+             torch.sigmoid(d1),
+             torch.sigmoid(d2),
+             torch.sigmoid(d3),
+             torch.sigmoid(d4),
+             torch.sigmoid(d5),
+             torch.sigmoid(d6),
+         )
+
+
+ ### U^2-Net small ###
+ class U2NETP(nn.Module):
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NETP, self).__init__()
+
+         self.stage1 = RSU7(in_ch, 16, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 16, 64)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(64, 16, 64)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(64, 16, 64)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(64, 16, 64)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(64, 16, 64)
+
+         # decoder
+         self.stage5d = RSU4F(128, 16, 64)
+         self.stage4d = RSU4(128, 16, 64)
+         self.stage3d = RSU5(128, 16, 64)
+         self.stage2d = RSU6(128, 16, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
+
+         self.outconv = nn.Conv2d(6, out_ch, 1)
+
+     def forward(self, x):
+         hx = x
+
+         # stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         # decoder
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # side output
+         d1 = self.side1(hx1d)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+         return (
+             torch.sigmoid(d0),
+             torch.sigmoid(d1),
+             torch.sigmoid(d2),
+             torch.sigmoid(d3),
+             torch.sigmoid(d4),
+             torch.sigmoid(d5),
+             torch.sigmoid(d6),
+         )
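How the segmenter above is consumed by `inference.py`: the first (fused) output `d0` is a per-pixel document probability, thresholded at 0.5 to mask out the background before rectification. A minimal sketch (the 288x288 input size is an assumption for illustration):

```python
# Minimal usage sketch for U2NETP.
import torch

seg = U2NETP(3, 1).eval()
img = torch.rand(1, 3, 288, 288)
with torch.no_grad():
    d0, *_ = seg(img)          # fused side output, shape (1, 1, 288, 288), values in (0, 1)
mask = (d0 > 0.5).float()      # binary document mask
masked_input = mask * img      # background-excluded image fed to the rectifier
```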
models/DocScanner/update.py ADDED
@@ -0,0 +1,119 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class FlowHead(nn.Module):
+     def __init__(self, input_dim=128, hidden_dim=256):
+         super(FlowHead, self).__init__()
+         self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
+         self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         return self.conv2(self.relu(self.conv1(x)))
+
+
+ class ConvGRU(nn.Module):
+     def __init__(self, hidden_dim=128, input_dim=192 + 128):
+         super(ConvGRU, self).__init__()
+         self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
+         self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
+         self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
+
+     def forward(self, h, x):
+         hx = torch.cat([h, x], dim=1)
+
+         z = torch.sigmoid(self.convz(hx))
+         r = torch.sigmoid(self.convr(hx))
+         q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
+
+         h = (1 - z) * h + z * q
+         return h
+
+
+ class SepConvGRU(nn.Module):
+     def __init__(self, hidden_dim=128, input_dim=192 + 128):
+         super(SepConvGRU, self).__init__()
+         self.convz1 = nn.Conv2d(
+             hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
+         )
+         self.convr1 = nn.Conv2d(
+             hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
+         )
+         self.convq1 = nn.Conv2d(
+             hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
+         )
+
+         self.convz2 = nn.Conv2d(
+             hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
+         )
+         self.convr2 = nn.Conv2d(
+             hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
+         )
+         self.convq2 = nn.Conv2d(
+             hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
+         )
+
+     def forward(self, h, x):
+         # horizontal
+         hx = torch.cat([h, x], dim=1)
+         z = torch.sigmoid(self.convz1(hx))
+         r = torch.sigmoid(self.convr1(hx))
+         q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
+         h = (1 - z) * h + z * q
+
+         # vertical
+         hx = torch.cat([h, x], dim=1)
+         z = torch.sigmoid(self.convz2(hx))
+         r = torch.sigmoid(self.convr2(hx))
+         q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
+         h = (1 - z) * h + z * q
+
+         return h
+
+
+ class BasicMotionEncoder(nn.Module):
+     def __init__(self):
+         super(BasicMotionEncoder, self).__init__()
+         self.convc1 = nn.Conv2d(320, 240, 1, padding=0)
+         self.convc2 = nn.Conv2d(240, 160, 3, padding=1)
+         self.convf1 = nn.Conv2d(2, 160, 7, padding=3)
+         self.convf2 = nn.Conv2d(160, 80, 3, padding=1)
+         self.conv = nn.Conv2d(160 + 80, 160 - 2, 3, padding=1)
+
+     def forward(self, flow, corr):
+         cor = F.relu(self.convc1(corr))
+         cor = F.relu(self.convc2(cor))
+         flo = F.relu(self.convf1(flow))
+         flo = F.relu(self.convf2(flo))
+
+         cor_flo = torch.cat([cor, flo], dim=1)
+         out = F.relu(self.conv(cor_flo))
+         return torch.cat([out, flow], dim=1)
+
+
+ class BasicUpdateBlock(nn.Module):
+     def __init__(self, hidden_dim=128):
+         super(BasicUpdateBlock, self).__init__()
+         self.encoder = BasicMotionEncoder()
+         self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=160 + 160)
+         self.flow_head = FlowHead(hidden_dim, hidden_dim=320)
+
+         self.mask = nn.Sequential(
+             nn.Conv2d(hidden_dim, 288, 3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(288, 64 * 9, 1, padding=0),
+         )
+
+     def forward(self, net, inp, corr, flow):
+         motion_features = self.encoder(flow, corr)
+         inp = torch.cat([inp, motion_features], dim=1)
+
+         net = self.gru(net, inp)
+
+         delta_flow = self.flow_head(net)
+
+         mask = 0.25 * self.mask(net)
+
+         return net, mask, delta_flow
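As wired up in `model.py`, the block above runs once per refinement iteration: the motion encoder fuses the 320-channel warped features with the 2-channel flow into 160-channel motion features, the separable ConvGRU updates the 160-channel hidden state, and two heads emit the flow increment and the 64x9-channel convex-upsampling mask (scaled by 0.25 for stability). A shape sketch (the 36x36 grid corresponds to a hypothetical 288x288 input at 1/8 resolution):

```python
# Shape sketch for one BasicUpdateBlock step (dimensions follow model.py).
import torch

block = BasicUpdateBlock(hidden_dim=160)
net = torch.randn(1, 160, 36, 36)    # GRU hidden state
inp = torch.randn(1, 160, 36, 36)    # context features
corr = torch.randn(1, 320, 36, 36)   # warped feature map (fmap1 resampled at coords1)
flow = torch.randn(1, 2, 36, 36)     # current flow estimate at 1/8 resolution
net, up_mask, delta_flow = block(net, inp, corr, flow)
print(up_mask.shape, delta_flow.shape)  # torch.Size([1, 576, 36, 36]) torch.Size([1, 2, 36, 36])
```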
models/DocTr-Plus/GeoTr.py ADDED
@@ -0,0 +1,960 @@
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5
+
6
+
7
+ import argparse
8
+ import copy
9
+ from typing import Optional
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import Tensor, nn
15
+
16
+ from .extractor import BasicEncoder
17
+ from .position_encoding import build_position_encoding
18
+
19
+
20
+ class attnLayer(nn.Module):
21
+ def __init__(
22
+ self,
23
+ d_model,
24
+ nhead=8,
25
+ dim_feedforward=2048,
26
+ dropout=0.1,
27
+ activation="relu",
28
+ normalize_before=False,
29
+ ):
30
+ super().__init__()
31
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
32
+ self.multihead_attn_list = nn.ModuleList(
33
+ [
34
+ copy.deepcopy(nn.MultiheadAttention(d_model, nhead, dropout=dropout))
35
+ for i in range(2)
36
+ ]
37
+ )
38
+ # Implementation of Feedforward model
39
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
40
+ self.dropout = nn.Dropout(dropout)
41
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
42
+
43
+ self.norm1 = nn.LayerNorm(d_model)
44
+ self.norm2_list = nn.ModuleList(
45
+ [copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)]
46
+ )
47
+
48
+ self.norm3 = nn.LayerNorm(d_model)
49
+ self.dropout1 = nn.Dropout(dropout)
50
+ self.dropout2_list = nn.ModuleList(
51
+ [copy.deepcopy(nn.Dropout(dropout)) for i in range(2)]
52
+ )
53
+ self.dropout3 = nn.Dropout(dropout)
54
+
55
+ self.activation = _get_activation_fn(activation)
56
+ self.normalize_before = normalize_before
57
+
58
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
59
+ return tensor if pos is None else tensor + pos
60
+
61
+ def forward_post(
62
+ self,
63
+ tgt,
64
+ memory_list,
65
+ tgt_mask=None,
66
+ memory_mask=None,
67
+ tgt_key_padding_mask=None,
68
+ memory_key_padding_mask=None,
69
+ pos=None,
70
+ memory_pos=None,
71
+ ):
72
+ q = k = self.with_pos_embed(tgt, pos)
73
+ tgt2 = self.self_attn(
74
+ q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
75
+ )[0]
76
+ tgt = tgt + self.dropout1(tgt2)
77
+ tgt = self.norm1(tgt)
78
+ for memory, multihead_attn, norm2, dropout2, m_pos in zip(
79
+ memory_list,
80
+ self.multihead_attn_list,
81
+ self.norm2_list,
82
+ self.dropout2_list,
83
+ memory_pos,
84
+ ):
85
+ tgt2 = multihead_attn(
86
+ query=self.with_pos_embed(tgt, pos),
87
+ key=self.with_pos_embed(memory, m_pos),
88
+ value=memory,
89
+ attn_mask=memory_mask,
90
+ key_padding_mask=memory_key_padding_mask,
91
+ )[0]
92
+ tgt = tgt + dropout2(tgt2)
93
+ tgt = norm2(tgt)
94
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
95
+ tgt = tgt + self.dropout3(tgt2)
96
+ tgt = self.norm3(tgt)
97
+ return tgt
98
+
99
+ def forward_pre(
100
+ self,
101
+ tgt,
102
+ memory,
103
+ tgt_mask=None,
104
+ memory_mask=None,
105
+ tgt_key_padding_mask=None,
106
+ memory_key_padding_mask=None,
107
+ pos=None,
108
+ memory_pos=None,
109
+ ):
110
+ tgt2 = self.norm1(tgt)
111
+ q = k = self.with_pos_embed(tgt2, pos)
112
+ tgt2 = self.self_attn(
113
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
114
+ )[0]
115
+ tgt = tgt + self.dropout1(tgt2)
116
+ tgt2 = self.norm2(tgt)
117
+ tgt2 = self.multihead_attn(
118
+ query=self.with_pos_embed(tgt2, pos),
119
+ key=self.with_pos_embed(memory, memory_pos),
120
+ value=memory,
121
+ attn_mask=memory_mask,
122
+ key_padding_mask=memory_key_padding_mask,
123
+ )[0]
124
+ tgt = tgt + self.dropout2(tgt2)
125
+ tgt2 = self.norm3(tgt)
126
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
127
+ tgt = tgt + self.dropout3(tgt2)
128
+ return tgt
129
+
130
+ def forward(
131
+ self,
132
+ tgt,
133
+ memory_list,
134
+ tgt_mask=None,
135
+ memory_mask=None,
136
+ tgt_key_padding_mask=None,
137
+ memory_key_padding_mask=None,
138
+ pos=None,
139
+ memory_pos=None,
140
+ ):
141
+ if self.normalize_before:
142
+ return self.forward_pre(
143
+ tgt,
144
+ memory_list,
145
+ tgt_mask,
146
+ memory_mask,
147
+ tgt_key_padding_mask,
148
+ memory_key_padding_mask,
149
+ pos,
150
+ memory_pos,
151
+ )
152
+ return self.forward_post(
153
+ tgt,
154
+ memory_list,
155
+ tgt_mask,
156
+ memory_mask,
157
+ tgt_key_padding_mask,
158
+ memory_key_padding_mask,
159
+ pos,
160
+ memory_pos,
161
+ )
162
+
163
+
164
+ def _get_clones(module, N):
165
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
166
+
167
+
168
+ def _get_activation_fn(activation):
169
+ """Return an activation function given a string"""
170
+ if activation == "relu":
171
+ return F.relu
172
+ if activation == "gelu":
173
+ return F.gelu
174
+ if activation == "glu":
175
+ return F.glu
176
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
177
+
178
+
179
+ class TransDecoder(nn.Module):
180
+ def __init__(self, num_attn_layers, hidden_dim=128):
181
+ super(TransDecoder, self).__init__()
182
+ attn_layer = attnLayer(hidden_dim)
183
+ self.layers = _get_clones(attn_layer, num_attn_layers)
184
+ self.position_embedding = build_position_encoding(hidden_dim)
185
+
186
+ def forward(self, imgf, query_embed):
187
+ pos = self.position_embedding(
188
+ torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool().cuda()
189
+ ) # torch.Size([1, 128, 36, 36])
190
+
191
+ bs, c, h, w = imgf.shape
192
+ imgf = imgf.flatten(2).permute(2, 0, 1)
193
+ # query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
194
+ pos = pos.flatten(2).permute(2, 0, 1)
195
+
196
+ for layer in self.layers:
197
+ query_embed = layer(query_embed, [imgf], pos=pos, memory_pos=[pos, pos])
198
+ query_embed = query_embed.permute(1, 2, 0).reshape(bs, c, h, w)
199
+
200
+ return query_embed
201
+
202
+
203
+ class TransEncoder(nn.Module):
204
+ def __init__(self, num_attn_layers, hidden_dim=128):
205
+ super(TransEncoder, self).__init__()
206
+ attn_layer = attnLayer(hidden_dim)
207
+ self.layers = _get_clones(attn_layer, num_attn_layers)
208
+ self.position_embedding = build_position_encoding(hidden_dim)
209
+
210
+ def forward(self, imgf):
211
+ pos = self.position_embedding(
212
+ torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool().cuda()
213
+ ) # torch.Size([1, 128, 36, 36])
214
+ bs, c, h, w = imgf.shape
215
+ imgf = imgf.flatten(2).permute(2, 0, 1)
216
+ pos = pos.flatten(2).permute(2, 0, 1)
217
+
218
+ for layer in self.layers:
219
+ imgf = layer(imgf, [imgf], pos=pos, memory_pos=[pos, pos])
220
+ imgf = imgf.permute(1, 2, 0).reshape(bs, c, h, w)
221
+
222
+ return imgf
223
+
224
+
225
+ class FlowHead(nn.Module):
226
+ def __init__(self, input_dim=128, hidden_dim=256):
227
+ super(FlowHead, self).__init__()
228
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
229
+ self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
230
+ self.relu = nn.ReLU(inplace=True)
231
+
232
+ def forward(self, x):
233
+ return self.conv2(self.relu(self.conv1(x)))
234
+
235
+
236
+ class UpdateBlock(nn.Module):
237
+ def __init__(self, hidden_dim=128):
238
+ super(UpdateBlock, self).__init__()
239
+ self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
240
+ self.mask = nn.Sequential(
241
+ nn.Conv2d(hidden_dim, 256, 3, padding=1),
242
+ nn.ReLU(inplace=True),
243
+ nn.Conv2d(256, 64 * 9, 1, padding=0),
244
+ )
245
+
246
+ def forward(self, imgf, coords1):
247
+ mask = 0.25 * self.mask(imgf) # scale mask to balance gradients
248
+ dflow = self.flow_head(imgf)
249
+ coords1 = coords1 + dflow
250
+
251
+ return mask, coords1
252
+
253
+
254
+ def coords_grid(batch, ht, wd):
255
+ coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
256
+ coords = torch.stack(coords[::-1], dim=0).float()
257
+ return coords[None].repeat(batch, 1, 1, 1)
258
+
259
+
260
+ def upflow8(flow, mode="bilinear"):
261
+ new_size = (8 * flow.shape[2], 8 * flow.shape[3])
262
+ return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
263
+
264
+
265
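+ # NOTE: OverlapPatchEmbed is not referenced by GeoTr below; it also relies on
+ # to_2tuple, trunc_normal_ and math (e.g. from timm.models.layers and the standard
+ # library), which are not imported in this file.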
+ class OverlapPatchEmbed(nn.Module):
266
+ """Image to Patch Embedding"""
267
+
268
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
269
+ super().__init__()
270
+ img_size = to_2tuple(img_size)
271
+ patch_size = to_2tuple(patch_size)
272
+
273
+ self.img_size = img_size
274
+ self.patch_size = patch_size
275
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
276
+ self.num_patches = self.H * self.W
277
+ self.proj = nn.Conv2d(
278
+ in_chans,
279
+ embed_dim,
280
+ kernel_size=patch_size,
281
+ stride=stride,
282
+ padding=(patch_size[0] // 2, patch_size[1] // 2),
283
+ )
284
+ self.norm = nn.LayerNorm(embed_dim)
285
+
286
+ self.apply(self._init_weights)
287
+
288
+ def _init_weights(self, m):
289
+ if isinstance(m, nn.Linear):
290
+ trunc_normal_(m.weight, std=0.02)
291
+ if isinstance(m, nn.Linear) and m.bias is not None:
292
+ nn.init.constant_(m.bias, 0)
293
+ elif isinstance(m, nn.LayerNorm):
294
+ nn.init.constant_(m.bias, 0)
295
+ nn.init.constant_(m.weight, 1.0)
296
+ elif isinstance(m, nn.Conv2d):
297
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
298
+ fan_out //= m.groups
299
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
300
+ if m.bias is not None:
301
+ m.bias.data.zero_()
302
+
303
+ def forward(self, x):
304
+ x = self.proj(x)
305
+ _, _, H, W = x.shape
306
+ x = x.flatten(2).transpose(1, 2)
307
+ x = self.norm(x)
308
+
309
+ return x, H, W
310
+
311
+
312
+ class GeoTr(nn.Module):
313
+ def __init__(self):
314
+ super(GeoTr, self).__init__()
315
+
316
+ self.hidden_dim = hdim = 256
317
+
318
+ self.fnet = BasicEncoder(output_dim=hdim, norm_fn="instance")
319
+
320
+ self.encoder_block = ["encoder_block" + str(i) for i in range(3)]
321
+ for i in self.encoder_block:
322
+ self.__setattr__(i, TransEncoder(2, hidden_dim=hdim))
323
+ self.down_layer = ["down_layer" + str(i) for i in range(2)]
324
+ for i in self.down_layer:
325
+ self.__setattr__(i, nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1))
326
+
327
+ self.decoder_block = ["decoder_block" + str(i) for i in range(3)]
328
+ for i in self.decoder_block:
329
+ self.__setattr__(i, TransDecoder(2, hidden_dim=hdim))
330
+ self.up_layer = ["up_layer" + str(i) for i in range(2)]
331
+ for i in self.up_layer:
332
+ self.__setattr__(
333
+ i, nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
334
+ )
335
+
336
+ self.query_embed = nn.Embedding(81, self.hidden_dim)
337
+
338
+ self.update_block = UpdateBlock(self.hidden_dim)
339
+
340
+ def initialize_flow(self, img):
341
+ N, C, H, W = img.shape
342
+ coodslar = coords_grid(N, H, W).to(img.device)
343
+ coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
344
+ coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
345
+
346
+ return coodslar, coords0, coords1
347
+
348
+ def upsample_flow(self, flow, mask):
349
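+ # Convex upsampling (RAFT-style): every pixel of the 8x-upsampled flow is a
+ # softmax-weighted combination of its 3x3 coarse-grid neighborhood, using the
+ # 9 weights per 8x8 sub-position carried in `mask`.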
+ N, _, H, W = flow.shape
350
+ mask = mask.view(N, 1, 9, 8, 8, H, W)
351
+ mask = torch.softmax(mask, dim=2)
352
+
353
+ up_flow = F.unfold(8 * flow, [3, 3], padding=1)
354
+ up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
355
+
356
+ up_flow = torch.sum(mask * up_flow, dim=2)
357
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
358
+
359
+ return up_flow.reshape(N, 2, 8 * H, 8 * W)
360
+
361
+ def forward(self, image1):
362
+ fmap = self.fnet(image1)
363
+ fmap = torch.relu(fmap)
364
+
365
+ # fmap = self.TransEncoder(fmap)
366
+ fmap1 = self.__getattr__(self.encoder_block[0])(fmap)
367
+ fmap1d = self.__getattr__(self.down_layer[0])(fmap1)
368
+ fmap2 = self.__getattr__(self.encoder_block[1])(fmap1d)
369
+ fmap2d = self.__getattr__(self.down_layer[1])(fmap2)
370
+ fmap3 = self.__getattr__(self.encoder_block[2])(fmap2d)
371
+
372
+ query_embed0 = self.query_embed.weight.unsqueeze(1).repeat(1, fmap3.size(0), 1)
373
+ fmap3d_ = self.__getattr__(self.decoder_block[0])(fmap3, query_embed0)
374
+ fmap3du_ = (
375
+ self.__getattr__(self.up_layer[0])(fmap3d_).flatten(2).permute(2, 0, 1)
376
+ )
377
+ fmap2d_ = self.__getattr__(self.decoder_block[1])(fmap2, fmap3du_)
378
+ fmap2du_ = (
379
+ self.__getattr__(self.up_layer[1])(fmap2d_).flatten(2).permute(2, 0, 1)
380
+ )
381
+ fmap_out = self.__getattr__(self.decoder_block[2])(fmap1, fmap2du_)
382
+
383
+ # convex upsample based on fmap_out
384
+ coodslar, coords0, coords1 = self.initialize_flow(image1)
385
+ coords1 = coords1.detach()
386
+ mask, coords1 = self.update_block(fmap_out, coords1)
387
+ flow_up = self.upsample_flow(coords1 - coords0, mask)
388
+ bm_up = coodslar + flow_up
389
+
390
+ return bm_up
391
+
392
+
393
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
394
+ def _upsample_like(src, tar):
395
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
396
+
397
+ return src
398
+
399
+
400
+ class REBNCONV(nn.Module):
401
+ def __init__(self, in_ch=3, out_ch=3, dirate=1):
402
+ super(REBNCONV, self).__init__()
403
+
404
+ self.conv_s1 = nn.Conv2d(
405
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
406
+ )
407
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
408
+ self.relu_s1 = nn.ReLU(inplace=True)
409
+
410
+ def forward(self, x):
411
+ hx = x
412
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
413
+
414
+ return xout
415
+
416
+
417
+ ### RSU-4 ###
418
+ class RSU4(nn.Module): # UNet04DRES(nn.Module):
419
+
420
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
421
+ super(RSU4, self).__init__()
422
+
423
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
424
+
425
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
426
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
427
+
428
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
429
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
430
+
431
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
432
+
433
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
434
+
435
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
436
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
437
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
438
+
439
+ def forward(self, x):
440
+ hx = x
441
+
442
+ hxin = self.rebnconvin(hx)
443
+
444
+ hx1 = self.rebnconv1(hxin)
445
+ hx = self.pool1(hx1)
446
+
447
+ hx2 = self.rebnconv2(hx)
448
+ hx = self.pool2(hx2)
449
+
450
+ hx3 = self.rebnconv3(hx)
451
+
452
+ hx4 = self.rebnconv4(hx3)
453
+
454
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
455
+ hx3dup = _upsample_like(hx3d, hx2)
456
+
457
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
458
+ hx2dup = _upsample_like(hx2d, hx1)
459
+
460
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
461
+
462
+ return hx1d + hxin
463
+
464
+
465
+ ### RSU-4F ###
466
+ class RSU4F(nn.Module): # UNet04FRES(nn.Module):
467
+
468
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
469
+ super(RSU4F, self).__init__()
470
+
471
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
472
+
473
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
474
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
475
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
476
+
477
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
478
+
479
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
480
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
481
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
482
+
483
+ def forward(self, x):
484
+ hx = x
485
+
486
+ hxin = self.rebnconvin(hx)
487
+
488
+ hx1 = self.rebnconv1(hxin)
489
+ hx2 = self.rebnconv2(hx1)
490
+ hx3 = self.rebnconv3(hx2)
491
+
492
+ hx4 = self.rebnconv4(hx3)
493
+
494
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
495
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
496
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
497
+
498
+ return hx1d + hxin
499
+
500
+
501
+ class sobel_net(nn.Module):
502
+ def __init__(self):
503
+ super().__init__()
504
+ self.conv_opx = nn.Conv2d(1, 1, 3, bias=False)
505
+ self.conv_opy = nn.Conv2d(1, 1, 3, bias=False)
506
+ sobel_kernelx = np.array(
507
+ [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype="float32"
508
+ ).reshape((1, 1, 3, 3))
509
+ sobel_kernely = np.array(
510
+ [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype="float32"
511
+ ).reshape((1, 1, 3, 3))
512
+ self.conv_opx.weight.data = torch.from_numpy(sobel_kernelx)
513
+ self.conv_opy.weight.data = torch.from_numpy(sobel_kernely)
514
+
515
+ for p in self.parameters():
516
+ p.requires_grad = False
517
+
518
+ def forward(self, im): # input rgb
519
+ x = (
520
+ 0.299 * im[:, 0, :, :] + 0.587 * im[:, 1, :, :] + 0.114 * im[:, 2, :, :]
521
+ ).unsqueeze(
522
+ 1
523
+ ) # rgb2gray
524
+ gradx = self.conv_opx(x)
525
+ grady = self.conv_opy(x)
526
+
527
+ x = (gradx**2 + grady**2) ** 0.5
528
+ x = (x - x.min()) / (x.max() - x.min())
529
+ x = F.pad(x, (1, 1, 1, 1))
530
+
531
+ x = torch.cat([im, x], dim=1)
532
+ return x
533
+
534
+
535
+ ##### U^2-Net ####
536
+ class U2NET(nn.Module):
537
+
538
+ def __init__(self, in_ch=3, out_ch=1):
539
+ super(U2NET, self).__init__()
540
+ self.edge = sobel_net()
541
+
542
+ self.stage1 = RSU7(in_ch, 32, 64)
543
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
544
+
545
+ self.stage2 = RSU6(64, 32, 128)
546
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
547
+
548
+ self.stage3 = RSU5(128, 64, 256)
549
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
550
+
551
+ self.stage4 = RSU4(256, 128, 512)
552
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
553
+
554
+ self.stage5 = RSU4F(512, 256, 512)
555
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
556
+
557
+ self.stage6 = RSU4F(512, 256, 512)
558
+
559
+ # decoder
560
+ self.stage5d = RSU4F(1024, 256, 512)
561
+ self.stage4d = RSU4(1024, 128, 256)
562
+ self.stage3d = RSU5(512, 64, 128)
563
+ self.stage2d = RSU6(256, 32, 64)
564
+ self.stage1d = RSU7(128, 16, 64)
565
+
566
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
567
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
568
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
569
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
570
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
571
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
572
+
573
+ self.outconv = nn.Conv2d(6, out_ch, 1)
574
+
575
+ def forward(self, x):
576
+ x = self.edge(x)
577
+ hx = x
578
+
579
+ # stage 1
580
+ hx1 = self.stage1(hx)
581
+ hx = self.pool12(hx1)
582
+
583
+ # stage 2
584
+ hx2 = self.stage2(hx)
585
+ hx = self.pool23(hx2)
586
+
587
+ # stage 3
588
+ hx3 = self.stage3(hx)
589
+ hx = self.pool34(hx3)
590
+
591
+ # stage 4
592
+ hx4 = self.stage4(hx)
593
+ hx = self.pool45(hx4)
594
+
595
+ # stage 5
596
+ hx5 = self.stage5(hx)
597
+ hx = self.pool56(hx5)
598
+
599
+ # stage 6
600
+ hx6 = self.stage6(hx)
601
+ hx6up = _upsample_like(hx6, hx5)
602
+
603
+ # -------------------- decoder --------------------
604
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
605
+ hx5dup = _upsample_like(hx5d, hx4)
606
+
607
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
608
+ hx4dup = _upsample_like(hx4d, hx3)
609
+
610
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
611
+ hx3dup = _upsample_like(hx3d, hx2)
612
+
613
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
614
+ hx2dup = _upsample_like(hx2d, hx1)
615
+
616
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
617
+
618
+ # side output
619
+ d1 = self.side1(hx1d)
620
+
621
+ d2 = self.side2(hx2d)
622
+ d2 = _upsample_like(d2, d1)
623
+
624
+ d3 = self.side3(hx3d)
625
+ d3 = _upsample_like(d3, d1)
626
+
627
+ d4 = self.side4(hx4d)
628
+ d4 = _upsample_like(d4, d1)
629
+
630
+ d5 = self.side5(hx5d)
631
+ d5 = _upsample_like(d5, d1)
632
+
633
+ d6 = self.side6(hx6)
634
+ d6 = _upsample_like(d6, d1)
635
+
636
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
637
+
638
+ return (
639
+ torch.sigmoid(d0),
640
+ torch.sigmoid(d1),
641
+ torch.sigmoid(d2),
642
+ torch.sigmoid(d3),
643
+ torch.sigmoid(d4),
644
+ torch.sigmoid(d5),
645
+ torch.sigmoid(d6),
646
+ )
647
+
648
+
649
+ ### RSU-5 ###
650
+ class RSU5(nn.Module): # UNet05DRES(nn.Module):
651
+
652
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
653
+ super(RSU5, self).__init__()
654
+
655
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
656
+
657
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
658
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
659
+
660
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
661
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
662
+
663
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
664
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
665
+
666
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
667
+
668
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
669
+
670
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
671
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
672
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
673
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
674
+
675
+ def forward(self, x):
676
+ hx = x
677
+
678
+ hxin = self.rebnconvin(hx)
679
+
680
+ hx1 = self.rebnconv1(hxin)
681
+ hx = self.pool1(hx1)
682
+
683
+ hx2 = self.rebnconv2(hx)
684
+ hx = self.pool2(hx2)
685
+
686
+ hx3 = self.rebnconv3(hx)
687
+ hx = self.pool3(hx3)
688
+
689
+ hx4 = self.rebnconv4(hx)
690
+
691
+ hx5 = self.rebnconv5(hx4)
692
+
693
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
694
+ hx4dup = _upsample_like(hx4d, hx3)
695
+
696
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
697
+ hx3dup = _upsample_like(hx3d, hx2)
698
+
699
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
700
+ hx2dup = _upsample_like(hx2d, hx1)
701
+
702
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
703
+
704
+ return hx1d + hxin
705
+
706
+
707
+ ### RSU-6 ###
708
+ class RSU6(nn.Module): # UNet06DRES(nn.Module):
709
+
710
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
711
+ super(RSU6, self).__init__()
712
+
713
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
714
+
715
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
716
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
717
+
718
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
719
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
720
+
721
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
722
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
723
+
724
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
725
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
726
+
727
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
728
+
729
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
730
+
731
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
732
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
733
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
734
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
735
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
736
+
737
+ def forward(self, x):
738
+ hx = x
739
+
740
+ hxin = self.rebnconvin(hx)
741
+
742
+ hx1 = self.rebnconv1(hxin)
743
+ hx = self.pool1(hx1)
744
+
745
+ hx2 = self.rebnconv2(hx)
746
+ hx = self.pool2(hx2)
747
+
748
+ hx3 = self.rebnconv3(hx)
749
+ hx = self.pool3(hx3)
750
+
751
+ hx4 = self.rebnconv4(hx)
752
+ hx = self.pool4(hx4)
753
+
754
+ hx5 = self.rebnconv5(hx)
755
+
756
+ hx6 = self.rebnconv6(hx5)
757
+
758
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
759
+ hx5dup = _upsample_like(hx5d, hx4)
760
+
761
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
762
+ hx4dup = _upsample_like(hx4d, hx3)
763
+
764
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
765
+ hx3dup = _upsample_like(hx3d, hx2)
766
+
767
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
768
+ hx2dup = _upsample_like(hx2d, hx1)
769
+
770
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
771
+
772
+ return hx1d + hxin
773
+
774
+
775
+ ### RSU-7 ###
776
+ class RSU7(nn.Module): # UNet07DRES(nn.Module):
777
+
778
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
779
+ super(RSU7, self).__init__()
780
+
781
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
782
+
783
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
784
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
785
+
786
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
787
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
788
+
789
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
790
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
791
+
792
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
793
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
794
+
795
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
796
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
797
+
798
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
799
+
800
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
801
+
802
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
803
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
804
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
805
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
806
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
807
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
808
+
809
+ def forward(self, x):
810
+ hx = x
811
+ hxin = self.rebnconvin(hx)
812
+
813
+ hx1 = self.rebnconv1(hxin)
814
+ hx = self.pool1(hx1)
815
+
816
+ hx2 = self.rebnconv2(hx)
817
+ hx = self.pool2(hx2)
818
+
819
+ hx3 = self.rebnconv3(hx)
820
+ hx = self.pool3(hx3)
821
+
822
+ hx4 = self.rebnconv4(hx)
823
+ hx = self.pool4(hx4)
824
+
825
+ hx5 = self.rebnconv5(hx)
826
+ hx = self.pool5(hx5)
827
+
828
+ hx6 = self.rebnconv6(hx)
829
+
830
+ hx7 = self.rebnconv7(hx6)
831
+
832
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
833
+ hx6dup = _upsample_like(hx6d, hx5)
834
+
835
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
836
+ hx5dup = _upsample_like(hx5d, hx4)
837
+
838
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
839
+ hx4dup = _upsample_like(hx4d, hx3)
840
+
841
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
842
+ hx3dup = _upsample_like(hx3d, hx2)
843
+
844
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
845
+ hx2dup = _upsample_like(hx2d, hx1)
846
+
847
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
848
+
849
+ return hx1d + hxin
850
+
851
+
852
+ class U2NETP(nn.Module):
853
+
854
+ def __init__(self, in_ch=3, out_ch=1):
855
+ super(U2NETP, self).__init__()
856
+
857
+ self.stage1 = RSU7(in_ch, 16, 64)
858
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
859
+
860
+ self.stage2 = RSU6(64, 16, 64)
861
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
862
+
863
+ self.stage3 = RSU5(64, 16, 64)
864
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
865
+
866
+ self.stage4 = RSU4(64, 16, 64)
867
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
868
+
869
+ self.stage5 = RSU4F(64, 16, 64)
870
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
871
+
872
+ self.stage6 = RSU4F(64, 16, 64)
873
+
874
+ # decoder
875
+ self.stage5d = RSU4F(128, 16, 64)
876
+ self.stage4d = RSU4(128, 16, 64)
877
+ self.stage3d = RSU5(128, 16, 64)
878
+ self.stage2d = RSU6(128, 16, 64)
879
+ self.stage1d = RSU7(128, 16, 64)
880
+
881
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
882
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
883
+ self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
884
+ self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
885
+ self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
886
+ self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
887
+
888
+ self.outconv = nn.Conv2d(6, out_ch, 1)
889
+
890
+ def forward(self, x):
891
+ hx = x
892
+
893
+ # stage 1
894
+ hx1 = self.stage1(hx)
895
+ hx = self.pool12(hx1)
896
+
897
+ # stage 2
898
+ hx2 = self.stage2(hx)
899
+ hx = self.pool23(hx2)
900
+
901
+ # stage 3
902
+ hx3 = self.stage3(hx)
903
+ hx = self.pool34(hx3)
904
+
905
+ # stage 4
906
+ hx4 = self.stage4(hx)
907
+ hx = self.pool45(hx4)
908
+
909
+ # stage 5
910
+ hx5 = self.stage5(hx)
911
+ hx = self.pool56(hx5)
912
+
913
+ # stage 6
914
+ hx6 = self.stage6(hx)
915
+ hx6up = _upsample_like(hx6, hx5)
916
+
917
+ # decoder
918
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
919
+ hx5dup = _upsample_like(hx5d, hx4)
920
+
921
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
922
+ hx4dup = _upsample_like(hx4d, hx3)
923
+
924
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
925
+ hx3dup = _upsample_like(hx3d, hx2)
926
+
927
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
928
+ hx2dup = _upsample_like(hx2d, hx1)
929
+
930
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
931
+
932
+ # side output
933
+ d1 = self.side1(hx1d)
934
+
935
+ d2 = self.side2(hx2d)
936
+ d2 = _upsample_like(d2, d1)
937
+
938
+ d3 = self.side3(hx3d)
939
+ d3 = _upsample_like(d3, d1)
940
+
941
+ d4 = self.side4(hx4d)
942
+ d4 = _upsample_like(d4, d1)
943
+
944
+ d5 = self.side5(hx5d)
945
+ d5 = _upsample_like(d5, d1)
946
+
947
+ d6 = self.side6(hx6)
948
+ d6 = _upsample_like(d6, d1)
949
+
950
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
951
+
952
+ return (
953
+ torch.sigmoid(d0),
954
+ torch.sigmoid(d1),
955
+ torch.sigmoid(d2),
956
+ torch.sigmoid(d3),
957
+ torch.sigmoid(d4),
958
+ torch.sigmoid(d5),
959
+ torch.sigmoid(d6),
960
+ )
models/DocTr-Plus/LICENSE.md ADDED
@@ -0,0 +1,54 @@
1
+ # License
2
+
3
+ Copyright © Hao Feng 2024. All Rights Reserved.
4
+
5
+ ## 1. Definitions
6
+
7
+ 1.1 "Algorithm" refers to the deep learning algorithm contained in this repository, including all associated code, documentation, and data.
8
+
9
+ 1.2 "Author" refers to Hao Feng, the creator and copyright holder of the Algorithm.
10
+
11
+ 1.3 "Non-Commercial Use" means use for academic research, personal study, or non-profit projects, without any direct or indirect commercial advantage.
12
+
13
+ 1.4 "Commercial Use" means any use intended for or directed toward commercial advantage or monetary compensation.
14
+
15
+ ## 2. Grant of Rights
16
+
17
+ 2.1 Non-Commercial Use: The Author hereby grants you a worldwide, royalty-free, non-exclusive license to use, copy, modify, and distribute the Algorithm for Non-Commercial Use, subject to the conditions in Section 3.
18
+
19
+ 2.2 Commercial Use: Any Commercial Use of the Algorithm is strictly prohibited without explicit prior written permission from the Author.
20
+
21
+ ## 3. Conditions
22
+
23
+ 3.1 For Non-Commercial Use:
24
+ a) Attribution: You must give appropriate credit to the Author, provide a link to this license, and indicate if changes were made.
25
+ b) Share-Alike: If you modify, transform, or build upon the Algorithm, you must distribute your contributions under the same license as this one.
26
+ c) No additional restrictions: You may not apply legal terms or technological measures that legally restrict others from doing anything this license permits.
27
+
28
+ 3.2 For Commercial Use:
29
+ a) Prior Contact: Before any Commercial Use, you must contact the Author at haof@mail.ustc.edu.cn and obtain explicit written permission.
30
+ b) Separate Agreement: Commercial Use terms will be stipulated in a separate commercial license agreement.
31
+
32
+ ## 4. Disclaimer of Warranty
33
+
34
+ The Algorithm is provided "as is", without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and non-infringement. In no event shall the Author be liable for any claim, damages, or other liability arising from, out of, or in connection with the Algorithm or the use or other dealings in the Algorithm.
35
+
36
+ ## 5. Limitation of Liability
37
+
38
+ In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, shall the Author be liable to you for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this license or out of the use or inability to use the Algorithm.
39
+
40
+ ## 6. Termination
41
+
42
+ 6.1 This license and the rights granted hereunder will terminate automatically upon any breach by you of the terms of this license.
43
+
44
+ 6.2 All sections which by their nature should survive the termination of this license shall survive such termination.
45
+
46
+ ## 7. Miscellaneous
47
+
48
+ 7.1 If any provision of this license is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable.
49
+
50
+ 7.2 This license represents the complete agreement concerning the subject matter hereof.
51
+
52
+ By using the Algorithm, you acknowledge that you have read this license, understand it, and agree to be bound by its terms and conditions. If you do not agree to the terms and conditions of this license, do not use, modify, or distribute the Algorithm.
53
+
54
+ For permissions beyond the scope of this license, please contact the Author at haof@mail.ustc.edu.cn.
models/DocTr-Plus/OCR_eval.py ADDED
@@ -0,0 +1,121 @@
1
+ import time
2
+
3
+ import numpy as np
4
+ import pytesseract
5
+ from PIL import Image
6
+
7
+ pytesseract.get_tesseract_version()
8
+
9
+
10
+ def Levenshtein_Distance(str1, str2):
11
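+ # classic O(len(str1) * len(str2)) dynamic programme: matrix[i][j] holds the edit
+ # distance between the first i characters of str1 and the first j characters of str2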
+ matrix = [[i + j for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]
12
+ for i in range(1, len(str1) + 1):
13
+ for j in range(1, len(str2) + 1):
14
+ if str1[i - 1] == str2[j - 1]:
15
+ d = 0
16
+ else:
17
+ d = 1
18
+ matrix[i][j] = min(
19
+ matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, matrix[i - 1][j - 1] + d
20
+ )
21
+
22
+ return matrix[len(str1)][len(str2)]
23
+
24
+
25
+ def cal_cer_ed(path_ours, tail="_rec"):
26
+ print(path_ours, "start")
27
+ print(f"started at {time.strftime('%H:%M:%S')}")
28
+ path_gt = "./scan/"
29
+ N = 196
30
+ cer1 = []
31
+ ed1 = []
32
+ check = [0 for _ in range(N + 1)]
33
+ # img index in UDIR test set for OCR evaluation
34
+ lis = [
35
+ 2,
36
+ 5,
37
+ 17,
38
+ 19,
39
+ 20,
40
+ 23,
41
+ 31,
42
+ 37,
43
+ 38,
44
+ 39,
45
+ 40,
46
+ 41,
47
+ 43,
48
+ 45,
49
+ 47,
50
+ 48,
51
+ 51,
52
+ 54,
53
+ 57,
54
+ 60,
55
+ 61,
56
+ 62,
57
+ 64,
58
+ 65,
59
+ 67,
60
+ 68,
61
+ 70,
62
+ 75,
63
+ 76,
64
+ 77,
65
+ 78,
66
+ 80,
67
+ 81,
68
+ 83,
69
+ 84,
70
+ 85,
71
+ 87,
72
+ 88,
73
+ 90,
74
+ 91,
75
+ 93,
76
+ 96,
77
+ 99,
78
+ 100,
79
+ 101,
80
+ 102,
81
+ 103,
82
+ 104,
83
+ 105,
84
+ 134,
85
+ 137,
86
+ 138,
87
+ 140,
88
+ 150,
89
+ 151,
90
+ 155,
91
+ 158,
92
+ 162,
93
+ 163,
94
+ 164,
95
+ 165,
96
+ 166,
97
+ 169,
98
+ 170,
99
+ 172,
100
+ 173,
101
+ 175,
102
+ 177,
103
+ 178,
104
+ 182,
105
+ ]
106
+ for i in range(1, N):
107
+ if i not in lis:
108
+ continue
109
+ gt = Image.open(path_gt + str(i) + ".png")
110
+ img1 = Image.open(path_ours + str(i) + tail)
111
+ content_gt = pytesseract.image_to_string(gt)
112
+ content1 = pytesseract.image_to_string(img1)
113
+ l1 = Levenshtein_Distance(content_gt, content1)
114
+ ed1.append(l1)
115
+ cer1.append(l1 / len(content_gt))
116
+ check[i] = cer1[-1]
117
+
118
+ CER = np.mean(cer1)
119
+ ED = np.mean(ed1)
120
+ print(f"finished at {time.strftime('%H:%M:%S')}")
121
+ return [path_ours, CER, ED]
models/DocTr-Plus/README.md ADDED
@@ -0,0 +1,79 @@
1
+ 🔥 **Good news! Our demo has already exceeded 20,000 calls!**
2
+
3
+ 🔥 **Good news! Our work has been accepted by IEEE Transactions on Multimedia.**
4
+
5
+ 🚀 **Exciting update! We have created a demo for our paper, showcasing the generic rectification capabilities of our method. [Check it out here!](https://doctrp.docscanner.top/)**
6
+
7
+ 🔥 **Good news! Our new work exhibits state-of-the-art performance on the [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html) dataset:
8
+ [DocScanner: Robust Document Image Rectification with Progressive Learning](https://drive.google.com/file/d/1mmCUj90rHyuO1SmpLt361youh-07Y0sD/view?usp=share_link)** with [Repo](https://github.com/fh2019ustc/DocScanner).
9
+
10
+ 🔥 **Good news! A comprehensive list of [Awesome Document Image Rectification](https://github.com/fh2019ustc/Awesome-Document-Image-Rectification) methods is available.**
11
+
12
+ # DocTr++
13
+
14
+ <p>
15
+ <a href='https://project.doctrp.top/' target="_blank"><img src='https://img.shields.io/badge/Project-Page-Green'></a>
16
+ <a href='https://arxiv.org/abs/2304.08796' target="_blank"><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
17
+ <a href='https://demo.doctrp.top/' target="_blank"><img src='https://img.shields.io/badge/Online-Demo-green'></a>
18
+ </p>
19
+
20
+
21
+ ![Demo](assets/github_demo.png)
22
+ ![Demo](assets/github_demo_v2.png)
23
+ > **[DocTr++: Deep Unrestricted Document Image Rectification](https://arxiv.org/abs/2304.08796)**
24
+
25
+ > DocTr++ is an enhanced version of the original [DocTr: Document Image Transformer for Geometric Unwarping and Illumination Correction](https://github.com/fh2019ustc/DocTr), aiming to rectify various distorted document images in the wild,
26
+ whether or not the document is fully present in the image.
27
+
28
+ Any questions or discussions are welcome!
29
+
30
+
31
+ ## 🚀 Demo [(Link)](https://demo.doctrp.top/)
32
+ 1. Upload the distorted document image to be rectified in the left box.
33
+ 2. Click the "Submit" button.
34
+ 3. The rectified image will be displayed in the right box.
35
+ 4. Our demo runs on CPU infrastructure, and because images are transmitted over the network, some display latency may be experienced.
36
+
37
+ [![Alt text](https://user-images.githubusercontent.com/50725551/232952015-15508ad6-e38c-475b-bf9e-91cb74bc5fea.png)](https://demo.doctrp.top/)
38
+
39
+
40
+
41
+ ## Inference
42
+ 1. Put the pretrained model in `$ROOT/model_pretrained/`.
43
+ 2. Put the distorted images in `$ROOT/distorted/`.
44
+ 3. Run the script and the rectified images are saved in `$ROOT/rectified/` by default.
45
+ ```
46
+ python inference.py
47
+ ```
48
+
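+ For reference, a minimal rectification sketch built on the `GeoTrP` wrapper and `reload_model`
+ from `inference.py` (the checkpoint and image paths are placeholders, and the pre/post-processing
+ is an assumption based on the wrapper returning a 2560x2560 backward map in pixel coordinates):
+ ```
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ from inference import GeoTrP, reload_model  # adjust the import to your package layout
+
+ model = reload_model(GeoTrP(), "./model_pretrained/model.pt").cuda().eval()
+
+ img = cv2.cvtColor(cv2.imread("./distorted/demo.png"), cv2.COLOR_BGR2RGB)
+ inp = cv2.resize(img, (288, 288)).astype(np.float32) / 255.0
+ inp = torch.from_numpy(inp.transpose(2, 0, 1)).unsqueeze(0).cuda()
+
+ with torch.no_grad():
+     bm = model(inp)                      # (1, 2, 2560, 2560), pixel coordinates
+ grid = (bm / 2560.0) * 2 - 1             # normalize to [-1, 1] for grid_sample
+ src = cv2.resize(img, (2560, 2560)).astype(np.float32) / 255.0
+ src = torch.from_numpy(src.transpose(2, 0, 1)).unsqueeze(0).cuda()
+ rec = F.grid_sample(src, grid.permute(0, 2, 3, 1), align_corners=True)
+ out = (rec[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+ cv2.imwrite("./rectified/demo.png", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
+ ```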
49
+ ## Evaluation
50
+ - ***Image Metrics:*** We propose the metrics MS-SSIM-M and LD-M, which differ from those used for the [DocUNet Benchmark](https://www3.cs.stonybrook.edu/~cvl/docunet.html) dataset. We use Matlab 2019a; since scores can vary across Matlab versions, please compare results computed with the same version. Our Matlab interface file is provided at ```$ROOT/ssimm_ldm_eval.m```.
51
+ - ***OCR Metrics:*** The indices of the 70 document images in the [UDIR test set](https://drive.google.com/drive/folders/15rknyt7XE2k6jrxaTc_n5dzXIdCukJLh?usp=share_link) used for our OCR evaluation are provided in ```$ROOT/OCR_eval.py```.
52
+ We use pytesseract 0.3.8 and [Tesseract](https://digi.bib.uni-mannheim.de/tesseract/) 5.0.1.20220118 on Windows.
53
+ Note that the calculated performance differs slightly across operating systems; a minimal CER sketch is given below.
54
+
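+ A minimal sketch of the per-image CER/ED computation (mirroring ```$ROOT/OCR_eval.py```; file paths
+ are placeholders):
+ ```
+ import pytesseract
+ from PIL import Image
+
+ from OCR_eval import Levenshtein_Distance  # adjust the import to your package layout
+
+ gt_text = pytesseract.image_to_string(Image.open("scan/2.png"))        # ground-truth scan
+ rec_text = pytesseract.image_to_string(Image.open("rectified/2.png"))  # rectified result
+ ed = Levenshtein_Distance(gt_text, rec_text)  # Edit Distance (ED)
+ cer = ed / len(gt_text)                       # Character Error Rate (CER)
+ print(f"ED={ed}, CER={cer:.4f}")
+ ```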
55
+ ## Citation
56
+
57
+ If you find this code useful for your research, please use the following BibTeX entry.
58
+
59
+ ```
60
+ @inproceedings{feng2021doctr,
61
+ title={DocTr: Document Image Transformer for Geometric Unwarping and Illumination Correction},
62
+ author={Feng, Hao and Wang, Yuechen and Zhou, Wengang and Deng, Jiajun and Li, Houqiang},
63
+ booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
64
+ pages={273--281},
65
+ year={2021}
66
+ }
67
+ ```
68
+
69
+ ```
70
+ @article{feng2023doctrp,
71
+ title={Deep Unrestricted Document Image Rectification},
72
+ author={Feng, Hao and Liu, Shaokai and Deng, Jiajun and Zhou, Wengang and Li, Houqiang},
73
+ journal={IEEE Transactions on Multimedia},
74
+ year={2023}
75
+ }
76
+ ```
77
+
78
+ ## Contact
79
+ For commercial usage, please contact Professor Wengang Zhou ([zhwg@ustc.edu.cn](zhwg@ustc.edu.cn)) and Hao Feng ([haof@mail.ustc.edu.cn](haof@mail.ustc.edu.cn)).
models/DocTr-Plus/__init__.py ADDED
File without changes
models/DocTr-Plus/__pycache__/GeoTr.cpython-38.pyc ADDED
Binary file (23.5 kB).
models/DocTr-Plus/__pycache__/GeoTr.cpython-39.pyc ADDED
Binary file (23.5 kB).
models/DocTr-Plus/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (154 Bytes).
models/DocTr-Plus/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (154 Bytes).
models/DocTr-Plus/__pycache__/extractor.cpython-38.pyc ADDED
Binary file (3.13 kB).
models/DocTr-Plus/__pycache__/extractor.cpython-39.pyc ADDED
Binary file (3.12 kB).
models/DocTr-Plus/__pycache__/inference.cpython-38.pyc ADDED
Binary file (1.61 kB).
models/DocTr-Plus/__pycache__/inference.cpython-39.pyc ADDED
Binary file (1.61 kB).
models/DocTr-Plus/__pycache__/position_encoding.cpython-38.pyc ADDED
Binary file (4.38 kB).
models/DocTr-Plus/__pycache__/position_encoding.cpython-39.pyc ADDED
Binary file (4.34 kB).
models/DocTr-Plus/evalUnwarp.m ADDED
@@ -0,0 +1,46 @@
1
+ function [ms, ld] = evalUnwarp(A, ref, ref_msk)
2
+
3
+ x = A;
4
+ y = ref;
5
+ z = ref_msk;
6
+
7
+ im1=imresize(imfilter(x,fspecial('gaussian',7,1.),'same','replicate'),0.5,'bicubic');
8
+ im2=imresize(imfilter(y,fspecial('gaussian',7,1.),'same','replicate'),0.5,'bicubic');
9
+ im3=imresize(imfilter(z,fspecial('gaussian',7,1.),'same','replicate'),0.5,'bicubic');
10
+
11
+ im1=im2double(im1);
12
+ im2=im2double(im2);
13
+ im3=im2double(im3);
14
+
15
+ cellsize=3;
16
+ gridspacing=1;
17
+
18
+ sift1 = mexDenseSIFT(im1,cellsize,gridspacing);
19
+ sift2 = mexDenseSIFT(im2,cellsize,gridspacing);
20
+
21
+ SIFTflowpara.alpha=2*255;
22
+ SIFTflowpara.d=40*255;
23
+ SIFTflowpara.gamma=0.005*255;
24
+ SIFTflowpara.nlevels=4;
25
+ SIFTflowpara.wsize=2;
26
+ SIFTflowpara.topwsize=10;
27
+ SIFTflowpara.nTopIterations = 60;
28
+ SIFTflowpara.nIterations= 30;
29
+
30
+
31
+ [vx,vy,~]=SIFTflowc2f(sift1,sift2,SIFTflowpara);
32
+
33
+ d = sqrt(vx.^2 + vy.^2);
34
+ mskk = (im3==0);
35
+ ld = mean(d(~mskk));
36
+
37
+ wt = [0.0448 0.2856 0.3001 0.2363 0.1333];
38
+ ss = zeros(5, 1);
39
+ for s = 1 : 5
40
+ ss(s) = ssim(x, z);
41
+ x = impyramid(x, 'reduce');
42
+ z = impyramid(z, 'reduce');
43
+ end
44
+ ms = wt * ss;
45
+
46
+ end
models/DocTr-Plus/extractor.py ADDED
@@ -0,0 +1,117 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class ResidualBlock(nn.Module):
7
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1):
8
+ super(ResidualBlock, self).__init__()
9
+
10
+ self.conv1 = nn.Conv2d(
11
+ in_planes, planes, kernel_size=3, padding=1, stride=stride
12
+ )
13
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
14
+ self.relu = nn.ReLU(inplace=True)
15
+
16
+ num_groups = planes // 8
17
+
18
+ if norm_fn == "group":
19
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
20
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21
+ if not stride == 1:
22
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
23
+
24
+ elif norm_fn == "batch":
25
+ self.norm1 = nn.BatchNorm2d(planes)
26
+ self.norm2 = nn.BatchNorm2d(planes)
27
+ if not stride == 1:
28
+ self.norm3 = nn.BatchNorm2d(planes)
29
+
30
+ elif norm_fn == "instance":
31
+ self.norm1 = nn.InstanceNorm2d(planes)
32
+ self.norm2 = nn.InstanceNorm2d(planes)
33
+ if not stride == 1:
34
+ self.norm3 = nn.InstanceNorm2d(planes)
35
+
36
+ elif norm_fn == "none":
37
+ self.norm1 = nn.Sequential()
38
+ self.norm2 = nn.Sequential()
39
+ if not stride == 1:
40
+ self.norm3 = nn.Sequential()
41
+
42
+ if stride == 1:
43
+ self.downsample = None
44
+
45
+ else:
46
+ self.downsample = nn.Sequential(
47
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
48
+ )
49
+
50
+ def forward(self, x):
51
+ y = x
52
+ y = self.relu(self.norm1(self.conv1(y)))
53
+ y = self.relu(self.norm2(self.conv2(y)))
54
+
55
+ if self.downsample is not None:
56
+ x = self.downsample(x)
57
+
58
+ return self.relu(x + y)
59
+
60
+
61
+ class BasicEncoder(nn.Module):
62
+ def __init__(self, output_dim=128, norm_fn="batch"):
63
+ super(BasicEncoder, self).__init__()
64
+ self.norm_fn = norm_fn
65
+
66
+ if self.norm_fn == "group":
67
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
68
+
69
+ elif self.norm_fn == "batch":
70
+ self.norm1 = nn.BatchNorm2d(64)
71
+
72
+ elif self.norm_fn == "instance":
73
+ self.norm1 = nn.InstanceNorm2d(64)
74
+
75
+ elif self.norm_fn == "none":
76
+ self.norm1 = nn.Sequential()
77
+
78
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
79
+ self.relu1 = nn.ReLU(inplace=True)
80
+
81
+ self.in_planes = 64
82
+ self.layer1 = self._make_layer(64, stride=1)
83
+ self.layer2 = self._make_layer(128, stride=2)
84
+ self.layer3 = self._make_layer(192, stride=2)
85
+
86
+ # output convolution
87
+ self.conv2 = nn.Conv2d(192, output_dim, kernel_size=1)
88
+
89
+ for m in self.modules():
90
+ if isinstance(m, nn.Conv2d):
91
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
92
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
93
+ if m.weight is not None:
94
+ nn.init.constant_(m.weight, 1)
95
+ if m.bias is not None:
96
+ nn.init.constant_(m.bias, 0)
97
+
98
+ def _make_layer(self, dim, stride=1):
99
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
100
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
101
+ layers = (layer1, layer2)
102
+
103
+ self.in_planes = dim
104
+ return nn.Sequential(*layers)
105
+
106
+ def forward(self, x):
107
+ x = self.conv1(x)
108
+ x = self.norm1(x)
109
+ x = self.relu1(x)
110
+
111
+ x = self.layer1(x)
112
+ x = self.layer2(x)
113
+ x = self.layer3(x)
114
+
115
+ x = self.conv2(x)
116
+
117
+ return x
models/DocTr-Plus/inference.py ADDED
@@ -0,0 +1,51 @@
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5
+
6
+ import argparse
7
+ import glob
8
+ import os
9
+ import warnings
10
+
11
+ import cv2
12
+ import numpy as np
13
+ import skimage.io as io
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from PIL import Image
18
+
19
+ from .GeoTr import U2NETP, GeoTr
20
+
21
+ warnings.filterwarnings("ignore")
22
+
23
+
24
+ class GeoTrP(nn.Module):
25
+ def __init__(self):
26
+ super(GeoTrP, self).__init__()
27
+ self.GeoTr = GeoTr()
28
+
29
+ def forward(self, x):
30
+ bm = self.GeoTr(x) # [0]
31
+ bm = 2 * (bm / 288) - 1
32
+
33
+ bm = (bm + 1) / 2 * 2560
34
+
35
+ bm = F.interpolate(bm, size=(2560, 2560), mode="bilinear", align_corners=True)
36
+
37
+ return bm
38
+
39
+
40
+ def reload_model(model, path=""):
41
+ if not bool(path):
42
+ return model
43
+ else:
44
+ model_dict = model.state_dict()
45
+ pretrained_dict = torch.load(path, map_location="cuda:0")
46
+ print(len(pretrained_dict.keys()))
48
+ model_dict.update(pretrained_dict)
49
+ model.load_state_dict(model_dict)
50
+
51
+ return model
models/DocTr-Plus/position_encoding.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Various positional encodings for the transformer.
4
+ """
5
+ import math
6
+ from typing import List, Optional
7
+
8
+ import torch
9
+ from torch import Tensor, nn
10
+
11
+
12
+ class NestedTensor(object):
13
+ def __init__(self, tensors, mask: Optional[Tensor]):
14
+ self.tensors = tensors
15
+ self.mask = mask
16
+
17
+ def to(self, device):
18
+ # type: (Device) -> NestedTensor # noqa
19
+ cast_tensor = self.tensors.to(device)
20
+ mask = self.mask
21
+ if mask is not None:
22
+ assert mask is not None
23
+ cast_mask = mask.to(device)
24
+ else:
25
+ cast_mask = None
26
+ return NestedTensor(cast_tensor, cast_mask)
27
+
28
+ def decompose(self):
29
+ return self.tensors, self.mask
30
+
31
+ def __repr__(self):
32
+ return str(self.tensors)
33
+
34
+
35
+ class PositionEmbeddingSine(nn.Module):
36
+ """
37
+ This is a more standard version of the position embedding, very similar to the one
38
+ used by the Attention is all you need paper, generalized to work on images.
39
+ """
40
+
41
+ def __init__(
42
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
43
+ ):
44
+ super().__init__()
45
+ self.num_pos_feats = num_pos_feats
46
+ self.temperature = temperature
47
+ self.normalize = normalize
48
+ if scale is not None and normalize is False:
49
+ raise ValueError("normalize should be True if scale is passed")
50
+ if scale is None:
51
+ scale = 2 * math.pi
52
+ self.scale = scale
53
+
54
+ def forward(self, mask):
55
+ assert mask is not None
56
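+ # cumulative sums over the all-ones mask give 1-based row/column positions, which are
+ # then encoded with interleaved sin/cos at geometrically spaced frequencies (dim_t)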
+ y_embed = mask.cumsum(1, dtype=torch.float32)
57
+ x_embed = mask.cumsum(2, dtype=torch.float32)
58
+ if self.normalize:
59
+ eps = 1e-6
60
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
61
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
62
+
63
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32).cuda()
64
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
65
+
66
+ pos_x = x_embed[:, :, :, None] / dim_t
67
+ pos_y = y_embed[:, :, :, None] / dim_t
68
+ pos_x = torch.stack(
69
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
70
+ ).flatten(3)
71
+ pos_y = torch.stack(
72
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
73
+ ).flatten(3)
74
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
75
+ # print(pos.shape)
76
+ return pos
77
+
78
+
79
+ class PositionEmbeddingLearned(nn.Module):
80
+ """
81
+ Absolute pos embedding, learned.
82
+ """
83
+
84
+ def __init__(self, num_pos_feats=256):
85
+ super().__init__()
86
+ self.row_embed = nn.Embedding(50, num_pos_feats)
87
+ self.col_embed = nn.Embedding(50, num_pos_feats)
88
+ self.reset_parameters()
89
+
90
+ def reset_parameters(self):
91
+ nn.init.uniform_(self.row_embed.weight)
92
+ nn.init.uniform_(self.col_embed.weight)
93
+
94
+ def forward(self, tensor_list: NestedTensor):
95
+ x = tensor_list.tensors
96
+ h, w = x.shape[-2:]
97
+ i = torch.arange(w, device=x.device)
98
+ j = torch.arange(h, device=x.device)
99
+ x_emb = self.col_embed(i)
100
+ y_emb = self.row_embed(j)
101
+ pos = (
102
+ torch.cat(
103
+ [
104
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
105
+ y_emb.unsqueeze(1).repeat(1, w, 1),
106
+ ],
107
+ dim=-1,
108
+ )
109
+ .permute(2, 0, 1)
110
+ .unsqueeze(0)
111
+ .repeat(x.shape[0], 1, 1, 1)
112
+ )
113
+ return pos
114
+
115
+
116
+ def build_position_encoding(hidden_dim=512, position_embedding="sine"):
117
+ N_steps = hidden_dim // 2
118
+ if position_embedding in ("v2", "sine"):
119
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
120
+ elif position_embedding in ("v3", "learned"):
121
+ position_embedding = PositionEmbeddingLearned(N_steps)
122
+ else:
123
+ raise ValueError(f"not supported {position_embedding}")
124
+
125
+ return position_embedding
models/DocTr-Plus/pyimagesearch/__init__.py ADDED
File without changes
models/DocTr-Plus/pyimagesearch/transform.py ADDED
@@ -0,0 +1,64 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ def order_points(pts):
6
+ # initialize a list of coordinates that will be ordered
7
+ # such that the first entry in the list is the top-left,
8
+ # the second entry is the top-right, the third is the
9
+ # bottom-right, and the fourth is the bottom-left
10
+ rect = np.zeros((4, 2), dtype="float32")
11
+
12
+ # the top-left point will have the smallest sum, whereas
13
+ # the bottom-right point will have the largest sum
14
+ s = pts.sum(axis=1)
15
+ rect[0] = pts[np.argmin(s)]
16
+ rect[2] = pts[np.argmax(s)]
17
+
18
+ # now, compute the difference between the points, the
19
+ # top-right point will have the smallest difference,
20
+ # whereas the bottom-left will have the largest difference
21
+ diff = np.diff(pts, axis=1)
22
+ rect[1] = pts[np.argmin(diff)]
23
+ rect[3] = pts[np.argmax(diff)]
24
+
25
+ # return the ordered coordinates
26
+ return rect
27
+
28
+
29
+ def four_point_transform(image, pts):
30
+ # obtain a consistent order of the points and unpack them
31
+ # individually
32
+ rect = order_points(pts)
33
+ (tl, tr, br, bl) = rect
34
+
35
+ # compute the width of the new image, which will be the
36
+ # maximum distance between bottom-right and bottom-left
37
+ # x-coordinates or the top-right and top-left x-coordinates
38
+ widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
39
+ widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
40
+ maxWidth = max(int(widthA), int(widthB))
41
+
42
+ # compute the height of the new image, which will be the
43
+ # maximum distance between the top-right and bottom-right
44
+ # y-coordinates or the top-left and bottom-left y-coordinates
45
+ heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
46
+ heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
47
+ maxHeight = max(int(heightA), int(heightB))
48
+
49
+ # now that we have the dimensions of the new image, construct
50
+ # the set of destination points to obtain a "birds eye view",
51
+ # (i.e. top-down view) of the image, again specifying points
52
+ # in the top-left, top-right, bottom-right, and bottom-left
53
+ # order
54
+ dst = np.array(
55
+ [[0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1], [0, maxHeight - 1]],
56
+ dtype="float32",
57
+ )
58
+
59
+ # compute the perspective transform matrix and then apply it
60
+ M = cv2.getPerspectiveTransform(rect, dst)
61
+ warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
62
+
63
+ # return the warped image
64
+ return warped
models/DocTr-Plus/requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ numpy==1.19.0
2
+ opencv_python==4.2.0.34
3
+ Pillow==9.4.0
4
+ scikit_image==0.17.2
5
+ skimage==0.0
6
+ thop==0.1.1.post2209072238
7
+ torch==1.5.1+cu101
models/DocTr-Plus/ssimm_ldm_eval.m ADDED
@@ -0,0 +1,36 @@
1
+ path_rec = 'xxx'; % rectified image path
2
+ path_scan = './UDIR/gt/'; % scan image path
3
+
4
+ tarea=598400;
5
+ ms=0;
6
+ ld=0;
7
+
8
+ for i=1:195
9
+ path_rec_1 = sprintf("%s%d%s", path_rec, i, '.png'); % rectified image path
10
+ path_scan_new = sprintf("%s%d%s", path_scan, i, '.png'); % corresponding scan image path
11
+
12
+ % imread and rgb2gray
13
+ A1 = imread(path_rec_1);
14
+ ref = imread(path_scan_new);
15
+ A1 = rgb2gray(A1);
16
+ ref = rgb2gray(ref);
17
+
18
+ % resize
19
+ b = sqrt(tarea/size(ref,1)/size(ref,2));
20
+ ref = imresize(ref,b);
21
+ ref_msk = ref;
22
+ A1 = imresize(A1,[size(ref,1),size(ref,2)]);
23
+
24
+ % mask the gt image
25
+ m1 = A1 == 0;
26
+ ref_msk(m1) = 0;
27
+
28
+ % calculate
29
+ [ms_1,ld_1] = evalUnwarp(A1, ref, ref_msk);
30
+ ms = ms + ms_1;
31
+ ld = ld + ld_1;
32
+
33
+ end
34
+
35
+ ms_m = ms / 195
36
+ ld_m = ld / 195
models/Document-Image-Unwarping-pytorch ADDED
@@ -0,0 +1 @@
1
+ Subproject commit 92b29172b981d132f7b31e767505524f8cc7af7a