Upload folder using huggingface_hub (#1)
Browse files- d2f9d5b1696d723607b469880b3d5616ee5b225a5296662d3b494a9f93762c27 (864c14f7179bf9dcdf7c90dbed7a41ee5e839da2)
- 72712ee369aab2d20658d111064b5f0d7a1499fc75277a6606fda5c45e0621b4 (62ad3c18015dcb4dc60ff7e6e468ff15c0334876)
- 343785fe216bc214dd04cc3e9c10d02e7a94143f99508481cd37022497f4783f (ea980d56daed988b608e7f7f5d95c5c04d4ea697)
- 22e429e0f0ccd12e262eff7dd5f76de81fa895e252330ff2269d3b0a71d57919 (13fdcc6ad6afb851b5358aba34dcab1cf21d6767)
- 4dc2a277835db86d453a6e9dbd5e071290d13d21e944b285578b6730ed36adcd (de0a1295bf46d3cd023a5c387424e6e3de075548)
- 2aad947094aa55aba4cb28a3be8165f085f0bc0ea622bfc1455937cd076e3e3c (ca28d59b4827feb16274699cd99de89ece6aa1ab)
- 9e97593500c52ee9f4ebe3dbc3b2d3f8ccf2806573bc32848b7002ed330eabfe (f00e6a25ea16d066400536988df75e970f7b4b09)
- 6932ae5e409108b07a5f2e0032f8d1426c5b1d8b45c336198a9a3d24e32bf27a (3518a2576bcc63d5a312d8257279ef3541937813)
- 9f813b4cf8b1b68c2650dd47e56fdf0f4fa11f0b634b34e113fe049555cf6b0b (80075679b76b6d7a771dee7bfc9e525c8541bdfd)
- d25a8a27251818db766711c75332797dec2e5ed154c67d95a710789abc38f1ae (0353ff565da94cf44d164cee53bedd43c546c686)
- 94fb5f9916f447a06693ce9c08cae6053e3a53b089203b07a77e111c7a48b7c5 (b1e50c8ba84454750277304e2add63747f969c24)
- 9d957872a8d1cfe72442bd80af4bdabc68981732ae7b7ef38873f94c71a6e38d (78524c180ca1df1858325b21b1dcd51a48d0b832)
- 7b42a8465af5b9c3bbbd35e9ee8f6eabdb767b82963b9a1edd54c0d0db78f519 (56ce03842edf47a35ff25011e68d7d33a91b9fd4)
- f6a04718ee1cee63d5f53af5947287d0e196e596e2c9bd2073739f83e55a2151 (45df139ec12e1336e6508de6ea5dccb3b72dce21)
- 2a2858e318f8a24f63e670f274755189120791789d2895e55be73e616158ec99 (9410754a0b758ea43963ddc4bf4575258e30b991)
- eafc0cef804e25c9a59bb93bf2ce5891c01e84b35c44e013fb0fcb88c091ed8a (84ceb6e6903ed2bb5717a9880353c9d0d63cfd14)
- 125c0488499d6143b105d0abc6a2f4cad87f10ad297c5bd54f0b64b882043133 (ed9d66f64a08d44ed8311fe81d7aa8050a6bd555)
- d79e8401776d882537e4f9e8b5e318b6017157ad821c12616710c8ead8005262 (051417feb00ee958d6fbaa78b8b1f5723593047a)
- d4428c1827509739b091f111ac62cbc9473f2af30cd19df41af86ce4f0aac3ac (9385b680a01363dadff5316da1b45b872d1c78da)
- f96abbc6a4b337e150b0a6f74d41fc562fa78901362eff8f935f731ee83c8d32 (fc44d6f314c4fe9414424fed7dbdd107b2157b0c)
- ab5595192559ac1a498ecc905c693ee5da489683cb13d9e4f70c8e4787a08e84 (d82ce16a74aaf42e3b5197dc20ae786ee12b56da)
- 691a76ad38e7f277af60b80f88e313c3b87a5ac30404aa5e9ccd1a821900f690 (03d5d897031cb74afc71644c7992cdd100f32609)
- 5f6ec4caedf83a3b9abd85036479f89d1fbe048f32bf90b09ed9dbff1434f198 (18700e75e8dc8bab031935b7ebbd733381bde785)
- ad3cefbea0cffd469a20aec62ec552f05c4aeec849036134b0774f527363947c (0ac1b1d0362aeacba7440ed5c46b41f460c00887)
- 1fd5cf87b52feeeb909bc84cc1c095d45a7fe8daadbed67fa2759f70108ac491 (d4c05cda6218d732d66a7f5ea51107e929f94c34)
- c6dfaa42dbb815ebb99abbb8cf8fe821fc45ef8208f95baf5f87b2ea762cfdee (173b2db59c6f2136a07d67506f3b70220d1a5928)
- 327a839011e474979217760c545d723bbb788967992f0ac692e4b3c48d9dd9b4 (269e400d6d434f77586d58bf853059a3aedfde52)
- 62e8bae67a6a2e2bfa75077f462b466439d3b603753904ec73a0a30250993387 (8a678657f6d88989dfe54e39aac033d2e70d7d28)
- fe2f79481822147e39f82eea4b911b45eb34b349efe358d2c523b96a3fd80571 (8537475fa7301119c409ec1adc68d9fd70043798)
- b8b16f6cd140bd482475a53ddaf68e0469a9011a7779f955a3c51a0103feb946 (03d5ebf0693a875f460168561bf5b95975d101c9)
- a46cc9965b4d399b3760d64abf4a24f368d62d727ba52c0c14045ba05b207b8f (82c0f6c9cce2ace10ee83e66b2e2a00a68144240)
- cb1175da0a7309cae6b5b239d54f8922be35970697125993c879327f694d03e1 (713781f6b3cc63827644f0ed02f6039fe61df6a5)
- 5081d0a653640f56b2f3eb3b1036115a02dbe906c74d4f1c818560c6b73d8a70 (e4f7258fba290f62526114265327167565b7d6fa)
- 8c5129423a57ac8c4dca1e65620c05adf6f1cbd624645a01dd03c5e5bc058d8b (fcfbafd24a19af4913dbd928a36f9d952f58c652)
- f35cb770179564361359a27b337e92fc9fd5a30b6ebdce95650998efb5c8a9c6 (61cb7b39db104df88a26ceb308d387c65ee5dfe0)
- fc8cc9b183f7e8ff0520eeb4452013ff0c1da97a790bb63e27be39002d0dce04 (6265ee5f69b2ed60a5e785ee765dfb979462730c)
- ee820fb0057f7294c3a40d195ba1b6c55e5aa7ed299f4a43190900a4043f92b7 (a737bd9a5af19cb5a1acf412ffa89efb1eca9e09)
- 66c2c9caf035f38a3689f26af8f9bb8b3b638d49459c292809b22417aacab0d6 (9cf38ebab276f2bbfa08008e1f01c9859b2b9a91)
- 882cd5887028c73f672dfc62dce8f2caccf201beab3a2ccbec9835932993e229 (71d39bbe94a51963f1f40343cf09b5e3de094e99)
- bd10a2393aebf0d10a96aa795e603fa632555e088590388b7ae13f36dedfd3a0 (ea91aa2508545f805f2bcfa4cda1c496fcaa7f40)
- 8bb324f87650de78aff21b8630cdb906f4ecf8e89b0b58d903f81632883d92fd (99a4cfecdbe7cd81ace3416e91a9bce6ffb6cc47)
- a1c56cbf345b93554b5c0a0d189e694601913359a6387424bc4a9d53971d6d27 (71f6d19caf6d5bfb2481f228f8e0ccdd694bf4c2)
- 15d12d23663866ea9352d97c0bc14cc1af0664bbd7640feebb3bebaa5181645a (d5447884e6009d5613092f75b7d9c7be03b3f39b)
- 547784d9ad8a093d6c7ef661df30ee89480ad1b53d4aabf0c289810aad4bb200 (1e0a71b4430e0fb4603669d20c1adead564014bb)
- e6dbb1fef7a153779585553550c0165b8b4f0713b3f8f69dcb954dac2b8abf23 (787d9a8cd403d8ab32b0d10a4906c117fbf2cc7e)
- 1f3e00d60711eea36626167c2a4e62ffa55b899e5f1b19a5aed0291a513db8ed (9f6d2471a7cc23a4c0b021b6f7f0ba5d20f17bd1)
- 91d8ba13af8d3fc37ddf0df4fd21a975bab53ce8df4c6e4c0cbe898f928fef0b (434a797575d14b302f78128d853b25121e78ac1f)
- 61838a48aa440bbb958ddb1960152f470e106a8642a927573921f5a14eb395a0 (a9a3f7373637f499dbad0bdd9886cc941895c93c)
- 933e5f6ad3865341fbe5f6fc68c602f1329c75a4b501a4d8de35292cf28fdfca (12f72b70f59948c2ab995e43a8b32da6448cf7a4)
- 6921ac405dc2a28b3898aea80ef36b99eadb8ea960d230b913e1a6cca56bbfe5 (26189145f005f38c8cdd6d45002c3771633a5292)
- 7763265e7706f2bce75530b24b89637a4effb2217ea0d5ccd5b8ad4be21057a8 (4461db0a4c205b3f7b7371e049d0014a1c6bd25d)
- d87f84dd7b77f46fe6f5baf4823a194f7ce7939556ff0369e9dd38c3c059d9c5 (471e459a4cac7804ec99fb28097b3a9fb350e666)
- 2081b42ff277f260028bf8d71da124f2d941925c2fac45dd2494a9eae4dc7f73 (0bcdbfca59152312578d8d615f1f874f867442b7)
- 4cdee9c9b19362e8c3e5788c9f7229cbfd0e7b42b2390406219f047de62f2111 (2c6d105de52a97c8f80d95a29906354af2cd5b18)
- 4e4b7b69db35bb5c463f1503ba2ddd7a30ade2565ee9418749951fb270d7f63a (7f75dd1f3a3fe08c385b925ca3ce98ca6000091b)
- 915698ad5bce7f3c862e93266a918d8c921fb5612a6adba2c4983a36e914f580 (e156afc4f8dfcc09cdbad8a096d18e833f59377e)
- dcbf7733203f7fa96b3758df957db6c9c5df613e8f3c1adbe524f0ec9e30aec6 (8835d4a648e01873b7225a98c04b0cf9b1057030)
- 4b7f12f6f9b8d96202a9940d257b76effbc443fc5ed3026f42b211f22ef0e412 (c59b859c39d259f794a5f557698926c8c5f256bf)
- 280683d634d93e294ede67c37f2c78e9833a2d642cf297d0b164969f3c4ca84f (40c1498d5ca5373090951c9a64dfc2d06f0fb708)
- 87ad050f1c232289d79f96521c1a8b2f1e85f1b8c7c610802852ac0eddb0e533 (0902e558132e0f7eceb2dc74a57252b80ca8de75)
- 522d040300b5ee19e831bf5edb526b71afa6cdae37f73e9e86eae6ce080b1d0f (c9338e9dcaaa12aff40ae47ed431098c207e9050)
- 71aa1e18dc90216ec14cf5e45f7107f04706961ed229f4b33e24afc2acc445a9 (bdeac5594b1a2b2187b5d8049ee17eaa5cc4efb3)
- 66e5c610f3dc2ce92aa5594ed50d01f995e0418cbb916a7cd88907cc20945e1f (781875521e92d6784dc8a9bd212c03a4a9e2e3fe)
- e2ec34a652004fa63d71c6e540533af9d01c6d0bfd324e9a4348adadb440141e (ef0758d3bd0c0f063663512b9357cedfb6ee950e)
- 06c6b6e3b67360a3de065ba177977c07047dfee394ee3a58c4b008d62ec5a46c (603de8e3f336d361f584646ac11dd51ffb8047af)
- c9971ce5330913107bfb5e4f45fca3f9a5c57e9eac9b99a5bdcdaddd3b74c2f9 (d60c8e433079603eca5d681564e6e0758d04c3a0)
- 777437eeca03308b6027523feea509ceb4546c4495d61f877e9ada901ad303ab (50fdebc04bb05dd92cf908333a714117d0f9784c)
- 535abdb9083eaed7ae028e53cd9450a43798ae202570f710bdbe57a724b12e47 (723c3daf6fe21a67031655f60738a39682fecc7c)
- e13a39e7cf8697228652e34866a8546e37ac97a39ca9cd45afbcba4a15c33f60 (2bc54c71c738d6f6158b8269d54cd1a790b2b322)
- a712baac141437803214ed623433bcd9ec04ca000c7eda39f33f81838c46f3d2 (f49214b9cb5808aacbe1184ee3b11e7484e864b5)
- 9c72f8aa408a516ca83bab7df52c229f4d8e8491b118e2887d3f94462af70ed8 (d51330ea04b497391d1b42defb1e91d7a8f62874)
- d7c1bfb4298b6a78d441e891b8012e2169359bf1956b3095f59971691df224b7 (d148a75472db16f33266832278c1ffda78551cee)
- d3ce69adf84eb6ba8269dd1d1fda8baa26a46afa2f4dabe3287c198c3f72bd56 (4b45703bf2c46f335510b6db07a6f98e90d96e4a)
- 4023396a8ca8105460f07fdf79fc343d55600c8e897fae0f15ef9447e938f978 (b356c44c85d9d654476bcce191756cfc457fd20f)
- 3553c5aa50058c1726ced12ff1d60903b92a9840e78d7f3a72af7cfe0f7015a2 (ad891a1810c210daf9e9ae9edd499f3a5e270555)
- bdd3a489b5ffccf9a584e67062f90c5c2e713d5b8900ceb8668b848919193e67 (e489981a15e70ba487f1b64153f73b864e53e3a9)
- 08546b8d569655c81daa3d8df7879c65a03e74f35448f5b154465a7a11d4066f (fa4944dc9e4138338e8da6337bafce44b180e7b8)
- ce7110a255ce401609500a35ef8e3d06db706c68999c31ac18b5bf7ffb9cb16e (74c140cbe35d5bde83d7d4b6a5b142d3276d6ad0)
- 98af78a4f3c383ce8ddbd71089f447b2d034c77cefd3289b7a516b3e8fedb6ae (ca73a0d823510168c97b7e84a0a7e8bfaf2d0f07)
- 7e1541c4e59861bcdc7ea662b8634b067d2e330613aadcee07a80e2250c51ac7 (c65d93e73ce371bec992974e399fadb75849d7f1)
- 85ab2f14035c69cc871c66640f36fb9ce2c44b18a12c99eaed5529ac25b263d4 (bf220846cddab7c5a1aa4644585b599949339ebc)
- e5388fa1e1aaf90e6b45d03a5406cc34376a6f0f3af85440148c19703f079525 (0541e2bf4e03da3c13e1fdc41c8527ea6318204c)
- 07f2d5d3f62ae322a9599f19f4af52522f2c98eebf001360b8c7427af4c8f6cb (13065a1b4444b37904739de54633b25a2a413870)
- 21da302dccb6e851a71f0abbe1a7a317befbd2f905e1b2cd5f94afa84a20bc8d (35234fb13998eebd67a9d456de9a377516ac1270)
- ab793765b3d74e07d56258a7a531d7f64a36b2b98521782000fb9a610c934cc3 (51949dcfcb54eef8eac9fde406c0ffa5168ca8af)
- 481db94ba864163d857b0d1c846d0fb7d19855504d4c9925143816d8a864373d (0bd0d2d4b31465e0f746d607a8de09183b9ee9d8)
- 3cfbbc1eb6228763b97f48964f310be97730ae77b7aa1853ac570a89ff1c1c8d (99531e2247bfbcb8843425b312ec227df68de975)
- 606e600402bbef8e2a3d565cb11e156adc6b29f232788e841298d7fd8bead5c2 (5e40247447d89a72bf1b20f7903df98e53dcb0f9)
- 33de2a8fb960aea29c487739242d50bda1e964636a9d60763bf0a0a9adaab34e (e38e50ca7707548cc3bbacf10fdb0f1f8f64bb09)
- 45d2303d438b69d811587ad4f88d372a63dcf4983b9c287434e5a01ebe5d2c12 (3955d7a0f27f2a887a4c6a7565cbbeb1f32d83ec)
- 2652710804a318b68c69757f8879b0a1e81f3608d487651ce4290c0bf0cacf77 (37ef98bb91015b442e1bfc775d1bb08a1cc61e6e)
- eab8b9631435f3fc0285f8b2dd758c188510898b0b1199888a13d91d6ff366a6 (5525ee7a7e137045e3816b99a7e99de3dbfd9f4d)
- a5cef1a1c11c3729674467378e568b54d6282fa9cbd5bc5d009c9959f20d3fcc (69a7927b6
- .gitignore +0 -0
- 3DUnet_Like/__pycache__/trainer.cpython-39.pyc +0 -0
- 3DUnet_Like/dataset/__pycache__/utils.cpython-39.pyc +0 -0
- 3DUnet_Like/dataset/brats.py +31 -0
- 3DUnet_Like/dataset/utils.py +100 -0
- 3DUnet_Like/logs/SegTransVAE/version_0/events.out.tfevents.1710047381.speech-demo.148199.0 +0 -0
- 3DUnet_Like/logs/SegTransVAE/version_0/hparams.yaml +1 -0
- 3DUnet_Like/logs/SegTransVAE/version_0/metric_log.csv +2 -0
- 3DUnet_Like/loss/__init__.py +0 -0
- 3DUnet_Like/loss/__pycache__/__init__.cpython-39.pyc +0 -0
- 3DUnet_Like/loss/__pycache__/loss.cpython-39.pyc +0 -0
- 3DUnet_Like/loss/loss.py +55 -0
- 3DUnet_Like/models/SegTranVAE/SegTranVAE.py +538 -0
- 3DUnet_Like/models/SegTranVAE/__init__.py +0 -0
- 3DUnet_Like/models/SegTranVAE/__pycache__/SegTranVAE.cpython-39.pyc +0 -0
- 3DUnet_Like/models/SegTranVAE/__pycache__/__init__.cpython-39.pyc +0 -0
- 3DUnet_Like/train.py +69 -0
- 3DUnet_Like/trainer.py +233 -0
- brats_2021_task1/BraTS2021_Training_Data/.DS_Store +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_flair.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_seg.nii.gz +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t1.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t1ce.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t2.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_flair.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_seg.nii.gz +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t1.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t1ce.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t2.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_flair.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_seg.nii.gz +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t1.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t1ce.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t2.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_flair.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_seg.nii.gz +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t1.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t1ce.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t2.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_flair.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_seg.nii.gz +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t1.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t1ce.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t2.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_flair.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_seg.nii.gz +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_t1.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_t1ce.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_t2.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00009/BraTS2021_00009_flair.nii.gz +3 -0
File without changes
|
Binary file (7.55 kB). View file
|
|
Binary file (2.66 kB). View file
|
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from monai.transforms import MapTransform
|
3 |
+
|
4 |
+
|
5 |
+
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
|
6 |
+
"""
|
7 |
+
Convert labels to multi channels based on brats classes:
|
8 |
+
label 1 is the necrotic and non-enhancing tumor core
|
9 |
+
label 2 is the peritumoral edema
|
10 |
+
label 4 is the GD-enhancing tumor
|
11 |
+
The possible classes are TC (Tumor core), WT (Whole tumor)
|
12 |
+
and ET (Enhancing tumor).
|
13 |
+
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __call__(self, data):
|
17 |
+
d = dict(data)
|
18 |
+
for key in self.keys:
|
19 |
+
result = []
|
20 |
+
# merge label 1 and label 4 to construct TC
|
21 |
+
result.append(np.logical_or(d[key] == 1, d[key] == 4))
|
22 |
+
# merge labels 1, 2 and 4 to construct WT
|
23 |
+
result.append(
|
24 |
+
np.logical_or(
|
25 |
+
np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2
|
26 |
+
)
|
27 |
+
)
|
28 |
+
# label 4 is ET
|
29 |
+
result.append(d[key] == 4)
|
30 |
+
d[key] = np.stack(result, axis=0).astype(np.float32)
|
31 |
+
return d
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from sklearn.model_selection import train_test_split
|
4 |
+
|
5 |
+
from monai.data import DataLoader, Dataset
|
6 |
+
from monai import transforms
|
7 |
+
|
8 |
+
def datafold_read(datalist, basedir, fold=0, key="training"):
|
9 |
+
with open(datalist) as f:
|
10 |
+
json_data = json.load(f)
|
11 |
+
|
12 |
+
json_data = json_data[key]
|
13 |
+
|
14 |
+
for d in json_data:
|
15 |
+
for k in d:
|
16 |
+
if isinstance(d[k], list):
|
17 |
+
d[k] = [os.path.join(basedir, iv) for iv in d[k]]
|
18 |
+
elif isinstance(d[k], str):
|
19 |
+
d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]
|
20 |
+
|
21 |
+
tr = []
|
22 |
+
val = []
|
23 |
+
for d in json_data:
|
24 |
+
if "fold" in d and d["fold"] == fold:
|
25 |
+
val.append(d)
|
26 |
+
else:
|
27 |
+
tr.append(d)
|
28 |
+
|
29 |
+
return tr, val
|
30 |
+
|
31 |
+
|
32 |
+
def split_train_test(datalist, basedir, fold,test_size = 0.2, volume : float = None) :
|
33 |
+
train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=fold)
|
34 |
+
if volume != None :
|
35 |
+
train_files, _ = train_test_split(train_files,test_size=volume,random_state=42)
|
36 |
+
|
37 |
+
train_files,validation_files = train_test_split(train_files,test_size=test_size, random_state=42)
|
38 |
+
|
39 |
+
validation_files,test_files = train_test_split(validation_files,test_size=test_size, random_state=42)
|
40 |
+
return train_files, validation_files, test_files
|
41 |
+
|
42 |
+
|
43 |
+
def get_loader(batch_size, data_dir, json_list, fold, roi,volume :float = None,test_size = 0.2):
|
44 |
+
train_files,validation_files,test_files = split_train_test(datalist = json_list,basedir = data_dir,test_size=test_size,fold = fold,volume= volume)
|
45 |
+
|
46 |
+
train_transform = transforms.Compose(
|
47 |
+
[
|
48 |
+
transforms.LoadImaged(keys=["image", "label"]),
|
49 |
+
transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
|
50 |
+
transforms.CropForegroundd(
|
51 |
+
keys=["image", "label"],
|
52 |
+
source_key="image",
|
53 |
+
k_divisible=[roi[0], roi[1], roi[2]],
|
54 |
+
),
|
55 |
+
transforms.RandSpatialCropd(
|
56 |
+
keys=["image", "label"],
|
57 |
+
roi_size=[roi[0], roi[1], roi[2]],
|
58 |
+
random_size=False,
|
59 |
+
),
|
60 |
+
transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
|
61 |
+
transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
|
62 |
+
transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
|
63 |
+
transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
|
64 |
+
transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
|
65 |
+
transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
|
66 |
+
]
|
67 |
+
)
|
68 |
+
val_transform = transforms.Compose(
|
69 |
+
[
|
70 |
+
transforms.LoadImaged(keys=["image", "label"]),
|
71 |
+
transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
|
72 |
+
transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
|
73 |
+
]
|
74 |
+
)
|
75 |
+
|
76 |
+
train_ds = Dataset(data=train_files, transform=train_transform)
|
77 |
+
train_loader = DataLoader(
|
78 |
+
train_ds,
|
79 |
+
batch_size=batch_size,
|
80 |
+
shuffle=True,
|
81 |
+
num_workers=2,
|
82 |
+
pin_memory=True,
|
83 |
+
)
|
84 |
+
val_ds = Dataset(data=validation_files, transform=val_transform)
|
85 |
+
val_loader = DataLoader(
|
86 |
+
val_ds,
|
87 |
+
batch_size=1,
|
88 |
+
shuffle=False,
|
89 |
+
num_workers=2,
|
90 |
+
pin_memory=True,
|
91 |
+
)
|
92 |
+
test_ds = Dataset(data=test_files, transform=val_transform)
|
93 |
+
test_loader = DataLoader(
|
94 |
+
test_ds,
|
95 |
+
batch_size=1,
|
96 |
+
shuffle=False,
|
97 |
+
num_workers=2,
|
98 |
+
pin_memory=True,
|
99 |
+
)
|
100 |
+
return train_loader, val_loader,test_loader
|
Binary file (117 kB). View file
|
|
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{}
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
Epoch,Mean Dice Score,Dice TC,Dice WT,Dice ET
|
2 |
+
0,0.004601036664098501,0.0006361556006595492,0.012770041823387146,0.0003969123645219952
|
File without changes
|
Binary file (124 Bytes). View file
|
|
Binary file (2.16 kB). View file
|
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
class Loss_VAE(nn.Module):
|
5 |
+
def __init__(self):
|
6 |
+
super().__init__()
|
7 |
+
self.mse = nn.MSELoss(reduction='sum')
|
8 |
+
|
9 |
+
def forward(self, recon_x, x, mu, log_var):
|
10 |
+
mse = self.mse(recon_x, x)
|
11 |
+
kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
|
12 |
+
loss = mse + kld
|
13 |
+
return loss
|
14 |
+
|
15 |
+
|
16 |
+
def DiceScore(
|
17 |
+
y_pred: torch.Tensor,
|
18 |
+
y: torch.Tensor,
|
19 |
+
include_background: bool = True,
|
20 |
+
) -> torch.Tensor:
|
21 |
+
"""Computes Dice score metric from full size Tensor and collects average.
|
22 |
+
Args:
|
23 |
+
y_pred: input data to compute, typical segmentation model output.
|
24 |
+
It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
|
25 |
+
should be binarized.
|
26 |
+
y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.
|
27 |
+
The values should be binarized.
|
28 |
+
include_background: whether to skip Dice computation on the first channel of
|
29 |
+
the predicted output. Defaults to True.
|
30 |
+
Returns:
|
31 |
+
Dice scores per batch and per class, (shape [batch_size, num_classes]).
|
32 |
+
Raises:
|
33 |
+
ValueError: when `y_pred` and `y` have different shapes.
|
34 |
+
"""
|
35 |
+
|
36 |
+
y = y.float()
|
37 |
+
y_pred = y_pred.float()
|
38 |
+
|
39 |
+
if y.shape != y_pred.shape:
|
40 |
+
raise ValueError("y_pred and y should have same shapes.")
|
41 |
+
|
42 |
+
# reducing only spatial dimensions (not batch nor channels)
|
43 |
+
n_len = len(y_pred.shape)
|
44 |
+
reduce_axis = list(range(2, n_len))
|
45 |
+
intersection = torch.sum(y * y_pred, dim=reduce_axis)
|
46 |
+
|
47 |
+
y_o = torch.sum(y, reduce_axis)
|
48 |
+
y_pred_o = torch.sum(y_pred, dim=reduce_axis)
|
49 |
+
denominator = y_o + y_pred_o
|
50 |
+
|
51 |
+
return torch.where(
|
52 |
+
denominator > 0,
|
53 |
+
(2.0 * intersection) / denominator,
|
54 |
+
torch.tensor(float("1"), device=y_o.device),
|
55 |
+
)
|
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
###########Resnet Block############
|
9 |
+
def normalization(planes, norm = 'instance'):
|
10 |
+
if norm == 'bn':
|
11 |
+
m = nn.BatchNorm3d(planes)
|
12 |
+
elif norm == 'gn':
|
13 |
+
m = nn.GroupNorm(8, planes)
|
14 |
+
elif norm == 'instance':
|
15 |
+
m = nn.InstanceNorm3d(planes)
|
16 |
+
else:
|
17 |
+
raise ValueError("Does not support this kind of norm.")
|
18 |
+
return m
|
19 |
+
class ResNetBlock(nn.Module):
|
20 |
+
def __init__(self, in_channels, norm = 'instance'):
|
21 |
+
super().__init__()
|
22 |
+
self.resnetblock = nn.Sequential(
|
23 |
+
normalization(in_channels, norm = norm),
|
24 |
+
nn.LeakyReLU(0.2, inplace=True),
|
25 |
+
nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1),
|
26 |
+
normalization(in_channels, norm = norm),
|
27 |
+
nn.LeakyReLU(0.2, inplace=True),
|
28 |
+
nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1)
|
29 |
+
)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
y = self.resnetblock(x)
|
33 |
+
return y + x
|
34 |
+
|
35 |
+
|
36 |
+
##############VAE###############
|
37 |
+
def calculate_total_dimension(a):
|
38 |
+
res = 1
|
39 |
+
for x in a:
|
40 |
+
res *= x
|
41 |
+
return res
|
42 |
+
|
43 |
+
class VAE(nn.Module):
|
44 |
+
def __init__(self, input_shape, latent_dim, num_channels):
|
45 |
+
super().__init__()
|
46 |
+
self.input_shape = input_shape
|
47 |
+
self.in_channels = input_shape[1] #input_shape[0] is batch size
|
48 |
+
self.latent_dim = latent_dim
|
49 |
+
self.encoder_channels = self.in_channels // 16
|
50 |
+
|
51 |
+
#Encoder
|
52 |
+
self.VAE_reshape = nn.Conv3d(self.in_channels, self.encoder_channels,
|
53 |
+
kernel_size = 3, stride = 2, padding=1)
|
54 |
+
# self.VAE_reshape = nn.Sequential(
|
55 |
+
# nn.GroupNorm(8, self.in_channels),
|
56 |
+
# nn.ReLU(),
|
57 |
+
# nn.Conv3d(self.in_channels, self.encoder_channels,
|
58 |
+
# kernel_size = 3, stride = 2, padding=1),
|
59 |
+
# )
|
60 |
+
|
61 |
+
flatten_input_shape = calculate_total_dimension(input_shape)
|
62 |
+
flatten_input_shape_after_vae_reshape = \
|
63 |
+
flatten_input_shape * self.encoder_channels // (8 * self.in_channels)
|
64 |
+
|
65 |
+
#Convert from total dimension to latent space
|
66 |
+
self.to_latent_space = nn.Linear(
|
67 |
+
flatten_input_shape_after_vae_reshape // self.in_channels, 1)
|
68 |
+
|
69 |
+
self.mean = nn.Linear(self.in_channels, self.latent_dim)
|
70 |
+
self.logvar = nn.Linear(self.in_channels, self.latent_dim)
|
71 |
+
# self.epsilon = nn.Parameter(torch.randn(1, latent_dim))
|
72 |
+
|
73 |
+
#Decoder
|
74 |
+
self.to_original_dimension = nn.Linear(self.latent_dim, flatten_input_shape_after_vae_reshape)
|
75 |
+
self.Reconstruct = nn.Sequential(
|
76 |
+
nn.LeakyReLU(0.2, inplace=True),
|
77 |
+
nn.Conv3d(
|
78 |
+
self.encoder_channels, self.in_channels,
|
79 |
+
stride = 1, kernel_size = 1),
|
80 |
+
nn.Upsample(scale_factor=2, mode = 'nearest'),
|
81 |
+
|
82 |
+
nn.Conv3d(
|
83 |
+
self.in_channels, self.in_channels // 2,
|
84 |
+
stride = 1, kernel_size = 1),
|
85 |
+
nn.Upsample(scale_factor=2, mode = 'nearest'),
|
86 |
+
ResNetBlock(self.in_channels // 2),
|
87 |
+
|
88 |
+
nn.Conv3d(
|
89 |
+
self.in_channels // 2, self.in_channels // 4,
|
90 |
+
stride = 1, kernel_size = 1),
|
91 |
+
nn.Upsample(scale_factor=2, mode = 'nearest'),
|
92 |
+
ResNetBlock(self.in_channels // 4),
|
93 |
+
|
94 |
+
nn.Conv3d(
|
95 |
+
self.in_channels // 4, self.in_channels // 8,
|
96 |
+
stride = 1, kernel_size = 1),
|
97 |
+
nn.Upsample(scale_factor=2, mode = 'nearest'),
|
98 |
+
ResNetBlock(self.in_channels // 8),
|
99 |
+
|
100 |
+
nn.InstanceNorm3d(self.in_channels // 8),
|
101 |
+
nn.LeakyReLU(0.2, inplace=True),
|
102 |
+
nn.Conv3d(
|
103 |
+
self.in_channels // 8, num_channels,
|
104 |
+
kernel_size = 3, padding = 1),
|
105 |
+
# nn.Sigmoid()
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
def forward(self, x): #x has shape = input_shape
|
110 |
+
#Encoder
|
111 |
+
# print(x.shape)
|
112 |
+
x = self.VAE_reshape(x)
|
113 |
+
shape = x.shape
|
114 |
+
|
115 |
+
x = x.view(self.in_channels, -1)
|
116 |
+
x = self.to_latent_space(x)
|
117 |
+
x = x.view(1, self.in_channels)
|
118 |
+
|
119 |
+
mean = self.mean(x)
|
120 |
+
logvar = self.logvar(x)
|
121 |
+
# sigma = torch.exp(0.5 * logvar)
|
122 |
+
# Reparameter
|
123 |
+
epsilon = torch.randn_like(logvar)
|
124 |
+
sample = mean + epsilon * torch.exp(0.5*logvar)
|
125 |
+
|
126 |
+
#Decoder
|
127 |
+
y = self.to_original_dimension(sample)
|
128 |
+
y = y.view(*shape)
|
129 |
+
return self.Reconstruct(y), mean, logvar
|
130 |
+
def total_params(self):
|
131 |
+
total = sum(p.numel() for p in self.parameters())
|
132 |
+
return format(total, ',')
|
133 |
+
|
134 |
+
def total_trainable_params(self):
|
135 |
+
total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
136 |
+
return format(total_trainable, ',')
|
137 |
+
|
138 |
+
|
139 |
+
# x = torch.rand((1, 256, 16, 16, 16))
|
140 |
+
# vae = VAE(input_shape = x.shape, latent_dim = 256, num_channels = 4)
|
141 |
+
# y = vae(x)
|
142 |
+
# print(y[0].shape, y[1].shape, y[2].shape)
|
143 |
+
# print(vae.total_trainable_params())
|
144 |
+
|
145 |
+
|
146 |
+
### Decoder ####
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
class Upsample(nn.Module):
|
151 |
+
def __init__(self, in_channel, out_channel):
|
152 |
+
super().__init__()
|
153 |
+
self.conv1 = nn.Conv3d(in_channel, out_channel, kernel_size = 1)
|
154 |
+
self.deconv = nn.ConvTranspose3d(out_channel, out_channel, kernel_size = 2, stride = 2)
|
155 |
+
self.conv2 = nn.Conv3d(out_channel * 2, out_channel, kernel_size = 1)
|
156 |
+
|
157 |
+
def forward(self, prev, x):
|
158 |
+
x = self.deconv(self.conv1(x))
|
159 |
+
y = torch.cat((prev, x), dim = 1)
|
160 |
+
return self.conv2(y)
|
161 |
+
|
162 |
+
class FinalConv(nn.Module): # Input channels are equal to output channels
|
163 |
+
def __init__(self, in_channels, out_channels=32, norm="instance"):
|
164 |
+
super(FinalConv, self).__init__()
|
165 |
+
if norm == "batch":
|
166 |
+
norm_layer = nn.BatchNorm3d(num_features=in_channels)
|
167 |
+
elif norm == "group":
|
168 |
+
norm_layer = nn.GroupNorm(num_groups=8, num_channels=in_channels)
|
169 |
+
elif norm == 'instance':
|
170 |
+
norm_layer = nn.InstanceNorm3d(in_channels)
|
171 |
+
|
172 |
+
self.layer = nn.Sequential(
|
173 |
+
norm_layer,
|
174 |
+
nn.LeakyReLU(0.2, inplace=True),
|
175 |
+
nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
|
176 |
+
)
|
177 |
+
def forward(self, x):
|
178 |
+
return self.layer(x)
|
179 |
+
|
180 |
+
class Decoder(nn.Module):
|
181 |
+
def __init__(self, img_dim, patch_dim, embedding_dim, num_classes = 3):
|
182 |
+
super().__init__()
|
183 |
+
self.img_dim = img_dim
|
184 |
+
self.patch_dim = patch_dim
|
185 |
+
self.embedding_dim = embedding_dim
|
186 |
+
|
187 |
+
self.decoder_upsample_1 = Upsample(128, 64)
|
188 |
+
self.decoder_block_1 = ResNetBlock(64)
|
189 |
+
|
190 |
+
self.decoder_upsample_2 = Upsample(64, 32)
|
191 |
+
self.decoder_block_2 = ResNetBlock(32)
|
192 |
+
|
193 |
+
self.decoder_upsample_3 = Upsample(32, 16)
|
194 |
+
self.decoder_block_3 = ResNetBlock(16)
|
195 |
+
|
196 |
+
self.endconv = FinalConv(16, num_classes)
|
197 |
+
# self.normalize = nn.Sigmoid()
|
198 |
+
|
199 |
+
def forward(self, x1, x2, x3, x):
|
200 |
+
x = self.decoder_upsample_1(x3, x)
|
201 |
+
x = self.decoder_block_1(x)
|
202 |
+
|
203 |
+
x = self.decoder_upsample_2(x2, x)
|
204 |
+
x = self.decoder_block_2(x)
|
205 |
+
|
206 |
+
x = self.decoder_upsample_3(x1, x)
|
207 |
+
x = self.decoder_block_3(x)
|
208 |
+
|
209 |
+
y = self.endconv(x)
|
210 |
+
return y
|
211 |
+
|
212 |
+
|
213 |
+
|
214 |
+
###############Encoder##############
|
215 |
+
class InitConv(nn.Module):
|
216 |
+
def __init__(self, in_channels = 4, out_channels = 16, dropout = 0.2):
|
217 |
+
super().__init__()
|
218 |
+
self.layer = nn.Sequential(
|
219 |
+
nn.Conv3d(in_channels, out_channels, kernel_size = 3, padding = 1),
|
220 |
+
nn.Dropout3d(dropout)
|
221 |
+
)
|
222 |
+
def forward(self, x):
|
223 |
+
y = self.layer(x)
|
224 |
+
return y
|
225 |
+
|
226 |
+
|
227 |
+
class DownSample(nn.Module):
|
228 |
+
def __init__(self, in_channels, out_channels):
|
229 |
+
super().__init__()
|
230 |
+
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1)
|
231 |
+
def forward(self, x):
|
232 |
+
return self.conv(x)
|
233 |
+
|
234 |
+
class Encoder(nn.Module):
|
235 |
+
def __init__(self, in_channels, base_channels, dropout = 0.2):
|
236 |
+
super().__init__()
|
237 |
+
|
238 |
+
self.init_conv = InitConv(in_channels, base_channels, dropout = dropout)
|
239 |
+
self.encoder_block1 = ResNetBlock(in_channels = base_channels)
|
240 |
+
self.encoder_down1 = DownSample(base_channels, base_channels * 2)
|
241 |
+
|
242 |
+
self.encoder_block2_1 = ResNetBlock(base_channels * 2)
|
243 |
+
self.encoder_block2_2 = ResNetBlock(base_channels * 2)
|
244 |
+
self.encoder_down2 = DownSample(base_channels * 2, base_channels * 4)
|
245 |
+
|
246 |
+
self.encoder_block3_1 = ResNetBlock(base_channels * 4)
|
247 |
+
self.encoder_block3_2 = ResNetBlock(base_channels * 4)
|
248 |
+
self.encoder_down3 = DownSample(base_channels * 4, base_channels * 8)
|
249 |
+
|
250 |
+
self.encoder_block4_1 = ResNetBlock(base_channels * 8)
|
251 |
+
self.encoder_block4_2 = ResNetBlock(base_channels * 8)
|
252 |
+
self.encoder_block4_3 = ResNetBlock(base_channels * 8)
|
253 |
+
self.encoder_block4_4 = ResNetBlock(base_channels * 8)
|
254 |
+
# self.encoder_down3 = EncoderDown(base_channels * 8, base_channels * 16)
|
255 |
+
def forward(self, x):
|
256 |
+
x = self.init_conv(x) #(1, 16, 128, 128, 128)
|
257 |
+
|
258 |
+
x1 = self.encoder_block1(x)
|
259 |
+
x1_down = self.encoder_down1(x1) #(1, 32, 64, 64, 64)
|
260 |
+
|
261 |
+
x2 = self.encoder_block2_2(self.encoder_block2_1(x1_down))
|
262 |
+
x2_down = self.encoder_down2(x2) #(1, 64, 32, 32, 32)
|
263 |
+
|
264 |
+
x3 = self.encoder_block3_2(self.encoder_block3_1(x2_down))
|
265 |
+
x3_down = self.encoder_down3(x3) #(1, 128, 16, 16, 16)
|
266 |
+
|
267 |
+
output = self.encoder_block4_4(
|
268 |
+
self.encoder_block4_3(
|
269 |
+
self.encoder_block4_2(
|
270 |
+
self.encoder_block4_1(x3_down)))) #(1, 256, 16, 16, 16)
|
271 |
+
return x1, x2, x3, output
|
272 |
+
|
273 |
+
# x = torch.rand((1, 4, 128, 128, 128))
|
274 |
+
# Enc = Encoder(4, 32)
|
275 |
+
# _, _, _, y = Enc(x)
|
276 |
+
# print(y.shape) (1,256,16,16)
|
277 |
+
|
278 |
+
|
279 |
+
###############FeatureMapping###############
|
280 |
+
|
281 |
+
class FeatureMapping(nn.Module):
|
282 |
+
def __init__(self, in_channel, out_channel, norm = 'instance'):
|
283 |
+
super().__init__()
|
284 |
+
if norm == 'bn':
|
285 |
+
norm_layer_1 = nn.BatchNorm3d(out_channel)
|
286 |
+
norm_layer_2 = nn.BatchNorm3d(out_channel)
|
287 |
+
elif norm == 'gn':
|
288 |
+
norm_layer_1 = nn.GroupNorm(8, out_channel)
|
289 |
+
norm_layer_2 = nn.GroupNorm(8, out_channel)
|
290 |
+
elif norm == 'instance':
|
291 |
+
norm_layer_1 = nn.InstanceNorm3d(out_channel)
|
292 |
+
norm_layer_2 = nn.InstanceNorm3d(out_channel)
|
293 |
+
self.feature_mapping = nn.Sequential(
|
294 |
+
nn.Conv3d(in_channel, out_channel, kernel_size = 3, padding = 1),
|
295 |
+
norm_layer_1,
|
296 |
+
nn.LeakyReLU(0.2, inplace=True),
|
297 |
+
nn.Conv3d(out_channel, out_channel, kernel_size = 3, padding = 1),
|
298 |
+
norm_layer_2,
|
299 |
+
nn.LeakyReLU(0.2, inplace=True)
|
300 |
+
)
|
301 |
+
|
302 |
+
def forward(self, x):
|
303 |
+
return self.feature_mapping(x)
|
304 |
+
|
305 |
+
|
306 |
+
class FeatureMapping1(nn.Module):
|
307 |
+
def __init__(self, in_channel, norm = 'instance'):
|
308 |
+
super().__init__()
|
309 |
+
if norm == 'bn':
|
310 |
+
norm_layer_1 = nn.BatchNorm3d(in_channel)
|
311 |
+
norm_layer_2 = nn.BatchNorm3d(in_channel)
|
312 |
+
elif norm == 'gn':
|
313 |
+
norm_layer_1 = nn.GroupNorm(8, in_channel)
|
314 |
+
norm_layer_2 = nn.GroupNorm(8, in_channel)
|
315 |
+
elif norm == 'instance':
|
316 |
+
norm_layer_1 = nn.InstanceNorm3d(in_channel)
|
317 |
+
norm_layer_2 = nn.InstanceNorm3d(in_channel)
|
318 |
+
self.feature_mapping1 = nn.Sequential(
|
319 |
+
nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
|
320 |
+
norm_layer_1,
|
321 |
+
nn.LeakyReLU(0.2, inplace=True),
|
322 |
+
nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
|
323 |
+
norm_layer_2,
|
324 |
+
nn.LeakyReLU(0.2, inplace=True)
|
325 |
+
)
|
326 |
+
def forward(self, x):
|
327 |
+
y = self.feature_mapping1(x)
|
328 |
+
return x + y #Resnet Like
|
329 |
+
|
330 |
+
################Transformer#######################
|
331 |
+
|
332 |
+
|
333 |
+
def pair(t):
|
334 |
+
return t if isinstance(t, tuple) else (t, t)
|
335 |
+
|
336 |
+
|
337 |
+
class PreNorm(nn.Module):
|
338 |
+
def __init__(self, dim, function):
|
339 |
+
super().__init__()
|
340 |
+
self.norm = nn.LayerNorm(dim)
|
341 |
+
self.function = function
|
342 |
+
|
343 |
+
def forward(self, x):
|
344 |
+
return self.function(self.norm(x))
|
345 |
+
|
346 |
+
|
347 |
+
class FeedForward(nn.Module):
|
348 |
+
def __init__(self, dim, hidden_dim, dropout = 0.0):
|
349 |
+
super().__init__()
|
350 |
+
self.net = nn.Sequential(
|
351 |
+
nn.Linear(dim, hidden_dim),
|
352 |
+
nn.GELU(),
|
353 |
+
nn.Dropout(dropout),
|
354 |
+
nn.Linear(hidden_dim, dim),
|
355 |
+
nn.Dropout(dropout)
|
356 |
+
)
|
357 |
+
|
358 |
+
def forward(self, x):
|
359 |
+
return self.net(x)
|
360 |
+
|
361 |
+
class Attention(nn.Module):
|
362 |
+
def __init__(self, dim, heads, dim_head, dropout = 0.0):
|
363 |
+
super().__init__()
|
364 |
+
all_head_size = heads * dim_head
|
365 |
+
project_out = not (heads == 1 and dim_head == dim)
|
366 |
+
|
367 |
+
self.heads = heads
|
368 |
+
self.scale = dim_head ** -0.5
|
369 |
+
|
370 |
+
self.softmax = nn.Softmax(dim = -1)
|
371 |
+
self.to_qkv = nn.Linear(dim, all_head_size * 3, bias = False)
|
372 |
+
|
373 |
+
self.to_out = nn.Sequential(
|
374 |
+
nn.Linear(all_head_size, dim),
|
375 |
+
nn.Dropout(dropout)
|
376 |
+
) if project_out else nn.Identity()
|
377 |
+
|
378 |
+
def forward(self, x):
|
379 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
380 |
+
#(batch, heads * dim_head) -> (batch, all_head_size)
|
381 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
382 |
+
|
383 |
+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
384 |
+
|
385 |
+
atten = self.softmax(dots)
|
386 |
+
|
387 |
+
out = torch.matmul(atten, v)
|
388 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
389 |
+
return self.to_out(out)
|
390 |
+
|
391 |
+
class Transformer(nn.Module):
|
392 |
+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.0):
|
393 |
+
super().__init__()
|
394 |
+
self.layers = nn.ModuleList([])
|
395 |
+
for _ in range(depth):
|
396 |
+
self.layers.append(nn.ModuleList([
|
397 |
+
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
|
398 |
+
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
|
399 |
+
]))
|
400 |
+
def forward(self, x):
|
401 |
+
for attention, feedforward in self.layers:
|
402 |
+
x = attention(x) + x
|
403 |
+
x = feedforward(x) + x
|
404 |
+
return x
|
405 |
+
|
406 |
+
class FixedPositionalEncoding(nn.Module):
|
407 |
+
def __init__(self, embedding_dim, max_length=768):
|
408 |
+
super(FixedPositionalEncoding, self).__init__()
|
409 |
+
|
410 |
+
pe = torch.zeros(max_length, embedding_dim)
|
411 |
+
position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
|
412 |
+
div_term = torch.exp(
|
413 |
+
torch.arange(0, embedding_dim, 2).float()
|
414 |
+
* (-torch.log(torch.tensor(10000.0)) / embedding_dim)
|
415 |
+
)
|
416 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
417 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
418 |
+
pe = pe.unsqueeze(0).transpose(0, 1)
|
419 |
+
self.register_buffer('pe', pe)
|
420 |
+
|
421 |
+
def forward(self, x):
|
422 |
+
x = x + self.pe[: x.size(0), :]
|
423 |
+
return x
|
424 |
+
|
425 |
+
|
426 |
+
class LearnedPositionalEncoding(nn.Module):
|
427 |
+
def __init__(self, embedding_dim, seq_length):
|
428 |
+
super(LearnedPositionalEncoding, self).__init__()
|
429 |
+
self.seq_length = seq_length
|
430 |
+
self.position_embeddings = nn.Parameter(torch.zeros(1, seq_length, embedding_dim)) #8x
|
431 |
+
|
432 |
+
def forward(self, x, position_ids=None):
|
433 |
+
position_embeddings = self.position_embeddings
|
434 |
+
# print(x.shape, self.position_embeddings.shape)
|
435 |
+
return x + position_embeddings
|
436 |
+
|
437 |
+
|
438 |
+
|
439 |
+
|
440 |
+
|
441 |
+
###############Main model#################
|
442 |
+
|
443 |
+
class SegTransVAE(nn.Module):
|
444 |
+
def __init__(self, img_dim, patch_dim, num_channels, num_classes,
|
445 |
+
embedding_dim, num_heads, num_layers, hidden_dim, in_channels_vae,
|
446 |
+
dropout = 0.0, attention_dropout = 0.0,
|
447 |
+
conv_patch_representation = True, positional_encoding = 'learned',
|
448 |
+
use_VAE = False):
|
449 |
+
|
450 |
+
super().__init__()
|
451 |
+
assert embedding_dim % num_heads == 0
|
452 |
+
assert img_dim[0] % patch_dim == 0 and img_dim[1] % patch_dim == 0 and img_dim[2] % patch_dim == 0
|
453 |
+
|
454 |
+
self.img_dim = img_dim
|
455 |
+
self.embedding_dim = embedding_dim
|
456 |
+
self.num_heads = num_heads
|
457 |
+
self.num_classes = num_classes
|
458 |
+
self.patch_dim = patch_dim
|
459 |
+
self.num_channels = num_channels
|
460 |
+
self.in_channels_vae = in_channels_vae
|
461 |
+
self.dropout = dropout
|
462 |
+
self.attention_dropout = attention_dropout
|
463 |
+
self.conv_patch_representation = conv_patch_representation
|
464 |
+
self.use_VAE = use_VAE
|
465 |
+
|
466 |
+
self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim))
|
467 |
+
self.seq_length = self.num_patches
|
468 |
+
self.flatten_dim = 128 * num_channels
|
469 |
+
|
470 |
+
self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)
|
471 |
+
if positional_encoding == "learned":
|
472 |
+
self.position_encoding = LearnedPositionalEncoding(
|
473 |
+
self.embedding_dim, self.seq_length
|
474 |
+
)
|
475 |
+
elif positional_encoding == "fixed":
|
476 |
+
self.position_encoding = FixedPositionalEncoding(
|
477 |
+
self.embedding_dim,
|
478 |
+
)
|
479 |
+
self.pe_dropout = nn.Dropout(self.dropout)
|
480 |
+
|
481 |
+
self.transformer = Transformer(
|
482 |
+
embedding_dim, num_layers, num_heads, embedding_dim // num_heads, hidden_dim, dropout
|
483 |
+
)
|
484 |
+
self.pre_head_ln = nn.LayerNorm(embedding_dim)
|
485 |
+
|
486 |
+
if self.conv_patch_representation:
|
487 |
+
self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1)
|
488 |
+
self.encoder = Encoder(self.num_channels, 16)
|
489 |
+
self.bn = nn.InstanceNorm3d(128)
|
490 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
491 |
+
self.FeatureMapping = FeatureMapping(in_channel = self.embedding_dim, out_channel= self.in_channels_vae)
|
492 |
+
self.FeatureMapping1 = FeatureMapping1(in_channel = self.in_channels_vae)
|
493 |
+
self.decoder = Decoder(self.img_dim, self.patch_dim, self.embedding_dim, num_classes)
|
494 |
+
|
495 |
+
self.vae_input = (1, self.in_channels_vae, img_dim[0] // 8, img_dim[1] // 8, img_dim[2] // 8)
|
496 |
+
if use_VAE:
|
497 |
+
self.vae = VAE(input_shape = self.vae_input , latent_dim= 256, num_channels= self.num_channels)
|
498 |
+
def encode(self, x):
|
499 |
+
if self.conv_patch_representation:
|
500 |
+
x1, x2, x3, x = self.encoder(x)
|
501 |
+
x = self.bn(x)
|
502 |
+
x = self.relu(x)
|
503 |
+
x = self.conv_x(x)
|
504 |
+
x = x.permute(0, 2, 3, 4, 1).contiguous()
|
505 |
+
x = x.view(x.size(0), -1, self.embedding_dim)
|
506 |
+
x = self.position_encoding(x)
|
507 |
+
x = self.pe_dropout(x)
|
508 |
+
x = self.transformer(x)
|
509 |
+
x = self.pre_head_ln(x)
|
510 |
+
|
511 |
+
return x1, x2, x3, x
|
512 |
+
|
513 |
+
def decode(self, x1, x2, x3, x):
|
514 |
+
#x: (1, 4096, 512) -> (1, 16, 16, 16, 512)
|
515 |
+
# print("In decode...")
|
516 |
+
# print(" x1: {} \n x2: {} \n x3: {} \n x: {}".format( x1.shape, x2.shape, x3.shape, x.shape))
|
517 |
+
# break
|
518 |
+
return self.decoder(x1, x2, x3, x)
|
519 |
+
|
520 |
+
def forward(self, x, is_validation = True):
|
521 |
+
x1, x2, x3, x = self.encode(x)
|
522 |
+
x = x.view( x.size(0),
|
523 |
+
self.img_dim[0] // self.patch_dim,
|
524 |
+
self.img_dim[1] // self.patch_dim,
|
525 |
+
self.img_dim[2] // self.patch_dim,
|
526 |
+
self.embedding_dim)
|
527 |
+
x = x.permute(0, 4, 1, 2, 3).contiguous()
|
528 |
+
x = self.FeatureMapping(x)
|
529 |
+
x = self.FeatureMapping1(x)
|
530 |
+
if self.use_VAE and not is_validation:
|
531 |
+
vae_out, mu, sigma = self.vae(x)
|
532 |
+
y = self.decode(x1, x2, x3, x)
|
533 |
+
if self.use_VAE and not is_validation:
|
534 |
+
return y, vae_out, mu, sigma
|
535 |
+
else:
|
536 |
+
return y
|
537 |
+
|
538 |
+
|
File without changes
|
Binary file (16.3 kB). View file
|
|
Binary file (137 Bytes). View file
|
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
from monai.utils import set_determinism
|
4 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
5 |
+
import os
|
6 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
7 |
+
from trainer import BRATS
|
8 |
+
from dataset.utils import get_loader
|
9 |
+
import pytorch_lightning as pl
|
10 |
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
11 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
+
|
13 |
+
set_determinism(seed=0)
|
14 |
+
|
15 |
+
os.system('cls||clear')
|
16 |
+
print("Training ...")
|
17 |
+
|
18 |
+
data_dir = "/app/brats_2021_task1"
|
19 |
+
json_list = "/app/info.json"
|
20 |
+
roi = (128, 128, 128)
|
21 |
+
batch_size = 1
|
22 |
+
fold = 1
|
23 |
+
max_epochs = 500
|
24 |
+
val_every = 10
|
25 |
+
train_loader, val_loader,test_loader = get_loader(batch_size, data_dir, json_list, fold, roi, volume=1, test_size=0.2)
|
26 |
+
print("Done initialize dataloader !! ")
|
27 |
+
|
28 |
+
model = BRATS(use_VAE = True, train_loader = train_loader,val_loader = val_loader, test_loader=test_loader )
|
29 |
+
checkpoint_callback = ModelCheckpoint(
|
30 |
+
monitor='val/MeanDiceScore',
|
31 |
+
dirpath='./checkpoints/{}'.format("SegTransVAE"),
|
32 |
+
filename='Epoch{epoch:3d}-MeanDiceScore{val/MeanDiceScore:.4f}',
|
33 |
+
save_top_k=3,
|
34 |
+
mode='max',
|
35 |
+
save_last= True,
|
36 |
+
auto_insert_metric_name=False
|
37 |
+
)
|
38 |
+
early_stop_callback = EarlyStopping(
|
39 |
+
monitor='val/MeanDiceScore',
|
40 |
+
min_delta=0.0001,
|
41 |
+
patience=15,
|
42 |
+
verbose=False,
|
43 |
+
mode='max'
|
44 |
+
)
|
45 |
+
tensorboardlogger = TensorBoardLogger(
|
46 |
+
'logs',
|
47 |
+
name = "SegTransVAE",
|
48 |
+
default_hp_metric = None
|
49 |
+
)
|
50 |
+
trainer = pl.Trainer(#fast_dev_run = 10,
|
51 |
+
# accelerator='ddp',
|
52 |
+
#overfit_batches=5,
|
53 |
+
devices = [0],
|
54 |
+
precision=16,
|
55 |
+
max_epochs = max_epochs,
|
56 |
+
enable_progress_bar=True,
|
57 |
+
callbacks=[checkpoint_callback, early_stop_callback],
|
58 |
+
# auto_lr_find=True,
|
59 |
+
num_sanity_val_steps=1,
|
60 |
+
logger = tensorboardlogger,
|
61 |
+
check_val_every_n_epoch = 10,
|
62 |
+
# limit_train_batches=0.01,
|
63 |
+
# limit_val_batches=0.01
|
64 |
+
)
|
65 |
+
# trainer.tune(model)
|
66 |
+
trainer.fit(model)
|
67 |
+
|
68 |
+
|
69 |
+
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pytorch_lightning as pl
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import csv
|
5 |
+
import torch
|
6 |
+
from monai.transforms import AsDiscrete, Activations, Compose, EnsureType
|
7 |
+
from models.SegTranVAE.SegTranVAE import SegTransVAE
|
8 |
+
from loss.loss import Loss_VAE, DiceScore
|
9 |
+
from monai.losses import DiceLoss
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
from monai.inferers import sliding_window_inference
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
class BRATS(pl.LightningModule):
|
18 |
+
def __init__(self,train_loader,val_loader,test_loader, use_VAE = True, lr = 1e-4 ):
|
19 |
+
super().__init__()
|
20 |
+
self.train_loader = train_loader
|
21 |
+
self.val_loader = val_loader
|
22 |
+
self.test_loader = test_loader
|
23 |
+
self.use_vae = use_VAE
|
24 |
+
self.lr = lr
|
25 |
+
self.model = SegTransVAE((128, 128, 128), 8, 4, 3, 768, 8, 4, 3072, in_channels_vae=128, use_VAE = use_VAE)
|
26 |
+
|
27 |
+
self.loss_vae = Loss_VAE()
|
28 |
+
self.dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)
|
29 |
+
self.post_trans_images = Compose(
|
30 |
+
[EnsureType(),
|
31 |
+
Activations(sigmoid=True),
|
32 |
+
AsDiscrete(threshold_values=True),
|
33 |
+
]
|
34 |
+
)
|
35 |
+
|
36 |
+
self.best_val_dice = 0
|
37 |
+
|
38 |
+
self.training_step_outputs = []
|
39 |
+
self.val_step_loss = []
|
40 |
+
self.val_step_dice = []
|
41 |
+
self.val_step_dice_tc = []
|
42 |
+
self.val_step_dice_wt = []
|
43 |
+
self.val_step_dice_et = []
|
44 |
+
self.test_step_loss = []
|
45 |
+
self.test_step_dice = []
|
46 |
+
self.test_step_dice_tc = []
|
47 |
+
self.test_step_dice_wt = []
|
48 |
+
self.test_step_dice_et = []
|
49 |
+
|
50 |
+
def forward(self, x, is_validation = True):
|
51 |
+
return self.model(x, is_validation)
|
52 |
+
def training_step(self, batch, batch_index):
|
53 |
+
inputs, labels = (batch['image'], batch['label'])
|
54 |
+
|
55 |
+
if not self.use_vae:
|
56 |
+
outputs = self.forward(inputs, is_validation=False)
|
57 |
+
loss = self.dice_loss(outputs, labels)
|
58 |
+
else:
|
59 |
+
outputs, recon_batch, mu, sigma = self.forward(inputs, is_validation=False)
|
60 |
+
|
61 |
+
vae_loss = self.loss_vae(recon_batch, inputs, mu, sigma)
|
62 |
+
dice_loss = self.dice_loss(outputs, labels)
|
63 |
+
loss = dice_loss + 1/(4 * 128 * 128 * 128) * vae_loss
|
64 |
+
self.training_step_outputs.append(loss)
|
65 |
+
self.log('train/vae_loss', vae_loss)
|
66 |
+
self.log('train/dice_loss', dice_loss)
|
67 |
+
if batch_index == 10:
|
68 |
+
|
69 |
+
tensorboard = self.logger.experiment
|
70 |
+
fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(10, 5))
|
71 |
+
|
72 |
+
|
73 |
+
ax[0].imshow(inputs.detach().cpu()[0][0][:, :, 80], cmap='gray')
|
74 |
+
ax[0].set_title("Input")
|
75 |
+
|
76 |
+
ax[1].imshow(recon_batch.detach().cpu().float()[0][0][:,:, 80], cmap='gray')
|
77 |
+
ax[1].set_title("Reconstruction")
|
78 |
+
|
79 |
+
ax[2].imshow(labels.detach().cpu().float()[0][0][:,:, 80], cmap='gray')
|
80 |
+
ax[2].set_title("Labels TC")
|
81 |
+
|
82 |
+
ax[3].imshow(outputs.sigmoid().detach().cpu().float()[0][0][:,:, 80], cmap='gray')
|
83 |
+
ax[3].set_title("TC")
|
84 |
+
|
85 |
+
ax[4].imshow(labels.detach().cpu().float()[0][2][:,:, 80], cmap='gray')
|
86 |
+
ax[4].set_title("Labels ET")
|
87 |
+
|
88 |
+
ax[5].imshow(outputs.sigmoid().detach().cpu().float()[0][2][:,:, 80], cmap='gray')
|
89 |
+
ax[5].set_title("ET")
|
90 |
+
|
91 |
+
|
92 |
+
tensorboard.add_figure('train_visualize', fig, self.current_epoch)
|
93 |
+
|
94 |
+
self.log('train/loss', loss)
|
95 |
+
|
96 |
+
return loss
|
97 |
+
|
98 |
+
def on_train_epoch_end(self):
|
99 |
+
## F1 Macro all epoch saving outputs and target per batch
|
100 |
+
|
101 |
+
# free up the memory
|
102 |
+
# --> HERE STEP 3 <--
|
103 |
+
epoch_average = torch.stack(self.training_step_outputs).mean()
|
104 |
+
self.log("training_epoch_average", epoch_average)
|
105 |
+
self.training_step_outputs.clear() # free memory
|
106 |
+
|
107 |
+
def validation_step(self, batch, batch_index):
|
108 |
+
inputs, labels = (batch['image'], batch['label'])
|
109 |
+
roi_size = (128, 128, 128)
|
110 |
+
sw_batch_size = 1
|
111 |
+
outputs = sliding_window_inference(
|
112 |
+
inputs, roi_size, sw_batch_size, self.model, overlap = 0.5)
|
113 |
+
loss = self.dice_loss(outputs, labels)
|
114 |
+
|
115 |
+
|
116 |
+
val_outputs = self.post_trans_images(outputs)
|
117 |
+
|
118 |
+
|
119 |
+
metric_tc = DiceScore(y_pred=val_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)
|
120 |
+
metric_wt = DiceScore(y_pred=val_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)
|
121 |
+
metric_et = DiceScore(y_pred=val_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)
|
122 |
+
mean_val_dice = (metric_tc + metric_wt + metric_et)/3
|
123 |
+
self.val_step_loss.append(loss)
|
124 |
+
self.val_step_dice.append(mean_val_dice)
|
125 |
+
self.val_step_dice_tc.append(metric_tc)
|
126 |
+
self.val_step_dice_wt.append(metric_wt)
|
127 |
+
self.val_step_dice_et.append(metric_et)
|
128 |
+
return {'val_loss': loss, 'val_mean_dice': mean_val_dice, 'val_dice_tc': metric_tc,
|
129 |
+
'val_dice_wt': metric_wt, 'val_dice_et': metric_et}
|
130 |
+
|
131 |
+
def on_validation_epoch_end(self):
|
132 |
+
|
133 |
+
loss = torch.stack(self.val_step_loss).mean()
|
134 |
+
mean_val_dice = torch.stack(self.val_step_dice).mean()
|
135 |
+
metric_tc = torch.stack(self.val_step_dice_tc).mean()
|
136 |
+
metric_wt = torch.stack(self.val_step_dice_wt).mean()
|
137 |
+
metric_et = torch.stack(self.val_step_dice_et).mean()
|
138 |
+
self.log('val/Loss', loss)
|
139 |
+
self.log('val/MeanDiceScore', mean_val_dice)
|
140 |
+
self.log('val/DiceTC', metric_tc)
|
141 |
+
self.log('val/DiceWT', metric_wt)
|
142 |
+
self.log('val/DiceET', metric_et)
|
143 |
+
os.makedirs(self.logger.log_dir, exist_ok=True)
|
144 |
+
if self.current_epoch == 0:
|
145 |
+
with open('{}/metric_log.csv'.format(self.logger.log_dir), 'w') as f:
|
146 |
+
writer = csv.writer(f)
|
147 |
+
writer.writerow(['Epoch', 'Mean Dice Score', 'Dice TC', 'Dice WT', 'Dice ET'])
|
148 |
+
with open('{}/metric_log.csv'.format(self.logger.log_dir), 'a') as f:
|
149 |
+
writer = csv.writer(f)
|
150 |
+
writer.writerow([self.current_epoch, mean_val_dice.item(), metric_tc.item(), metric_wt.item(), metric_et.item()])
|
151 |
+
|
152 |
+
if mean_val_dice > self.best_val_dice:
|
153 |
+
self.best_val_dice = mean_val_dice
|
154 |
+
self.best_val_epoch = self.current_epoch
|
155 |
+
print(
|
156 |
+
f"\n Current epoch: {self.current_epoch} Current mean dice: {mean_val_dice:.4f}"
|
157 |
+
f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
|
158 |
+
f"\n Best mean dice: {self.best_val_dice}"
|
159 |
+
f" at epoch: {self.best_val_epoch}"
|
160 |
+
)
|
161 |
+
|
162 |
+
self.val_step_loss.clear()
|
163 |
+
self.val_step_dice.clear()
|
164 |
+
self.val_step_dice_tc.clear()
|
165 |
+
self.val_step_dice_wt.clear()
|
166 |
+
self.val_step_dice_et.clear()
|
167 |
+
return {'val_MeanDiceScore': mean_val_dice}
|
168 |
+
def test_step(self, batch, batch_index):
|
169 |
+
inputs, labels = (batch['image'], batch['label'])
|
170 |
+
|
171 |
+
roi_size = (128, 128, 128)
|
172 |
+
sw_batch_size = 1
|
173 |
+
test_outputs = sliding_window_inference(
|
174 |
+
inputs, roi_size, sw_batch_size, self.forward, overlap = 0.5)
|
175 |
+
loss = self.dice_loss(test_outputs, labels)
|
176 |
+
test_outputs = self.post_trans_images(test_outputs)
|
177 |
+
metric_tc = DiceScore(y_pred=test_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)
|
178 |
+
metric_wt = DiceScore(y_pred=test_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)
|
179 |
+
metric_et = DiceScore(y_pred=test_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)
|
180 |
+
mean_test_dice = (metric_tc + metric_wt + metric_et)/3
|
181 |
+
|
182 |
+
self.test_step_loss.append(loss)
|
183 |
+
self.test_step_dice.append(mean_test_dice)
|
184 |
+
self.test_step_dice_tc.append(metric_tc)
|
185 |
+
self.test_step_dice_wt.append(metric_wt)
|
186 |
+
self.test_step_dice_et.append(metric_et)
|
187 |
+
|
188 |
+
return {'test_loss': loss, 'test_mean_dice': mean_test_dice, 'test_dice_tc': metric_tc,
|
189 |
+
'test_dice_wt': metric_wt, 'test_dice_et': metric_et}
|
190 |
+
|
191 |
+
def test_epoch_end(self):
|
192 |
+
loss = torch.stack(self.test_step_loss).mean()
|
193 |
+
mean_test_dice = torch.stack(self.test_step_dice).mean()
|
194 |
+
metric_tc = torch.stack(self.test_step_dice_tc).mean()
|
195 |
+
metric_wt = torch.stack(self.test_step_dice_wt).mean()
|
196 |
+
metric_et = torch.stack(self.test_step_dice_et).mean()
|
197 |
+
self.log('test/Loss', loss)
|
198 |
+
self.log('test/MeanDiceScore', mean_test_dice)
|
199 |
+
self.log('test/DiceTC', metric_tc)
|
200 |
+
self.log('test/DiceWT', metric_wt)
|
201 |
+
self.log('test/DiceET', metric_et)
|
202 |
+
|
203 |
+
with open('{}/test_log.csv'.format(self.logger.log_dir), 'w') as f:
|
204 |
+
writer = csv.writer(f)
|
205 |
+
writer.writerow(["Mean Test Dice", "Dice TC", "Dice WT", "Dice ET"])
|
206 |
+
writer.writerow([mean_test_dice, metric_tc, metric_wt, metric_et])
|
207 |
+
|
208 |
+
self.test_step_loss.clear()
|
209 |
+
self.test_step_dice.clear()
|
210 |
+
self.test_step_dice_tc.clear()
|
211 |
+
self.test_step_dice_wt.clear()
|
212 |
+
self.test_step_dice_et.clear()
|
213 |
+
return {'test_MeanDiceScore': mean_test_dice}
|
214 |
+
|
215 |
+
|
216 |
+
def configure_optimizers(self):
|
217 |
+
optimizer = torch.optim.Adam(
|
218 |
+
self.model.parameters(), self.lr, weight_decay=1e-5, amsgrad=True
|
219 |
+
)
|
220 |
+
# optimizer = AdaBelief(self.model.parameters(),
|
221 |
+
# lr=self.lr, eps=1e-16,
|
222 |
+
# betas=(0.9,0.999), weight_decouple = True,
|
223 |
+
# rectify = False)
|
224 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)
|
225 |
+
return [optimizer], [scheduler]
|
226 |
+
|
227 |
+
def train_dataloader(self):
|
228 |
+
return self.train_loader
|
229 |
+
def val_dataloader(self):
|
230 |
+
return self.val_loader
|
231 |
+
|
232 |
+
def test_dataloader(self):
|
233 |
+
return self.test_loader
|
Binary file (6.15 kB). View file
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bb899a83627591e55cada00b2c6d5402199832b717c8b9f90bb550fe35d971ff
|
3 |
+
size 2532638
|
Binary file (57.9 kB). View file
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e92b72ee221624c36cf89ac826deceeee7097f46dd66a5d218d18b7916ebd67d
|
3 |
+
size 2332393
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:735c27fd7a17b1702875837bcc843eedc88ced6ac2cb0e73cdf995e3e64ba82f
|
3 |
+
size 2643179
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:955ff59d053e87153bb7c809235743ec904817727ec02c630f3141e191d6f452
|
3 |
+
size 2432699
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:208e4e8fadbdf2b1c1a87f8bad8854bc7d4becd604bc01eefad725d19b43c6ef
|
3 |
+
size 2331912
|
Binary file (78 kB). View file
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4b76c5e6326f0f89e2f6ee243473060390982fc83e9bceb27a4b94899b2b0df1
|
3 |
+
size 2170543
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4f49aa266cd907ce890bcc9c96f82534810f85bb98a76f8a092a5b529d3b6b6e
|
3 |
+
size 2486326
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b5c9da53d4a573fe37fa8a8075b8220d4dfe3bbc377c3e1943fbd3ef86d3b118
|
3 |
+
size 2303833
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8f68b6d8026f2c6ac097ea26a4dd727571025e26b69b485f9a1a84e244222721
|
3 |
+
size 2719582
|
Binary file (63.4 kB). View file
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:90997c21753a6d49bdb8629f5b2c6ddb61ef7c2f86f977b18001ce6c5e3161ed
|
3 |
+
size 2488450
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a211cd8c26dd8d396341399d0a13fb888d5df6252bd9c5aa01156680e97c5577
|
3 |
+
size 2834759
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5c440a9ef1ef3b494d74e647fab9e41eb48f1fc40c8eabffd55863f46779ed5d
|
3 |
+
size 2635293
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c5c3ed117922199afb26cae01b090400d6ca90f06825c12f5a5a3ae7ced098e3
|
3 |
+
size 2265964
|
Binary file (70.7 kB). View file
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:54f109386253cb60694dec35e10e1e8f2ee02a2671faa860ac212cce8997b434
|
3 |
+
size 2085481
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2e074611e7388f7f7fefb5bbe39abb09f16f0dbeb0a35fbd661a0bbcd14d4b5a
|
3 |
+
size 2323871
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8d1689213fb1bb46330f16d30489a9875e2b5e9f30c8c5659a6b689011bd69e9
|
3 |
+
size 2144371
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0f42d02ba6637b8d98c0c44869ccfa444935dedb8f5ada29a3db65d33c878c0b
|
3 |
+
size 2588071
|
Binary file (70.5 kB). View file
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5f9e9ff16cdffd375f31a969994e31725c329950e7c3e64789238afc3faadead
|
3 |
+
size 2386395
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:20847fc162fe9193fb2d5a08e7b9009d16603eae268f769bf6d7bbffb0d79c42
|
3 |
+
size 2705917
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0a18fb6644ca828ef43067264a631d297d5aeaad4ebae05bce4ea09c9c76f898
|
3 |
+
size 2479204
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c5dbec780a624127d829b4b7165ff6afe6258e9224a529abf8ddaf25474888c5
|
3 |
+
size 2452120
|
Binary file (45.5 kB). View file
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f46e84f3f81a10b11b994fca26ae7f2bf537a6dfc0cb60853c097243924e39e8
|
3 |
+
size 2397406
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3b92f907a585a13ae99b0c79f315d396a3dd4524f1a35001212509957e264abf
|
3 |
+
size 2639147
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:545a1519ae217731184caad17d6474e3d62736efbad11326b5c21e63124ace12
|
3 |
+
size 2368904
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0f407ffd3d48d678b63de7158fa566626a0ad917e3d523d7d6a3cdd6f7e99789
|
3 |
+
size 2233856
|