diff --git a/Data/test_samples/img/austin2_01179_258_289.png b/Data/test_samples/img/austin2_01179_258_289.png new file mode 100644 index 0000000000000000000000000000000000000000..604d58df4ddabac52abfbd4e3ea7ff90392b55bd Binary files /dev/null and b/Data/test_samples/img/austin2_01179_258_289.png differ diff --git a/Data/test_samples/img/austin3_00294_332_293.png b/Data/test_samples/img/austin3_00294_332_293.png new file mode 100644 index 0000000000000000000000000000000000000000..c5e089254a29720313baab44de2cbd2db97c5e9f Binary files /dev/null and b/Data/test_samples/img/austin3_00294_332_293.png differ diff --git a/Data/test_samples/img/austin3_00497_432_361.png b/Data/test_samples/img/austin3_00497_432_361.png new file mode 100644 index 0000000000000000000000000000000000000000..1d8c9363a30c0cb379ed22b2ac24252c4fcc49dd Binary files /dev/null and b/Data/test_samples/img/austin3_00497_432_361.png differ diff --git a/Data/test_samples/img/austin3_01452_334_335.png b/Data/test_samples/img/austin3_01452_334_335.png new file mode 100644 index 0000000000000000000000000000000000000000..b7ddfbfde123bd3a15047bb6700e06e7d5b8a143 Binary files /dev/null and b/Data/test_samples/img/austin3_01452_334_335.png differ diff --git a/Data/test_samples/img/austin4_00061_481_276.png b/Data/test_samples/img/austin4_00061_481_276.png new file mode 100644 index 0000000000000000000000000000000000000000..41148c854cfc7bcfb7fd88623488c39cbe8af6e1 Binary files /dev/null and b/Data/test_samples/img/austin4_00061_481_276.png differ diff --git a/Data/test_samples/img/austin4_00133_264_308.png b/Data/test_samples/img/austin4_00133_264_308.png new file mode 100644 index 0000000000000000000000000000000000000000..57e7e40e473f8a31ff7cb58352071b0ab1b0029d Binary files /dev/null and b/Data/test_samples/img/austin4_00133_264_308.png differ diff --git a/Data/test_samples/img/austin4_00163_356_367.png b/Data/test_samples/img/austin4_00163_356_367.png new file mode 100644 index 
0000000000000000000000000000000000000000..a913dd02436e6a22496f9169a5ccd009ba8a1d31 Binary files /dev/null and b/Data/test_samples/img/austin4_00163_356_367.png differ diff --git a/Data/test_samples/img/austin4_00205_468_478.png b/Data/test_samples/img/austin4_00205_468_478.png new file mode 100644 index 0000000000000000000000000000000000000000..752e4d7cfe102c9e8692bffb795810a79a76abb1 Binary files /dev/null and b/Data/test_samples/img/austin4_00205_468_478.png differ diff --git a/Data/test_samples/img/austin7_00079_495_463.png b/Data/test_samples/img/austin7_00079_495_463.png new file mode 100644 index 0000000000000000000000000000000000000000..7cf747568f7b2fe8d386749e013de59514d426f4 Binary files /dev/null and b/Data/test_samples/img/austin7_00079_495_463.png differ diff --git a/Data/test_samples/img/austin7_00100_300_263.png b/Data/test_samples/img/austin7_00100_300_263.png new file mode 100644 index 0000000000000000000000000000000000000000..038df3f3a4b15d99020efdbfa84f7b3eeeb5c962 Binary files /dev/null and b/Data/test_samples/img/austin7_00100_300_263.png differ diff --git a/Data/test_samples/img/austin7_00166_430_319.png b/Data/test_samples/img/austin7_00166_430_319.png new file mode 100644 index 0000000000000000000000000000000000000000..bc64351619024ab0a08a49e58f04c3287227f3f9 Binary files /dev/null and b/Data/test_samples/img/austin7_00166_430_319.png differ diff --git a/Data/test_samples/label/austin2_01179_258_289.png b/Data/test_samples/label/austin2_01179_258_289.png new file mode 100644 index 0000000000000000000000000000000000000000..502f36179390c0fa7be584d4fec6270676d1cf3d Binary files /dev/null and b/Data/test_samples/label/austin2_01179_258_289.png differ diff --git a/Data/test_samples/label/austin3_00294_332_293.png b/Data/test_samples/label/austin3_00294_332_293.png new file mode 100644 index 0000000000000000000000000000000000000000..d484925c65b18a546e174d61ec82135b277f044a Binary files /dev/null and 
b/Data/test_samples/label/austin3_00294_332_293.png differ diff --git a/Data/test_samples/label/austin3_00497_432_361.png b/Data/test_samples/label/austin3_00497_432_361.png new file mode 100644 index 0000000000000000000000000000000000000000..f973edb64193ec01fd3f1d9eb24aa4f0d228df38 Binary files /dev/null and b/Data/test_samples/label/austin3_00497_432_361.png differ diff --git a/Data/test_samples/label/austin3_01452_334_335.png b/Data/test_samples/label/austin3_01452_334_335.png new file mode 100644 index 0000000000000000000000000000000000000000..c3ae963a85d644c28309f9eab16997f6b8f50881 Binary files /dev/null and b/Data/test_samples/label/austin3_01452_334_335.png differ diff --git a/Data/test_samples/label/austin4_00061_481_276.png b/Data/test_samples/label/austin4_00061_481_276.png new file mode 100644 index 0000000000000000000000000000000000000000..5f9e1504571104c2e87b9ef2a4b3a07e23a1fc7e Binary files /dev/null and b/Data/test_samples/label/austin4_00061_481_276.png differ diff --git a/Data/test_samples/label/austin4_00133_264_308.png b/Data/test_samples/label/austin4_00133_264_308.png new file mode 100644 index 0000000000000000000000000000000000000000..7ffc1813cdc925fc7236fcf1bc1f2be1cf4f53e9 Binary files /dev/null and b/Data/test_samples/label/austin4_00133_264_308.png differ diff --git a/Data/test_samples/label/austin4_00163_356_367.png b/Data/test_samples/label/austin4_00163_356_367.png new file mode 100644 index 0000000000000000000000000000000000000000..cf6b29fd2a0c027b07d3e68ea9706c0044f187b3 Binary files /dev/null and b/Data/test_samples/label/austin4_00163_356_367.png differ diff --git a/Data/test_samples/label/austin4_00205_468_478.png b/Data/test_samples/label/austin4_00205_468_478.png new file mode 100644 index 0000000000000000000000000000000000000000..6a62461711fd4e4d4cbd7215e62bac756e08fc76 Binary files /dev/null and b/Data/test_samples/label/austin4_00205_468_478.png differ diff --git a/Data/test_samples/label/austin7_00079_495_463.png 
b/Data/test_samples/label/austin7_00079_495_463.png new file mode 100644 index 0000000000000000000000000000000000000000..036313128207db4cbe4fcabce901bdb2b20b5e1d Binary files /dev/null and b/Data/test_samples/label/austin7_00079_495_463.png differ diff --git a/Data/test_samples/label/austin7_00100_300_263.png b/Data/test_samples/label/austin7_00100_300_263.png new file mode 100644 index 0000000000000000000000000000000000000000..a149c4f61bb4b3a31d5ad86a412b428ec536c7fd Binary files /dev/null and b/Data/test_samples/label/austin7_00100_300_263.png differ diff --git a/Data/test_samples/label/austin7_00166_430_319.png b/Data/test_samples/label/austin7_00166_430_319.png new file mode 100644 index 0000000000000000000000000000000000000000..22397b71540422d9890de6b3c43c6fc3cd50eb0c Binary files /dev/null and b/Data/test_samples/label/austin7_00166_430_319.png differ diff --git a/Data/train_samples/img/austin2_01179_258_289.png b/Data/train_samples/img/austin2_01179_258_289.png new file mode 100644 index 0000000000000000000000000000000000000000..604d58df4ddabac52abfbd4e3ea7ff90392b55bd Binary files /dev/null and b/Data/train_samples/img/austin2_01179_258_289.png differ diff --git a/Data/train_samples/img/austin3_00294_332_293.png b/Data/train_samples/img/austin3_00294_332_293.png new file mode 100644 index 0000000000000000000000000000000000000000..c5e089254a29720313baab44de2cbd2db97c5e9f Binary files /dev/null and b/Data/train_samples/img/austin3_00294_332_293.png differ diff --git a/Data/train_samples/img/austin3_00497_432_361.png b/Data/train_samples/img/austin3_00497_432_361.png new file mode 100644 index 0000000000000000000000000000000000000000..1d8c9363a30c0cb379ed22b2ac24252c4fcc49dd Binary files /dev/null and b/Data/train_samples/img/austin3_00497_432_361.png differ diff --git a/Data/train_samples/img/austin3_01452_334_335.png b/Data/train_samples/img/austin3_01452_334_335.png new file mode 100644 index 
0000000000000000000000000000000000000000..b7ddfbfde123bd3a15047bb6700e06e7d5b8a143 Binary files /dev/null and b/Data/train_samples/img/austin3_01452_334_335.png differ diff --git a/Data/train_samples/img/austin4_00061_481_276.png b/Data/train_samples/img/austin4_00061_481_276.png new file mode 100644 index 0000000000000000000000000000000000000000..41148c854cfc7bcfb7fd88623488c39cbe8af6e1 Binary files /dev/null and b/Data/train_samples/img/austin4_00061_481_276.png differ diff --git a/Data/train_samples/img/austin4_00133_264_308.png b/Data/train_samples/img/austin4_00133_264_308.png new file mode 100644 index 0000000000000000000000000000000000000000..57e7e40e473f8a31ff7cb58352071b0ab1b0029d Binary files /dev/null and b/Data/train_samples/img/austin4_00133_264_308.png differ diff --git a/Data/train_samples/img/austin4_00163_356_367.png b/Data/train_samples/img/austin4_00163_356_367.png new file mode 100644 index 0000000000000000000000000000000000000000..a913dd02436e6a22496f9169a5ccd009ba8a1d31 Binary files /dev/null and b/Data/train_samples/img/austin4_00163_356_367.png differ diff --git a/Data/train_samples/img/austin4_00205_468_478.png b/Data/train_samples/img/austin4_00205_468_478.png new file mode 100644 index 0000000000000000000000000000000000000000..752e4d7cfe102c9e8692bffb795810a79a76abb1 Binary files /dev/null and b/Data/train_samples/img/austin4_00205_468_478.png differ diff --git a/Data/train_samples/img/austin7_00079_495_463.png b/Data/train_samples/img/austin7_00079_495_463.png new file mode 100644 index 0000000000000000000000000000000000000000..7cf747568f7b2fe8d386749e013de59514d426f4 Binary files /dev/null and b/Data/train_samples/img/austin7_00079_495_463.png differ diff --git a/Data/train_samples/img/austin7_00100_300_263.png b/Data/train_samples/img/austin7_00100_300_263.png new file mode 100644 index 0000000000000000000000000000000000000000..038df3f3a4b15d99020efdbfa84f7b3eeeb5c962 Binary files /dev/null and 
b/Data/train_samples/img/austin7_00100_300_263.png differ diff --git a/Data/train_samples/img/austin7_00166_430_319.png b/Data/train_samples/img/austin7_00166_430_319.png new file mode 100644 index 0000000000000000000000000000000000000000..bc64351619024ab0a08a49e58f04c3287227f3f9 Binary files /dev/null and b/Data/train_samples/img/austin7_00166_430_319.png differ diff --git a/Data/train_samples/label/austin2_01179_258_289.png b/Data/train_samples/label/austin2_01179_258_289.png new file mode 100644 index 0000000000000000000000000000000000000000..502f36179390c0fa7be584d4fec6270676d1cf3d Binary files /dev/null and b/Data/train_samples/label/austin2_01179_258_289.png differ diff --git a/Data/train_samples/label/austin3_00294_332_293.png b/Data/train_samples/label/austin3_00294_332_293.png new file mode 100644 index 0000000000000000000000000000000000000000..d484925c65b18a546e174d61ec82135b277f044a Binary files /dev/null and b/Data/train_samples/label/austin3_00294_332_293.png differ diff --git a/Data/train_samples/label/austin3_00497_432_361.png b/Data/train_samples/label/austin3_00497_432_361.png new file mode 100644 index 0000000000000000000000000000000000000000..f973edb64193ec01fd3f1d9eb24aa4f0d228df38 Binary files /dev/null and b/Data/train_samples/label/austin3_00497_432_361.png differ diff --git a/Data/train_samples/label/austin3_01452_334_335.png b/Data/train_samples/label/austin3_01452_334_335.png new file mode 100644 index 0000000000000000000000000000000000000000..c3ae963a85d644c28309f9eab16997f6b8f50881 Binary files /dev/null and b/Data/train_samples/label/austin3_01452_334_335.png differ diff --git a/Data/train_samples/label/austin4_00061_481_276.png b/Data/train_samples/label/austin4_00061_481_276.png new file mode 100644 index 0000000000000000000000000000000000000000..5f9e1504571104c2e87b9ef2a4b3a07e23a1fc7e Binary files /dev/null and b/Data/train_samples/label/austin4_00061_481_276.png differ diff --git a/Data/train_samples/label/austin4_00133_264_308.png 
b/Data/train_samples/label/austin4_00133_264_308.png new file mode 100644 index 0000000000000000000000000000000000000000..7ffc1813cdc925fc7236fcf1bc1f2be1cf4f53e9 Binary files /dev/null and b/Data/train_samples/label/austin4_00133_264_308.png differ diff --git a/Data/train_samples/label/austin4_00163_356_367.png b/Data/train_samples/label/austin4_00163_356_367.png new file mode 100644 index 0000000000000000000000000000000000000000..cf6b29fd2a0c027b07d3e68ea9706c0044f187b3 Binary files /dev/null and b/Data/train_samples/label/austin4_00163_356_367.png differ diff --git a/Data/train_samples/label/austin4_00205_468_478.png b/Data/train_samples/label/austin4_00205_468_478.png new file mode 100644 index 0000000000000000000000000000000000000000..6a62461711fd4e4d4cbd7215e62bac756e08fc76 Binary files /dev/null and b/Data/train_samples/label/austin4_00205_468_478.png differ diff --git a/Data/train_samples/label/austin7_00079_495_463.png b/Data/train_samples/label/austin7_00079_495_463.png new file mode 100644 index 0000000000000000000000000000000000000000..036313128207db4cbe4fcabce901bdb2b20b5e1d Binary files /dev/null and b/Data/train_samples/label/austin7_00079_495_463.png differ diff --git a/Data/train_samples/label/austin7_00100_300_263.png b/Data/train_samples/label/austin7_00100_300_263.png new file mode 100644 index 0000000000000000000000000000000000000000..a149c4f61bb4b3a31d5ad86a412b428ec536c7fd Binary files /dev/null and b/Data/train_samples/label/austin7_00100_300_263.png differ diff --git a/Data/train_samples/label/austin7_00166_430_319.png b/Data/train_samples/label/austin7_00166_430_319.png new file mode 100644 index 0000000000000000000000000000000000000000..22397b71540422d9890de6b3c43c6fc3cd50eb0c Binary files /dev/null and b/Data/train_samples/label/austin7_00166_430_319.png differ diff --git a/Data/val_samples/img/austin2_01179_258_289.png b/Data/val_samples/img/austin2_01179_258_289.png new file mode 100644 index 
0000000000000000000000000000000000000000..604d58df4ddabac52abfbd4e3ea7ff90392b55bd Binary files /dev/null and b/Data/val_samples/img/austin2_01179_258_289.png differ diff --git a/Data/val_samples/img/austin3_00294_332_293.png b/Data/val_samples/img/austin3_00294_332_293.png new file mode 100644 index 0000000000000000000000000000000000000000..c5e089254a29720313baab44de2cbd2db97c5e9f Binary files /dev/null and b/Data/val_samples/img/austin3_00294_332_293.png differ diff --git a/Data/val_samples/img/austin3_00497_432_361.png b/Data/val_samples/img/austin3_00497_432_361.png new file mode 100644 index 0000000000000000000000000000000000000000..1d8c9363a30c0cb379ed22b2ac24252c4fcc49dd Binary files /dev/null and b/Data/val_samples/img/austin3_00497_432_361.png differ diff --git a/Data/val_samples/img/austin3_01452_334_335.png b/Data/val_samples/img/austin3_01452_334_335.png new file mode 100644 index 0000000000000000000000000000000000000000..b7ddfbfde123bd3a15047bb6700e06e7d5b8a143 Binary files /dev/null and b/Data/val_samples/img/austin3_01452_334_335.png differ diff --git a/Data/val_samples/img/austin4_00061_481_276.png b/Data/val_samples/img/austin4_00061_481_276.png new file mode 100644 index 0000000000000000000000000000000000000000..41148c854cfc7bcfb7fd88623488c39cbe8af6e1 Binary files /dev/null and b/Data/val_samples/img/austin4_00061_481_276.png differ diff --git a/Data/val_samples/img/austin4_00133_264_308.png b/Data/val_samples/img/austin4_00133_264_308.png new file mode 100644 index 0000000000000000000000000000000000000000..57e7e40e473f8a31ff7cb58352071b0ab1b0029d Binary files /dev/null and b/Data/val_samples/img/austin4_00133_264_308.png differ diff --git a/Data/val_samples/img/austin4_00163_356_367.png b/Data/val_samples/img/austin4_00163_356_367.png new file mode 100644 index 0000000000000000000000000000000000000000..a913dd02436e6a22496f9169a5ccd009ba8a1d31 Binary files /dev/null and b/Data/val_samples/img/austin4_00163_356_367.png differ diff --git 
a/Data/val_samples/img/austin4_00205_468_478.png b/Data/val_samples/img/austin4_00205_468_478.png new file mode 100644 index 0000000000000000000000000000000000000000..752e4d7cfe102c9e8692bffb795810a79a76abb1 Binary files /dev/null and b/Data/val_samples/img/austin4_00205_468_478.png differ diff --git a/Data/val_samples/img/austin7_00079_495_463.png b/Data/val_samples/img/austin7_00079_495_463.png new file mode 100644 index 0000000000000000000000000000000000000000..7cf747568f7b2fe8d386749e013de59514d426f4 Binary files /dev/null and b/Data/val_samples/img/austin7_00079_495_463.png differ diff --git a/Data/val_samples/img/austin7_00100_300_263.png b/Data/val_samples/img/austin7_00100_300_263.png new file mode 100644 index 0000000000000000000000000000000000000000..038df3f3a4b15d99020efdbfa84f7b3eeeb5c962 Binary files /dev/null and b/Data/val_samples/img/austin7_00100_300_263.png differ diff --git a/Data/val_samples/img/austin7_00166_430_319.png b/Data/val_samples/img/austin7_00166_430_319.png new file mode 100644 index 0000000000000000000000000000000000000000..bc64351619024ab0a08a49e58f04c3287227f3f9 Binary files /dev/null and b/Data/val_samples/img/austin7_00166_430_319.png differ diff --git a/Data/val_samples/label/austin2_01179_258_289.png b/Data/val_samples/label/austin2_01179_258_289.png new file mode 100644 index 0000000000000000000000000000000000000000..502f36179390c0fa7be584d4fec6270676d1cf3d Binary files /dev/null and b/Data/val_samples/label/austin2_01179_258_289.png differ diff --git a/Data/val_samples/label/austin3_00294_332_293.png b/Data/val_samples/label/austin3_00294_332_293.png new file mode 100644 index 0000000000000000000000000000000000000000..d484925c65b18a546e174d61ec82135b277f044a Binary files /dev/null and b/Data/val_samples/label/austin3_00294_332_293.png differ diff --git a/Data/val_samples/label/austin3_00497_432_361.png b/Data/val_samples/label/austin3_00497_432_361.png new file mode 100644 index 
0000000000000000000000000000000000000000..f973edb64193ec01fd3f1d9eb24aa4f0d228df38 Binary files /dev/null and b/Data/val_samples/label/austin3_00497_432_361.png differ diff --git a/Data/val_samples/label/austin3_01452_334_335.png b/Data/val_samples/label/austin3_01452_334_335.png new file mode 100644 index 0000000000000000000000000000000000000000..c3ae963a85d644c28309f9eab16997f6b8f50881 Binary files /dev/null and b/Data/val_samples/label/austin3_01452_334_335.png differ diff --git a/Data/val_samples/label/austin4_00061_481_276.png b/Data/val_samples/label/austin4_00061_481_276.png new file mode 100644 index 0000000000000000000000000000000000000000..5f9e1504571104c2e87b9ef2a4b3a07e23a1fc7e Binary files /dev/null and b/Data/val_samples/label/austin4_00061_481_276.png differ diff --git a/Data/val_samples/label/austin4_00133_264_308.png b/Data/val_samples/label/austin4_00133_264_308.png new file mode 100644 index 0000000000000000000000000000000000000000..7ffc1813cdc925fc7236fcf1bc1f2be1cf4f53e9 Binary files /dev/null and b/Data/val_samples/label/austin4_00133_264_308.png differ diff --git a/Data/val_samples/label/austin4_00163_356_367.png b/Data/val_samples/label/austin4_00163_356_367.png new file mode 100644 index 0000000000000000000000000000000000000000..cf6b29fd2a0c027b07d3e68ea9706c0044f187b3 Binary files /dev/null and b/Data/val_samples/label/austin4_00163_356_367.png differ diff --git a/Data/val_samples/label/austin4_00205_468_478.png b/Data/val_samples/label/austin4_00205_468_478.png new file mode 100644 index 0000000000000000000000000000000000000000..6a62461711fd4e4d4cbd7215e62bac756e08fc76 Binary files /dev/null and b/Data/val_samples/label/austin4_00205_468_478.png differ diff --git a/Data/val_samples/label/austin7_00079_495_463.png b/Data/val_samples/label/austin7_00079_495_463.png new file mode 100644 index 0000000000000000000000000000000000000000..036313128207db4cbe4fcabce901bdb2b20b5e1d Binary files /dev/null and 
b/Data/val_samples/label/austin7_00079_495_463.png differ diff --git a/Data/val_samples/label/austin7_00100_300_263.png b/Data/val_samples/label/austin7_00100_300_263.png new file mode 100644 index 0000000000000000000000000000000000000000..a149c4f61bb4b3a31d5ad86a412b428ec536c7fd Binary files /dev/null and b/Data/val_samples/label/austin7_00100_300_263.png differ diff --git a/Data/val_samples/label/austin7_00166_430_319.png b/Data/val_samples/label/austin7_00166_430_319.png new file mode 100644 index 0000000000000000000000000000000000000000..22397b71540422d9890de6b3c43c6fc3cd50eb0c Binary files /dev/null and b/Data/val_samples/label/austin7_00166_430_319.png differ diff --git a/Models/BackBone/GetBackbone.py b/Models/BackBone/GetBackbone.py new file mode 100644 index 0000000000000000000000000000000000000000..2f20ad1df3879bacf2a38cfbc99017de9f11a345 --- /dev/null +++ b/Models/BackBone/GetBackbone.py @@ -0,0 +1,16 @@ +from .ResNet import * +from .VGGNet import * + +__all__ = ['get_backbone'] + + +def get_backbone(model_name='', pretrained=True, num_classes=None, **kwargs): + if 'res' in model_name: + model = get_resnet(model_name, pretrained=pretrained, num_classes=num_classes, **kwargs) + + elif 'vgg' in model_name: + model = get_vgg(model_name, pretrained=pretrained, num_classes=num_classes, **kwargs) + else: + raise NotImplementedError + return model + diff --git a/Models/BackBone/ResNet.py b/Models/BackBone/ResNet.py new file mode 100644 index 0000000000000000000000000000000000000000..8f7e1e7e36550b7c504faf9af15147f9386d9a59 --- /dev/null +++ b/Models/BackBone/ResNet.py @@ -0,0 +1,325 @@ +import torch +import torch.nn as nn +from torch.hub import load_state_dict_from_url + +__all__ = ['get_resnet', 'BasicBlock'] + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + # if dilation > 1: + # raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, dilation=dilation) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + 
self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None, out_keys=None, in_channels=3, **kwargs): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + self.out_keys = out_keys + self.num_classes = num_classes + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + 
dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + if 'block5' in self.out_keys: + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + if self.num_classes is not None: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, self.num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + endpoints = dict() + endpoints['block0'] = 
x + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + endpoints['block1'] = x + x = self.maxpool(x) + x = self.layer1(x) + endpoints['block2'] = x + x = self.layer2(x) + endpoints['block3'] = x + x = self.layer3(x) + endpoints['block4'] = x + if 'block5' in self.out_keys: + x = self.layer4(x) + endpoints['block5'] = x + + if self.num_classes is not None: + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + if self.out_keys is not None: + endpoints = {key: endpoints[key] for key in self.out_keys} + return x, endpoints + + +def _resnet(arch, block, layers, pretrained, progress, num_classes=1000, in_channels=3, out_keys=None, **kwargs): + model = ResNet(block, layers, num_classes, out_keys=out_keys, in_channels=in_channels, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + if in_channels != 3: + keys = state_dict.keys() + keys = [x for x in keys if 'conv1.weight' in x] + for key in keys: + del state_dict[key] + if num_classes !=1000: + keys = state_dict.keys() + keys = [x for x in keys if 'fc' in x] + for key in keys: + del state_dict[key] + if 'block5' not in out_keys: + keys = state_dict.keys() + keys = [x for x in keys if 'layer4' in x] + for key in keys: + del state_dict[key] + model.load_state_dict(state_dict) + print('load resnet model...') + + return model + + +def _resnet18(name='resnet18', pretrained=True, progress=True, num_classes=1000, out_keys=None, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet(name, BasicBlock, [2, 2, 2, 2], pretrained, progress, + num_classes=num_classes, out_keys=out_keys, **kwargs) + +def _resnet50(name='resnet50',pretrained=False, progress=True,num_classes=1000,out_keys=None, **kwargs): + r"""ResNet-50 model from + `"Deep Residual 
def get_resnet(model_name='resnet50', pretrained=True, progress=True, num_classes=1000, out_keys=None,
               in_channels=3, **kwargs):
    '''
    Get resnet model with name.

    :param model_name: resnet model name, optional values: [resnet18, resnet50, resnet101, resnet152]
    :param pretrained: If True, returns a model pre-trained on ImageNet
    :param progress: If True, displays a progress bar of the download to stderr
    :param num_classes: forwarded to the selected builder; a warning is printed when it
        differs from 1000 while ``pretrained`` is set
    :param out_keys: forwarded unchanged to the selected builder
    :param in_channels: forwarded to the selected builder; a warning is printed when it
        differs from 3 while ``pretrained`` is set
    :param kwargs: extra keyword arguments forwarded to the selected builder
    :return: whatever the selected ``_resnetXX`` builder returns
    :raises NotImplementedError: if ``model_name`` is not one of the supported names
    '''
    if pretrained and num_classes != 1000:
        print('warning: num_class is not equal to 1000, which will cause some parameters to fail to load!')
    if pretrained and in_channels != 3:
        print('warning: in_channels is not equal to 3, which will cause some parameters to fail to load!')

    # Arguments shared by every builder call below.
    builder_args = dict(name=model_name, pretrained=pretrained, progress=progress,
                        num_classes=num_classes, out_keys=out_keys, in_channels=in_channels)

    if model_name == 'resnet18':
        return _resnet18(**builder_args, **kwargs)
    elif model_name == 'resnet50':
        return _resnet50(**builder_args, **kwargs)
    elif model_name == 'resnet101':
        return _resnet101(**builder_args, **kwargs)
    elif model_name == 'resnet152':
        return _resnet152(**builder_args, **kwargs)

    # The message must format `model_name`; the former `.format(name)` referenced an
    # undefined variable and turned this branch into a NameError.
    raise NotImplementedError(
        '{0} is not an available value. Please choose one of the available values in '
        '[resnet18, resnet50, resnet101, resnet152]'.format(model_name))
def make_layers(in_channels, out_keys, cfg: List[Union[str, int]], batch_norm: bool = False):
    """Build the convolutional feature extractor described by ``cfg``.

    ``cfg`` entries are interpreted as follows:

    * an ``int`` in 1..5 is a stage marker, not a layer: building stops at the first
      marker greater than the stage number parsed from the last entry of ``out_keys``;
    * ``'M'`` appends a 2x2 max-pool and records its index as the end of a stage;
    * any other value is the output width of a 3x3 conv (padding 1), followed by an
      optional ``BatchNorm2d`` and an in-place ``ReLU``.

    Args:
        in_channels: number of channels of the network input.
        out_keys: list of names such as ``'block3'``; only the last entry is used to
            decide where to stop. ``None`` or an empty list builds every stage.
        cfg: layer configuration list (see above).
        batch_norm: insert ``BatchNorm2d`` after every convolution when True.

    Returns:
        ``(stage_ids, features)`` where ``stage_ids`` holds the indices of the max-pool
        layers inside ``features`` and ``features`` is an ``nn.Sequential``.
    """
    # Deepest requested stage; None means "no limit". The original code indexed
    # out_keys unconditionally and crashed when it was None.
    last_stage = int(out_keys[-1].replace('block', '')) if out_keys else None

    layer_list = []
    stage_ids = []
    idx = 0  # index the next appended module will get inside the Sequential
    for v in cfg:
        if isinstance(v, int) and v in (1, 2, 3, 4, 5):
            # Stage marker: stop before building a stage nobody asked for.
            if last_stage is not None and v > last_stage:
                break
            continue
        if v == 'M':
            layer_list += [nn.MaxPool2d(kernel_size=2, stride=2)]
            stage_ids += [idx]
            idx += 1
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layer_list += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                idx += 3
            else:
                layer_list += [conv2d, nn.ReLU(inplace=True)]
                idx += 2
            in_channels = v

    return stage_ids, nn.Sequential(*layer_list)
def _vgg(in_channels, num_classes, out_keys, arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
    """Build a VGG model and optionally load the matching pretrained weights.

    Args:
        in_channels: number of input channels of the first convolution.
        num_classes: size of the classifier output, passed to ``VGG``.
        out_keys: names of the feature blocks to expose, passed to ``make_layers``
            and ``VGG``.
        arch: key into ``model_urls`` selecting the checkpoint to download.
        cfg: key into ``cfgs`` selecting the layer configuration.
        batch_norm: build the batch-normalised variant.
        pretrained: download and load the checkpoint for ``arch``.
        progress: show a download progress bar.
        **kwargs: forwarded to the ``VGG`` constructor.

    Returns:
        The constructed ``VGG`` instance.
    """
    if pretrained:
        # Weights are about to be overwritten, so skip the random initialisation.
        kwargs['init_weights'] = False
    stage_id, ops = make_layers(in_channels, out_keys, cfgs[cfg], batch_norm=batch_norm)
    model = VGG(num_classes, out_keys, (stage_id, ops), **kwargs)
    if not pretrained:
        return model

    state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
    # Becomes False once we drop checkpoint entries that the model still owns;
    # a strict load would then fail on the missing keys.
    strict = True

    if in_channels != 3:
        # The first convolution has a different shape; keep its fresh initialisation.
        for key in [k for k in state_dict if 'features.0.' in k]:
            del state_dict[key]
        strict = False

    if num_classes != 1000:
        for key in [k for k in state_dict if 'classifier' in k]:
            del state_dict[key]
        if num_classes is not None:
            # A classifier exists but with a different output size.
            strict = False

    if out_keys is not None and 'block5' not in out_keys:
        # Drop weights of feature layers located after the last max-pool index
        # reported by make_layers. Only `features.<idx>.*` keys carry a feature-layer
        # index; classifier keys must not be matched here.
        for key in list(state_dict):
            parts = key.split('.')
            if parts[0] == 'features' and int(parts[1]) >= stage_id[-1]:
                del state_dict[key]

    model.load_state_dict(state_dict, strict=strict)
    return model
def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 16-layer model (configuration "D") without batch normalization, from
    "Very Deep Convolutional Networks For Large-Scale Image Recognition".

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: may carry ``in_channels`` (default 3), ``num_classes`` (default 1000)
            and ``out_keys`` (default ``['block5']``); anything else is forwarded
            to ``_vgg``.
    """
    # `_vgg` takes (in_channels, num_classes, out_keys, arch, cfg, batch_norm, ...);
    # the previous call passed the architecture name in the `in_channels` slot.
    in_channels = kwargs.pop('in_channels', 3)
    num_classes = kwargs.pop('num_classes', 1000)
    out_keys = kwargs.pop('out_keys', None) or ['block5']
    return _vgg(in_channels, num_classes, out_keys, 'vgg16', 'D', False, pretrained, progress, **kwargs)
def get_vgg(name='vgg16_bn', pretrained=True, progress=True, num_classes=None, out_keys=None, in_channels=3, **kwargs):
    """Build a VGG backbone selected by name.

    Only ``'vgg16_bn'`` is supported here. The former ``resnet*`` branches called
    builders that this module never imports, so they could only raise ``NameError``.

    Args:
        name: architecture name; must be ``'vgg16_bn'``.
        pretrained: If True, returns a model pre-trained on ImageNet.
        progress: If True, displays a progress bar of the download to stderr.
        num_classes: forwarded to ``vgg16_bn``; a warning is printed when it differs
            from 1000 while ``pretrained`` is set.
        out_keys: forwarded unchanged to ``vgg16_bn``.
        in_channels: forwarded to ``vgg16_bn``; a warning is printed when it differs
            from 3 while ``pretrained`` is set.
        **kwargs: forwarded to ``vgg16_bn``.

    Returns:
        Whatever ``vgg16_bn`` returns.

    Raises:
        NotImplementedError: if ``name`` is not a supported architecture.
    """
    if pretrained and num_classes != 1000:
        print('warning: num_class is not equal to 1000, which will cause some parameters to fail to load!')
    if pretrained and in_channels != 3:
        print('warning: in_channels is not equal to 3, which will cause some parameters to fail to load!')

    if name == 'vgg16_bn':
        return vgg16_bn(in_channels=in_channels, num_classes=num_classes,
                        out_keys=out_keys, pretrained=pretrained, progress=progress, **kwargs)

    raise NotImplementedError(
        '{0} is not an available value. Please choose one of the available values in '
        '[vgg16_bn]'.format(name))
class DoubleConv(nn.Module):
    """Two stacked 3x3 conv -> BatchNorm -> LeakyReLU stages.

    ``mid_channels`` is the width between the two stages; when it is falsy the
    output width is used instead.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super(DoubleConv, self).__init__()
        hidden = mid_channels or out_channels
        stages = []
        # Same module order as writing the six layers out by hand.
        for width_in, width_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.append(nn.Conv2d(width_in, width_out, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(width_out))
            stages.append(nn.LeakyReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        out = self.double_conv(x)
        return out
+ # kwargs['out_keys'] = ['block_4'] or ['block_5'] + self.last_block = kwargs['out_keys'][-1] + + if '18' in kwargs['backbone']: + # 512 256 128 64 32 16 + layer_channels = [64, 64, 128, 256, 512] + self.reduce_dim_in = 256 + self.reduce_dim_out = 256 // 4 + elif '50' in kwargs['backbone']: + layer_channels = [64, 256, 512, 1024, 2048] + self.reduce_dim_in = 1024 + self.reduce_dim_out = 1024 // 16 + elif '16' in kwargs['backbone']: + layer_channels = [64, 128, 256, 512, 512] + self.reduce_dim_in = 512 + self.reduce_dim_out = 512 // 8 + + self.f_map_size = 32 + + # kwargs['top_k_s'] = 64 + self.top_k_s = kwargs['top_k_s'] + # kwargs['top_k_c'] = 16 + self.top_k_c = kwargs['top_k_c'] + # kwargs['encoder_pos'] = True or False + self.encoder_pos = kwargs['encoder_pos'] + # kwargs['decoder_pos'] = True or False + self.decoder_pos = kwargs['decoder_pos'] + # kwargs['model_pattern'] = ['X', 'A', 'S', 'C'] means different features concatenation + self.model_pattern = kwargs['model_pattern'] + + self.cat_num = len(self.model_pattern) + if 'A' in self.model_pattern: + self.cat_num += 1 + + self.num_head_s = max(2, min(self.top_k_s // 8, 64)) + self.num_head_c = min(2, min(self.top_k_c // 4, 64)) + + self.reduce_channel_b5 = nn.Sequential( + nn.Conv2d(in_channels=self.reduce_dim_in, out_channels=self.reduce_dim_out, kernel_size=1), + nn.BatchNorm2d(self.reduce_dim_out), + nn.LeakyReLU() + ) + # position embedding + # if self.encoder_pos or self.decoder_pos: + self.spatial_embedding_h = nn.Parameter( + torch.randn(1, self.reduce_dim_out, self.f_map_size, 1), requires_grad=True) + self.spatial_embedding_w = nn.Parameter( + torch.randn(1, self.reduce_dim_out, 1, self.f_map_size), requires_grad=True) + self.channel_embedding = nn.Parameter( + torch.randn(1, self.reduce_dim_out, self.f_map_size ** 2), requires_grad=True) + # spatial attention ops + + self.get_s_probability = nn.Sequential( + nn.Conv2d(self.reduce_dim_out, self.reduce_dim_out // 4, kernel_size=3, padding=1), + 
nn.BatchNorm2d(self.reduce_dim_out // 4), + nn.LeakyReLU(inplace=True), + nn.Conv2d(self.reduce_dim_out // 4, 1, kernel_size=3, padding=1), + nn.Sigmoid() + ) + + # b5 spatial encoder and decoder + self.tf_encoder_spatial_b5 = BoTMultiHeadAttention( + in_feature_dim=self.reduce_dim_out, + num_heads=self.num_head_s + ) + self.tf_decoder_spatial_b5 = BoTMultiHeadAttention( + in_feature_dim=self.reduce_dim_out, + num_heads=self.num_head_s + ) + # channel attention ops + + self.get_c_probability = nn.Sequential( + nn.Conv2d(self.reduce_dim_out, self.reduce_dim_out // 8, kernel_size=self.f_map_size), + nn.BatchNorm2d(self.reduce_dim_out // 8), + nn.LeakyReLU(inplace=True), + nn.Conv2d(self.reduce_dim_out // 8, self.reduce_dim_out, kernel_size=1), + nn.Sigmoid() + ) + + # b5 channel encoder and decoder + self.tf_encoder_channel_b5 = BoTMultiHeadAttention( + in_feature_dim=self.f_map_size ** 2, + num_heads=self.num_head_c + ) + self.tf_decoder_channel_b5 = BoTMultiHeadAttention( + in_feature_dim=self.f_map_size ** 2, + num_heads=self.num_head_c + ) + self.before_predict_head_conv = nn.Sequential( + nn.Conv2d(in_channels=self.reduce_dim_out * self.cat_num, out_channels=self.reduce_dim_in, kernel_size=1), + nn.BatchNorm2d(self.reduce_dim_in), + nn.LeakyReLU() + ) + + if self.last_block == 'block5': + self.pre_pixel_shuffle = nn.PixelShuffle(2) + # 128, 256, 256 + self.pre_double_conv = DoubleConv( + in_channels=layer_channels[4] // 4, + out_channels=layer_channels[3], + mid_channels=layer_channels[3] + ) + + self.pixel_shuffle1 = nn.PixelShuffle(4) + # 16, 64, 64 + self.double_conv1 = DoubleConv( + in_channels=layer_channels[3] // 16, + out_channels=layer_channels[1], + mid_channels=layer_channels[3] // 4 + ) + # 4, 16, 16 + self.pixel_shuffle2 = nn.PixelShuffle(4) + self.double_conv2 = DoubleConv( + in_channels=layer_channels[1] // 16, + out_channels=layer_channels[1] // 4, + mid_channels=layer_channels[1] // 4 + ) + + last_channels = layer_channels[1] // 4 + # 16, 32 + # 
32, 2 + if '18' in kwargs['backbone']: + scale_factor = 2 + else: + scale_factor = 1 + self.predict_head_out = nn.Sequential( + nn.Conv2d(in_channels=last_channels, out_channels=last_channels * scale_factor, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(last_channels * scale_factor), + nn.LeakyReLU(), + nn.Conv2d(in_channels=last_channels * scale_factor, out_channels=n_classes, kernel_size=3, stride=1, padding=1), + ) + + self.loss_att_branch = nn.Sequential( + nn.Conv2d(in_channels=self.reduce_dim_out * 2, out_channels=64, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(64), + nn.LeakyReLU(), + nn.Conv2d(in_channels=64, out_channels=n_classes, kernel_size=3, stride=1, padding=1), + ) + + def forward(self, x, *args, **kwargs): + x, endpoints = self.res_backbone(x) + + # reduce channel 512 to 128 + x_reduced_channel = self.reduce_channel_b5(x) # B 128 h w + + prob_s_map = self.get_s_probability(x_reduced_channel) + + prob_c_map = self.get_c_probability(x_reduced_channel) # B C 1 1 + x_att_s = x_reduced_channel * prob_s_map + x_att_c = x_reduced_channel * prob_c_map + + output_cat = [] + if 'X' in self.model_pattern: + output_cat.append(x_reduced_channel) + if 'A' in self.model_pattern: + output_cat.append(x_att_s) + output_cat.append(x_att_c) + + if 'S' in self.model_pattern: + # spatial pos embedding + prob_s_vector = rearrange(prob_s_map, 'b c h w -> b (h w) c') + x_vec_s = rearrange(x_reduced_channel, 'b c h w -> b (h w) c') + # get top k, k = 16 * 16 // 4 x_b5_reduced_channel_vector + _, indices_s = torch.topk(prob_s_vector, k=self.top_k_s, dim=1, sorted=False) # B K 1 + indices_s = repeat(indices_s, 'b k m -> b k (m c)', c=self.reduce_dim_out) + x_s_vec_topk = torch.gather(x_vec_s, 1, indices_s) # B K 128 + if self.encoder_pos or self.decoder_pos: + s_pos_embedding = self.spatial_embedding_h + self.spatial_embedding_w # 1 128 16 16 + s_pos_embedding = repeat(s_pos_embedding, 'm c h w -> (b m) c h w', b=x.size(0)) + s_pos_embedding_vec = 
rearrange(s_pos_embedding, 'b c h w -> b (h w) c') + s_pos_embedding_vec_topk = torch.gather(s_pos_embedding_vec, 1, indices_s) # B K 128 + + if self.encoder_pos is True: + pos_encoder = s_pos_embedding_vec_topk + else: + pos_encoder = None + + # b5 encoder and decoder op + tf_encoder_s_x = self.tf_encoder_spatial_b5( + q_s=x_s_vec_topk, k_s=None, v_s=None, pos_emb=pos_encoder + ) + if self.decoder_pos is True: + pos_decoder = s_pos_embedding_vec_topk + else: + pos_decoder = None + + tf_decoder_s_x = self.tf_decoder_spatial_b5( + q_s=x_vec_s, k_s=tf_encoder_s_x, v_s=None, + pos_emb=pos_decoder + ) # B (16*16) 128 + + # B 128 16 16 + tf_decoder_s_x = rearrange(tf_decoder_s_x, 'b (h w) c -> b c h w', h=self.f_map_size) + output_cat.append(tf_decoder_s_x) + + if 'C' in self.model_pattern: + # channel attention ops + prob_c_vec = rearrange(prob_c_map, 'b c h w -> b c (h w)') + x_vec_c = rearrange(x_reduced_channel, 'b c h w -> b c (h w)') + + # get top k, k = 128 // 4 = 32 + _, indices_c = torch.topk(prob_c_vec, k=self.top_k_c, dim=1, sorted=True) # b k 1 + indices_c = repeat(indices_c, 'b k m -> b k (m c)', c=self.f_map_size ** 2) + x_vec_c_topk = torch.gather(x_vec_c, 1, indices_c) # B K 256 + if self.encoder_pos or self.decoder_pos: + c_pos_embedding_vec = repeat(self.channel_embedding, 'm len c -> (m b) len c', b=x.size(0)) + c_pos_embedding_vec_topk = torch.gather(c_pos_embedding_vec, 1, indices_c) # B K 256 + + if self.encoder_pos is True: + pos_encoder = c_pos_embedding_vec_topk + else: + pos_encoder = None + # b5 encoder and decoder op + tf_encoder_c_x = self.tf_encoder_channel_b5( + q_s=x_vec_c_topk, k_s=None, v_s=None, + pos_emb=pos_encoder + ) + if self.decoder_pos is True: + pos_decoder = c_pos_embedding_vec_topk + else: + pos_decoder = None + tf_decoder_c_x = self.tf_decoder_channel_b5( + q_s=x_vec_c, k_s=tf_encoder_c_x, v_s=None, + pos_emb=pos_decoder + ) # B 128 (16*16) + + # B 128 16 16 + tf_decoder_c_x = rearrange(tf_decoder_c_x, 'b c (h w) -> b c h 
w', h=self.f_map_size) + output_cat.append(tf_decoder_c_x) + + x_cat = torch.cat(output_cat, dim=1) + x_cat = self.before_predict_head_conv(x_cat) + + x = self.double_conv1(self.pixel_shuffle1(x_cat)) + x = self.double_conv2(self.pixel_shuffle2(x)) + logits = self.predict_head_out(x) + + att_output = torch.cat([x_att_s, x_att_c], dim=1) + att_branch_output = self.loss_att_branch(att_output) + return logits, att_branch_output + + diff --git a/Test.py b/Test.py new file mode 100644 index 0000000000000000000000000000000000000000..a3d33174ef8d5888fee19f095823636e82db6b9a --- /dev/null +++ b/Test.py @@ -0,0 +1,126 @@ +import os +# Change the numbers when you want to test with specific gpus +# os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3' +import torch +from STTNet import STTNet +import torch.nn.functional as F +from Utils.Datasets import get_data_loader +from Utils.Utils import make_numpy_img, inv_normalize_img, encode_onehot_to_mask, get_metrics, Logger +import matplotlib.pyplot as plt +import numpy as np +from collections import OrderedDict + +if __name__ == '__main__': + model_infos = { + # vgg16_bn, resnet50, resnet18 + 'backbone': 'resnet50', + 'pretrained': True, + 'out_keys': ['block4'], + 'in_channel': 3, + 'n_classes': 2, + 'top_k_s': 64, + 'top_k_c': 16, + 'encoder_pos': True, + 'decoder_pos': True, + 'model_pattern': ['X', 'A', 'S', 'C'], + + 'log_path': 'Results', + 'NUM_WORKERS': 0, + # if you need the validation process. + 'IS_VAL': True, + 'VAL_BATCH_SIZE': 4, + 'VAL_DATASET': 'Tools/generate_dep_info/val_data.csv', + # if you need the test process. 
+ 'IS_TEST': True, + 'TEST_DATASET': 'Tools/generate_dep_info/test_data.csv', + 'IMG_SIZE': [512, 512], + 'PHASE': 'seg', + + # INRIA Dataset + 'PRIOR_MEAN': [0.40672500537632994, 0.42829032416229895, 0.39331840468605667], + 'PRIOR_STD': [0.029498464618176873, 0.027740088491668233, 0.028246722411879095], + # # # WHU Dataset + # 'PRIOR_MEAN': [0.4352682576428411, 0.44523221318154493, 0.41307610541534784], + # 'PRIOR_STD': [0.026973196780331585, 0.026424642808887323, 0.02791246590291434], + + # load state dict path + 'load_checkpoint_path': r'E:\BuildingExtractionDataset\INRIA_ckpt_latest.pt', + } + if model_infos['IS_VAL']: + os.makedirs(model_infos['log_path']+'/val', exist_ok=True) + if model_infos['IS_TEST']: + os.makedirs(model_infos['log_path']+'/test', exist_ok=True) + logger = Logger(model_infos['log_path'] + '/log.log') + + data_loaders = get_data_loader(model_infos, test_mode=True) + loss_weight = 0.1 + model = STTNet(**model_infos) + + logger.write(f'load checkpoint from {model_infos["load_checkpoint_path"]}\n') + state_dict = torch.load(model_infos['load_checkpoint_path'], map_location='cpu') + model_dict = state_dict['model_state_dict'] + try: + model_dict = OrderedDict({k.replace('module.', ''): v for k, v in model_dict.items()}) + model.load_state_dict(model_dict) + except Exception as e: + model.load_state_dict(model_dict) + model = model.cuda() + device_ids = range(torch.cuda.device_count()) + if len(device_ids) > 1: + model = torch.nn.DataParallel(model, device_ids=device_ids) + logger.write(f'Use GPUs: {device_ids}\n') + else: + logger.write(f'Use GPUs: 1\n') + + patterns = ['val', 'test'] + for pattern_id, is_pattern in enumerate([model_infos['IS_VAL'], model_infos['IS_TEST']]): + if is_pattern: + # pred: logits, tensor, nBatch * nClass * W * H + # target: labels, tensor, nBatch * nClass * W * H + # output, batch['label'] + collect_result = {'pred': [], 'target': []} + pattern = patterns[pattern_id] + model.eval() + for batch_id, batch in 
enumerate(data_loaders[pattern]): + # Get data + img_batch = batch['img'].cuda() + label_batch = batch['label'].cuda() + img_names = batch['img_name'] + collect_result['target'].append(label_batch.data.cpu()) + + # inference + with torch.no_grad(): + logits, att_branch_output = model(img_batch) + + collect_result['pred'].append(logits.data.cpu()) + # get segmentation result, when the phase is test. + pred_label = torch.argmax(logits, 1) + pred_label *= 255 + + # output the segmentation result + if pattern == 'test' or batch_id % 5 == 1: + batch_size = pred_label.size(0) + # k = np.clip(int(0.3 * batch_size), a_min=1, a_max=batch_size) + # ids = np.random.choice(range(batch_size), k, replace=False) + ids = range(batch_size) + for img_id in ids: + img = img_batch[img_id].detach().cpu() + target = label_batch[img_id].detach().cpu() + pred = pred_label[img_id].detach().cpu() + img_name = img_names[img_id] + + img = make_numpy_img( + inv_normalize_img(img, model_infos['PRIOR_MEAN'], model_infos['PRIOR_STD'])) + target = make_numpy_img(encode_onehot_to_mask(target)) * 255 + pred = make_numpy_img(pred) + + vis = np.concatenate([img / 255., target / 255., pred / 255.], axis=0) + vis = np.clip(vis, a_min=0, a_max=1) + file_name = os.path.join(model_infos['log_path'], pattern, f'{img_name.split(".")[0]}.png') + plt.imsave(file_name, vis) + + collect_result['pred'] = torch.cat(collect_result['pred'], dim=0) + collect_result['target'] = torch.cat(collect_result['target'], dim=0) + IoU, OA, F1_score = get_metrics('seg', **collect_result) + logger.write(f'{pattern}: Iou:{IoU[-1]:.4f} OA:{OA[-1]:.4f} F1:{F1_score[-1]:.4f}\n') + diff --git a/Tools/CutImgSegWithLabel.py b/Tools/CutImgSegWithLabel.py new file mode 100644 index 0000000000000000000000000000000000000000..7f12a40947787dad7fcb6f28f7f5e1fb567045fd --- /dev/null +++ b/Tools/CutImgSegWithLabel.py @@ -0,0 +1,44 @@ +import os +import glob +from skimage import io +import tqdm +img_piece_size = (512, 512) + + +def 
get_pieces(img_path, label_path, img_format): + pieces_folder = os.path.abspath(img_path + '/..') + if not os.path.exists(pieces_folder + '/img_pieces'): + os.makedirs(pieces_folder + '/img_pieces') + if not os.path.exists(pieces_folder + '/label_pieces'): + os.makedirs(pieces_folder + '/label_pieces') + + img_path_list = glob.glob(img_path+'/austin31.%s' % img_format) + for idx in tqdm.tqdm(range(len(img_path_list))): + img = io.imread(img_path_list[idx]) + label = io.imread(label_path + '/' + os.path.basename(img_path_list[idx]).replace(img_format, img_format)) + h, w, c = img.shape + h_list = list(range(0, h-img_piece_size[1], int(0.9 * img_piece_size[1]))) + h_list = h_list + [h - img_piece_size[1]] + # h_list[-1] = h - img_piece_size[1] + w_list = list(range(0, w-img_piece_size[0], int(0.9 * img_piece_size[0]))) + # w_list[-1] = w - img_piece_size[0] + w_list = w_list + [w - img_piece_size[0]] + for h_step in h_list: + for w_step in w_list: + img_piece = img[h_step:h_step+img_piece_size[1], w_step:w_step+img_piece_size[0]] + label_piece = label[h_step:h_step + img_piece_size[1], w_step:w_step + img_piece_size[0]] + assert label_piece.shape[0] == img_piece_size[1] and label_piece.shape[1] == img_piece_size[0], 'shape error' + io.imsave(pieces_folder + '/img_pieces%s_%d_%d.png' % + (img_path_list[idx].replace(img_path, '').replace('.' + img_format, ''), w_step, h_step), img_piece, check_contrast=False) + io.imsave(pieces_folder + '/label_pieces%s_%d_%d.png' % + (img_path_list[idx].replace(img_path, '').replace('.' 
# R, G, B
class GetImgMeanStd:
    """Compute per-channel mean and standard deviation of the images listed in a CSV.

    The CSV is read with ``index_col=0`` and must contain an ``img`` column holding
    image paths. Statistics are computed per image on values scaled by 1/255 and
    then averaged over all images.

    NOTE(review): the std used to be computed as ``std(std(img, axis=0), axis=0)``,
    which is the spread of column-wise stds rather than the pixel std. Constants
    produced by the old formula (e.g. hard-coded ``PRIOR_STD`` values) will not
    match the output of this version - confirm before mixing them.
    """

    def __init__(self, data_file):
        assert os.path.exists(data_file), 'train.csv dose not exist!'
        self.data_info = pd.read_csv(data_file, index_col=0)
        self.save_path_mean_std_info = 'generate_dep_info'
        self.mean = None
        self.std = None

    @staticmethod
    def _channel_stats(img):
        """Return ``(mean, std)`` taken over the two leading (spatial) axes of ``img``."""
        pixels = np.asarray(img, dtype=np.float64)
        return pixels.mean(axis=(0, 1)), pixels.std(axis=(0, 1))

    def get_img_mean_std(self):
        """Read every listed image and return ``{'mean': [...], 'std': [...]}``.

        The result is also stored on ``self.mean`` / ``self.std``.
        """
        means = []
        stds = []
        bar = tqdm.tqdm(total=len(self.data_info))
        for _, row in self.data_info.iterrows():
            bar.update(1)
            img_name = row['img']
            img = io.imread(img_name)
            # Validate before doing arithmetic on the image.
            assert img is not None, img_name + 'is not valid'
            mean, std = self._channel_stats(img / 255.)
            means.append(mean)
            stds.append(std)
        bar.close()
        self.mean = np.mean(np.array(means), axis=0).tolist()
        self.std = np.mean(np.array(stds), axis=0).tolist()
        return {'mean': self.mean, 'std': self.std}

    def write_mean_std_information(self):
        """Compute the statistics, dump them as JSON and print them."""
        info = self.get_img_mean_std()
        # The output folder is not guaranteed to exist yet.
        os.makedirs(self.save_path_mean_std_info, exist_ok=True)
        writer = os.path.join(self.save_path_mean_std_info, 'mean_std_info_test.json')
        with open(writer, 'w') as f_writer:
            json.dump(info, f_writer)
        print('\'PRIOR_MEAN\': %s\n\'PRIOR_STD\': %s\n' % (info['mean'], info['std']))
def generate_val_data_from_train_data(self, frac=0.1, random_state=None):
    """Split the saved train CSV into a train CSV and a validation CSV.

    Reads ``<save_path_csv>/<csv_name>``, moves a random ``frac`` of its rows
    into a validation file and rewrites the train file with the remaining
    rows. Both files get a fresh 0..n-1 index.

    The validation file name is ``csv_name`` with ``'train'`` replaced by
    ``'val'``. When ``csv_name`` does not contain ``'train'`` the replacement
    would leave the name unchanged and the validation rows would overwrite
    the train file, so ``'val_'`` is prefixed instead.

    Args:
        frac: fraction of rows moved to the validation split.
        random_state: optional seed forwarded to ``DataFrame.sample`` to make
            the split reproducible. ``None`` keeps the previous random behavior.

    Raises:
        FileNotFoundError: if the train CSV does not exist.
    """
    train_csv = self.save_path_csv + '/' + self.csv_name
    if not os.path.exists(train_csv):
        raise FileNotFoundError(f'no train data: {train_csv}')
    data = pd.read_csv(train_csv)

    val_data = data.sample(frac=frac, replace=False, random_state=random_state)
    # Drop by the original index before either frame is re-indexed.
    train_data = data.drop(val_data.index).reset_index(drop=True)
    val_data = val_data.reset_index(drop=True)

    val_name = self.csv_name.replace('train', 'val')
    if val_name == self.csv_name:
        # Never let the validation split clobber the train split.
        val_name = 'val_' + self.csv_name

    train_data.to_csv(train_csv, index_label=False)
    val_data.to_csv(self.save_path_csv + '/' + val_name, index_label=False)
b/Tools/generate_dep_info/mean_std_info_test.json new file mode 100644 index 0000000000000000000000000000000000000000..feb15e2d6a13c7edf972ae561711a73048bd30f6 --- /dev/null +++ b/Tools/generate_dep_info/mean_std_info_test.json @@ -0,0 +1 @@ +{"mean": [0.46278404739026296, 0.469763416147487, 0.44496931596235817], "std": [0.036004664402296035, 0.036798446555721516, 0.038701379834091894]} \ No newline at end of file diff --git a/Tools/generate_dep_info/test_data.csv b/Tools/generate_dep_info/test_data.csv new file mode 100644 index 0000000000000000000000000000000000000000..9171ae847c143814455e3d51fef14095e9b7551d --- /dev/null +++ b/Tools/generate_dep_info/test_data.csv @@ -0,0 +1,12 @@ +img,label +0,D:/Code/ProjectOnGithub/STT/Data/test_samples/img\austin3_00497_432_361.png,D:/Code/ProjectOnGithub/STT/Data/test_samples/label\austin3_00497_432_361.png +1,D:/Code/ProjectOnGithub/STT/Data/test_samples/img\austin4_00133_264_308.png,D:/Code/ProjectOnGithub/STT/Data/test_samples/label\austin4_00133_264_308.png +2,D:/Code/ProjectOnGithub/STT/Data/test_samples/img\austin2_01179_258_289.png,D:/Code/ProjectOnGithub/STT/Data/test_samples/label\austin2_01179_258_289.png +3,D:/Code/ProjectOnGithub/STT/Data/test_samples/img\austin7_00079_495_463.png,D:/Code/ProjectOnGithub/STT/Data/test_samples/label\austin7_00079_495_463.png +4,D:/Code/ProjectOnGithub/STT/Data/test_samples/img\austin3_01452_334_335.png,D:/Code/ProjectOnGithub/STT/Data/test_samples/label\austin3_01452_334_335.png +5,D:/Code/ProjectOnGithub/STT/Data/test_samples/img\austin4_00205_468_478.png,D:/Code/ProjectOnGithub/STT/Data/test_samples/label\austin4_00205_468_478.png +6,D:/Code/ProjectOnGithub/STT/Data/test_samples/img\austin7_00166_430_319.png,D:/Code/ProjectOnGithub/STT/Data/test_samples/label\austin7_00166_430_319.png +7,D:/Code/ProjectOnGithub/STT/Data/test_samples/img\austin4_00061_481_276.png,D:/Code/ProjectOnGithub/STT/Data/test_samples/label\austin4_00061_481_276.png 
+8,D:/Code/ProjectOnGithub/STT/Data/test_samples/img\austin7_00100_300_263.png,D:/Code/ProjectOnGithub/STT/Data/test_samples/label\austin7_00100_300_263.png +9,D:/Code/ProjectOnGithub/STT/Data/test_samples/img\austin4_00163_356_367.png,D:/Code/ProjectOnGithub/STT/Data/test_samples/label\austin4_00163_356_367.png +10,D:/Code/ProjectOnGithub/STT/Data/test_samples/img\austin3_00294_332_293.png,D:/Code/ProjectOnGithub/STT/Data/test_samples/label\austin3_00294_332_293.png diff --git a/Tools/generate_dep_info/train_data.csv b/Tools/generate_dep_info/train_data.csv new file mode 100644 index 0000000000000000000000000000000000000000..53feb1f8be1982cbd5fc2c0d6bada1fb4bb30d70 --- /dev/null +++ b/Tools/generate_dep_info/train_data.csv @@ -0,0 +1,12 @@ +img,label +0,D:/Code/ProjectOnGithub/STT/Data/train_samples/img\austin7_00166_430_319.png,D:/Code/ProjectOnGithub/STT/Data/train_samples/label\austin7_00166_430_319.png +1,D:/Code/ProjectOnGithub/STT/Data/train_samples/img\austin3_00497_432_361.png,D:/Code/ProjectOnGithub/STT/Data/train_samples/label\austin3_00497_432_361.png +2,D:/Code/ProjectOnGithub/STT/Data/train_samples/img\austin3_00294_332_293.png,D:/Code/ProjectOnGithub/STT/Data/train_samples/label\austin3_00294_332_293.png +3,D:/Code/ProjectOnGithub/STT/Data/train_samples/img\austin4_00205_468_478.png,D:/Code/ProjectOnGithub/STT/Data/train_samples/label\austin4_00205_468_478.png +4,D:/Code/ProjectOnGithub/STT/Data/train_samples/img\austin4_00133_264_308.png,D:/Code/ProjectOnGithub/STT/Data/train_samples/label\austin4_00133_264_308.png +5,D:/Code/ProjectOnGithub/STT/Data/train_samples/img\austin7_00100_300_263.png,D:/Code/ProjectOnGithub/STT/Data/train_samples/label\austin7_00100_300_263.png +6,D:/Code/ProjectOnGithub/STT/Data/train_samples/img\austin4_00061_481_276.png,D:/Code/ProjectOnGithub/STT/Data/train_samples/label\austin4_00061_481_276.png 
+7,D:/Code/ProjectOnGithub/STT/Data/train_samples/img\austin3_01452_334_335.png,D:/Code/ProjectOnGithub/STT/Data/train_samples/label\austin3_01452_334_335.png +8,D:/Code/ProjectOnGithub/STT/Data/train_samples/img\austin4_00163_356_367.png,D:/Code/ProjectOnGithub/STT/Data/train_samples/label\austin4_00163_356_367.png +9,D:/Code/ProjectOnGithub/STT/Data/train_samples/img\austin2_01179_258_289.png,D:/Code/ProjectOnGithub/STT/Data/train_samples/label\austin2_01179_258_289.png +10,D:/Code/ProjectOnGithub/STT/Data/train_samples/img\austin7_00079_495_463.png,D:/Code/ProjectOnGithub/STT/Data/train_samples/label\austin7_00079_495_463.png diff --git a/Tools/generate_dep_info/val_data.csv b/Tools/generate_dep_info/val_data.csv new file mode 100644 index 0000000000000000000000000000000000000000..2278258af29d573a77c1572ce725c449e0c1c053 --- /dev/null +++ b/Tools/generate_dep_info/val_data.csv @@ -0,0 +1,12 @@ +img,label +0,D:/Code/ProjectOnGithub/STT/Data/val_samples/img\austin7_00166_430_319.png,D:/Code/ProjectOnGithub/STT/Data/val_samples/label\austin7_00166_430_319.png +1,D:/Code/ProjectOnGithub/STT/Data/val_samples/img\austin3_00497_432_361.png,D:/Code/ProjectOnGithub/STT/Data/val_samples/label\austin3_00497_432_361.png +2,D:/Code/ProjectOnGithub/STT/Data/val_samples/img\austin2_01179_258_289.png,D:/Code/ProjectOnGithub/STT/Data/val_samples/label\austin2_01179_258_289.png +3,D:/Code/ProjectOnGithub/STT/Data/val_samples/img\austin4_00163_356_367.png,D:/Code/ProjectOnGithub/STT/Data/val_samples/label\austin4_00163_356_367.png +4,D:/Code/ProjectOnGithub/STT/Data/val_samples/img\austin7_00100_300_263.png,D:/Code/ProjectOnGithub/STT/Data/val_samples/label\austin7_00100_300_263.png +5,D:/Code/ProjectOnGithub/STT/Data/val_samples/img\austin3_01452_334_335.png,D:/Code/ProjectOnGithub/STT/Data/val_samples/label\austin3_01452_334_335.png 
+6,D:/Code/ProjectOnGithub/STT/Data/val_samples/img\austin7_00079_495_463.png,D:/Code/ProjectOnGithub/STT/Data/val_samples/label\austin7_00079_495_463.png +7,D:/Code/ProjectOnGithub/STT/Data/val_samples/img\austin3_00294_332_293.png,D:/Code/ProjectOnGithub/STT/Data/val_samples/label\austin3_00294_332_293.png +8,D:/Code/ProjectOnGithub/STT/Data/val_samples/img\austin4_00061_481_276.png,D:/Code/ProjectOnGithub/STT/Data/val_samples/label\austin4_00061_481_276.png +9,D:/Code/ProjectOnGithub/STT/Data/val_samples/img\austin4_00205_468_478.png,D:/Code/ProjectOnGithub/STT/Data/val_samples/label\austin4_00205_468_478.png +10,D:/Code/ProjectOnGithub/STT/Data/val_samples/img\austin4_00133_264_308.png,D:/Code/ProjectOnGithub/STT/Data/val_samples/label\austin4_00133_264_308.png diff --git a/Train.py b/Train.py new file mode 100644 index 0000000000000000000000000000000000000000..8a34e02dafd1ed55a686e90e43dc0e031cd35986 --- /dev/null +++ b/Train.py @@ -0,0 +1,185 @@ +import os +# Change the numbers when you want to train with specific gpus +# os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3' +import torch +from STTNet import STTNet +import torch.nn.functional as F +from Utils.Datasets import get_data_loader +from Utils.Utils import make_numpy_img, inv_normalize_img, encode_onehot_to_mask, get_metrics, Logger +import matplotlib.pyplot as plt +import numpy as np +from collections import OrderedDict +from torch.optim.lr_scheduler import MultiStepLR + +if __name__ == '__main__': + model_infos = { + # vgg16_bn, resnet50, resnet18 + 'backbone': 'resnet50', + 'pretrained': True, + 'out_keys': ['block4'], + 'in_channel': 3, + 'n_classes': 2, + 'top_k_s': 64, + 'top_k_c': 16, + 'encoder_pos': True, + 'decoder_pos': True, + 'model_pattern': ['X', 'A', 'S', 'C'], + + 'BATCH_SIZE': 8, + 'IS_SHUFFLE': True, + 'NUM_WORKERS': 0, + 'DATASET': 'Tools/generate_dep_info/train_data.csv', + 'model_path': 'Checkpoints', + 'log_path': 'Results', + # if you need the validation process. 
+ 'IS_VAL': True, + 'VAL_BATCH_SIZE': 4, + 'VAL_DATASET': 'Tools/generate_dep_info/val_data.csv', + # if you need the test process. + 'IS_TEST': True, + 'TEST_DATASET': 'Tools/generate_dep_info/test_data.csv', + 'IMG_SIZE': [512, 512], + 'PHASE': 'seg', + + # INRIA Dataset + 'PRIOR_MEAN': [0.40672500537632994, 0.42829032416229895, 0.39331840468605667], + 'PRIOR_STD': [0.029498464618176873, 0.027740088491668233, 0.028246722411879095], + # # # WHU Dataset + # 'PRIOR_MEAN': [0.4352682576428411, 0.44523221318154493, 0.41307610541534784], + # 'PRIOR_STD': [0.026973196780331585, 0.026424642808887323, 0.02791246590291434], + + # if you want to load state dict + 'load_checkpoint_path': r'E:\BuildingExtractionDataset\INRIA_ckpt_latest.pt', + # if you want to resume a checkpoint + 'resume_checkpoint_path': '', + + } + os.makedirs(model_infos['model_path'], exist_ok=True) + if model_infos['IS_VAL']: + os.makedirs(model_infos['log_path']+'/val', exist_ok=True) + if model_infos['IS_TEST']: + os.makedirs(model_infos['log_path']+'/test', exist_ok=True) + logger = Logger(model_infos['log_path'] + '/log.log') + + data_loaders = get_data_loader(model_infos) + loss_weight = 0.1 + model = STTNet(**model_infos) + + epoch_start = 0 + if model_infos['load_checkpoint_path'] is not None and os.path.exists(model_infos['load_checkpoint_path']): + logger.write(f'load checkpoint from {model_infos["load_checkpoint_path"]}\n') + state_dict = torch.load(model_infos['load_checkpoint_path'], map_location='cpu') + model_dict = state_dict['model_state_dict'] + try: + model_dict = OrderedDict({k.replace('module.', ''): v for k, v in model_dict.items()}) + model.load_state_dict(model_dict) + except Exception as e: + model.load_state_dict(model_dict) + if model_infos['resume_checkpoint_path'] is not None and os.path.exists(model_infos['resume_checkpoint_path']): + logger.write(f'resume checkpoint path from {model_infos["resume_checkpoint_path"]}\n') + state_dict = 
torch.load(model_infos['resume_checkpoint_path'], map_location='cpu') + epoch_start = state_dict['epoch_id'] + model_dict = state_dict['model_state_dict'] + logger.write(f'resume checkpoint from epoch {epoch_start}\n') + try: + model_dict = OrderedDict({k.replace('module.', ''): v for k, v in model_dict.items()}) + model.load_state_dict(model_dict) + except Exception as e: + model.load_state_dict(model_dict) + model = model.cuda() + device_ids = range(torch.cuda.device_count()) + if len(device_ids) > 1: + model = torch.nn.DataParallel(model, device_ids=device_ids) + logger.write(f'Use GPUs: {device_ids}\n') + else: + logger.write(f'Use GPUs: 1\n') + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + max_epoch = 300 + scheduler = MultiStepLR(optimizer, [int(max_epoch*2/3), int(max_epoch*5/6)], 0.5) + + for epoch_id in range(epoch_start, max_epoch): + pattern = 'train' + model.train() # Set model to training mode + for batch_id, batch in enumerate(data_loaders[pattern]): + # Get data + img_batch = batch['img'].cuda() + label_batch = batch['label'].cuda() + + # inference + optimizer.zero_grad() + logits, att_branch_output = model(img_batch) + + # compute loss + label_downs = F.interpolate(label_batch, att_branch_output.size()[2:], mode='nearest') + loss_branch = F.binary_cross_entropy_with_logits(att_branch_output, label_downs) + loss_master = F.binary_cross_entropy_with_logits(logits, label_batch) + loss = loss_master + loss_weight * loss_branch + # loss backward + loss.backward() + optimizer.step() + + if batch_id % 20 == 1: + logger.write( + f'{pattern}: {epoch_id}/{max_epoch} {batch_id}/{len(data_loaders[pattern])} loss: {loss.item():.4f}\n') + + scheduler.step() + patterns = ['val', 'test'] + for pattern_id, is_pattern in enumerate([model_infos['IS_VAL'], model_infos['IS_TEST']]): + if is_pattern: + # pred: logits, tensor, nBatch * nClass * W * H + # target: labels, tensor, nBatch * nClass * W * H + # output, batch['label'] + collect_result = {'pred': 
[], 'target': []} + pattern = patterns[pattern_id] + model.eval() + for batch_id, batch in enumerate(data_loaders[pattern]): + # Get data + img_batch = batch['img'].cuda() + label_batch = batch['label'].cuda() + img_names = batch['img_name'] + collect_result['target'].append(label_batch.data.cpu()) + + # inference + with torch.no_grad(): + logits, att_branch_output = model(img_batch) + + collect_result['pred'].append(logits.data.cpu()) + # get segmentation result, when the phase is test. + pred_label = torch.argmax(logits, 1) + pred_label *= 255 + + if pattern == 'test' or batch_id % 5 == 1: + batch_size = pred_label.size(0) + # k = np.clip(int(0.3 * batch_size), a_min=1, a_max=batch_size) + # ids = np.random.choice(range(batch_size), k, replace=False) + ids = range(batch_size) + for img_id in ids: + img = img_batch[img_id].detach().cpu() + target = label_batch[img_id].detach().cpu() + pred = pred_label[img_id].detach().cpu() + img_name = img_names[img_id] + + img = make_numpy_img( + inv_normalize_img(img, model_infos['PRIOR_MEAN'], model_infos['PRIOR_STD'])) + target = make_numpy_img(encode_onehot_to_mask(target)) * 255 + pred = make_numpy_img(pred) + + vis = np.concatenate([img / 255., target / 255., pred / 255.], axis=0) + vis = np.clip(vis, a_min=0, a_max=1) + file_name = os.path.join(model_infos['log_path'], pattern, f'Epoch_{epoch_id}_{img_name.split(".")[0]}.png') + plt.imsave(file_name, vis) + + collect_result['pred'] = torch.cat(collect_result['pred'], dim=0) + collect_result['target'] = torch.cat(collect_result['target'], dim=0) + IoU, OA, F1_score = get_metrics('seg', **collect_result) + logger.write(f'{pattern}: {epoch_id}/{max_epoch} Iou:{IoU[-1]:.4f} OA:{OA[-1]:.4f} F1:{F1_score[-1]:.4f}\n') + if epoch_id % 20 == 1: + torch.save({ + 'epoch_id': epoch_id, + 'model_state_dict': model.state_dict() + }, os.path.join(model_infos['model_path'], f'ckpt_{epoch_id}.pt')) + torch.save({ + 'epoch_id': epoch_id, + 'model_state_dict': model.state_dict() + }, 
class Compose(object):
    """Chain several transforms into one callable.

    Each transform receives the output of the previous one; the first one
    receives the value passed to the composed object.

    Args:
        transforms (list of callables): transforms applied in list order.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        result = data
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self):
        # One transform per line between the opening and closing parenthesis.
        pieces = [self.__class__.__name__ + '(']
        for transform in self.transforms:
            pieces.append('\n')
            pieces.append(' {0}'.format(transform))
        pieces.append('\n)')
        return ''.join(pieces)
class RandomBrightness(object):
    """Randomly shift image intensities by one constant offset.

    With probability ``prob`` a single offset drawn uniformly from
    ``[-delta, delta]`` is added to every value of the image. The image is
    modified in place, so it must be a float array.

    Args:
        phase: ``'seg'`` / ``'od'`` for ``(img, label)`` samples, ``'cd'`` for
            ``(img1, label1, img2, label2)`` samples; in ``'cd'`` each image
            draws its own offset.
        delta: maximum absolute offset, ``0 <= delta < 255``.
        prob: probability of applying the shift to an image.

    Raises:
        NotImplementedError: when called with an unsupported ``phase``
            (previously this surfaced as an ``UnboundLocalError``).
    """

    def __init__(self, phase, delta=10, prob=0.5):
        self.phase = phase
        self.delta = delta
        self.prob = prob
        assert 0. <= self.delta < 255., "brightness delta must between 0 to 255"

    def _shift(self, img):
        """Apply the random offset to ``img`` in place and return it."""
        if torch.rand(1) < self.prob:
            delta = torch.FloatTensor(1).uniform_(-self.delta, self.delta)
            img += delta.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._shift(img), label
        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            # Keep the draw order: first image, then second image.
            img1 = self._shift(img1)
            img2 = self._shift(img2)
            return img1, label1, img2, label2
        raise NotImplementedError(
            "RandomBrightness does not support phase %r" % (self.phase,))
+ """ + def __init__(self, phase, current='RGB', target='HSV'): + self.phase = phase + self.current = current + self.target = target + + def __call__(self, data): + + if self.phase in ['od', 'seg']: + img, _ = data + if self.current == 'RGB' and self.target == 'HSV': + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) + elif self.current == 'HSV' and self.target == 'RGB': + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + else: + raise NotImplementedError("Convert color fail!") + return_data = img, _ + + elif self.phase == 'cd': + img1, label1, img2, label2 = data + if self.current == 'RGB' and self.target == 'HSV': + img1 = cv2.cvtColor(img1, cv2.COLOR_RGB2HSV) + img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2HSV) + elif self.current == 'HSV' and self.target == 'RGB': + img1 = cv2.cvtColor(img1, cv2.COLOR_HSV2RGB) + img2 = cv2.cvtColor(img2, cv2.COLOR_HSV2RGB) + else: + raise NotImplementedError("Convert color fail!") + return_data = img1, label1, img2, label2 + + return return_data + + +class RandomSaturation(object): + """ + get random saturation img + apply the restriction on saturation S + """ + def __init__(self, phase, lower=0.8, upper=1.2, prob=0.5): + self.phase = phase + self.lower = lower + self.upper = upper + self.prob = prob + assert self.upper >= self.lower, "saturation upper must be >= lower!" + assert self.lower > 0, "saturation lower must be non-negative!" 
class RandomHue(object):
    """Randomly rotate the first channel of an image.

    With probability ``prob`` one offset drawn uniformly from
    ``[-delta, delta]`` is added to channel 0; values that leave the
    ``[0, 360]`` interval are wrapped back by 360. The image is edited in
    place. Channel 0 is presumably the H channel of an HSV image, as the
    class name suggests -- confirm against the caller.

    ``'seg'`` / ``'od'`` samples are ``(img, label)``; ``'cd'`` samples are
    ``(img1, label1, img2, label2)`` and each image draws its own offset.
    """

    def __init__(self, phase, delta=10., prob=0.5):
        self.phase = phase
        self.delta = delta
        self.prob = prob
        assert 0 <= self.delta < 360, "Hue delta must between 0 to 360!"

    def _jitter(self, image):
        """Shift and wrap channel 0 of ``image`` in place (random skip)."""
        if not (torch.rand(1) < self.prob):
            return
        offset = torch.FloatTensor(1).uniform_(-self.delta, self.delta)
        hue = image[:, :, 0]  # view: writes go straight into ``image``
        hue += offset.numpy()
        hue[hue > 360.0] -= 360.0
        hue[hue < 0.0] += 360.0

    def __call__(self, data):
        if self.phase in ('od', 'seg'):
            image, target = data
            self._jitter(image)
            return_data = image, target
        elif self.phase == 'cd':
            first, first_label, second, second_label = data
            self._jitter(first)
            self._jitter(second)
            return_data = first, first_label, second, second_label
        return return_data
class ExpandImg(object):
    """Randomly pad the sample with a constant border.

    With probability ``prob`` the sample is returned untouched; otherwise a
    random border is added on each side with ``cv2.copyMakeBorder``. Images
    are padded with ``prior_mean * 255`` and masks with 0. The left/right
    borders are bounded by ``width * expand_ratio`` and the top/bottom
    borders by ``height * expand_ratio``.

    Fix: the top/bottom bound used the image *width*, so non-square images
    received a vertical padding scaled by the wrong dimension.

    Args:
        phase: ``'seg'`` for ``(img, label)``, ``'cd'`` for
            ``(img1, label1, img2, label2)`` (both pairs share one border),
            ``'od'`` for ``(img, boxes)``.
        prior_mean: per-channel mean in [0, 1] used as fill value.
        prob: probability of skipping the expansion.
        expand_ratio: upper bound of the border size relative to the image.

    Raises:
        NotImplementedError: for an unsupported ``phase`` (previously the
            call silently returned ``None``).
    """

    def __init__(self, phase, prior_mean, prob=0.5, expand_ratio=0.2):
        self.phase = phase
        self.prior_mean = np.array(prior_mean) * 255
        self.prob = prob
        self.expand_ratio = expand_ratio

    def _sample_borders(self, height, width):
        """Draw ``(top, bottom, left, right)`` border sizes as Python ints."""
        ratio_width = self.expand_ratio * torch.rand([])
        ratio_height = self.expand_ratio * torch.rand([])
        # ``high`` is exclusive and at least 1, so a border may be 0.
        left, right = torch.randint(high=int(max(1, width * ratio_width)), size=[2])
        top, bottom = torch.randint(high=int(max(1, height * ratio_height)), size=[2])
        return int(top), int(bottom), int(left), int(right)

    @staticmethod
    def _pad(array, borders, value):
        """Pad ``array`` by ``borders`` = (top, bottom, left, right) with ``value``."""
        top, bottom, left, right = borders
        return cv2.copyMakeBorder(
            array, top, bottom, left, right, cv2.BORDER_CONSTANT, value=value)

    def __call__(self, data):
        if self.phase not in ('seg', 'cd', 'od'):
            raise NotImplementedError(
                "ExpandImg does not support phase %r" % (self.phase,))
        if torch.rand(1) < self.prob:
            return data

        if self.phase == 'seg':
            img, label = data
            borders = self._sample_borders(*img.shape[:2])
            return self._pad(img, borders, self.prior_mean), self._pad(label, borders, 0)

        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            # One border for both acquisitions keeps them aligned.
            borders = self._sample_borders(*img1.shape[:2])
            return (self._pad(img1, borders, self.prior_mean),
                    self._pad(label1, borders, 0),
                    self._pad(img2, borders, self.prior_mean),
                    self._pad(label2, borders, 0))

        img, label = data
        borders = self._sample_borders(*img.shape[:2])
        top, _, left, _ = borders
        img = self._pad(img, borders, self.prior_mean)
        # Odd columns are shifted by the left border and even columns (from 2)
        # by the top border; presumably rows are [cls, x1, y1, x2, y2] --
        # TODO confirm the box layout against the dataset loader.
        label[:, 1::2] += left
        label[:, 2::2] += top
        return img, label
class RandomSampleCrop(object):
    """Randomly crop a window out of the sample.

    With probability ``prob`` the sample is returned untouched. Otherwise a
    crop size is drawn from ``crop_scale_ratios_range`` times
    ``original_size`` (clamped to the image size) until its ``h / w`` ratio
    lies strictly inside ``aspect_ratio_range``; the window is then placed at
    a random position. If no valid size is found within ``max_try_times``
    draws, the sample is returned unchanged.

    Fixes over the previous version:
      * ``img.shape`` is ``(height, width, channels)``; it used to be unpacked
        as ``(w, h, c)``, which mis-cropped every non-square image.
      * the ``'od'`` branch retried forever; it now shares the bounded search.
      * the ``'cd'`` phase returned ``None``; both image/label pairs are now
        cropped with the same window.
      * a degenerate scale range no longer makes ``torch.randint`` raise.

    Args:
        phase: ``'seg'`` for ``(img, label)``, ``'cd'`` for
            ``(img1, label1, img2, label2)``, ``'od'`` for ``(img, boxes)``.
        original_size: reference ``(width, height)`` the scale range refers to.
        prob: probability of skipping the crop.
        crop_scale_ratios_range: ``(min, max)`` scale of the crop relative to
            ``original_size``.
        aspect_ratio_range: ``(min, max)`` allowed ``h / w`` of the crop.

    Raises:
        NotImplementedError: for an unsupported ``phase``.
    """

    def __init__(self,
                 phase,
                 original_size=(512, 512),
                 prob=0.5,
                 crop_scale_ratios_range=(0.8, 1.2),
                 aspect_ratio_range=(4./5, 5./4)):
        self.phase = phase
        self.prob = prob
        self.scale_range = crop_scale_ratios_range
        self.original_size = original_size
        self.aspect_ratio_range = aspect_ratio_range  # h/w
        self.max_try_times = 500

    @staticmethod
    def _randint(low, high):
        """Return a random int in ``[low, high)``; ``low`` if the range is empty."""
        if high <= low:
            return int(low)
        return int(torch.randint(low, high, size=[]))

    def _sample_window(self, height, width):
        """Return ``(top, left, crop_h, crop_w)`` or ``None`` when no size fits."""
        min_ratio, max_ratio = self.aspect_ratio_range[0], self.aspect_ratio_range[1]
        w_low = min(width, int(self.scale_range[0] * self.original_size[0]))
        w_high = min(width + 1, int(self.scale_range[1] * self.original_size[0]))
        h_low = min(height, int(self.scale_range[0] * self.original_size[1]))
        h_high = min(height + 1, int(self.scale_range[1] * self.original_size[1]))
        for _ in range(self.max_try_times):
            crop_w = self._randint(w_low, w_high)
            crop_h = self._randint(h_low, h_high)
            if crop_w <= 0 or crop_h <= 0:
                continue
            # aspect ratio constraint
            if min_ratio < crop_h / crop_w < max_ratio:
                left = self._randint(0, width - crop_w + 1)
                top = self._randint(0, height - crop_h + 1)
                return top, left, crop_h, crop_w
        print("try times over max threshold!", flush=True)
        return None

    def __call__(self, data):
        if self.phase not in ('seg', 'cd', 'od'):
            raise NotImplementedError(
                "RandomSampleCrop does not support phase %r" % (self.phase,))
        if torch.rand(1) < self.prob:
            return data

        height, width = data[0].shape[:2]
        window = self._sample_window(height, width)
        if window is None:
            return data
        top, left, crop_h, crop_w = window
        rows = slice(top, top + crop_h)
        cols = slice(left, left + crop_w)

        if self.phase == 'seg':
            img, label = data
            return img[rows, cols, :], label[rows, cols]

        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            return (img1[rows, cols, :], label1[rows, cols],
                    img2[rows, cols, :], label2[rows, cols])

        img, label = data
        img = img[rows, cols, :]
        if len(label):
            # Presumably each row is [cls, x1, y1, x2, y2] -- TODO confirm the
            # box layout against the dataset loader.
            # keep a gt box only if its center lies inside the sampled patch
            centers = (label[:, 1:3] + label[:, 3:]) / 2.0
            m1 = (left <= centers[:, 0]) * (top <= centers[:, 1])
            m2 = ((left + crop_w) >= centers[:, 0]) * ((top + crop_h) > centers[:, 1])
            mask = m1 * m2
            # Boolean indexing copies, so the caller's array is not modified.
            label = label[mask, :]
            # move the boxes into the crop's coordinate frame
            label[:, 1::2] -= left
            label[:, 2::2] -= top
        return img, label
class RandomFlipV(object):
    """Flip a sample upside down (reverse the row axis) with probability ``prob``.

    Supported phases:
        'seg': ``(img, label)``; both arrays are flipped.
        'cd':  ``(img1, label1, img2, label2)``; all four arrays are flipped.
        'od':  ``(img, label)``; the image is flipped and columns 2 and 4 of
               ``label`` are mirrored against the image height, in place.
    Any other phase yields ``None``.
    """

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob

    def __call__(self, data):
        phase = self.phase
        if phase not in ('seg', 'cd', 'od'):
            return None

        if phase == 'cd':
            first_img, first_lbl, second_img, second_lbl = data
            if torch.rand(1) < self.prob:
                return tuple(
                    arr[::-1, :]
                    for arr in (first_img, first_lbl, second_img, second_lbl)
                )
            return first_img, first_lbl, second_img, second_lbl

        image, target = data
        if not (torch.rand(1) < self.prob):
            return image, target

        if phase == 'seg':
            return image[::-1, :], target[::-1, :]

        # 'od': top/bottom swap roles after the flip, hence the reversed slice.
        height, _, _ = image.shape
        image = image[::-1, :]
        target[:, 2::2] = height - target[:, 4:1:-2]
        return image, target
class InvNormalize(object):
    """Map a normalised image back to the [0, 255] value range.

    The statistics are stored as float32 arrays wrapped in two leading
    singleton axes so they broadcast against the image passed to
    ``__call__``.
    """

    def __init__(self, prior_mean, prior_std):
        mean, std = (
            np.array([[stat]], dtype=np.float32)
            for stat in (prior_mean, prior_std)
        )
        self.prior_mean = mean
        self.prior_std = std

    def __call__(self, img):
        # De-standardise, rescale to 8-bit range, then clip overshoot.
        restored = (img * self.prior_std + self.prior_mean) * 255.
        return np.clip(restored, 0, 255)
class Datasets(Dataset):
    """Segmentation dataset driven by a CSV index of image/label paths.

    The CSV is read with its first column as the index and must provide the
    columns ``img`` and ``label``.

    Args:
        data_file: path (or file-like object) of the CSV index.
        transform: optional callable applied to ``(image, label)`` and
            returning the transformed pair; when ``None`` the raw arrays are
            used as read from disk.
        phase: stored on the instance; not used by this class itself.
    """

    def __init__(self, data_file, transform=None, phase='train', *args, **kwargs):
        self.transform = transform
        self.data_info = pd.read_csv(data_file, index_col=0)
        self.phase = phase

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, index):
        return self.pull_item_seg(index)

    def pull_item_seg(self, index):
        """Load one sample and return it as a dict.

        :param index: positional row index into the CSV.
        :return: ``{'img': C x H x W tensor, 'label': 2 x H x W one-hot
            tensor (background / foreground), 'img_name': file name}``.
        """
        data = self.data_info.iloc[index]
        img_name = data['img']
        label_name = data['label']

        ori_img = io.imread(img_name, as_gray=False)
        ori_label = io.imread(label_name, as_gray=True)
        assert (ori_img is not None and ori_label is not None), f'{img_name} or {label_name} is not valid'

        if self.transform is not None:
            img, label = self.transform((ori_img, ori_label))
        else:
            # Without a transform, fall back to the raw arrays instead of
            # failing on unbound names.
            img, label = ori_img, ori_label

        # np.float was only an alias of the builtin float (float64) and is
        # gone from NumPy >= 1.24; name the dtype explicitly.
        one_hot_label = np.zeros([2] + list(label.shape), dtype=np.float64)
        one_hot_label[0] = label == 0
        one_hot_label[1] = label > 0
        return_dict = {
            'img': torch.from_numpy(img).permute(2, 0, 1),
            'label': torch.from_numpy(one_hot_label),
            'img_name': os.path.basename(img_name)
        }
        return return_dict
def collate_fn(batch):
    """Merge a list of sample dicts into one batch dict.

    NumPy arrays are converted to float32 tensors; every other value is
    passed through as is. For each key except ``'img_name'``, the values are
    stacked along a new leading batch axis when all of them are tensors;
    otherwise (strings, lists, dicts, ...) they are returned as a plain list
    in sample order.

    Args:
        batch: list of dicts sharing the keys of the first sample.
    Returns:
        dict mapping each key to a stacked tensor or a list of values.
        An empty batch yields an empty dict.
    """
    def to_tensor(item):
        if isinstance(item, np.ndarray):
            return torch.from_numpy(item).float()
        # Tensors and non-array values (names, lists, dicts) stay untouched.
        return item

    if not batch:
        return {}

    return_data = {key: [] for key in batch[0]}
    for sample in batch:
        for key, value in sample.items():
            return_data[key].append(to_tensor(value))

    for key, items in return_data.items():
        # 'img_name' is always kept as a list; other fields are stacked only
        # when stacking is actually possible.
        if key != 'img_name' and all(torch.is_tensor(item) for item in items):
            return_data[key] = torch.stack(items, dim=0)

    return return_data
class Infos(object):
    """Bookkeeping for the object-detection ('od') training/validation loop.

    Collects per-batch losses and per-image predictions/targets, computes
    COCO-style bbox metrics at the end of an epoch, and writes progress to a
    logger, a summary writer and visualisation images.

    Args:
        phase: must be 'od'.
        class_names: list of class names, indexed by class id.
    """

    _METRIC_KEYS = (
        'AP_m_all_100',
        'AP_50_all_100',
        'AP_75_all_100',
        'AP_m_small_100',
        'AP_m_medium_100',
        'AP_m_large_100',
        'AR_m_all_1',
        'AR_m_all_10',
        'AR_m_all_100',
        'AR_m_small_100',
        'AR_m_medium_100',
        'AR_m_large_100',
    )

    def __init__(self, phase, class_names=None):
        assert phase in ['od'], "Error in Infos"
        self.phase = phase
        self.class_names = class_names
        self.register()
        self.pattern = 'train'
        self.epoch_id = 0
        self.max_epoch = 0
        self.batch_id = 0
        self.batch_num = 0
        self.lr = 0
        self.fps_data_load = 0
        self.fps = 0
        self.val_metric = 0
        # Read by print_epoch_state_infos; give it a value even when
        # set_epoch_training_time was never called.
        self.epoch_training_time = 0

    def set_epoch_training_time(self, data):
        self.epoch_training_time = data

    def set_pattern(self, data):
        self.pattern = data

    def set_epoch_id(self, data):
        self.epoch_id = data

    def set_max_epoch(self, data):
        self.max_epoch = data

    def set_batch_id(self, data):
        self.batch_id = data

    def set_batch_num(self, data):
        self.batch_num = data

    def set_lr(self, data):
        self.lr = data

    def set_fps_data_load(self, data):
        self.fps_data_load = data

    def set_fps(self, data):
        self.fps = data

    def clear_cache(self):
        """Drop all accumulated results and losses."""
        self.register()

    def get_val_metric(self):
        return self.val_metric

    def _build_coco(self, boxes_per_image, with_score):
        """Wrap per-image box arrays into an indexed in-memory COCO object.

        Rows are ``class, l, t, r, b`` or, when ``with_score`` is true,
        ``class, score, l, t, r, b``.
        """
        # NOTE(review): COCO / COCOeval are not imported in the visible part
        # of this module - confirm they are in scope at runtime.
        coco_api = COCO()
        coco_api.dataset['images'] = []
        coco_api.dataset['annotations'] = []
        first = 2 if with_score else 1  # column of the 'left' coordinate
        ann_id = 0
        for image_id, boxes in enumerate(boxes_per_image):
            for box in boxes:
                left_top = box[first:first + 2]
                width_height = box[first + 2:first + 4] - left_top
                annotation = {
                    'image_id': image_id,
                    "category_id": int(box[0]),
                    "bbox": np.hstack([left_top, width_height]),
                    "area": np.prod(width_height),
                    "id": ann_id,
                    "iscrowd": 0
                }
                if with_score:
                    annotation['score'] = box[1]
                # NOTE(review): an image is registered only when it owns at
                # least one box, so images without ground truth are left out
                # of the evaluated image ids - confirm this is intended.
                coco_api.dataset['images'].append({'id': image_id})
                coco_api.dataset['annotations'].append(annotation)
                ann_id += 1
        coco_api.dataset['categories'] = [{"id": i, "supercategory": c, "name": c} for i, c in
                                          enumerate(self.class_names)]
        coco_api.createIndex()
        return coco_api

    def cal_metrics(self):
        """Run COCO bbox evaluation on everything in ``result_all``.

        Stores the 12 summary numbers in ``self.metrics`` and AP@0.50 in
        ``self.val_metric``.
        """
        if self.phase == 'od':
            coco_api_gt = self._build_coco(self.result_all['target_all'], with_score=False)
            coco_api_pred = self._build_coco(self.result_all['pred_all'], with_score=True)

            coco_eval = COCOeval(coco_api_gt, coco_api_pred, "bbox")
            coco_eval.params.imgIds = coco_api_gt.getImgIds()
            coco_eval.evaluate()
            coco_eval.accumulate()
            stats = coco_eval.summarize()
            if stats is None:
                # Upstream pycocotools stores the summary on the evaluator
                # and returns nothing from summarize().
                stats = coco_eval.stats
            self.metrics = stats
            self.val_metric = self.metrics[1]

    def print_epoch_state_infos(self, logger):
        """Log the epoch loss, run the evaluation and log the metric table."""
        infos_str = 'Pattern: %s Epoch [%d,%d], time: %d loss: %.4f' % \
                    (self.pattern, self.epoch_id, self.max_epoch, self.epoch_training_time, np.mean(self.loss_all['loss']))
        logger.write(infos_str + '\n')
        time_start = time.time()
        self.cal_metrics()
        time_end = time.time()
        logger.write('Pattern: %s Epoch Eval_time: %d\n' % (self.pattern, (time_end - time_start)))

        if self.phase == 'od':
            titleStr = 6 * ['Average Precision'] + 6 * ['Average Recall']
            typeStr = 6 * ['(AP)'] + 6 * ['(AR)']
            iouStr = 12 * ['0.50:0.95']
            iouStr[1] = '0.50'
            iouStr[2] = '0.75'
            areaRng = 3 * ['all'] + ['small', 'medium', 'large'] + 3 * ['all'] + ['small', 'medium', 'large']
            maxDets = 6 * [100] + [1, 10, 100] + 3 * [100]
            infos_str = '{:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}\n'
            for i in range(12):
                logger.write(infos_str.format(titleStr[i], typeStr[i], iouStr[i], areaRng[i], maxDets[i], self.metrics[i]))

    def save_epoch_state_infos(self, writer):
        """Write the 12 COCO summary numbers as scalars for this epoch."""
        step = self.epoch_id
        for i, key in enumerate(self._METRIC_KEYS):
            writer.add_scalar('%s/epoch/%s' % (self.pattern, key), self.metrics[i], step)

    def print_batch_state_infos(self, logger):
        """Log one line with the current batch position, lr, speed and loss."""
        infos_str = 'Pattern: %s [%d,%d][%d,%d], lr: %5f, fps_data_load: %.2f, fps: %.2f' % \
                    (self.pattern, self.epoch_id, self.max_epoch, self.batch_id,
                     self.batch_num, self.lr, self.fps_data_load, self.fps)
        # add loss
        infos_str += ', loss: %.4f' % self.loss_all['loss'][-1]
        logger.write(infos_str + '\n')

    def save_batch_state_infos(self, writer):
        """Write lr and the latest value of every tracked loss as scalars."""
        step = self.epoch_id * self.batch_num + self.batch_id
        writer.add_scalar('%s/lr' % self.pattern, self.lr, step)
        for key, value in self.loss_all.items():
            writer.add_scalar('%s/%s' % (self.pattern, key), value[-1], step)

    def save_results(self, img_batch, prior_mean, prior_std, vis_dir, *args, **kwargs):
        """Save image / prediction / target visualisations for ~30% of the batch.

        The predictions and targets of the current batch are taken from the
        tail of ``result_all`` (they must have been appended already).
        """
        batch_size = img_batch.size(0)
        k = np.clip(int(0.3 * batch_size), a_min=1, a_max=batch_size)
        ids = np.random.choice(range(batch_size), k, replace=False)
        for img_id in ids:
            img = img_batch[img_id].detach().cpu()
            # negative index: position of this image within the last batch
            pred = self.result_all['pred_all'][img_id - batch_size]
            target = self.result_all['target_all'][img_id - batch_size]

            img = make_numpy_img(inv_normalize_img(img, prior_mean, prior_std))
            # draw_bboxes takes (img, bboxes, color, class_names, ...); pass
            # by keyword so colour and class names cannot be swapped.
            pred_draw = draw_bboxes(img, pred, color=(255, 0, 0), class_names=self.class_names)
            target_draw = draw_bboxes(img, target, color=(0, 255, 0), class_names=self.class_names)

            vis = np.concatenate([img/255., pred_draw/255., target_draw/255.], axis=0)
            vis = np.clip(vis, a_min=0, a_max=1)
            file_name = os.path.join(vis_dir, self.pattern, f'{self.epoch_id}_{self.batch_id}_{img_id}.png')
            plt.imsave(file_name, vis)

    def register(self):
        """Reset the result and loss accumulators."""
        self.is_registered_result = False
        self.result_all = {}

        self.is_registered_loss = False
        self.loss_all = {}

    def register_result(self, data: dict):
        for key in data.keys():
            self.result_all[key] = []
        self.is_registered_result = True

    def append_result(self, data: dict):
        """Extend each result list with the per-image entries in ``data``."""
        if not self.is_registered_result:
            self.register_result(data)
        for key, value in data.items():
            self.result_all[key] += value

    def register_loss(self, data: dict):
        for key in data.keys():
            self.loss_all[key] = []
        self.is_registered_loss = True

    def append_loss(self, data: dict):
        """Append each loss tensor in ``data`` as a detached NumPy value."""
        if not self.is_registered_loss:
            self.register_loss(data)
        for key, value in data.items():
            self.loss_all[key].append(value.detach().cpu().numpy())
def get_coords_grid(h_end, w_end, h_start=0, w_start=0, h_steps=None, w_steps=None, is_normalize=False):
    '''
    Build a regular coordinate grid.
    Args:
        h_end, w_end: last coordinate along the row / column axis (inclusive)
        h_start, w_start: first coordinate along the row / column axis
        h_steps, w_steps: number of samples per axis; defaults to one sample
            per unit step, i.e. int(end - start) + 1
        is_normalize: divide the coordinates by h_end / w_end
    Returns:
        tensor [2, h_steps, w_steps]; channel 0 holds x (column) coordinates,
        channel 1 holds y (row) coordinates
    '''
    n_rows = int(h_end - h_start) + 1 if h_steps is None else h_steps
    n_cols = int(w_end - w_start) + 1 if w_steps is None else w_steps

    ys = torch.linspace(h_start, h_end, n_rows)
    xs = torch.linspace(w_start, w_end, n_cols)
    if is_normalize:
        ys, xs = ys / h_end, xs / w_end

    grid_y, grid_x = torch.meshgrid(ys, xs)
    # x first, then y
    return torch.stack((grid_x, grid_y), dim=0)
def get_all_dict(dict_infos: dict) -> dict:
    """Flatten a nested dict into a single level.

    Non-dict values keep their key; dict values are replaced by their own
    flattened entries (the key of the nested dict itself is dropped). When
    the same key occurs more than once, the entry met last wins.

    Args:
        dict_infos: possibly nested mapping.
    Returns:
        A new flat dict; the input is not modified.
    """
    return_dict = {}
    for key, value in dict_infos.items():
        if isinstance(value, dict):
            # update() accepts any hashable key, unlike dict(..., **nested),
            # which requires every nested key to be a string.
            return_dict.update(get_all_dict(value))
        else:
            return_dict[key] = value
    return return_dict
def inv_normalize_img(img, prior_mean=(0, 0, 0), prior_std=(1, 1, 1)):
    '''
    Undo mean/std normalisation and rescale to [0, 255].
    Args:
        img: tensor, CxHxW (or BxCxHxW)
        prior_mean: per-channel mean (C values), or a single value applied to
            every channel
        prior_std: per-channel std (C values), or a single value applied to
            every channel
    Returns:
        tensor with the shape of img, values clamped to [0, 255]
    Raises:
        ValueError: img has fewer than 3 dims, or a statistic has neither 1
            nor C values
    '''
    if img.dim() < 3:
        raise ValueError('inv_normalize_img expects a CxHxW or BxCxHxW tensor')
    channels = img.size(-3)

    def _channel_stat(value, name):
        # (-1, 1, 1) lines the values up with the channel axis for both CxHxW
        # and BxCxHxW input, and lets a scalar broadcast over all channels.
        stat = torch.tensor(value, dtype=torch.float).to(img.device).view(-1, 1, 1)
        if stat.size(0) not in (1, channels):
            raise ValueError('%s has %d values but img has %d channels'
                             % (name, stat.size(0), channels))
        return stat

    prior_mean = _channel_stat(prior_mean, 'prior_mean')
    prior_std = _channel_stat(prior_std, 'prior_std')
    img = img * prior_std + prior_mean
    img = img * 255.
    img = torch.clamp(img, min=0, max=255)
    return img
def decode_mask_to_onehot(mask, n_class):
    '''
    Expand a class-index mask into one-hot planes.
    mask : BxWxH or WxH, tensor of class indices
    n_class : n, number of planes to produce
    return : BxnxWxH or nxWxH (matching the input rank), default float dtype,
             on the device of mask; indices outside [0, n) give all-zero
             pixels
    '''
    assert len(mask.shape) in [2, 3], "decode_mask_to_onehot error!"
    # Record the rank before adding the batch axis; checking it afterwards
    # can never detect an unbatched input.
    is_unbatched = len(mask.shape) == 2
    if is_unbatched:
        mask = mask.unsqueeze(0)
    onehot = torch.zeros((mask.size(0), n_class, mask.size(1), mask.size(2))).to(mask.device)
    for i in range(n_class):
        onehot[:, i, ...] = mask == i
    if is_unbatched:
        onehot = onehot.squeeze(0)
    return onehot
def get_metrics(phase, pred, target):
    '''
    Per-class segmentation metrics over a batch.
    phase: only 'seg' is handled; any other value yields None
    pred: logits, tensor, nBatch*nClass*W*H
    target: labels, tensor, nBatch*nClass*W*H
    return: (IoU, OA, F1_score), each a tensor with one entry per class
    '''
    if phase != 'seg':
        return None

    reduce_dims = (0, 2, 3)  # everything except the class axis
    hard_pred = torch.argmax(pred.detach(), dim=1)
    hard_pred = decode_mask_to_onehot(hard_pred, target.size(1))

    gt_hit = target == 1
    pred_hit = hard_pred == 1

    # pixel counts: ground-truth positives, predicted positives, overlap
    n_gt = torch.sum(gt_hit, dim=reduce_dims)
    n_pred = torch.sum(pred_hit, dim=reduce_dims)
    n_true = torch.sum(gt_hit * pred_hit, dim=reduce_dims)
    n_pixels = torch.sum(target >= 0, dim=reduce_dims)

    precision = n_true / (n_pred + 1e-15)
    recall = n_true / (n_gt + 1e-15)

    IoU = n_true / (n_pred + n_gt - n_true + 1e-15)
    # overall accuracy: one minus the share of mismatching pixels
    OA = 1 - (n_pred + n_gt - 2 * n_true) / n_pixels
    F1_score = 2 * precision * recall / (precision + recall + 1e-15)
    return IoU, OA, F1_score