DuyTa committed on
Commit
44168e6
1 Parent(s): a84cb9d

Upload folder using huggingface_hub (#1)


- d2f9d5b1696d723607b469880b3d5616ee5b225a5296662d3b494a9f93762c27 (864c14f7179bf9dcdf7c90dbed7a41ee5e839da2)
- 72712ee369aab2d20658d111064b5f0d7a1499fc75277a6606fda5c45e0621b4 (62ad3c18015dcb4dc60ff7e6e468ff15c0334876)
- 343785fe216bc214dd04cc3e9c10d02e7a94143f99508481cd37022497f4783f (ea980d56daed988b608e7f7f5d95c5c04d4ea697)
- 22e429e0f0ccd12e262eff7dd5f76de81fa895e252330ff2269d3b0a71d57919 (13fdcc6ad6afb851b5358aba34dcab1cf21d6767)
- 4dc2a277835db86d453a6e9dbd5e071290d13d21e944b285578b6730ed36adcd (de0a1295bf46d3cd023a5c387424e6e3de075548)
- 2aad947094aa55aba4cb28a3be8165f085f0bc0ea622bfc1455937cd076e3e3c (ca28d59b4827feb16274699cd99de89ece6aa1ab)
- 9e97593500c52ee9f4ebe3dbc3b2d3f8ccf2806573bc32848b7002ed330eabfe (f00e6a25ea16d066400536988df75e970f7b4b09)
- 6932ae5e409108b07a5f2e0032f8d1426c5b1d8b45c336198a9a3d24e32bf27a (3518a2576bcc63d5a312d8257279ef3541937813)
- 9f813b4cf8b1b68c2650dd47e56fdf0f4fa11f0b634b34e113fe049555cf6b0b (80075679b76b6d7a771dee7bfc9e525c8541bdfd)
- d25a8a27251818db766711c75332797dec2e5ed154c67d95a710789abc38f1ae (0353ff565da94cf44d164cee53bedd43c546c686)
- 94fb5f9916f447a06693ce9c08cae6053e3a53b089203b07a77e111c7a48b7c5 (b1e50c8ba84454750277304e2add63747f969c24)
- 9d957872a8d1cfe72442bd80af4bdabc68981732ae7b7ef38873f94c71a6e38d (78524c180ca1df1858325b21b1dcd51a48d0b832)
- 7b42a8465af5b9c3bbbd35e9ee8f6eabdb767b82963b9a1edd54c0d0db78f519 (56ce03842edf47a35ff25011e68d7d33a91b9fd4)
- f6a04718ee1cee63d5f53af5947287d0e196e596e2c9bd2073739f83e55a2151 (45df139ec12e1336e6508de6ea5dccb3b72dce21)
- 2a2858e318f8a24f63e670f274755189120791789d2895e55be73e616158ec99 (9410754a0b758ea43963ddc4bf4575258e30b991)
- eafc0cef804e25c9a59bb93bf2ce5891c01e84b35c44e013fb0fcb88c091ed8a (84ceb6e6903ed2bb5717a9880353c9d0d63cfd14)
- 125c0488499d6143b105d0abc6a2f4cad87f10ad297c5bd54f0b64b882043133 (ed9d66f64a08d44ed8311fe81d7aa8050a6bd555)
- d79e8401776d882537e4f9e8b5e318b6017157ad821c12616710c8ead8005262 (051417feb00ee958d6fbaa78b8b1f5723593047a)
- d4428c1827509739b091f111ac62cbc9473f2af30cd19df41af86ce4f0aac3ac (9385b680a01363dadff5316da1b45b872d1c78da)
- f96abbc6a4b337e150b0a6f74d41fc562fa78901362eff8f935f731ee83c8d32 (fc44d6f314c4fe9414424fed7dbdd107b2157b0c)
- ab5595192559ac1a498ecc905c693ee5da489683cb13d9e4f70c8e4787a08e84 (d82ce16a74aaf42e3b5197dc20ae786ee12b56da)
- 691a76ad38e7f277af60b80f88e313c3b87a5ac30404aa5e9ccd1a821900f690 (03d5d897031cb74afc71644c7992cdd100f32609)
- 5f6ec4caedf83a3b9abd85036479f89d1fbe048f32bf90b09ed9dbff1434f198 (18700e75e8dc8bab031935b7ebbd733381bde785)
- ad3cefbea0cffd469a20aec62ec552f05c4aeec849036134b0774f527363947c (0ac1b1d0362aeacba7440ed5c46b41f460c00887)
- 1fd5cf87b52feeeb909bc84cc1c095d45a7fe8daadbed67fa2759f70108ac491 (d4c05cda6218d732d66a7f5ea51107e929f94c34)
- c6dfaa42dbb815ebb99abbb8cf8fe821fc45ef8208f95baf5f87b2ea762cfdee (173b2db59c6f2136a07d67506f3b70220d1a5928)
- 327a839011e474979217760c545d723bbb788967992f0ac692e4b3c48d9dd9b4 (269e400d6d434f77586d58bf853059a3aedfde52)
- 62e8bae67a6a2e2bfa75077f462b466439d3b603753904ec73a0a30250993387 (8a678657f6d88989dfe54e39aac033d2e70d7d28)
- fe2f79481822147e39f82eea4b911b45eb34b349efe358d2c523b96a3fd80571 (8537475fa7301119c409ec1adc68d9fd70043798)
- b8b16f6cd140bd482475a53ddaf68e0469a9011a7779f955a3c51a0103feb946 (03d5ebf0693a875f460168561bf5b95975d101c9)
- a46cc9965b4d399b3760d64abf4a24f368d62d727ba52c0c14045ba05b207b8f (82c0f6c9cce2ace10ee83e66b2e2a00a68144240)
- cb1175da0a7309cae6b5b239d54f8922be35970697125993c879327f694d03e1 (713781f6b3cc63827644f0ed02f6039fe61df6a5)
- 5081d0a653640f56b2f3eb3b1036115a02dbe906c74d4f1c818560c6b73d8a70 (e4f7258fba290f62526114265327167565b7d6fa)
- 8c5129423a57ac8c4dca1e65620c05adf6f1cbd624645a01dd03c5e5bc058d8b (fcfbafd24a19af4913dbd928a36f9d952f58c652)
- f35cb770179564361359a27b337e92fc9fd5a30b6ebdce95650998efb5c8a9c6 (61cb7b39db104df88a26ceb308d387c65ee5dfe0)
- fc8cc9b183f7e8ff0520eeb4452013ff0c1da97a790bb63e27be39002d0dce04 (6265ee5f69b2ed60a5e785ee765dfb979462730c)
- ee820fb0057f7294c3a40d195ba1b6c55e5aa7ed299f4a43190900a4043f92b7 (a737bd9a5af19cb5a1acf412ffa89efb1eca9e09)
- 66c2c9caf035f38a3689f26af8f9bb8b3b638d49459c292809b22417aacab0d6 (9cf38ebab276f2bbfa08008e1f01c9859b2b9a91)
- 882cd5887028c73f672dfc62dce8f2caccf201beab3a2ccbec9835932993e229 (71d39bbe94a51963f1f40343cf09b5e3de094e99)
- bd10a2393aebf0d10a96aa795e603fa632555e088590388b7ae13f36dedfd3a0 (ea91aa2508545f805f2bcfa4cda1c496fcaa7f40)
- 8bb324f87650de78aff21b8630cdb906f4ecf8e89b0b58d903f81632883d92fd (99a4cfecdbe7cd81ace3416e91a9bce6ffb6cc47)
- a1c56cbf345b93554b5c0a0d189e694601913359a6387424bc4a9d53971d6d27 (71f6d19caf6d5bfb2481f228f8e0ccdd694bf4c2)
- 15d12d23663866ea9352d97c0bc14cc1af0664bbd7640feebb3bebaa5181645a (d5447884e6009d5613092f75b7d9c7be03b3f39b)
- 547784d9ad8a093d6c7ef661df30ee89480ad1b53d4aabf0c289810aad4bb200 (1e0a71b4430e0fb4603669d20c1adead564014bb)
- e6dbb1fef7a153779585553550c0165b8b4f0713b3f8f69dcb954dac2b8abf23 (787d9a8cd403d8ab32b0d10a4906c117fbf2cc7e)
- 1f3e00d60711eea36626167c2a4e62ffa55b899e5f1b19a5aed0291a513db8ed (9f6d2471a7cc23a4c0b021b6f7f0ba5d20f17bd1)
- 91d8ba13af8d3fc37ddf0df4fd21a975bab53ce8df4c6e4c0cbe898f928fef0b (434a797575d14b302f78128d853b25121e78ac1f)
- 61838a48aa440bbb958ddb1960152f470e106a8642a927573921f5a14eb395a0 (a9a3f7373637f499dbad0bdd9886cc941895c93c)
- 933e5f6ad3865341fbe5f6fc68c602f1329c75a4b501a4d8de35292cf28fdfca (12f72b70f59948c2ab995e43a8b32da6448cf7a4)
- 6921ac405dc2a28b3898aea80ef36b99eadb8ea960d230b913e1a6cca56bbfe5 (26189145f005f38c8cdd6d45002c3771633a5292)
- 7763265e7706f2bce75530b24b89637a4effb2217ea0d5ccd5b8ad4be21057a8 (4461db0a4c205b3f7b7371e049d0014a1c6bd25d)
- d87f84dd7b77f46fe6f5baf4823a194f7ce7939556ff0369e9dd38c3c059d9c5 (471e459a4cac7804ec99fb28097b3a9fb350e666)
- 2081b42ff277f260028bf8d71da124f2d941925c2fac45dd2494a9eae4dc7f73 (0bcdbfca59152312578d8d615f1f874f867442b7)
- 4cdee9c9b19362e8c3e5788c9f7229cbfd0e7b42b2390406219f047de62f2111 (2c6d105de52a97c8f80d95a29906354af2cd5b18)
- 4e4b7b69db35bb5c463f1503ba2ddd7a30ade2565ee9418749951fb270d7f63a (7f75dd1f3a3fe08c385b925ca3ce98ca6000091b)
- 915698ad5bce7f3c862e93266a918d8c921fb5612a6adba2c4983a36e914f580 (e156afc4f8dfcc09cdbad8a096d18e833f59377e)
- dcbf7733203f7fa96b3758df957db6c9c5df613e8f3c1adbe524f0ec9e30aec6 (8835d4a648e01873b7225a98c04b0cf9b1057030)
- 4b7f12f6f9b8d96202a9940d257b76effbc443fc5ed3026f42b211f22ef0e412 (c59b859c39d259f794a5f557698926c8c5f256bf)
- 280683d634d93e294ede67c37f2c78e9833a2d642cf297d0b164969f3c4ca84f (40c1498d5ca5373090951c9a64dfc2d06f0fb708)
- 87ad050f1c232289d79f96521c1a8b2f1e85f1b8c7c610802852ac0eddb0e533 (0902e558132e0f7eceb2dc74a57252b80ca8de75)
- 522d040300b5ee19e831bf5edb526b71afa6cdae37f73e9e86eae6ce080b1d0f (c9338e9dcaaa12aff40ae47ed431098c207e9050)
- 71aa1e18dc90216ec14cf5e45f7107f04706961ed229f4b33e24afc2acc445a9 (bdeac5594b1a2b2187b5d8049ee17eaa5cc4efb3)
- 66e5c610f3dc2ce92aa5594ed50d01f995e0418cbb916a7cd88907cc20945e1f (781875521e92d6784dc8a9bd212c03a4a9e2e3fe)
- e2ec34a652004fa63d71c6e540533af9d01c6d0bfd324e9a4348adadb440141e (ef0758d3bd0c0f063663512b9357cedfb6ee950e)
- 06c6b6e3b67360a3de065ba177977c07047dfee394ee3a58c4b008d62ec5a46c (603de8e3f336d361f584646ac11dd51ffb8047af)
- c9971ce5330913107bfb5e4f45fca3f9a5c57e9eac9b99a5bdcdaddd3b74c2f9 (d60c8e433079603eca5d681564e6e0758d04c3a0)
- 777437eeca03308b6027523feea509ceb4546c4495d61f877e9ada901ad303ab (50fdebc04bb05dd92cf908333a714117d0f9784c)
- 535abdb9083eaed7ae028e53cd9450a43798ae202570f710bdbe57a724b12e47 (723c3daf6fe21a67031655f60738a39682fecc7c)
- e13a39e7cf8697228652e34866a8546e37ac97a39ca9cd45afbcba4a15c33f60 (2bc54c71c738d6f6158b8269d54cd1a790b2b322)
- a712baac141437803214ed623433bcd9ec04ca000c7eda39f33f81838c46f3d2 (f49214b9cb5808aacbe1184ee3b11e7484e864b5)
- 9c72f8aa408a516ca83bab7df52c229f4d8e8491b118e2887d3f94462af70ed8 (d51330ea04b497391d1b42defb1e91d7a8f62874)
- d7c1bfb4298b6a78d441e891b8012e2169359bf1956b3095f59971691df224b7 (d148a75472db16f33266832278c1ffda78551cee)
- d3ce69adf84eb6ba8269dd1d1fda8baa26a46afa2f4dabe3287c198c3f72bd56 (4b45703bf2c46f335510b6db07a6f98e90d96e4a)
- 4023396a8ca8105460f07fdf79fc343d55600c8e897fae0f15ef9447e938f978 (b356c44c85d9d654476bcce191756cfc457fd20f)
- 3553c5aa50058c1726ced12ff1d60903b92a9840e78d7f3a72af7cfe0f7015a2 (ad891a1810c210daf9e9ae9edd499f3a5e270555)
- bdd3a489b5ffccf9a584e67062f90c5c2e713d5b8900ceb8668b848919193e67 (e489981a15e70ba487f1b64153f73b864e53e3a9)
- 08546b8d569655c81daa3d8df7879c65a03e74f35448f5b154465a7a11d4066f (fa4944dc9e4138338e8da6337bafce44b180e7b8)
- ce7110a255ce401609500a35ef8e3d06db706c68999c31ac18b5bf7ffb9cb16e (74c140cbe35d5bde83d7d4b6a5b142d3276d6ad0)
- 98af78a4f3c383ce8ddbd71089f447b2d034c77cefd3289b7a516b3e8fedb6ae (ca73a0d823510168c97b7e84a0a7e8bfaf2d0f07)
- 7e1541c4e59861bcdc7ea662b8634b067d2e330613aadcee07a80e2250c51ac7 (c65d93e73ce371bec992974e399fadb75849d7f1)
- 85ab2f14035c69cc871c66640f36fb9ce2c44b18a12c99eaed5529ac25b263d4 (bf220846cddab7c5a1aa4644585b599949339ebc)
- e5388fa1e1aaf90e6b45d03a5406cc34376a6f0f3af85440148c19703f079525 (0541e2bf4e03da3c13e1fdc41c8527ea6318204c)
- 07f2d5d3f62ae322a9599f19f4af52522f2c98eebf001360b8c7427af4c8f6cb (13065a1b4444b37904739de54633b25a2a413870)
- 21da302dccb6e851a71f0abbe1a7a317befbd2f905e1b2cd5f94afa84a20bc8d (35234fb13998eebd67a9d456de9a377516ac1270)
- ab793765b3d74e07d56258a7a531d7f64a36b2b98521782000fb9a610c934cc3 (51949dcfcb54eef8eac9fde406c0ffa5168ca8af)
- 481db94ba864163d857b0d1c846d0fb7d19855504d4c9925143816d8a864373d (0bd0d2d4b31465e0f746d607a8de09183b9ee9d8)
- 3cfbbc1eb6228763b97f48964f310be97730ae77b7aa1853ac570a89ff1c1c8d (99531e2247bfbcb8843425b312ec227df68de975)
- 606e600402bbef8e2a3d565cb11e156adc6b29f232788e841298d7fd8bead5c2 (5e40247447d89a72bf1b20f7903df98e53dcb0f9)
- 33de2a8fb960aea29c487739242d50bda1e964636a9d60763bf0a0a9adaab34e (e38e50ca7707548cc3bbacf10fdb0f1f8f64bb09)
- 45d2303d438b69d811587ad4f88d372a63dcf4983b9c287434e5a01ebe5d2c12 (3955d7a0f27f2a887a4c6a7565cbbeb1f32d83ec)
- 2652710804a318b68c69757f8879b0a1e81f3608d487651ce4290c0bf0cacf77 (37ef98bb91015b442e1bfc775d1bb08a1cc61e6e)
- eab8b9631435f3fc0285f8b2dd758c188510898b0b1199888a13d91d6ff366a6 (5525ee7a7e137045e3816b99a7e99de3dbfd9f4d)
- a5cef1a1c11c3729674467378e568b54d6282fa9cbd5bc5d009c9959f20d3fcc (69a7927b6

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +0 -0
  2. 3DUnet_Like/__pycache__/trainer.cpython-39.pyc +0 -0
  3. 3DUnet_Like/dataset/__pycache__/utils.cpython-39.pyc +0 -0
  4. 3DUnet_Like/dataset/brats.py +31 -0
  5. 3DUnet_Like/dataset/utils.py +100 -0
  6. 3DUnet_Like/logs/SegTransVAE/version_0/events.out.tfevents.1710047381.speech-demo.148199.0 +0 -0
  7. 3DUnet_Like/logs/SegTransVAE/version_0/hparams.yaml +1 -0
  8. 3DUnet_Like/logs/SegTransVAE/version_0/metric_log.csv +2 -0
  9. 3DUnet_Like/loss/__init__.py +0 -0
  10. 3DUnet_Like/loss/__pycache__/__init__.cpython-39.pyc +0 -0
  11. 3DUnet_Like/loss/__pycache__/loss.cpython-39.pyc +0 -0
  12. 3DUnet_Like/loss/loss.py +55 -0
  13. 3DUnet_Like/models/SegTranVAE/SegTranVAE.py +538 -0
  14. 3DUnet_Like/models/SegTranVAE/__init__.py +0 -0
  15. 3DUnet_Like/models/SegTranVAE/__pycache__/SegTranVAE.cpython-39.pyc +0 -0
  16. 3DUnet_Like/models/SegTranVAE/__pycache__/__init__.cpython-39.pyc +0 -0
  17. 3DUnet_Like/train.py +69 -0
  18. 3DUnet_Like/trainer.py +233 -0
  19. brats_2021_task1/BraTS2021_Training_Data/.DS_Store +0 -0
  20. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_flair.nii.gz +3 -0
  21. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_seg.nii.gz +0 -0
  22. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t1.nii.gz +3 -0
  23. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t1ce.nii.gz +3 -0
  24. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t2.nii.gz +3 -0
  25. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_flair.nii.gz +3 -0
  26. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_seg.nii.gz +0 -0
  27. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t1.nii.gz +3 -0
  28. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t1ce.nii.gz +3 -0
  29. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t2.nii.gz +3 -0
  30. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_flair.nii.gz +3 -0
  31. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_seg.nii.gz +0 -0
  32. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t1.nii.gz +3 -0
  33. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t1ce.nii.gz +3 -0
  34. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t2.nii.gz +3 -0
  35. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_flair.nii.gz +3 -0
  36. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_seg.nii.gz +0 -0
  37. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t1.nii.gz +3 -0
  38. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t1ce.nii.gz +3 -0
  39. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t2.nii.gz +3 -0
  40. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_flair.nii.gz +3 -0
  41. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_seg.nii.gz +0 -0
  42. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t1.nii.gz +3 -0
  43. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t1ce.nii.gz +3 -0
  44. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t2.nii.gz +3 -0
  45. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_flair.nii.gz +3 -0
  46. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_seg.nii.gz +0 -0
  47. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_t1.nii.gz +3 -0
  48. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_t1ce.nii.gz +3 -0
  49. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_t2.nii.gz +3 -0
  50. brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00009/BraTS2021_00009_flair.nii.gz +3 -0
.gitignore ADDED
File without changes
3DUnet_Like/__pycache__/trainer.cpython-39.pyc ADDED
Binary file (7.55 kB).
3DUnet_Like/dataset/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.66 kB).
3DUnet_Like/dataset/brats.py ADDED
@@ -0,0 +1,31 @@
+ import numpy as np
+ from monai.transforms import MapTransform
+
+
+ class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
+     """
+     Convert labels to multi channels based on brats classes:
+     label 1 is the necrotic and non-enhancing tumor core
+     label 2 is the peritumoral edema
+     label 4 is the GD-enhancing tumor
+     The possible classes are TC (Tumor core), WT (Whole tumor)
+     and ET (Enhancing tumor).
+
+     """
+
+     def __call__(self, data):
+         d = dict(data)
+         for key in self.keys:
+             result = []
+             # merge label 1 and label 4 to construct TC
+             result.append(np.logical_or(d[key] == 1, d[key] == 4))
+             # merge labels 1, 2 and 4 to construct WT
+             result.append(
+                 np.logical_or(
+                     np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2
+                 )
+             )
+             # label 4 is ET
+             result.append(d[key] == 4)
+             d[key] = np.stack(result, axis=0).astype(np.float32)
+         return d
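A quick sanity check of the transform above (illustrative sketch, not part of the commit; the tiny array shape and voxel positions are made up):

import numpy as np

label = np.zeros((4, 4, 4), dtype=np.uint8)
label[0, 0, 0] = 1   # necrotic / non-enhancing tumor core
label[1, 1, 1] = 2   # peritumoral edema
label[2, 2, 2] = 4   # GD-enhancing tumor
convert = ConvertToMultiChannelBasedOnBratsClassesd(keys="label")
out = convert({"label": label})
print(out["label"].shape)  # (3, 4, 4, 4) -- channel order is TC, WT, ET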
3DUnet_Like/dataset/utils.py ADDED
@@ -0,0 +1,100 @@
+ import json
+ import os
+ from sklearn.model_selection import train_test_split
+
+ from monai.data import DataLoader, Dataset
+ from monai import transforms
+
+ def datafold_read(datalist, basedir, fold=0, key="training"):
+     with open(datalist) as f:
+         json_data = json.load(f)
+
+     json_data = json_data[key]
+
+     for d in json_data:
+         for k in d:
+             if isinstance(d[k], list):
+                 d[k] = [os.path.join(basedir, iv) for iv in d[k]]
+             elif isinstance(d[k], str):
+                 d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]
+
+     tr = []
+     val = []
+     for d in json_data:
+         if "fold" in d and d["fold"] == fold:
+             val.append(d)
+         else:
+             tr.append(d)
+
+     return tr, val
+
+
+ def split_train_test(datalist, basedir, fold, test_size=0.2, volume: float = None):
+     train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=fold)
+     if volume is not None:
+         train_files, _ = train_test_split(train_files, test_size=volume, random_state=42)
+
+     train_files, validation_files = train_test_split(train_files, test_size=test_size, random_state=42)
+
+     validation_files, test_files = train_test_split(validation_files, test_size=test_size, random_state=42)
+     return train_files, validation_files, test_files
+
+
+ def get_loader(batch_size, data_dir, json_list, fold, roi, volume: float = None, test_size=0.2):
+     train_files, validation_files, test_files = split_train_test(datalist=json_list, basedir=data_dir, test_size=test_size, fold=fold, volume=volume)
+
+     train_transform = transforms.Compose(
+         [
+             transforms.LoadImaged(keys=["image", "label"]),
+             transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
+             transforms.CropForegroundd(
+                 keys=["image", "label"],
+                 source_key="image",
+                 k_divisible=[roi[0], roi[1], roi[2]],
+             ),
+             transforms.RandSpatialCropd(
+                 keys=["image", "label"],
+                 roi_size=[roi[0], roi[1], roi[2]],
+                 random_size=False,
+             ),
+             transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
+             transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
+             transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
+             transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
+             transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
+             transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
+         ]
+     )
+     val_transform = transforms.Compose(
+         [
+             transforms.LoadImaged(keys=["image", "label"]),
+             transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
+             transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
+         ]
+     )
+
+     train_ds = Dataset(data=train_files, transform=train_transform)
+     train_loader = DataLoader(
+         train_ds,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=2,
+         pin_memory=True,
+     )
+     val_ds = Dataset(data=validation_files, transform=val_transform)
+     val_loader = DataLoader(
+         val_ds,
+         batch_size=1,
+         shuffle=False,
+         num_workers=2,
+         pin_memory=True,
+     )
+     test_ds = Dataset(data=test_files, transform=val_transform)
+     test_loader = DataLoader(
+         test_ds,
+         batch_size=1,
+         shuffle=False,
+         num_workers=2,
+         pin_memory=True,
+     )
+     return train_loader, val_loader, test_loader
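For reference, a hedged sketch of the datalist JSON that datafold_read above expects, written as Python that generates it; the top-level "training" key and the per-entry "fold"/"image"/"label" fields are inferred from the code, and the BraTS paths are only examples relative to basedir:

import json

datalist = {
    "training": [
        {
            "fold": 0,
            "image": [
                "BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_flair.nii.gz",
                "BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t1.nii.gz",
                "BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t1ce.nii.gz",
                "BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t2.nii.gz",
            ],
            "label": "BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_seg.nii.gz",
        },
    ]
}
with open("info.json", "w") as f:
    json.dump(datalist, f, indent=2)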
3DUnet_Like/logs/SegTransVAE/version_0/events.out.tfevents.1710047381.speech-demo.148199.0 ADDED
Binary file (117 kB).
3DUnet_Like/logs/SegTransVAE/version_0/hparams.yaml ADDED
@@ -0,0 +1 @@
+ {}
3DUnet_Like/logs/SegTransVAE/version_0/metric_log.csv ADDED
@@ -0,0 +1,2 @@
+ Epoch,Mean Dice Score,Dice TC,Dice WT,Dice ET
+ 0,0.004601036664098501,0.0006361556006595492,0.012770041823387146,0.0003969123645219952
3DUnet_Like/loss/__init__.py ADDED
File without changes
3DUnet_Like/loss/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (124 Bytes).
3DUnet_Like/loss/__pycache__/loss.cpython-39.pyc ADDED
Binary file (2.16 kB).
3DUnet_Like/loss/loss.py ADDED
@@ -0,0 +1,55 @@
+ import torch
+ from torch import nn
+
+ class Loss_VAE(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.mse = nn.MSELoss(reduction='sum')
+
+     def forward(self, recon_x, x, mu, log_var):
+         mse = self.mse(recon_x, x)
+         kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
+         loss = mse + kld
+         return loss
+
+
+ def DiceScore(
+     y_pred: torch.Tensor,
+     y: torch.Tensor,
+     include_background: bool = True,
+ ) -> torch.Tensor:
+     """Computes Dice score metric from full size Tensor and collects average.
+     Args:
+         y_pred: input data to compute, typical segmentation model output.
+             It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
+             should be binarized.
+         y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.
+             The values should be binarized.
+         include_background: whether to skip Dice computation on the first channel of
+             the predicted output. Defaults to True.
+     Returns:
+         Dice scores per batch and per class, (shape [batch_size, num_classes]).
+     Raises:
+         ValueError: when `y_pred` and `y` have different shapes.
+     """
+
+     y = y.float()
+     y_pred = y_pred.float()
+
+     if y.shape != y_pred.shape:
+         raise ValueError("y_pred and y should have same shapes.")
+
+     # reducing only spatial dimensions (not batch nor channels)
+     n_len = len(y_pred.shape)
+     reduce_axis = list(range(2, n_len))
+     intersection = torch.sum(y * y_pred, dim=reduce_axis)
+
+     y_o = torch.sum(y, reduce_axis)
+     y_pred_o = torch.sum(y_pred, dim=reduce_axis)
+     denominator = y_o + y_pred_o
+
+     return torch.where(
+         denominator > 0,
+         (2.0 * intersection) / denominator,
+         torch.tensor(1.0, device=y_o.device),
+     )
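A small shape/value check for DiceScore above (illustrative sketch, not part of the commit; tensor sizes are made up):

import torch

y_pred = torch.zeros(2, 3, 4, 4, 4)  # [batch, class (TC/WT/ET), D, H, W], binarized
y = torch.zeros(2, 3, 4, 4, 4)
y_pred[..., :2, :, :] = 1.0           # identical foreground in prediction and label
y[..., :2, :, :] = 1.0
scores = DiceScore(y_pred, y)
print(scores.shape, scores.min().item())  # torch.Size([2, 3]) 1.0 (perfect overlap)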
3DUnet_Like/models/SegTranVAE/SegTranVAE.py ADDED
@@ -0,0 +1,538 @@
+
+ from torch import nn
+ from torch.nn import functional as F
+ import torch
+ from einops import rearrange
+ import torch
+ import torch.nn as nn
+ ###########Resnet Block############
+ def normalization(planes, norm = 'instance'):
+     if norm == 'bn':
+         m = nn.BatchNorm3d(planes)
+     elif norm == 'gn':
+         m = nn.GroupNorm(8, planes)
+     elif norm == 'instance':
+         m = nn.InstanceNorm3d(planes)
+     else:
+         raise ValueError("Does not support this kind of norm.")
+     return m
+ class ResNetBlock(nn.Module):
+     def __init__(self, in_channels, norm = 'instance'):
+         super().__init__()
+         self.resnetblock = nn.Sequential(
+             normalization(in_channels, norm = norm),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1),
+             normalization(in_channels, norm = norm),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1)
+         )
+
+     def forward(self, x):
+         y = self.resnetblock(x)
+         return y + x
+
+
+ ##############VAE###############
+ def calculate_total_dimension(a):
+     res = 1
+     for x in a:
+         res *= x
+     return res
+
+ class VAE(nn.Module):
+     def __init__(self, input_shape, latent_dim, num_channels):
+         super().__init__()
+         self.input_shape = input_shape
+         self.in_channels = input_shape[1] #input_shape[0] is batch size
+         self.latent_dim = latent_dim
+         self.encoder_channels = self.in_channels // 16
+
+         #Encoder
+         self.VAE_reshape = nn.Conv3d(self.in_channels, self.encoder_channels,
+                                      kernel_size = 3, stride = 2, padding=1)
+         # self.VAE_reshape = nn.Sequential(
+         #     nn.GroupNorm(8, self.in_channels),
+         #     nn.ReLU(),
+         #     nn.Conv3d(self.in_channels, self.encoder_channels,
+         #               kernel_size = 3, stride = 2, padding=1),
+         # )
+
+         flatten_input_shape = calculate_total_dimension(input_shape)
+         flatten_input_shape_after_vae_reshape = \
+             flatten_input_shape * self.encoder_channels // (8 * self.in_channels)
+
+         #Convert from total dimension to latent space
+         self.to_latent_space = nn.Linear(
+             flatten_input_shape_after_vae_reshape // self.in_channels, 1)
+
+         self.mean = nn.Linear(self.in_channels, self.latent_dim)
+         self.logvar = nn.Linear(self.in_channels, self.latent_dim)
+         # self.epsilon = nn.Parameter(torch.randn(1, latent_dim))
+
+         #Decoder
+         self.to_original_dimension = nn.Linear(self.latent_dim, flatten_input_shape_after_vae_reshape)
+         self.Reconstruct = nn.Sequential(
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv3d(
+                 self.encoder_channels, self.in_channels,
+                 stride = 1, kernel_size = 1),
+             nn.Upsample(scale_factor=2, mode = 'nearest'),
+
+             nn.Conv3d(
+                 self.in_channels, self.in_channels // 2,
+                 stride = 1, kernel_size = 1),
+             nn.Upsample(scale_factor=2, mode = 'nearest'),
+             ResNetBlock(self.in_channels // 2),
+
+             nn.Conv3d(
+                 self.in_channels // 2, self.in_channels // 4,
+                 stride = 1, kernel_size = 1),
+             nn.Upsample(scale_factor=2, mode = 'nearest'),
+             ResNetBlock(self.in_channels // 4),
+
+             nn.Conv3d(
+                 self.in_channels // 4, self.in_channels // 8,
+                 stride = 1, kernel_size = 1),
+             nn.Upsample(scale_factor=2, mode = 'nearest'),
+             ResNetBlock(self.in_channels // 8),
+
+             nn.InstanceNorm3d(self.in_channels // 8),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv3d(
+                 self.in_channels // 8, num_channels,
+                 kernel_size = 3, padding = 1),
+             # nn.Sigmoid()
+         )
+
+
+     def forward(self, x): #x has shape = input_shape
+         #Encoder
+         # print(x.shape)
+         x = self.VAE_reshape(x)
+         shape = x.shape
+
+         x = x.view(self.in_channels, -1)
+         x = self.to_latent_space(x)
+         x = x.view(1, self.in_channels)
+
+         mean = self.mean(x)
+         logvar = self.logvar(x)
+         # sigma = torch.exp(0.5 * logvar)
+         # Reparameter
+         epsilon = torch.randn_like(logvar)
+         sample = mean + epsilon * torch.exp(0.5*logvar)
+
+         #Decoder
+         y = self.to_original_dimension(sample)
+         y = y.view(*shape)
+         return self.Reconstruct(y), mean, logvar
+     def total_params(self):
+         total = sum(p.numel() for p in self.parameters())
+         return format(total, ',')
+
+     def total_trainable_params(self):
+         total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
+         return format(total_trainable, ',')
+
+
+ # x = torch.rand((1, 256, 16, 16, 16))
+ # vae = VAE(input_shape = x.shape, latent_dim = 256, num_channels = 4)
+ # y = vae(x)
+ # print(y[0].shape, y[1].shape, y[2].shape)
+ # print(vae.total_trainable_params())
+
+
+ ### Decoder ####
+
+
+
+ class Upsample(nn.Module):
+     def __init__(self, in_channel, out_channel):
+         super().__init__()
+         self.conv1 = nn.Conv3d(in_channel, out_channel, kernel_size = 1)
+         self.deconv = nn.ConvTranspose3d(out_channel, out_channel, kernel_size = 2, stride = 2)
+         self.conv2 = nn.Conv3d(out_channel * 2, out_channel, kernel_size = 1)
+
+     def forward(self, prev, x):
+         x = self.deconv(self.conv1(x))
+         y = torch.cat((prev, x), dim = 1)
+         return self.conv2(y)
+
+ class FinalConv(nn.Module): # Input channels are equal to output channels
+     def __init__(self, in_channels, out_channels=32, norm="instance"):
+         super(FinalConv, self).__init__()
+         if norm == "batch":
+             norm_layer = nn.BatchNorm3d(num_features=in_channels)
+         elif norm == "group":
+             norm_layer = nn.GroupNorm(num_groups=8, num_channels=in_channels)
+         elif norm == 'instance':
+             norm_layer = nn.InstanceNorm3d(in_channels)
+
+         self.layer = nn.Sequential(
+             norm_layer,
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+         )
+     def forward(self, x):
+         return self.layer(x)
+
+ class Decoder(nn.Module):
+     def __init__(self, img_dim, patch_dim, embedding_dim, num_classes = 3):
+         super().__init__()
+         self.img_dim = img_dim
+         self.patch_dim = patch_dim
+         self.embedding_dim = embedding_dim
+
+         self.decoder_upsample_1 = Upsample(128, 64)
+         self.decoder_block_1 = ResNetBlock(64)
+
+         self.decoder_upsample_2 = Upsample(64, 32)
+         self.decoder_block_2 = ResNetBlock(32)
+
+         self.decoder_upsample_3 = Upsample(32, 16)
+         self.decoder_block_3 = ResNetBlock(16)
+
+         self.endconv = FinalConv(16, num_classes)
+         # self.normalize = nn.Sigmoid()
+
+     def forward(self, x1, x2, x3, x):
+         x = self.decoder_upsample_1(x3, x)
+         x = self.decoder_block_1(x)
+
+         x = self.decoder_upsample_2(x2, x)
+         x = self.decoder_block_2(x)
+
+         x = self.decoder_upsample_3(x1, x)
+         x = self.decoder_block_3(x)
+
+         y = self.endconv(x)
+         return y
+
+
+
+ ###############Encoder##############
+ class InitConv(nn.Module):
+     def __init__(self, in_channels = 4, out_channels = 16, dropout = 0.2):
+         super().__init__()
+         self.layer = nn.Sequential(
+             nn.Conv3d(in_channels, out_channels, kernel_size = 3, padding = 1),
+             nn.Dropout3d(dropout)
+         )
+     def forward(self, x):
+         y = self.layer(x)
+         return y
+
+
+ class DownSample(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1)
+     def forward(self, x):
+         return self.conv(x)
+
+ class Encoder(nn.Module):
+     def __init__(self, in_channels, base_channels, dropout = 0.2):
+         super().__init__()
+
+         self.init_conv = InitConv(in_channels, base_channels, dropout = dropout)
+         self.encoder_block1 = ResNetBlock(in_channels = base_channels)
+         self.encoder_down1 = DownSample(base_channels, base_channels * 2)
+
+         self.encoder_block2_1 = ResNetBlock(base_channels * 2)
+         self.encoder_block2_2 = ResNetBlock(base_channels * 2)
+         self.encoder_down2 = DownSample(base_channels * 2, base_channels * 4)
+
+         self.encoder_block3_1 = ResNetBlock(base_channels * 4)
+         self.encoder_block3_2 = ResNetBlock(base_channels * 4)
+         self.encoder_down3 = DownSample(base_channels * 4, base_channels * 8)
+
+         self.encoder_block4_1 = ResNetBlock(base_channels * 8)
+         self.encoder_block4_2 = ResNetBlock(base_channels * 8)
+         self.encoder_block4_3 = ResNetBlock(base_channels * 8)
+         self.encoder_block4_4 = ResNetBlock(base_channels * 8)
+         # self.encoder_down3 = EncoderDown(base_channels * 8, base_channels * 16)
+     def forward(self, x):
+         x = self.init_conv(x) #(1, 16, 128, 128, 128)
+
+         x1 = self.encoder_block1(x)
+         x1_down = self.encoder_down1(x1) #(1, 32, 64, 64, 64)
+
+         x2 = self.encoder_block2_2(self.encoder_block2_1(x1_down))
+         x2_down = self.encoder_down2(x2) #(1, 64, 32, 32, 32)
+
+         x3 = self.encoder_block3_2(self.encoder_block3_1(x2_down))
+         x3_down = self.encoder_down3(x3) #(1, 128, 16, 16, 16)
+
+         output = self.encoder_block4_4(
+             self.encoder_block4_3(
+                 self.encoder_block4_2(
+                     self.encoder_block4_1(x3_down)))) #(1, 256, 16, 16, 16)
+         return x1, x2, x3, output
+
+ # x = torch.rand((1, 4, 128, 128, 128))
+ # Enc = Encoder(4, 32)
+ # _, _, _, y = Enc(x)
+ # print(y.shape) (1,256,16,16)
+
+
+ ###############FeatureMapping###############
+
+ class FeatureMapping(nn.Module):
+     def __init__(self, in_channel, out_channel, norm = 'instance'):
+         super().__init__()
+         if norm == 'bn':
+             norm_layer_1 = nn.BatchNorm3d(out_channel)
+             norm_layer_2 = nn.BatchNorm3d(out_channel)
+         elif norm == 'gn':
+             norm_layer_1 = nn.GroupNorm(8, out_channel)
+             norm_layer_2 = nn.GroupNorm(8, out_channel)
+         elif norm == 'instance':
+             norm_layer_1 = nn.InstanceNorm3d(out_channel)
+             norm_layer_2 = nn.InstanceNorm3d(out_channel)
+         self.feature_mapping = nn.Sequential(
+             nn.Conv3d(in_channel, out_channel, kernel_size = 3, padding = 1),
+             norm_layer_1,
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv3d(out_channel, out_channel, kernel_size = 3, padding = 1),
+             norm_layer_2,
+             nn.LeakyReLU(0.2, inplace=True)
+         )
+
+     def forward(self, x):
+         return self.feature_mapping(x)
+
+
+ class FeatureMapping1(nn.Module):
+     def __init__(self, in_channel, norm = 'instance'):
+         super().__init__()
+         if norm == 'bn':
+             norm_layer_1 = nn.BatchNorm3d(in_channel)
+             norm_layer_2 = nn.BatchNorm3d(in_channel)
+         elif norm == 'gn':
+             norm_layer_1 = nn.GroupNorm(8, in_channel)
+             norm_layer_2 = nn.GroupNorm(8, in_channel)
+         elif norm == 'instance':
+             norm_layer_1 = nn.InstanceNorm3d(in_channel)
+             norm_layer_2 = nn.InstanceNorm3d(in_channel)
+         self.feature_mapping1 = nn.Sequential(
+             nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
+             norm_layer_1,
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
+             norm_layer_2,
+             nn.LeakyReLU(0.2, inplace=True)
+         )
+     def forward(self, x):
+         y = self.feature_mapping1(x)
+         return x + y #Resnet Like
+
+ ################Transformer#######################
+
+
+ def pair(t):
+     return t if isinstance(t, tuple) else (t, t)
+
+
+ class PreNorm(nn.Module):
+     def __init__(self, dim, function):
+         super().__init__()
+         self.norm = nn.LayerNorm(dim)
+         self.function = function
+
+     def forward(self, x):
+         return self.function(self.norm(x))
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim, hidden_dim, dropout = 0.0):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(dim, hidden_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_dim, dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+ class Attention(nn.Module):
+     def __init__(self, dim, heads, dim_head, dropout = 0.0):
+         super().__init__()
+         all_head_size = heads * dim_head
+         project_out = not (heads == 1 and dim_head == dim)
+
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+
+         self.softmax = nn.Softmax(dim = -1)
+         self.to_qkv = nn.Linear(dim, all_head_size * 3, bias = False)
+
+         self.to_out = nn.Sequential(
+             nn.Linear(all_head_size, dim),
+             nn.Dropout(dropout)
+         ) if project_out else nn.Identity()
+
+     def forward(self, x):
+         qkv = self.to_qkv(x).chunk(3, dim = -1)
+         #(batch, heads * dim_head) -> (batch, all_head_size)
+         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
+         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+         atten = self.softmax(dots)
+
+         out = torch.matmul(atten, v)
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         return self.to_out(out)
+
+ class Transformer(nn.Module):
+     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.0):
+         super().__init__()
+         self.layers = nn.ModuleList([])
+         for _ in range(depth):
+             self.layers.append(nn.ModuleList([
+                 PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
+                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
+             ]))
+     def forward(self, x):
+         for attention, feedforward in self.layers:
+             x = attention(x) + x
+             x = feedforward(x) + x
+         return x
+
+ class FixedPositionalEncoding(nn.Module):
+     def __init__(self, embedding_dim, max_length=768):
+         super(FixedPositionalEncoding, self).__init__()
+
+         pe = torch.zeros(max_length, embedding_dim)
+         position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, embedding_dim, 2).float()
+             * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
+         )
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         x = x + self.pe[: x.size(0), :]
+         return x
+
+
+ class LearnedPositionalEncoding(nn.Module):
+     def __init__(self, embedding_dim, seq_length):
+         super(LearnedPositionalEncoding, self).__init__()
+         self.seq_length = seq_length
+         self.position_embeddings = nn.Parameter(torch.zeros(1, seq_length, embedding_dim)) #8x
+
+     def forward(self, x, position_ids=None):
+         position_embeddings = self.position_embeddings
+         # print(x.shape, self.position_embeddings.shape)
+         return x + position_embeddings
+
+
+
+
+
+ ###############Main model#################
+
+ class SegTransVAE(nn.Module):
+     def __init__(self, img_dim, patch_dim, num_channels, num_classes,
+                  embedding_dim, num_heads, num_layers, hidden_dim, in_channels_vae,
+                  dropout = 0.0, attention_dropout = 0.0,
+                  conv_patch_representation = True, positional_encoding = 'learned',
+                  use_VAE = False):
+
+         super().__init__()
+         assert embedding_dim % num_heads == 0
+         assert img_dim[0] % patch_dim == 0 and img_dim[1] % patch_dim == 0 and img_dim[2] % patch_dim == 0
+
+         self.img_dim = img_dim
+         self.embedding_dim = embedding_dim
+         self.num_heads = num_heads
+         self.num_classes = num_classes
+         self.patch_dim = patch_dim
+         self.num_channels = num_channels
+         self.in_channels_vae = in_channels_vae
+         self.dropout = dropout
+         self.attention_dropout = attention_dropout
+         self.conv_patch_representation = conv_patch_representation
+         self.use_VAE = use_VAE
+
+         self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim))
+         self.seq_length = self.num_patches
+         self.flatten_dim = 128 * num_channels
+
+         self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)
+         if positional_encoding == "learned":
+             self.position_encoding = LearnedPositionalEncoding(
+                 self.embedding_dim, self.seq_length
+             )
+         elif positional_encoding == "fixed":
+             self.position_encoding = FixedPositionalEncoding(
+                 self.embedding_dim,
+             )
+         self.pe_dropout = nn.Dropout(self.dropout)
+
+         self.transformer = Transformer(
+             embedding_dim, num_layers, num_heads, embedding_dim // num_heads, hidden_dim, dropout
+         )
+         self.pre_head_ln = nn.LayerNorm(embedding_dim)
+
+         if self.conv_patch_representation:
+             self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1)
+         self.encoder = Encoder(self.num_channels, 16)
+         self.bn = nn.InstanceNorm3d(128)
+         self.relu = nn.LeakyReLU(0.2, inplace=True)
+         self.FeatureMapping = FeatureMapping(in_channel = self.embedding_dim, out_channel= self.in_channels_vae)
+         self.FeatureMapping1 = FeatureMapping1(in_channel = self.in_channels_vae)
+         self.decoder = Decoder(self.img_dim, self.patch_dim, self.embedding_dim, num_classes)
+
+         self.vae_input = (1, self.in_channels_vae, img_dim[0] // 8, img_dim[1] // 8, img_dim[2] // 8)
+         if use_VAE:
+             self.vae = VAE(input_shape = self.vae_input , latent_dim= 256, num_channels= self.num_channels)
+     def encode(self, x):
+         if self.conv_patch_representation:
+             x1, x2, x3, x = self.encoder(x)
+             x = self.bn(x)
+             x = self.relu(x)
+             x = self.conv_x(x)
+             x = x.permute(0, 2, 3, 4, 1).contiguous()
+             x = x.view(x.size(0), -1, self.embedding_dim)
+             x = self.position_encoding(x)
+             x = self.pe_dropout(x)
+             x = self.transformer(x)
+             x = self.pre_head_ln(x)
+
+         return x1, x2, x3, x
+
+     def decode(self, x1, x2, x3, x):
+         #x: (1, 4096, 512) -> (1, 16, 16, 16, 512)
+         # print("In decode...")
+         # print(" x1: {} \n x2: {} \n x3: {} \n x: {}".format( x1.shape, x2.shape, x3.shape, x.shape))
+         # break
+         return self.decoder(x1, x2, x3, x)
+
+     def forward(self, x, is_validation = True):
+         x1, x2, x3, x = self.encode(x)
+         x = x.view( x.size(0),
+                     self.img_dim[0] // self.patch_dim,
+                     self.img_dim[1] // self.patch_dim,
+                     self.img_dim[2] // self.patch_dim,
+                     self.embedding_dim)
+         x = x.permute(0, 4, 1, 2, 3).contiguous()
+         x = self.FeatureMapping(x)
+         x = self.FeatureMapping1(x)
+         if self.use_VAE and not is_validation:
+             vae_out, mu, sigma = self.vae(x)
+         y = self.decode(x1, x2, x3, x)
+         if self.use_VAE and not is_validation:
+             return y, vae_out, mu, sigma
+         else:
+             return y
+
+
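A minimal smoke-test sketch for the model above (not part of the commit). It uses the hyper-parameters trainer.py passes to SegTransVAE, but with a reduced 64³ input so the forward pass is cheap; trainer.py itself trains on 128³ crops:

import torch

model = SegTransVAE((64, 64, 64), 8, 4, 3, 768, 8, 4, 3072,
                    in_channels_vae=128, use_VAE=True)
x = torch.rand(1, 4, 64, 64, 64)      # one 4-modality crop
seg, recon, mu, logvar = model(x, is_validation=False)
print(seg.shape)    # expected (1, 3, 64, 64, 64): TC/WT/ET logits
print(recon.shape)  # expected (1, 4, 64, 64, 64): VAE reconstruction of the input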
3DUnet_Like/models/SegTranVAE/__init__.py ADDED
File without changes
3DUnet_Like/models/SegTranVAE/__pycache__/SegTranVAE.cpython-39.pyc ADDED
Binary file (16.3 kB).
3DUnet_Like/models/SegTranVAE/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (137 Bytes).
3DUnet_Like/train.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ import os
+ from monai.utils import set_determinism
+ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+ import os
+ from pytorch_lightning.loggers import TensorBoardLogger
+ from trainer import BRATS
+ from dataset.utils import get_loader
+ import pytorch_lightning as pl
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ set_determinism(seed=0)
+
+ os.system('cls||clear')
+ print("Training ...")
+
+ data_dir = "/app/brats_2021_task1"
+ json_list = "/app/info.json"
+ roi = (128, 128, 128)
+ batch_size = 1
+ fold = 1
+ max_epochs = 500
+ val_every = 10
+ train_loader, val_loader, test_loader = get_loader(batch_size, data_dir, json_list, fold, roi, volume=1, test_size=0.2)
+ print("Done initializing dataloader!")
+
+ model = BRATS(use_VAE=True, train_loader=train_loader, val_loader=val_loader, test_loader=test_loader)
+ checkpoint_callback = ModelCheckpoint(
+     monitor='val/MeanDiceScore',
+     dirpath='./checkpoints/{}'.format("SegTransVAE"),
+     filename='Epoch{epoch:3d}-MeanDiceScore{val/MeanDiceScore:.4f}',
+     save_top_k=3,
+     mode='max',
+     save_last=True,
+     auto_insert_metric_name=False
+ )
+ early_stop_callback = EarlyStopping(
+     monitor='val/MeanDiceScore',
+     min_delta=0.0001,
+     patience=15,
+     verbose=False,
+     mode='max'
+ )
+ tensorboardlogger = TensorBoardLogger(
+     'logs',
+     name = "SegTransVAE",
+     default_hp_metric = None
+ )
+ trainer = pl.Trainer(#fast_dev_run = 10,
+                      # accelerator='ddp',
+                      #overfit_batches=5,
+                      devices = [0],
+                      precision=16,
+                      max_epochs = max_epochs,
+                      enable_progress_bar=True,
+                      callbacks=[checkpoint_callback, early_stop_callback],
+                      # auto_lr_find=True,
+                      num_sanity_val_steps=1,
+                      logger = tensorboardlogger,
+                      check_val_every_n_epoch = 10,
+                      # limit_train_batches=0.01,
+                      # limit_val_batches=0.01
+                      )
+ # trainer.tune(model)
+ trainer.fit(model)
+
+
+
3DUnet_Like/trainer.py ADDED
@@ -0,0 +1,233 @@
+ import os
+ import pytorch_lightning as pl
+ import matplotlib.pyplot as plt
+ import csv
+ import torch
+ from monai.transforms import AsDiscrete, Activations, Compose, EnsureType
+ from models.SegTranVAE.SegTranVAE import SegTransVAE
+ from loss.loss import Loss_VAE, DiceScore
+ from monai.losses import DiceLoss
+ import pytorch_lightning as pl
+ from monai.inferers import sliding_window_inference
+
+
+
+
+
+ class BRATS(pl.LightningModule):
+     def __init__(self, train_loader, val_loader, test_loader, use_VAE = True, lr = 1e-4):
+         super().__init__()
+         self.train_loader = train_loader
+         self.val_loader = val_loader
+         self.test_loader = test_loader
+         self.use_vae = use_VAE
+         self.lr = lr
+         self.model = SegTransVAE((128, 128, 128), 8, 4, 3, 768, 8, 4, 3072, in_channels_vae=128, use_VAE = use_VAE)
+
+         self.loss_vae = Loss_VAE()
+         self.dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)
+         self.post_trans_images = Compose(
+             [EnsureType(),
+              Activations(sigmoid=True),
+              AsDiscrete(threshold_values=True),
+              ]
+         )
+
+         self.best_val_dice = 0
+
+         self.training_step_outputs = []
+         self.val_step_loss = []
+         self.val_step_dice = []
+         self.val_step_dice_tc = []
+         self.val_step_dice_wt = []
+         self.val_step_dice_et = []
+         self.test_step_loss = []
+         self.test_step_dice = []
+         self.test_step_dice_tc = []
+         self.test_step_dice_wt = []
+         self.test_step_dice_et = []
+
+     def forward(self, x, is_validation = True):
+         return self.model(x, is_validation)
+     def training_step(self, batch, batch_index):
+         inputs, labels = (batch['image'], batch['label'])
+
+         if not self.use_vae:
+             outputs = self.forward(inputs, is_validation=False)
+             loss = self.dice_loss(outputs, labels)
+         else:
+             outputs, recon_batch, mu, sigma = self.forward(inputs, is_validation=False)
+
+             vae_loss = self.loss_vae(recon_batch, inputs, mu, sigma)
+             dice_loss = self.dice_loss(outputs, labels)
+             loss = dice_loss + 1/(4 * 128 * 128 * 128) * vae_loss
+         self.training_step_outputs.append(loss)
+         self.log('train/vae_loss', vae_loss)
+         self.log('train/dice_loss', dice_loss)
+         if batch_index == 10:
+
+             tensorboard = self.logger.experiment
+             fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(10, 5))
+
+
+             ax[0].imshow(inputs.detach().cpu()[0][0][:, :, 80], cmap='gray')
+             ax[0].set_title("Input")
+
+             ax[1].imshow(recon_batch.detach().cpu().float()[0][0][:, :, 80], cmap='gray')
+             ax[1].set_title("Reconstruction")
+
+             ax[2].imshow(labels.detach().cpu().float()[0][0][:, :, 80], cmap='gray')
+             ax[2].set_title("Labels TC")
+
+             ax[3].imshow(outputs.sigmoid().detach().cpu().float()[0][0][:, :, 80], cmap='gray')
+             ax[3].set_title("TC")
+
+             ax[4].imshow(labels.detach().cpu().float()[0][2][:, :, 80], cmap='gray')
+             ax[4].set_title("Labels ET")
+
+             ax[5].imshow(outputs.sigmoid().detach().cpu().float()[0][2][:, :, 80], cmap='gray')
+             ax[5].set_title("ET")
+
+
+             tensorboard.add_figure('train_visualize', fig, self.current_epoch)
+
+         self.log('train/loss', loss)
+
+         return loss
+
+     def on_train_epoch_end(self):
+         ## F1 Macro all epoch saving outputs and target per batch
+
+         # free up the memory
+         # --> HERE STEP 3 <--
+         epoch_average = torch.stack(self.training_step_outputs).mean()
+         self.log("training_epoch_average", epoch_average)
+         self.training_step_outputs.clear()  # free memory
+
+     def validation_step(self, batch, batch_index):
+         inputs, labels = (batch['image'], batch['label'])
+         roi_size = (128, 128, 128)
+         sw_batch_size = 1
+         outputs = sliding_window_inference(
+             inputs, roi_size, sw_batch_size, self.model, overlap = 0.5)
+         loss = self.dice_loss(outputs, labels)
+
+
+         val_outputs = self.post_trans_images(outputs)
+
+
+         metric_tc = DiceScore(y_pred=val_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)
+         metric_wt = DiceScore(y_pred=val_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)
+         metric_et = DiceScore(y_pred=val_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)
+         mean_val_dice = (metric_tc + metric_wt + metric_et)/3
+         self.val_step_loss.append(loss)
+         self.val_step_dice.append(mean_val_dice)
+         self.val_step_dice_tc.append(metric_tc)
+         self.val_step_dice_wt.append(metric_wt)
+         self.val_step_dice_et.append(metric_et)
+         return {'val_loss': loss, 'val_mean_dice': mean_val_dice, 'val_dice_tc': metric_tc,
+                 'val_dice_wt': metric_wt, 'val_dice_et': metric_et}
+
+     def on_validation_epoch_end(self):
+
+         loss = torch.stack(self.val_step_loss).mean()
+         mean_val_dice = torch.stack(self.val_step_dice).mean()
+         metric_tc = torch.stack(self.val_step_dice_tc).mean()
+         metric_wt = torch.stack(self.val_step_dice_wt).mean()
+         metric_et = torch.stack(self.val_step_dice_et).mean()
+         self.log('val/Loss', loss)
+         self.log('val/MeanDiceScore', mean_val_dice)
+         self.log('val/DiceTC', metric_tc)
+         self.log('val/DiceWT', metric_wt)
+         self.log('val/DiceET', metric_et)
+         os.makedirs(self.logger.log_dir, exist_ok=True)
+         if self.current_epoch == 0:
+             with open('{}/metric_log.csv'.format(self.logger.log_dir), 'w') as f:
+                 writer = csv.writer(f)
+                 writer.writerow(['Epoch', 'Mean Dice Score', 'Dice TC', 'Dice WT', 'Dice ET'])
+         with open('{}/metric_log.csv'.format(self.logger.log_dir), 'a') as f:
+             writer = csv.writer(f)
+             writer.writerow([self.current_epoch, mean_val_dice.item(), metric_tc.item(), metric_wt.item(), metric_et.item()])
+
+         if mean_val_dice > self.best_val_dice:
+             self.best_val_dice = mean_val_dice
+             self.best_val_epoch = self.current_epoch
+         print(
+             f"\n Current epoch: {self.current_epoch} Current mean dice: {mean_val_dice:.4f}"
+             f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
+             f"\n Best mean dice: {self.best_val_dice}"
+             f" at epoch: {self.best_val_epoch}"
+         )
+
+         self.val_step_loss.clear()
+         self.val_step_dice.clear()
+         self.val_step_dice_tc.clear()
+         self.val_step_dice_wt.clear()
+         self.val_step_dice_et.clear()
+         return {'val_MeanDiceScore': mean_val_dice}
+     def test_step(self, batch, batch_index):
+         inputs, labels = (batch['image'], batch['label'])
+
+         roi_size = (128, 128, 128)
+         sw_batch_size = 1
+         test_outputs = sliding_window_inference(
+             inputs, roi_size, sw_batch_size, self.forward, overlap = 0.5)
+         loss = self.dice_loss(test_outputs, labels)
+         test_outputs = self.post_trans_images(test_outputs)
+         metric_tc = DiceScore(y_pred=test_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)
+         metric_wt = DiceScore(y_pred=test_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)
+         metric_et = DiceScore(y_pred=test_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)
+         mean_test_dice = (metric_tc + metric_wt + metric_et)/3
+
+         self.test_step_loss.append(loss)
+         self.test_step_dice.append(mean_test_dice)
+         self.test_step_dice_tc.append(metric_tc)
+         self.test_step_dice_wt.append(metric_wt)
+         self.test_step_dice_et.append(metric_et)
+
+         return {'test_loss': loss, 'test_mean_dice': mean_test_dice, 'test_dice_tc': metric_tc,
+                 'test_dice_wt': metric_wt, 'test_dice_et': metric_et}
+
+     def test_epoch_end(self):
+         loss = torch.stack(self.test_step_loss).mean()
+         mean_test_dice = torch.stack(self.test_step_dice).mean()
+         metric_tc = torch.stack(self.test_step_dice_tc).mean()
+         metric_wt = torch.stack(self.test_step_dice_wt).mean()
+         metric_et = torch.stack(self.test_step_dice_et).mean()
+         self.log('test/Loss', loss)
+         self.log('test/MeanDiceScore', mean_test_dice)
+         self.log('test/DiceTC', metric_tc)
+         self.log('test/DiceWT', metric_wt)
+         self.log('test/DiceET', metric_et)
+
+         with open('{}/test_log.csv'.format(self.logger.log_dir), 'w') as f:
+             writer = csv.writer(f)
+             writer.writerow(["Mean Test Dice", "Dice TC", "Dice WT", "Dice ET"])
+             writer.writerow([mean_test_dice, metric_tc, metric_wt, metric_et])
+
+         self.test_step_loss.clear()
+         self.test_step_dice.clear()
+         self.test_step_dice_tc.clear()
+         self.test_step_dice_wt.clear()
+         self.test_step_dice_et.clear()
+         return {'test_MeanDiceScore': mean_test_dice}
+
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(
+             self.model.parameters(), self.lr, weight_decay=1e-5, amsgrad=True
+         )
+         # optimizer = AdaBelief(self.model.parameters(),
+         #                       lr=self.lr, eps=1e-16,
+         #                       betas=(0.9,0.999), weight_decouple = True,
+         #                       rectify = False)
+         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)
+         return [optimizer], [scheduler]
+
+     def train_dataloader(self):
+         return self.train_loader
+     def val_dataloader(self):
+         return self.val_loader
+
+     def test_dataloader(self):
+         return self.test_loader
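A hedged sketch (not part of the commit) of how the held-out split could be evaluated with the LightningModule above; the checkpoint path, data paths, and device choice are assumptions mirroring train.py:

import pytorch_lightning as pl
from dataset.utils import get_loader
from trainer import BRATS

train_loader, val_loader, test_loader = get_loader(
    batch_size=1, data_dir="/app/brats_2021_task1", json_list="/app/info.json",
    fold=1, roi=(128, 128, 128), volume=1, test_size=0.2)
model = BRATS.load_from_checkpoint(
    "checkpoints/SegTransVAE/last.ckpt",  # hypothetical path written by ModelCheckpoint(save_last=True)
    train_loader=train_loader, val_loader=val_loader, test_loader=test_loader)
pl.Trainer(devices=[0]).test(model)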
brats_2021_task1/BraTS2021_Training_Data/.DS_Store ADDED
Binary file (6.15 kB).
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_flair.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bb899a83627591e55cada00b2c6d5402199832b717c8b9f90bb550fe35d971ff
+ size 2532638
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_seg.nii.gz ADDED
Binary file (57.9 kB).
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t1.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e92b72ee221624c36cf89ac826deceeee7097f46dd66a5d218d18b7916ebd67d
+ size 2332393
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t1ce.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:735c27fd7a17b1702875837bcc843eedc88ced6ac2cb0e73cdf995e3e64ba82f
+ size 2643179
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t2.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:955ff59d053e87153bb7c809235743ec904817727ec02c630f3141e191d6f452
+ size 2432699
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_flair.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:208e4e8fadbdf2b1c1a87f8bad8854bc7d4becd604bc01eefad725d19b43c6ef
+ size 2331912
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_seg.nii.gz ADDED
Binary file (78 kB).
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t1.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4b76c5e6326f0f89e2f6ee243473060390982fc83e9bceb27a4b94899b2b0df1
+ size 2170543
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t1ce.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f49aa266cd907ce890bcc9c96f82534810f85bb98a76f8a092a5b529d3b6b6e
+ size 2486326
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t2.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b5c9da53d4a573fe37fa8a8075b8220d4dfe3bbc377c3e1943fbd3ef86d3b118
+ size 2303833
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_flair.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8f68b6d8026f2c6ac097ea26a4dd727571025e26b69b485f9a1a84e244222721
+ size 2719582
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_seg.nii.gz ADDED
Binary file (63.4 kB).
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t1.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:90997c21753a6d49bdb8629f5b2c6ddb61ef7c2f86f977b18001ce6c5e3161ed
+ size 2488450
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t1ce.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a211cd8c26dd8d396341399d0a13fb888d5df6252bd9c5aa01156680e97c5577
+ size 2834759
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t2.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5c440a9ef1ef3b494d74e647fab9e41eb48f1fc40c8eabffd55863f46779ed5d
+ size 2635293
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_flair.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c5c3ed117922199afb26cae01b090400d6ca90f06825c12f5a5a3ae7ced098e3
+ size 2265964
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_seg.nii.gz ADDED
Binary file (70.7 kB).
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t1.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:54f109386253cb60694dec35e10e1e8f2ee02a2671faa860ac212cce8997b434
+ size 2085481
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t1ce.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e074611e7388f7f7fefb5bbe39abb09f16f0dbeb0a35fbd661a0bbcd14d4b5a
+ size 2323871
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t2.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d1689213fb1bb46330f16d30489a9875e2b5e9f30c8c5659a6b689011bd69e9
+ size 2144371
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_flair.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0f42d02ba6637b8d98c0c44869ccfa444935dedb8f5ada29a3db65d33c878c0b
+ size 2588071
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_seg.nii.gz ADDED
Binary file (70.5 kB).
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t1.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5f9e9ff16cdffd375f31a969994e31725c329950e7c3e64789238afc3faadead
+ size 2386395
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t1ce.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:20847fc162fe9193fb2d5a08e7b9009d16603eae268f769bf6d7bbffb0d79c42
+ size 2705917
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t2.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0a18fb6644ca828ef43067264a631d297d5aeaad4ebae05bce4ea09c9c76f898
+ size 2479204
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_flair.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c5dbec780a624127d829b4b7165ff6afe6258e9224a529abf8ddaf25474888c5
+ size 2452120
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_seg.nii.gz ADDED
Binary file (45.5 kB).
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_t1.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f46e84f3f81a10b11b994fca26ae7f2bf537a6dfc0cb60853c097243924e39e8
+ size 2397406
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_t1ce.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3b92f907a585a13ae99b0c79f315d396a3dd4524f1a35001212509957e264abf
+ size 2639147
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_t2.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:545a1519ae217731184caad17d6474e3d62736efbad11326b5c21e63124ace12
+ size 2368904
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00009/BraTS2021_00009_flair.nii.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0f407ffd3d48d678b63de7158fa566626a0ad917e3d523d7d6a3cdd6f7e99789
+ size 2233856